浏览器端的机器学习 tensorflowjs(6) 训练模型

作者 : 开心源码 本文共1829个字,预计阅读时间需要5分钟 发布时间: 2022-05-14 共208人阅读

cover_002.png

现在模型已经定义好了,数据也下载并进行了解决,一切准备就绪准备开始训练。

async function trainModel(model, inputs, labels) {  // 准备要训练的模型  model.compile({    optimizer: tf.train.adam(),    loss: tf.losses.meanSquaredError,    metrics: ['mse'],  });  const batchSize = 32;  const epochs = 50;  return await model.fit(inputs, labels, {    batchSize,    epochs,    shuffle: true,    callbacks: tfvis.show.fitCallbacks(      { name: 'Training Performance' },      ['loss', 'mse'],      { height: 200, callbacks: ['onEpochEnd'] }    )  });}

训练前的少量准备

model.compile({  optimizer: tf.train.adam(),  loss: tf.losses.meanSquaredError,  metrics: ['mse'],});

在训练模型之前,需要 “编译 “该模型,那么具体应该如何做呢? 我们需要一个优化和一个损失函数,损失函数也可以了解目标函数,主要是指定训练,让我们训练一个目标,优化器这是给出一个策略如何在训练过程升级参数。

  • 优化器。这是一种算法,是升级参数的算法。在 TensorFlow.js 中有许多优化器可用。这里选择了 adam 优化器,也可以尝试用其余优化器
  • 损失函数:其实就是一个函数,告诉模型在学习过程中,在每个批次(数据子集)时的体现如何。这里选择 meanSquaredError 来比较模型的预测和真实值
const batchSize = 32;const epochs = 50;

设置超参数 batchSize 和一个 epochs 的数量。

  • batchSize 指的是模型在每次迭代训练中看到的数据子集的大小。常见的批次大小往往在 32-512 之间取值。批次大小对于训练速度是有所影响的

  • epochs 完成整个数据集进行训练的次数

开始训练

return await model.fit(inputs, labels, {  batchSize,  epochs,  callbacks: tfvis.show.fitCallbacks(    { name: 'Training Performance' },    ['loss', 'mse'],    { height: 200, callbacks: ['onEpochEnd'] }  )});

model.fit 是来启动训练的函数。这是一个异步函数,所以返回会是一个 promise。

为了监控训练进度,回调传函数作为 model.fit 来获取训练过程中信息。而后回调函数使用 tfvis.show.fitCallbacks 来定义,而后可以绘制损失值对于迭代的图标

const tensorData = convertToTensor(data);const {inputs, labels} = tensorData;// Train the modelawait trainModel(model, inputs, labels);console.log('Done Training');

这的注意的这部分代码要写在 run 函数中,具体如下

async function run() {    // 加载数据    const data = await getData();    // 解决原始数据,将数据 horsepower 映射为 x 而 mpg 则映射为 y    const values = data.map(d => ({      x: d.horsepower,      y: d.mpg,    }));    // 将数据以散点图形式显示在开发者调试工具          tfvis.render.scatterplot(      {name: 'Horsepower v MPG'},      {values},      {        xLabel: 'Horsepower',        yLabel: 'MPG',        height: 300      }    );    const model = createModel();    const tensorData = convertToTensor(data);    const {inputs, labels} = tensorData;    // Train the model    await trainModel(model, inputs, labels);    console.log('Done Training');}
说明
1. 本站所有资源来源于用户上传和网络,如有侵权请邮件联系站长!
2. 分享目的仅供大家学习和交流,您必须在下载后24小时内删除!
3. 不得使用于非法商业用途,不得违反国家法律。否则后果自负!
4. 本站提供的源码、模板、插件等等其他资源,都不包含技术服务请大家谅解!
5. 如有链接无法下载、失效或广告,请联系管理员处理!
6. 本站资源售价只是摆设,本站源码仅提供给会员学习使用!
7. 如遇到加密压缩包,请使用360解压,如遇到无法解压的请联系管理员
开心源码网 » 浏览器端的机器学习 tensorflowjs(6) 训练模型

发表回复