gpt4 book ai didi

javascript - 如何在nodejs(tensorflow.js)中训练模型?

转载 作者:行者123 更新时间:2023-12-03 00:21:07 24 4
gpt4 key购买 nike

我想做一个图像分类器,但我不会Python。Tensorflow.js 使用我熟悉的 javascript 工作。可以用它来训练模型吗?训练步骤是什么?坦白说,我不知道从哪里开始。

我唯一想到的是如何加载“mobilenet”,它显然是一组预先训练的模型,并用它对图像进行分类:

const tf = require('@tensorflow/tfjs'),
mobilenet = require('@tensorflow-models/mobilenet'),
tfnode = require('@tensorflow/tfjs-node'),
fs = require('fs-extra');

const imageBuffer = await fs.readFile(......),
tfimage = tfnode.node.decodeImage(imageBuffer),
mobilenetModel = await mobilenet.load();

const results = await mobilenetModel.classify(tfimage);

这可行,但对我来说没有用,因为我想使用带有我创建的标签的图像来训练我自己的模型。

========================

假设我有一堆图像和标签。如何使用它们来训练模型?

const myData = JSON.parse(await fs.readFile('files.json'));

for(const data of myData){
const image = await fs.readFile(data.imagePath),
labels = data.labels;

// how to train, where to pass image and labels ?

}

最佳答案

首先,图像需要转换为张量。第一种方法是创建一个包含所有特征的张量(分别是包含所有标签的张量)。仅当数据集包含少量图像时才应采用这种方法。

  const imageBuffer = await fs.readFile(feature_file);
tensorFeature = tfnode.node.decodeImage(imageBuffer) // create a tensor for the image

// create an array of all the features
// by iterating over all the images
tensorFeatures = tf.stack([tensorFeature, tensorFeature2, tensorFeature3])

标签将是一个数组,指示每个图像的类型

 labelArray = [0, 1, 2] // maybe 0 for dog, 1 for cat and 2 for birds

现在需要创建标签的热编码

 tensorLabels = tf.oneHot(tf.tensor1d(labelArray, 'int32'), 3);

一旦有了张量,就需要创建训练模型。这是一个简单的模型。

const model = tf.sequential();
model.add(tf.layers.conv2d({
inputShape: [height, width, numberOfChannels], // numberOfChannels = 3 for colorful images and one otherwise
filters: 32,
kernelSize: 3,
activation: 'relu',
}));
model.add(tf.layers.flatten());
model.add(tf.layers.dense({units: 3, activation: 'softmax'}));

然后就可以训练模型了

model.fit(tensorFeatures, tensorLabels)

如果数据集包含大量图像,则需要创建一个 tfDataset。这个answer讨论原因。

const genFeatureTensor = image => {
const imageBuffer = await fs.readFile(feature_file);
return tfnode.node.decodeImage(imageBuffer)
}

const labelArray = indice => Array.from({length: numberOfClasses}, (_, k) => k === indice ? 1 : 0)

function* dataGenerator() {
const numElements = numberOfImages;
let index = 0;
while (index < numFeatures) {
const feature = genFeatureTensor(imagePath);
const label = tf.tensor1d(labelArray(classImageIndex))
index++;
yield {xs: feature, ys: label};
}
}

const ds = tf.data.generator(dataGenerator).batch(1) // specify an appropriate batchsize;

并使用model.fitDataset(ds)来训练模型

<小时/>

以上是在nodejs中进行的训练。要在浏览器中进行这样的处理,genFeatureTensor可以编写如下:

function loadImage(url){
return new Promise((resolve, reject) => {
const im = new Image()
im.crossOrigin = 'anonymous'
im.src = 'url'
im.onload = () => {
resolve(im)
}
})
}

genFeatureTensor = image => {
const img = await loadImage(image);
return tf.browser.fromPixels(image);
}

需要注意的是,进行繁重的处理可能会阻塞浏览器中的主线程。这就是网络 worker 发挥作用的地方。

关于javascript - 如何在nodejs(tensorflow.js)中训练模型?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58953399/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com