我指的是这个例子
https://github.com/tensorflow/tfjs-examples/tree/master/webcam-transfer-learning
https://storage.googleapis.com/tfjs-examples/webcam-transfer-learning/dist/index.html
可以看到,该演示能够用网络摄像头采集各种场景的图像作为样本,训练模型,然后对摄像头当前的画面进行预测。
我将演示代码更改为自己的代码,并使用文件输入上传了大量图片作为输入示例。
当我上传较多(300-400 张)224*224 的图片(每张大小约为 70 KB)时,显卡(RX 570,4 GB)的显存会被占满,随后页面崩溃。
这是我的演示视频
https://www.youtube.com/watch?v=irnd29lcqi0
错误消息:`Uncaught (in promise) Error: Failed to compile fragment shader.`
这是我的代码:
/**
 * Accumulates examples and their one-hot labels as two tensors, `xs` and
 * `ys`, that grow along axis 0 with every call to `addExample()`.
 */
class ControllerDataset {
  /**
   * @param {number} numClasses Depth used when one-hot encoding labels.
   */
  constructor(numClasses) {
    this.numClasses = numClasses;
  }

  /**
   * Appends one example and its label to the dataset.
   *
   * The very first example is adopted as-is (it becomes `this.xs`). Every
   * later example is only read while concatenating, so the caller remains
   * responsible for disposing it.
   *
   * @param {Tensor} example A tensor representing the example. It can be an
   *     image, an activation, or any other type of Tensor.
   * @param {number} label The label of the example. Should be a number.
   */
  addExample(example, label) {
    // Build the one-hot label; tidy() cleans up the intermediate tensors.
    const oneHotLabel = tf.tidy(() => {
      const labelTensor = tf.tensor1d([label]).toInt();
      return tf.oneHot(labelTensor, this.numClasses);
    });

    if (this.xs != null) {
      // Grow the existing tensors, then release the superseded ones together
      // with the temporary label tensor.
      const previousXs = this.xs;
      this.xs = tf.keep(previousXs.concat(example, 0));
      const previousYs = this.ys;
      this.ys = tf.keep(previousYs.concat(oneHotLabel, 0));

      previousXs.dispose();
      previousYs.dispose();
      oneHotLabel.dispose();
      return;
    }

    // First example: keep() both tensors so the dataset owns their memory
    // even when addExample() is invoked from inside a tf.tidy().
    this.xs = tf.keep(example);
    this.ys = tf.keep(oneHotLabel);
  }
}
// Truncated MobileNet model. Stays undefined until the promise returned by
// loadTruncatedMobileNet() resolves and assigns it.
var truncatedMobileNet;
// Number of classes handed to ControllerDataset (depth of its one-hot labels).
const NUM_CLASSES = 3;
// Dataset that accumulates the examples added by addMultiSampleFromInputfile().
const controllerDataset = new ControllerDataset(NUM_CLASSES);
/**
 * Reads every file as an image, runs it through the truncated MobileNet and
 * adds the resulting activation to `controllerDataset` under `label`.
 *
 * Files are processed one at a time, and every tensor created for a file is
 * released before the next file is read.
 *
 * @param {FileList|File[]} files Image files to add (any indexable list
 *     with a `length`).
 * @param {number} label Class index stored with every example.
 * @returns {Promise<void>} Resolves once all files have been added; rejects
 *     with the first error raised while reading or processing a file.
 */
async function addMultiSampleFromInputfile(files, label) {
  for (let index = 0; index < files.length; index++) {
    const image = await readFileToImageElement(files[index]);
    const { sourceImageTensor, imageTensorNormalize } =
        getTensorImgFromElement(image);
    let activation = null;
    try {
      activation = truncatedMobileNet.predict(imageTensorNormalize);
      controllerDataset.addExample(activation, label);
    } finally {
      // Always release the per-image input tensors, even when predict() or
      // addExample() throws.
      sourceImageTensor.dispose();
      imageTensorNormalize.dispose();
      // The dataset adopts only the very first example as its `xs` tensor;
      // later examples are merely read while concatenating. Dispose the
      // activation unless the dataset now owns it, otherwise one activation
      // tensor leaks per image.
      if (activation !== null && controllerDataset.xs !== activation) {
        activation.dispose();
      }
    }
  }
}
/**
 * Loads MobileNet from the URL held by the "MobileNetUrl" element and wraps
 * it in a model that outputs the internal 'conv_pw_13_relu' activation,
 * which serves as the input to the classifier model.
 *
 * @returns {Promise<Object>} The model built by tf.model() over MobileNet's
 *     inputs and the chosen layer's output.
 */
async function loadTruncatedMobileNet() {
  const urlField = document.getElementById("MobileNetUrl");
  const baseModel = await tf.loadLayersModel(urlField.value);

  // Cut the network at an internal activation layer.
  const activationLayer = baseModel.getLayer('conv_pw_13_relu');
  const truncatedSpec = {
    inputs: baseModel.inputs,
    outputs: activationLayer.output,
  };
  return tf.model(truncatedSpec);
}
/**
 * Converts an image element into a raw pixel tensor plus a batched copy
 * that has been cast to float, divided by 127 and shifted down by 1.
 *
 * Both returned tensors are left for the caller to dispose.
 *
 * @param {HTMLImageElement} element Source handed to tf.browser.fromPixels().
 * @returns {{sourceImageTensor: Tensor, imageTensorNormalize: Tensor}}
 */
function getTensorImgFromElement(element) {
  const sourceImageTensor = tf.browser.fromPixels(element);

  // tidy() releases the intermediates of the chain and keeps only the result.
  const imageTensorNormalize = tf.tidy(() => {
    const batched = sourceImageTensor.expandDims(0);
    const scaled = batched.toFloat().div(127);
    return scaled.sub(1);
  });

  return { sourceImageTensor, imageTensorNormalize };
}
/**
 * Reads a file as a data URL and decodes it into a fully loaded <img>.
 *
 * @param {File|Blob} file The image file to read.
 * @returns {Promise<HTMLImageElement>} Resolves with the loaded image
 *     element. Rejects if the file cannot be read or its contents cannot be
 *     decoded as an image, so callers never wait on a promise that cannot
 *     settle.
 */
function readFileToImageElement(file) {
  return new Promise((resolve, reject) => {
    const reader = new FileReader();
    reader.onerror = () => {
      reject(reader.error || new Error('Failed to read the image file.'));
    };
    reader.onload = () => {
      const image = document.createElement('img');
      // Attach the handlers before assigning `src` so neither the load nor
      // the error event can be missed.
      image.onload = () => resolve(image);
      image.onerror = () => reject(new Error('Failed to decode the image file.'));
      image.src = reader.result;
    };
    reader.readAsDataURL(file);
  });
}
// Load the truncated MobileNet once at startup. Until this resolves,
// `truncatedMobileNet` stays undefined and the click handler below is a no-op.
loadTruncatedMobileNet()
  .then((model) => {
    truncatedMobileNet = model;
  })
  .catch((err) => {
    // Report load failures explicitly instead of leaving the promise floating.
    console.error('Failed to load the truncated MobileNet model:', err);
  });

// Add multiple samples from the HTML file input.
// NOTE(review): `addMultiSmapleBtn`, `imagefiles` and `label` are not defined
// in this snippet; presumably they are page elements exposed as globals.
addMultiSmapleBtn.onclick = () => {
  // Ignore clicks until the model has finished loading.
  if (!truncatedMobileNet) {
    return;
  }
  // The label value is 0, 1 or 2; parse it explicitly as base 10.
  const labelIndex = parseInt(label.value, 10);
  addMultiSampleFromInputfile(imagefiles.files, labelIndex).catch((err) => {
    console.error('Failed to add samples from the selected files:', err);
  });
};
1条答案
按热度按时间iugsix8n1#
试试这个: