浏览器端的机器学习 tensorflowjs(6) 训练模型
cover_002.png
现在模型已经定义好了,数据也下载并进行了解决,一切准备就绪准备开始训练。
async function trainModel(model, inputs, labels) { // 准备要训练的模型 model.compile({ optimizer: tf.train.adam(), loss: tf.losses.meanSquaredError, metrics: ['mse'], }); const batchSize = 32; const epochs = 50; return await model.fit(inputs, labels, { batchSize, epochs, shuffle: true, callbacks: tfvis.show.fitCallbacks( { name: 'Training Performance' }, ['loss', 'mse'], { height: 200, callbacks: ['onEpochEnd'] } ) });}
训练前的少量准备
model.compile({ optimizer: tf.train.adam(), loss: tf.losses.meanSquaredError, metrics: ['mse'],});
在训练模型之前,需要 “编译 “该模型,那么具体应该如何做呢? 我们需要一个优化和一个损失函数,损失函数也可以了解目标函数,主要是指定训练,让我们训练一个目标,优化器这是给出一个策略如何在训练过程升级参数。
- 优化器。这是一种算法,是升级参数的算法。在 TensorFlow.js 中有许多优化器可用。这里选择了 adam 优化器,也可以尝试用其余优化器
- 损失函数:其实就是一个函数,告诉模型在学习过程中,在每个批次(数据子集)时的体现如何。这里选择 meanSquaredError 来比较模型的预测和真实值
const batchSize = 32;const epochs = 50;
设置超参数 batchSize 和一个 epochs 的数量。
batchSize 指的是模型在每次迭代训练中看到的数据子集的大小。常见的批次大小往往在 32-512 之间取值。批次大小对于训练速度是有所影响的
epochs 完成整个数据集进行训练的次数
开始训练
return await model.fit(inputs, labels, { batchSize, epochs, callbacks: tfvis.show.fitCallbacks( { name: 'Training Performance' }, ['loss', 'mse'], { height: 200, callbacks: ['onEpochEnd'] } )});
model.fit 是来启动训练的函数。这是一个异步函数,所以返回会是一个 promise。
为了监控训练进度,回调传函数作为 model.fit 来获取训练过程中信息。而后回调函数使用 tfvis.show.fitCallbacks 来定义,而后可以绘制损失值对于迭代的图标
const tensorData = convertToTensor(data);const {inputs, labels} = tensorData;// Train the modelawait trainModel(model, inputs, labels);console.log('Done Training');
这的注意的这部分代码要写在 run 函数中,具体如下
async function run() { // 加载数据 const data = await getData(); // 解决原始数据,将数据 horsepower 映射为 x 而 mpg 则映射为 y const values = data.map(d => ({ x: d.horsepower, y: d.mpg, })); // 将数据以散点图形式显示在开发者调试工具 tfvis.render.scatterplot( {name: 'Horsepower v MPG'}, {values}, { xLabel: 'Horsepower', yLabel: 'MPG', height: 300 } ); const model = createModel(); const tensorData = convertToTensor(data); const {inputs, labels} = tensorData; // Train the model await trainModel(model, inputs, labels); console.log('Done Training');}
说明
1. 本站所有资源来源于用户上传和网络,如有侵权请邮件联系站长!
2. 分享目的仅供大家学习和交流,您必须在下载后24小时内删除!
3. 不得使用于非法商业用途,不得违反国家法律。否则后果自负!
4. 本站提供的源码、模板、插件等等其他资源,都不包含技术服务请大家谅解!
5. 如有链接无法下载、失效或广告,请联系管理员处理!
6. 本站资源售价只是摆设,本站源码仅提供给会员学习使用!
7. 如遇到加密压缩包,请使用360解压,如遇到无法解压的请联系管理员
开心源码网 » 浏览器端的机器学习 tensorflowjs(6) 训练模型
1. 本站所有资源来源于用户上传和网络,如有侵权请邮件联系站长!
2. 分享目的仅供大家学习和交流,您必须在下载后24小时内删除!
3. 不得使用于非法商业用途,不得违反国家法律。否则后果自负!
4. 本站提供的源码、模板、插件等等其他资源,都不包含技术服务请大家谅解!
5. 如有链接无法下载、失效或广告,请联系管理员处理!
6. 本站资源售价只是摆设,本站源码仅提供给会员学习使用!
7. 如遇到加密压缩包,请使用360解压,如遇到无法解压的请联系管理员
开心源码网 » 浏览器端的机器学习 tensorflowjs(6) 训练模型