<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@2.0.0/dist/tf.min.js"></script>
npm install @tensorflow/tfjs
or
yarn add @tensorflow/tfjs
Each element of a tensor's shape is the length of the array at the corresponding nesting level.
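As a rough illustration of that rule, the sketch below infers a shape from a plain nested array. `shapeOf` is a hypothetical helper written for this note, not part of TensorFlow.js, and it assumes the array is regular (every sub-array at a level has the same length).

```javascript
// Hypothetical helper: infer the shape of a regular nested array.
function shapeOf(value) {
  const shape = [];
  let level = value;
  // Walk down the first element of each nesting level and record its length.
  while (Array.isArray(level)) {
    shape.push(level.length);
    level = level[0];
  }
  return shape;
}

console.log(shapeOf([[1, 2, 3], [4, 5, 6]])); // [2, 3]
console.log(shapeOf(7)); // [] (a scalar has an empty shape)
```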
import * as tf from '@tensorflow/tfjs';
// Plain for loops
const input = [1, 2, 3, 4];
const w = [[1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]];
const output = [0, 0, 0, 0];
// 1. Compute with for loops
for (let i = 0; i < w.length; i++) {
for (let j = 0; j < input.length; j++) {
output[i] += input[j] * w[i][j];
}
}
console.log(output); // [30, 40, 50, 60]
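// 2. Compute with TensorFlow.js
// Dot product of the weights and the input
tf.tensor(w).dot(tf.tensor(input)).print(); // [30, 40, 50, 60]
// Besides being more concise, the vectorized form can be accelerated on the GPU, which improves performance.
// dot API reference: https://js.tensorflow.org/api/latest/#dot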
Install tfjs-vis to visualize the data:
https://js.tensorflow.org/api_vis/1.5.1/
Scatter plot:
import * as tfvis from '@tensorflow/tfjs-vis';
window.onload = async () => {
const xs = [1, 2, 3, 4];
const ys = [1, 3, 5, 7];
tfvis.render.scatterplot(
{ name: 'Linear regression training set' },
{ values: xs.map((x, i) => ({ x, y: ys[i] })) },
{ xAxisDomain: [0, 5], yAxisDomain: [0, 8] }
);
};
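Sequential model: https://js.tensorflow.org/api/3.16.0/#sequential
https://js.tensorflow.org/api/3.16.0/#layers.dense
A dense layer computes output = activation(dot(input, kernel) + bias), which fits this linear problem,
so layers.dense() is used to create the layer.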
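To make the relation output = activation(dot(input, kernel) + bias) concrete, here is a minimal plain-JavaScript sketch of a single dense layer's forward pass. `denseForward` and its weight layout are assumptions made for illustration; this is not how TensorFlow.js implements the layer internally.

```javascript
// Hypothetical sketch of what one dense layer computes:
// output = activation(dot(input, kernel) + bias)
function denseForward(input, kernel, bias, activation = (v) => v) {
  // kernel has shape [inputSize, units]; bias has length units.
  return bias.map((b, unit) => {
    let sum = b;
    for (let i = 0; i < input.length; i++) {
      sum += input[i] * kernel[i][unit];
    }
    return activation(sum);
  });
}

// One input feature and one unit with kernel 2 and bias -1, i.e. y = 2x - 1.
console.log(denseForward([5], [[2]], [-1])); // [9]
```

With `units: 1` and `inputShape: [1]`, training only has to find one kernel value and one bias.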
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
window.onload = async () => {
const xs = [1, 2, 3, 4];
const ys = [1, 3, 5, 7];
tfvis.render.scatterplot(
{ name: 'Linear regression training set' },
{ values: xs.map((x, i) => ({ x, y: ys[i] })) },
{ xAxisDomain: [0, 5], yAxisDomain: [0, 8] }
);
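// Create the model
const model = tf.sequential();
// Add a layer. A dense layer computes output = activation(dot(input, kernel) + bias), which fits this problem.
// units is the number of neurons; inputShape: [1] means each input is a 1-D array of length 1.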
model.add(tf.layers.dense({ units: 1, inputShape: [1] }));
};
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
window.onload = async () => {
const xs = [1, 2, 3, 4];
const ys = [1, 3, 5, 7];
tfvis.render.scatterplot(
{ name: 'Linear regression training set' },
{ values: xs.map((x, i) => ({ x, y: ys[i] })) },
{ xAxisDomain: [0, 5], yAxisDomain: [0, 8] }
);
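// Create the model
const model = tf.sequential();
// Add a layer. A dense layer computes output = activation(dot(input, kernel) + bias), which fits this problem.
// units is the number of neurons; inputShape: [1] means each input is a 1-D array of length 1.
model.add(tf.layers.dense({ units: 1, inputShape: [1] }));
// Compile with mean squared error (MSE) as the loss function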
model.compile({ loss: tf.losses.meanSquaredError });
};
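The mean squared error passed to `compile` can be sketched in plain JavaScript as follows. `meanSquaredError` here is a stand-alone illustration written for this note, not the TensorFlow.js function of the same name.

```javascript
// Mean squared error: the average of the squared differences
// between the labels and the predictions.
function meanSquaredError(labels, predictions) {
  let total = 0;
  for (let i = 0; i < labels.length; i++) {
    const diff = labels[i] - predictions[i];
    total += diff * diff;
  }
  return total / labels.length;
}

console.log(meanSquaredError([1, 3, 5, 7], [1, 3, 5, 7])); // 0
console.log(meanSquaredError([1, 3, 5, 7], [0, 0, 0, 0])); // 21
```

The second call corresponds to an untrained model that predicts 0 everywhere: (1 + 9 + 25 + 49) / 4 = 21.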
Iterative approach: https://developers.google.com/machine-learning/crash-course/reducing-loss/an-iterative-approach
Gradient descent: https://developers.google.com/machine-learning/crash-course/reducing-loss/gradient-descent
Learning rate: https://developers.google.com/machine-learning/crash-course/reducing-loss/learning-rate
Optimizing the learning rate: https://developers.google.com/machine-learning/crash-course/fitter/graph
Stochastic gradient descent: https://developers.google.cn/machine-learning/crash-course/reducing-loss/stochastic-gradient-descent
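The following is a minimal sketch of gradient descent on the same four points, written without TensorFlow.js. It assumes full-batch updates on a model y = w * x + b with both parameters starting at 0; `trainLinear` is a hypothetical name, and the optimizer in TensorFlow.js differs in detail (random initialization, shuffling).

```javascript
// Full-batch gradient descent for y = w * x + b with an MSE loss.
function trainLinear(xs, ys, learningRate, epochs) {
  let w = 0;
  let b = 0;
  const n = xs.length;
  for (let epoch = 0; epoch < epochs; epoch++) {
    let gradW = 0;
    let gradB = 0;
    for (let i = 0; i < n; i++) {
      const error = w * xs[i] + b - ys[i];
      // Partial derivatives of mean((w * x + b - y)^2) with respect to w and b.
      gradW += (2 / n) * error * xs[i];
      gradB += (2 / n) * error;
    }
    // Step against the gradient, scaled by the learning rate.
    w -= learningRate * gradW;
    b -= learningRate * gradB;
  }
  return { w, b };
}

const fitted = trainLinear([1, 2, 3, 4], [1, 3, 5, 7], 0.1, 200);
console.log(fitted.w * 5 + fitted.b); // close to 9, since the data follows y = 2x - 1
```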
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
window.onload = async () => {
const xs = [1, 2, 3, 4];
const ys = [1, 3, 5, 7];
tfvis.render.scatterplot(
{ name: 'Linear regression training set' },
{ values: xs.map((x, i) => ({ x, y: ys[i] })) },
{ xAxisDomain: [0, 5], yAxisDomain: [0, 8] }
);
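// Create the model
const model = tf.sequential();
// Add a layer. A dense layer computes output = activation(dot(input, kernel) + bias), which fits this problem.
// units is the number of neurons; inputShape: [1] means each input is a 1-D array of length 1.
model.add(tf.layers.dense({ units: 1, inputShape: [1] }));
// Compile with the mean squared error (MSE) loss and the stochastic gradient descent (SGD) optimizer, learning rate 0.1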
model.compile({ loss: tf.losses.meanSquaredError, optimizer: tf.train.sgd(0.1) });
};
Examples of other optimizers:
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
window.onload = async () => {
const xs = [1, 2, 3, 4];
const ys = [1, 3, 5, 7];
tfvis.render.scatterplot(
{ name: 'Linear regression training set' },
{ values: xs.map((x, i) => ({ x, y: ys[i] })) },
{ xAxisDomain: [0, 5], yAxisDomain: [0, 8] }
);
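// Create the model
const model = tf.sequential();
// Add a layer. A dense layer computes output = activation(dot(input, kernel) + bias), which fits this problem.
// units is the number of neurons; inputShape: [1] means each input is a 1-D array of length 1.
model.add(tf.layers.dense({ units: 1, inputShape: [1] }));
// Compile with the mean squared error (MSE) loss and the stochastic gradient descent (SGD) optimizer, learning rate 0.1
model.compile({ loss: tf.losses.meanSquaredError, optimizer: tf.train.sgd(0.1) });
// Convert the inputs and labels to tensors
const inputs = tf.tensor(xs);
const labels = tf.tensor(ys);
// Train the model (asynchronous)
await model.fit(inputs, labels, {
// Batch size: how many samples the model learns from per update
batchSize: 4,
epochs: 200,
callbacks: tfvis.show.fitCallbacks(
{ name: 'Training process' },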
['loss']
)
});
};
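The vertical axis, loss, is the loss value. The onBatchEnd chart is updated after every mini-batch; the onEpochEnd chart is updated after every completed epoch.
With epochs set to 100 and a batchSize of 1, the 4 training samples give 4 batches per epoch, so there are 400 batches in total.
After batchSize is changed to 4, the curve becomes smoother and the jitter disappears. (With a batchSize of 1 the model is fed one sample at a time. Under stochastic gradient descent the early updates vary widely from sample to sample, which causes the jitter.)
The learning rate should also not be set too high. When it is raised from 0.1 to 0.15, the chart shows the loss still rising, and the model cannot predict the data well.
When the learning rate is lowered from 0.1 to 0.01, the result is reasonably accurate, but learning becomes much slower.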
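The batch count and the learning-rate behaviour described above can be reproduced with a small plain-JavaScript sketch. It assumes full-batch gradient descent starting from w = 0 and b = 0, so the exact numbers will differ from a TensorFlow.js run; `totalBatches` and `finalLoss` are hypothetical helper names.

```javascript
// Number of batches (weight updates) over a whole training run.
function totalBatches(numSamples, batchSize, epochs) {
  return Math.ceil(numSamples / batchSize) * epochs;
}

// Train y = w * x + b with full-batch gradient descent and return the final MSE.
function finalLoss(xs, ys, learningRate, epochs) {
  let w = 0;
  let b = 0;
  const n = xs.length;
  for (let epoch = 0; epoch < epochs; epoch++) {
    let gradW = 0;
    let gradB = 0;
    for (let i = 0; i < n; i++) {
      const error = w * xs[i] + b - ys[i];
      gradW += (2 / n) * error * xs[i];
      gradB += (2 / n) * error;
    }
    w -= learningRate * gradW;
    b -= learningRate * gradB;
  }
  return xs.reduce((sum, x, i) => sum + (w * x + b - ys[i]) ** 2, 0) / n;
}

console.log(totalBatches(4, 1, 100)); // 400
console.log(totalBatches(4, 4, 100)); // 100

const xs = [1, 2, 3, 4];
const ys = [1, 3, 5, 7];
for (const learningRate of [0.01, 0.1, 0.15]) {
  console.log(learningRate, finalLoss(xs, ys, learningRate, 200));
}
```

In this sketch the loss at 0.15 grows instead of shrinking, and the loss at 0.01 is still well above the loss at 0.1 after the same 200 epochs.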
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
window.onload = async () => {
const xs = [1, 2, 3, 4];
const ys = [1, 3, 5, 7];
tfvis.render.scatterplot(
{ name: 'Linear regression training set' },
{ values: xs.map((x, i) => ({ x, y: ys[i] })) },
{ xAxisDomain: [0, 5], yAxisDomain: [0, 8] }
);
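// Create the model
const model = tf.sequential();
// Add a layer. A dense layer computes output = activation(dot(input, kernel) + bias), which fits this problem.
// units is the number of neurons; inputShape: [1] means each input is a 1-D array of length 1.
model.add(tf.layers.dense({ units: 1, inputShape: [1] }));
// Compile with the mean squared error (MSE) loss and the stochastic gradient descent (SGD) optimizer, learning rate 0.1
model.compile({ loss: tf.losses.meanSquaredError, optimizer: tf.train.sgd(0.1) });
// Convert the inputs and labels to tensors
const inputs = tf.tensor(xs);
const labels = tf.tensor(ys);
// Train the model (asynchronous)
await model.fit(inputs, labels, {
// Batch size: how many samples the model learns from per update
batchSize: 4,
epochs: 200,
callbacks: tfvis.show.fitCallbacks(
{ name: 'Training process' },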
['loss']
)
});
// Predict
const output = model.predict(tf.tensor([5]));
alert(`If x is 5, the predicted y is ${output.dataSync()[0]}`);
};
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
window.onload = async () => {
const heights = [150, 160, 170];
const weights = [40, 50, 60];
tfvis.render.scatterplot(
{ name: 'Height and weight training data' },
{ values: heights.map((x, i) => ({ x, y: weights[i] })) },
{
xAxisDomain: [140, 180],
yAxisDomain: [30, 70]
}
);
const inputs = tf.tensor(heights).sub(150).div(20);
const labels = tf.tensor(weights).sub(40).div(20);
const model = tf.sequential();
model.add(tf.layers.dense({ units: 1, inputShape: [1] }));
model.compile({ loss: tf.losses.meanSquaredError, optimizer: tf.train.sgd(0.1) });
await model.fit(inputs, labels, {
batchSize: 3,
epochs: 200,
callbacks: tfvis.show.fitCallbacks(
{ name: 'Training process' },
['loss']
)
});
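// The input must be normalized in the same way as the training data before predicting
const output = model.predict(tf.tensor([180]).sub(150).div(20));
// The prediction must be denormalized back to kilograms
alert(`If the height is 180 cm, the predicted weight is ${output.mul(20).add(40).dataSync()[0]} kg`);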
};
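The `sub(...).div(...)` and `mul(...).add(...)` calls above are a min-max style normalization and its inverse. A plain-JavaScript sketch of the same arithmetic, with hypothetical helper names `normalize` and `denormalize`:

```javascript
// Map a raw value into the normalized range: (value - min) / range.
function normalize(value, min, range) {
  return (value - min) / range;
}

// Invert the mapping: value * range + min.
function denormalize(value, min, range) {
  return value * range + min;
}

// Heights use min 150 and range 20; weights use min 40 and range 20.
console.log([150, 160, 170].map((h) => normalize(h, 150, 20))); // [0, 0.5, 1]
console.log(normalize(180, 150, 20)); // 1.5
console.log(denormalize(1.5, 40, 20)); // 70
```

A height of 180 cm lies outside the training range, so its normalized value exceeds 1.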
data.js:
export function getData(numSamples) {
let points = [];
function genGauss(cx, cy, label) {
for (let i = 0; i < numSamples / 2; i++) {
let x = normalRandom(cx);
let y = normalRandom(cy);
points.push({ x, y, label });
}
}
genGauss(2, 2, 1);
genGauss(-2, -2, 0);
return points;
}
/**
* Samples from a normal distribution. Uses Math.random as the random
* generator.
*
* @param mean The mean. Default is 0.
* @param variance The variance. Default is 1.
*/
function normalRandom(mean = 0, variance = 1) {
let v1, v2, s;
do {
v1 = 2 * Math.random() - 1;
v2 = 2 * Math.random() - 1;
s = v1 * v1 + v2 * v2;
} while (s > 1);
let result = Math.sqrt(-2 * Math.log(s) / s) * v1;
return mean + Math.sqrt(variance) * result;
}
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getData } from './data.js';
window.onload = async () => {
// Generate 400 data points
const data = getData(400);
// console.log('data', data); // inspect the generated points
tfvis.render.scatterplot(
{ name: 'Logistic regression training data' },
{
values: [
data.filter(p => p.label === 1),
data.filter(p => p.label === 0),
]
}
);
};
log loss:https://developers.google.cn/machine-learning/crash-course/logistic-regression/model-training?hl=zh_cn
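As a rough illustration of the two pieces used in the next step, here is a plain-JavaScript sketch of the sigmoid activation and the log loss. The function names and the clipping constant `eps` are choices made for this sketch, not the TensorFlow.js implementation.

```javascript
// Sigmoid squashes any real number into the range (0, 1).
function sigmoid(z) {
  return 1 / (1 + Math.exp(-z));
}

// Log loss: the mean of -(y * log(p) + (1 - y) * log(1 - p)).
function logLoss(labels, predictions) {
  const eps = 1e-7; // keeps Math.log away from log(0)
  let total = 0;
  for (let i = 0; i < labels.length; i++) {
    const p = Math.min(Math.max(predictions[i], eps), 1 - eps);
    total += -(labels[i] * Math.log(p) + (1 - labels[i]) * Math.log(1 - p));
  }
  return total / labels.length;
}

console.log(sigmoid(0)); // 0.5
console.log(logLoss([1, 0], [0.9, 0.1])); // about 0.105
console.log(logLoss([1, 0], [0.1, 0.9])); // much larger: confident and wrong
```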
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getData } from './data.js';
window.onload = async () => {
const data = getData(400);
console.log('data', data)
tfvis.render.scatterplot(
{ name: 'Logistic regression training data' },
{
values: [
data.filter(p => p.label === 1),
data.filter(p => p.label === 0),
]
}
);
const model = tf.sequential();
model.add(tf.layers.dense({
units: 1,
inputShape: [2],
activation: 'sigmoid'
}));
model.compile({
// Use log loss as the loss function
loss: tf.losses.logLoss,
});
};
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getData } from './data.js';
window.onload = async () => {
const data = getData(400);
console.log('data', data)
tfvis.render.scatterplot(
{ name: 'Logistic regression training data' },
{
values: [
data.filter(p => p.label === 1),
data.filter(p => p.label === 0),
]
}
);
const model = tf.sequential();
model.add(tf.layers.dense({
units: 1,
inputShape: [2],
activation: 'sigmoid'
}));
model.compile({
loss: tf.losses.logLoss,
optimizer: tf.train.adam(0.1)
});
// Convert the training data to tensors
const inputs = tf.tensor(data.map(p => [p.x, p.y]));
const labels = tf.tensor(data.map(p => p.label));
// Start training
await model.fit(inputs, labels, {
batchSize: 40,
epochs: 20,
callbacks: tfvis.show.fitCallbacks(
{ name: 'Training results' },
['loss']
)
});
};
<form action="" onsubmit="predict(this);return false;">
x: <input type="text" name="x">
y: <input type="text" name="y">
<button type="submit">Predict</button>
</form>
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getData } from './data.js';
window.onload = async () => {
const data = getData(400);
console.log('data', data)
tfvis.render.scatterplot(
{ name: 'Logistic regression training data' },
{
values: [
data.filter(p => p.label === 1),
data.filter(p => p.label === 0),
]
}
);
const model = tf.sequential();
model.add(tf.layers.dense({
units: 1,
inputShape: [2],
activation: 'sigmoid'
}));
model.compile({
loss: tf.losses.logLoss,
optimizer: tf.train.adam(0.1)
});
const inputs = tf.tensor(data.map(p => [p.x, p.y]));
const labels = tf.tensor(data.map(p => p.label));
await model.fit(inputs, labels, {
batchSize: 40,
epochs: 20,
callbacks: tfvis.show.fitCallbacks(
{ name: 'Training results' },
['loss']
)
});
window.predict = (form) => {
const pred = model.predict(tf.tensor([[form.x.value * 1, form.y.value * 1]]));
alert(`Prediction: ${pred.dataSync()[0]}`);
};
};
export function getData(numSamples) {
let points = [];
// Generate points that follow a Gaussian (normal) distribution around the given center
function genGauss(cx, cy, label) {
for (let i = 0; i < numSamples / 2; i++) {
let x = normalRandom(cx);
let y = normalRandom(cy);
points.push({ x, y, label });
}
}
genGauss(2, 2, 1);
genGauss(-2, -2, 0);
return points;
}
/**
* Samples from a normal distribution. Uses Math.random as the random
* generator.
*
* @param mean The mean. Default is 0.
* @param variance The variance. Default is 1.
*/
function normalRandom(mean = 0, variance = 1) {
let v1, v2, s;
do {
v1 = 2 * Math.random() - 1;
v2 = 2 * Math.random() - 1;
s = v1 * v1 + v2 * v2;
} while (s > 1);
let result = Math.sqrt(-2 * Math.log(s) / s) * v1;
return mean + Math.sqrt(variance) * result;
}
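The original formula uses trigonometric functions, which are comparatively slow. It can be rewritten in an equivalent form without them:
v1 and v2 correspond to u and v in the formula, s = v1² + v2² is the sum of their squares, and the sample is v1 * sqrt(-2 * ln(s) / s).
Only points inside the unit circle are accepted, so s must be less than 1 (a draw with s > 1 is discarded and repeated). The resulting function for generating Gaussian-distributed random values is: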
function normalRandom(mean = 0, variance = 1) {
let v1, v2, s;
do {
v1 = 2 * Math.random() - 1;
v2 = 2 * Math.random() - 1;
s = v1 * v1 + v2 * v2;
} while (s > 1);
let result = Math.sqrt(-2 * Math.log(s) / s) * v1;
// mean shifts the center of the distribution
return mean + Math.sqrt(variance) * result;
}
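For comparison, the trigonometric form mentioned above is presumably the basic Box-Muller transform; that identification is an assumption, since the original formula is not shown in the text. A minimal sketch, with the hypothetical name `normalRandomTrig`:

```javascript
// Basic Box-Muller transform (assumed to be the "original formula" with
// trigonometric functions). Returns one normally distributed sample.
function normalRandomTrig(mean = 0, variance = 1) {
  const u1 = 1 - Math.random(); // in (0, 1], so Math.log(u1) stays finite
  const u2 = Math.random();
  const z = Math.sqrt(-2 * Math.log(u1)) * Math.cos(2 * Math.PI * u2);
  return mean + Math.sqrt(variance) * z;
}

// Draw a few samples centered on 2, like genGauss(2, 2, ...) does for x and y.
const samples = Array.from({ length: 5 }, () => normalRandomTrig(2));
console.log(samples);
```

The polar version used in data.js avoids the `Math.cos` call at the cost of occasionally rejecting a draw.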
https://playground.tensorflow.org/
data.js:
export function getData(numSamples) {
let points = [];
function genGauss(cx, cy, label) {
for (let i = 0; i < numSamples / 2; i++) {
let x = normalRandom(cx);
let y = normalRandom(cy);
points.push({ x, y, label });
}
}
genGauss(2, 2, 0);
genGauss(-2, -2, 0);
genGauss(-2, 2, 1);
genGauss(2, -2, 1);
return points;
}
/**
* Samples from a normal distribution. Uses Math.random as the random
* generator.
*
* @param mean The mean. Default is 0.
* @param variance The variance. Default is 1.
*/
function normalRandom(mean = 0, variance = 1) {
let v1, v2, s;
do {
v1 = 2 * Math.random() - 1;
v2 = 2 * Math.random() - 1;
s = v1 * v1 + v2 * v2;
} while (s > 1);
let result = Math.sqrt(-2 * Math.log(s) / s) * v1;
return mean + Math.sqrt(variance) * result;
}
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getData } from './data.js';
window.onload = async () => {
const data = getData(400);
tfvis.render.scatterplot(
{ name: 'XOR training data' },
{
values: [
data.filter(p => p.label === 1),
data.filter(p => p.label === 0),
]
}
);
};
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getData } from './data.js';
window.onload = async () => {
const data = getData(400);
tfvis.render.scatterplot(
{ name: 'XOR training data' },
{
values: [
data.filter(p => p.label === 1),
data.filter(p => p.label === 0),
]
}
);
const model = tf.sequential();
// Create the hidden layer
model.add(tf.layers.dense({
units: 4,
inputShape: [2],
activation: 'relu'
}));
// Create the output layer
model.add(tf.layers.dense({
units: 1,
activation: 'sigmoid'
}));
};
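To see why a hidden layer of 4 ReLU units followed by a sigmoid output is enough for this XOR layout, the sketch below wires such a network by hand. The weights are picked manually for illustration; they are an assumption of this sketch and not what training in TensorFlow.js would necessarily find.

```javascript
const relu = (v) => Math.max(0, v);
const sigmoid = (v) => 1 / (1 + Math.exp(-v));

// Hand-picked weights: 4 hidden ReLU units, 1 sigmoid output unit.
function xorNet(x, y) {
  const hidden = [relu(x - y), relu(y - x), relu(x + y), relu(-x - y)];
  // hidden[0] + hidden[1] = |x - y| and hidden[2] + hidden[3] = |x + y|.
  // |x - y| - |x + y| is positive exactly when x and y have opposite signs.
  const z = hidden[0] + hidden[1] - hidden[2] - hidden[3];
  return sigmoid(z);
}

// The four cluster centers used in data.js.
console.log(xorNet(2, 2) < 0.5);   // true (label 0)
console.log(xorNet(-2, -2) < 0.5); // true (label 0)
console.log(xorNet(-2, 2) > 0.5);  // true (label 1)
console.log(xorNet(2, -2) > 0.5);  // true (label 1)
```

A single dense unit cannot do this, because no straight line separates the two diagonal pairs of clusters.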
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getData } from './data.js';
window.onload = async () => {
const data = getData(400);
tfvis.render.scatterplot(
{ name: 'XOR training data' },
{
values: [
data.filter(p => p.label === 1),
data.filter(p => p.label === 0),
]
}
);
const model = tf.sequential();
// Create the hidden layer
model.add(tf.layers.dense({
units: 4,
inputShape: [2],
activation: 'relu'
}));
// Create the output layer
model.add(tf.layers.dense({
units: 1,
activation: 'sigmoid'
}));
model.compile({
loss: tf.losses.logLoss,
optimizer: tf.train.adam(0.1)
});
const inputs = tf.tensor(data.map(p => [p.x, p.y]));
const labels = tf.tensor(data.map(p => p.label));
await model.fit(inputs, labels, {
epochs: 10,
callbacks: tfvis.show.fitCallbacks(
{ name: 'Training results' },
['loss']
)
});
window.predict = (form) => {
const pred = model.predict(tf.tensor([[form.x.value * 1, form.y.value * 1]]));
alert(`Prediction: ${pred.dataSync()[0]}`);
};
};
<form action="" onsubmit="predict(this);return false;">
x: <input type="text" name="x">
y: <input type="text" name="y">
<button type="submit">Predict</button>
</form>
data.js:
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import * as tf from '@tensorflow/tfjs';
export const IRIS_CLASSES =
['Iris setosa', 'Iris versicolor', 'Iris virginica'];
export const IRIS_NUM_CLASSES = IRIS_CLASSES.length;
// Iris flowers data. Source:
// https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data
const IRIS_DATA = [
[5.1, 3.5, 1.4, 0.2, 0], [4.9, 3.0, 1.4, 0.2, 0], [4.7, 3.2, 1.3, 0.2, 0],
[4.6, 3.1, 1.5, 0.2, 0], [5.0, 3.6, 1.4, 0.2, 0], [5.4, 3.9, 1.7, 0.4, 0],
[4.6, 3.4, 1.4, 0.3, 0], [5.0, 3.4, 1.5, 0.2, 0], [4.4, 2.9, 1.4, 0.2, 0],
[4.9, 3.1, 1.5, 0.1, 0], [5.4, 3.7, 1.5, 0.2, 0], [4.8, 3.4, 1.6, 0.2, 0],
[4.8, 3.0, 1.4, 0.1, 0], [4.3, 3.0, 1.1, 0.1, 0], [5.8, 4.0, 1.2, 0.2, 0],
[5.7, 4.4, 1.5, 0.4, 0], [5.4, 3.9, 1.3, 0.4, 0], [5.1, 3.5, 1.4, 0.3, 0],
[5.7, 3.8, 1.7, 0.3, 0], [5.1, 3.8, 1.5, 0.3, 0], [5.4, 3.4, 1.7, 0.2, 0],
[5.1, 3.7, 1.5, 0.4, 0], [4.6, 3.6, 1.0, 0.2, 0], [5.1, 3.3, 1.7, 0.5, 0],
[4.8, 3.4, 1.9, 0.2, 0], [5.0, 3.0, 1.6, 0.2, 0], [5.0, 3.4, 1.6, 0.4, 0],
[5.2, 3.5, 1.5, 0.2, 0], [5.2, 3.4, 1.4, 0.2, 0], [4.7, 3.2, 1.6, 0.2, 0],
[4.8, 3.1, 1.6, 0.2, 0], [5.4, 3.4, 1.5, 0.4, 0], [5.2, 4.1, 1.5, 0.1, 0],
[5.5, 4.2, 1.4, 0.2, 0], [4.9, 3.1, 1.5, 0.1, 0], [5.0, 3.2, 1.2, 0.2, 0],
[5.5, 3.5, 1.3, 0.2, 0], [4.9, 3.1, 1.5, 0.1, 0], [4.4, 3.0, 1.3, 0.2, 0],
[5.1, 3.4, 1.5, 0.2, 0], [5.0, 3.5, 1.3, 0.3, 0], [4.5, 2.3, 1.3, 0.3, 0],
[4.4, 3.2, 1.3, 0.2, 0], [5.0, 3.5, 1.6, 0.6, 0], [5.1, 3.8, 1.9, 0.4, 0],
[4.8, 3.0, 1.4, 0.3, 0], [5.1, 3.8, 1.6, 0.2, 0], [4.6, 3.2, 1.4, 0.2, 0],
[5.3, 3.7, 1.5, 0.2, 0], [5.0, 3.3, 1.4, 0.2, 0], [7.0, 3.2, 4.7, 1.4, 1],
[6.4, 3.2, 4.5, 1.5, 1], [6.9, 3.1, 4.9, 1.5, 1], [5.5, 2.3, 4.0, 1.3, 1],
[6.5, 2.8, 4.6, 1.5, 1], [5.7, 2.8, 4.5, 1.3, 1], [6.3, 3.3, 4.7, 1.6, 1],
[4.9, 2.4, 3.3, 1.0, 1], [6.6, 2.9, 4.6, 1.3, 1], [5.2, 2.7, 3.9, 1.4, 1],
[5.0, 2.0, 3.5, 1.0, 1], [5.9, 3.0, 4.2, 1.5, 1], [6.0, 2.2, 4.0, 1.0, 1],
[6.1, 2.9, 4.7, 1.4, 1], [5.6, 2.9, 3.6, 1.3, 1], [6.7, 3.1, 4.4, 1.4, 1],
[5.6, 3.0, 4.5, 1.5, 1], [5.8, 2.7, 4.1, 1.0, 1], [6.2, 2.2, 4.5, 1.5, 1],
[5.6, 2.5, 3.9, 1.1, 1], [5.9, 3.2, 4.8, 1.8, 1], [6.1, 2.8, 4.0, 1.3, 1],
[6.3, 2.5, 4.9, 1.5, 1], [6.1, 2.8, 4.7, 1.2, 1], [6.4, 2.9, 4.3, 1.3, 1],
[6.6, 3.0, 4.4, 1.4, 1], [6.8, 2.8, 4.8, 1.4, 1], [6.7, 3.0, 5.0, 1.7, 1],
[6.0, 2.9, 4.5, 1.5, 1], [5.7, 2.6, 3.5, 1.0, 1], [5.5, 2.4, 3.8, 1.1, 1],
[5.5, 2.4, 3.7, 1.0, 1], [5.8, 2.7, 3.9, 1.2, 1], [6.0, 2.7, 5.1, 1.6, 1],
[5.4, 3.0, 4.5, 1.5, 1], [6.0, 3.4, 4.5, 1.6, 1], [6.7, 3.1, 4.7, 1.5, 1],
[6.3, 2.3, 4.4, 1.3, 1], [5.6, 3.0, 4.1, 1.3, 1], [5.5, 2.5, 4.0, 1.3, 1],
[5.5, 2.6, 4.4, 1.2, 1], [6.1, 3.0, 4.6, 1.4, 1], [5.8, 2.6, 4.0, 1.2, 1],
[5.0, 2.3, 3.3, 1.0, 1], [5.6, 2.7, 4.2, 1.3, 1], [5.7, 3.0, 4.2, 1.2, 1],
[5.7, 2.9, 4.2, 1.3, 1], [6.2, 2.9, 4.3, 1.3, 1], [5.1, 2.5, 3.0, 1.1, 1],
[5.7, 2.8, 4.1, 1.3, 1], [6.3, 3.3, 6.0, 2.5, 2], [5.8, 2.7, 5.1, 1.9, 2],
[7.1, 3.0, 5.9, 2.1, 2], [6.3, 2.9, 5.6, 1.8, 2], [6.5, 3.0, 5.8, 2.2, 2],
[7.6, 3.0, 6.6, 2.1, 2], [4.9, 2.5, 4.5, 1.7, 2], [7.3, 2.9, 6.3, 1.8, 2],
[6.7, 2.5, 5.8, 1.8, 2], [7.2, 3.6, 6.1, 2.5, 2], [6.5, 3.2, 5.1, 2.0, 2],
[6.4, 2.7, 5.3, 1.9, 2], [6.8, 3.0, 5.5, 2.1, 2], [5.7, 2.5, 5.0, 2.0, 2],
[5.8, 2.8, 5.1, 2.4, 2], [6.4, 3.2, 5.3, 2.3, 2], [6.5, 3.0, 5.5, 1.8, 2],
[7.7, 3.8, 6.7, 2.2, 2], [7.7, 2.6, 6.9, 2.3, 2], [6.0, 2.2, 5.0, 1.5, 2],
[6.9, 3.2, 5.7, 2.3, 2], [5.6, 2.8, 4.9, 2.0, 2], [7.7, 2.8, 6.7, 2.0, 2],
[6.3, 2.7, 4.9, 1.8, 2], [6.7, 3.3, 5.7, 2.1, 2], [7.2, 3.2, 6.0, 1.8, 2],
[6.2, 2.8, 4.8, 1.8, 2], [6.1, 3.0, 4.9, 1.8, 2], [6.4, 2.8, 5.6, 2.1, 2],
[7.2, 3.0, 5.8, 1.6, 2], [7.4, 2.8, 6.1, 1.9, 2], [7.9, 3.8, 6.4, 2.0, 2],
[6.4, 2.8, 5.6, 2.2, 2], [6.3, 2.8, 5.1, 1.5, 2], [6.1, 2.6, 5.6, 1.4, 2],
[7.7, 3.0, 6.1, 2.3, 2], [6.3, 3.4, 5.6, 2.4, 2], [6.4, 3.1, 5.5, 1.8, 2],
[6.0, 3.0, 4.8, 1.8, 2], [6.9, 3.1, 5.4, 2.1, 2], [6.7, 3.1, 5.6, 2.4, 2],
[6.9, 3.1, 5.1, 2.3, 2], [5.8, 2.7, 5.1, 1.9, 2], [6.8, 3.2, 5.9, 2.3, 2],
[6.7, 3.3, 5.7, 2.5, 2], [6.7, 3.0, 5.2, 2.3, 2], [6.3, 2.5, 5.0, 1.9, 2],
[6.5, 3.0, 5.2, 2.0, 2], [6.2, 3.4, 5.4, 2.3, 2], [5.9, 3.0, 5.1, 1.8, 2],
];
/**
* Convert Iris data arrays to `tf.Tensor`s.
*
* @param data The Iris input feature data, an `Array` of `Array`s, each element
* of which is assumed to be a length-4 `Array` (for sepal length, sepal
* width, petal length, petal width).
* @param targets An `Array` of numbers, with values from the set {0, 1, 2}:
* representing the true category of the Iris flower. Assumed to have the same
* array length as `data`.
* @param testSplit Fraction of the data at the end to split as test data: a
* number between 0 and 1.
* @return A length-4 `Array`, with
* - training data as `tf.Tensor` of shape [numTrainExamples, 4].
* - training one-hot labels as a `tf.Tensor` of shape [numTrainExamples, 3]
* - test data as `tf.Tensor` of shape [numTestExamples, 4].
* - test one-hot labels as a `tf.Tensor` of shape [numTestExamples, 3]
*/
function convertToTensors(data, targets, testSplit) {
const numExamples = data.length;
if (numExamples !== targets.length) {
throw new Error('data and targets have different numbers of examples');
}
// Randomly shuffle `data` and `targets`.
const indices = [];
for (let i = 0; i < numExamples; ++i) {
indices.push(i);
}
tf.util.shuffle(indices);
const shuffledData = [];
const shuffledTargets = [];
for (let i = 0; i < numExamples; ++i) {
shuffledData.push(data[indices[i]]);
shuffledTargets.push(targets[indices[i]]);
}
// Split the data into a training set and a test set, based on `testSplit`.
const numTestExamples = Math.round(numExamples * testSplit);
const numTrainExamples = numExamples - numTestExamples;
const xDims = shuffledData[0].length;
// Create a 2D `tf.Tensor` to hold the feature data.
const xs = tf.tensor2d(shuffledData, [numExamples, xDims]);
// Create a 1D `tf.Tensor` to hold the labels, and convert the number label
// from the set {0, 1, 2} into one-hot encoding (e.g., 0 --> [1, 0, 0]).
const ys = tf.oneHot(tf.tensor1d(shuffledTargets).toInt(), IRIS_NUM_CLASSES);
// Split the data into training and test sets, using `slice`.
const xTrain = xs.slice([0, 0], [numTrainExamples, xDims]);
const xTest = xs.slice([numTrainExamples, 0], [numTestExamples, xDims]);
const yTrain = ys.slice([0, 0], [numTrainExamples, IRIS_NUM_CLASSES]);
const yTest = ys.slice([numTrainExamples, 0], [numTestExamples, IRIS_NUM_CLASSES]);
return [xTrain, yTrain, xTest, yTest];
}
/**
* Obtains Iris data, split into training and test sets.
*
* @param testSplit Fraction of the data at the end to split as test data: a
* number between 0 and 1.
*
* @return A length-4 `Array`, with
* - training data as an `Array` of length-4 `Array` of numbers.
* - training labels as an `Array` of numbers, with the same length as the
* return training data above. Each element of the `Array` is from the set
* {0, 1, 2}.
* - test data as an `Array` of length-4 `Array` of numbers.
* - test labels as an `Array` of numbers, with the same length as the
* return test data above. Each element of the `Array` is from the set
* {0, 1, 2}.
*/
export function getIrisData(testSplit) {
return tf.tidy(() => {
const dataByClass = [];
const targetsByClass = [];
for (let i = 0; i < IRIS_CLASSES.length; ++i) {
dataByClass.push([]);
targetsByClass.push([]);
}
for (const example of IRIS_DATA) {
const target = example[example.length - 1];
const data = example.slice(0, example.length - 1);
dataByClass[target].push(data);
targetsByClass[target].push(target);
}
const xTrains = [];
const yTrains = [];
const xTests = [];
const yTests = [];
for (let i = 0; i < IRIS_CLASSES.length; ++i) {
const [xTrain, yTrain, xTest, yTest] =
convertToTensors(dataByClass[i], targetsByClass[i], testSplit);
xTrains.push(xTrain);
yTrains.push(yTrain);
xTests.push(xTest);
yTests.push(yTest);
}
const concatAxis = 0;
return [
tf.concat(xTrains, concatAxis), tf.concat(yTrains, concatAxis),
tf.concat(xTests, concatAxis), tf.concat(yTests, concatAxis)
];
});
}
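Two steps in data.js are easy to check by hand: the one-hot encoding of the labels and the per-class train/test split sizes. The sketch below mirrors them in plain JavaScript; `oneHot` and `splitCounts` are hypothetical helpers written for this note, not the TensorFlow.js functions.

```javascript
// One-hot encode a class index, e.g. 0 -> [1, 0, 0] for 3 classes.
function oneHot(index, numClasses) {
  return Array.from({ length: numClasses }, (_, i) => (i === index ? 1 : 0));
}

// Same arithmetic as convertToTensors: round the test share, the rest is training.
function splitCounts(numExamples, testSplit) {
  const numTest = Math.round(numExamples * testSplit);
  return { numTrain: numExamples - numTest, numTest };
}

console.log(oneHot(0, 3)); // [1, 0, 0]
console.log(oneHot(2, 3)); // [0, 0, 1]
// Each iris class has 50 rows, and the split is applied per class.
console.log(splitCounts(50, 0.15)); // { numTrain: 42, numTest: 8 }
```

With three classes, a 0.15 split therefore gives 126 training rows and 24 validation rows.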
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
// getIrisData returns the training and validation sets; IRIS_CLASSES provides the class names for output
import { getIrisData, IRIS_CLASSES } from './data';
window.onload = async () => {
// Use 15% of the data as the validation set
// xTrain: training features, yTrain: training labels (expected outputs), xTest: validation features, yTest: validation labels (expected outputs)
const [xTrain, yTrain, xTest, yTest] = getIrisData(0.15);
xTrain.print();
yTrain.print();
xTest.print();
yTest.print();
console.log(IRIS_CLASSES);
console.log(xTrain);
console.log(yTrain);
console.log(xTest);
console.log(yTest);
};
The sigmoid activation function handles binary classification; the softmax activation function handles multi-class classification.
https://developers.google.com/machine-learning/recommendation/dnn/softmax
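A minimal plain-JavaScript sketch of softmax, to show how raw scores become probabilities that sum to 1. The function is an illustration written for this note; subtracting the maximum is a common numerical-stability step and does not change the result.

```javascript
// Softmax: exponentiate each score and divide by the sum of the exponentials.
function softmax(logits) {
  const max = Math.max(...logits); // subtract the max for numerical stability
  const exps = logits.map((v) => Math.exp(v - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((v) => v / sum);
}

const probs = softmax([2, 1, 0.1]);
console.log(probs); // the largest score gets the largest probability
console.log(probs.reduce((a, b) => a + b, 0)); // 1, up to floating-point rounding
```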
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
// getIrisData returns the training and validation sets; IRIS_CLASSES provides the class names for output
import { getIrisData, IRIS_CLASSES } from './data';
window.onload = async () => {
// Use 15% of the data as the validation set
// xTrain: training features, yTrain: training labels (expected outputs), xTest: validation features, yTest: validation labels (expected outputs)
const [xTrain, yTrain, xTest, yTest] = getIrisData(0.15);
const model = tf.sequential();
model.add(tf.layers.dense({
units: 10,
inputShape: [xTrain.shape[1]],
activation: 'sigmoid'
}));
// Output layer
model.add(tf.layers.dense({
units: 3, // One neuron per output class; there are 3 iris species, so use 3
activation: 'softmax' // Outputs 3 probabilities that sum to 1, hence softmax
}));
};
The cross-entropy loss function is the multi-class version of the log loss function.
https://ml-cheatsheet.readthedocs.io/en/latest/
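A plain-JavaScript sketch of categorical cross-entropy for a single example, assuming a one-hot label and a probability vector such as the softmax output. The helper name and the `eps` clipping are choices made for this sketch.

```javascript
// Categorical cross-entropy for one example: -sum(label_i * log(p_i)).
function categoricalCrossentropy(oneHotLabel, probabilities) {
  const eps = 1e-7; // keeps Math.log away from log(0)
  let loss = 0;
  for (let i = 0; i < oneHotLabel.length; i++) {
    loss -= oneHotLabel[i] * Math.log(Math.max(probabilities[i], eps));
  }
  return loss;
}

// Only the probability assigned to the true class matters.
console.log(categoricalCrossentropy([0, 1, 0], [0.1, 0.8, 0.1])); // about 0.223
// With two classes it reduces to the log loss of the true class: -log(0.9).
console.log(categoricalCrossentropy([1, 0], [0.9, 0.1])); // about 0.105
```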
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
// getIrisData returns the training and validation sets; IRIS_CLASSES provides the class names for output
import { getIrisData, IRIS_CLASSES } from './data';
window.onload = async () => {
// Use 15% of the data as the validation set
// xTrain: training features, yTrain: training labels (expected outputs), xTest: validation features, yTest: validation labels (expected outputs)
const [xTrain, yTrain, xTest, yTest] = getIrisData(0.15);
const model = tf.sequential();
model.add(tf.layers.dense({
units: 10,
inputShape: [xTrain.shape[1]],
activation: 'sigmoid'
}));
// Output layer
model.add(tf.layers.dense({
units: 3, // One neuron per output class; there are 3 iris species, so use 3
activation: 'softmax' // Outputs 3 probabilities that sum to 1, hence softmax
}));
model.compile({
// Use categorical cross-entropy as the loss function
loss: 'categoricalCrossentropy',
optimizer: tf.train.adam(0.1),
// Track accuracy as a metric
metrics: ['accuracy']
});
await model.fit(xTrain, yTrain, {
epochs: 100,
// Validation set
validationData: [xTest, yTest],
// Visualize the training process
callbacks: tfvis.show.fitCallbacks(
{ name: 'Training results' }, // chart name
['loss', 'val_loss', 'acc', 'val_acc'], // metrics: loss, validation loss, accuracy, validation accuracy
{ callbacks: ['onEpochEnd'] }
)
});
};
<form action="" onsubmit="predict(this); return false;">
Sepal length: <input type="text" name="a"><br>
Sepal width: <input type="text" name="b"><br>
Petal length: <input type="text" name="c"><br>
Petal width: <input type="text" name="d"><br>
<button type="submit">Predict</button>
</form>
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
// getIrisData returns the training and validation sets; IRIS_CLASSES provides the class names for output
import { getIrisData, IRIS_CLASSES } from './data';
window.onload = async () => {
// Use 15% of the data as the validation set
// xTrain: training features, yTrain: training labels (expected outputs), xTest: validation features, yTest: validation labels (expected outputs)
const [xTrain, yTrain, xTest, yTest] = getIrisData(0.15);
const model = tf.sequential();
model.add(tf.layers.dense({
units: 10,
inputShape: [xTrain.shape[1]],
activation: 'sigmoid'
}));
// Output layer
model.add(tf.layers.dense({
units: 3, // One neuron per output class; there are 3 iris species, so use 3
activation: 'softmax' // Outputs 3 probabilities that sum to 1, hence softmax
}));
model.compile({
// Use categorical cross-entropy as the loss function
loss: 'categoricalCrossentropy',
optimizer: tf.train.adam(0.1),
// Track accuracy as a metric
metrics: ['accuracy']
});
await model.fit(xTrain, yTrain, {
epochs: 100,
// Validation set
validationData: [xTest, yTest],
// Visualize the training process
callbacks: tfvis.show.fitCallbacks(
{ name: 'Training results' }, // chart name
['loss', 'val_loss', 'acc', 'val_acc'], // metrics: loss, validation loss, accuracy, validation accuracy
{ callbacks: ['onEpochEnd'] }
)
});
window.predict = (form) => {
const input = tf.tensor([[
form.a.value * 1,
form.b.value * 1,
form.c.value * 1,
form.d.value * 1,
]]);
const pred = model.predict(input);
// argMax returns the index of the largest value along the given axis
alert(`Prediction: ${IRIS_CLASSES[pred.argMax(1).dataSync()[0]]}`);
};
};
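The last step, turning the three softmax probabilities into a class name, can be sketched in plain JavaScript. `argMax` here is a stand-alone helper for a flat array, and `CLASS_NAMES` is a placeholder list standing in for IRIS_CLASSES.

```javascript
// Placeholder names standing in for IRIS_CLASSES.
const CLASS_NAMES = ['class 0', 'class 1', 'class 2'];

// Index of the largest value in a flat array.
function argMax(values) {
  let best = 0;
  for (let i = 1; i < values.length; i++) {
    if (values[i] > values[best]) {
      best = i;
    }
  }
  return best;
}

const probabilities = [0.1, 0.7, 0.2];
console.log(argMax(probabilities)); // 1
console.log(CLASS_NAMES[argMax(probabilities)]); // class 1
```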
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import * as tf from '@tensorflow/tfjs';
export const IRIS_CLASSES =
['Iris setosa', 'Iris versicolor', 'Iris virginica'];
export const IRIS_NUM_CLASSES = IRIS_CLASSES.length;
// Iris flowers data. Source:
// https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data
const IRIS_DATA = [
[5.1, 3.5, 1.4, 0.2, 0], [4.9, 3.0, 1.4, 0.2, 0], [4.7, 3.2, 1.3, 0.2, 0],
[4.6, 3.1, 1.5, 0.2, 0], [5.0, 3.6, 1.4, 0.2, 0], [5.4, 3.9, 1.7, 0.4, 0],
[4.6, 3.4, 1.4, 0.3, 0], [5.0, 3.4, 1.5, 0.2, 0], [4.4, 2.9, 1.4, 0.2, 0],
[4.9, 3.1, 1.5, 0.1, 0], [5.4, 3.7, 1.5, 0.2, 0], [4.8, 3.4, 1.6, 0.2, 0],
[4.8, 3.0, 1.4, 0.1, 0], [4.3, 3.0, 1.1, 0.1, 0], [5.8, 4.0, 1.2, 0.2, 0],
[5.7, 4.4, 1.5, 0.4, 0], [5.4, 3.9, 1.3, 0.4, 0], [5.1, 3.5, 1.4, 0.3, 0],
[5.7, 3.8, 1.7, 0.3, 0], [5.1, 3.8, 1.5, 0.3, 0], [5.4, 3.4, 1.7, 0.2, 0],
[5.1, 3.7, 1.5, 0.4, 0], [4.6, 3.6, 1.0, 0.2, 0], [5.1, 3.3, 1.7, 0.5, 0],
[4.8, 3.4, 1.9, 0.2, 0], [5.0, 3.0, 1.6, 0.2, 0], [5.0, 3.4, 1.6, 0.4, 0],
[5.2, 3.5, 1.5, 0.2, 0], [5.2, 3.4, 1.4, 0.2, 0], [4.7, 3.2, 1.6, 0.2, 0],
[4.8, 3.1, 1.6, 0.2, 0], [5.4, 3.4, 1.5, 0.4, 0], [5.2, 4.1, 1.5, 0.1, 0],
[5.5, 4.2, 1.4, 0.2, 0], [4.9, 3.1, 1.5, 0.1, 0], [5.0, 3.2, 1.2, 0.2, 0],
[5.5, 3.5, 1.3, 0.2, 0], [4.9, 3.1, 1.5, 0.1, 0], [4.4, 3.0, 1.3, 0.2, 0],
[5.1, 3.4, 1.5, 0.2, 0], [5.0, 3.5, 1.3, 0.3, 0], [4.5, 2.3, 1.3, 0.3, 0],
[4.4, 3.2, 1.3, 0.2, 0], [5.0, 3.5, 1.6, 0.6, 0], [5.1, 3.8, 1.9, 0.4, 0],
[4.8, 3.0, 1.4, 0.3, 0], [5.1, 3.8, 1.6, 0.2, 0], [4.6, 3.2, 1.4, 0.2, 0],
[5.3, 3.7, 1.5, 0.2, 0], [5.0, 3.3, 1.4, 0.2, 0], [7.0, 3.2, 4.7, 1.4, 1],
[6.4, 3.2, 4.5, 1.5, 1], [6.9, 3.1, 4.9, 1.5, 1], [5.5, 2.3, 4.0, 1.3, 1],
[6.5, 2.8, 4.6, 1.5, 1], [5.7, 2.8, 4.5, 1.3, 1], [6.3, 3.3, 4.7, 1.6, 1],
[4.9, 2.4, 3.3, 1.0, 1], [6.6, 2.9, 4.6, 1.3, 1], [5.2, 2.7, 3.9, 1.4, 1],
[5.0, 2.0, 3.5, 1.0, 1], [5.9, 3.0, 4.2, 1.5, 1], [6.0, 2.2, 4.0, 1.0, 1],
[6.1, 2.9, 4.7, 1.4, 1], [5.6, 2.9, 3.6, 1.3, 1], [6.7, 3.1, 4.4, 1.4, 1],
[5.6, 3.0, 4.5, 1.5, 1], [5.8, 2.7, 4.1, 1.0, 1], [6.2, 2.2, 4.5, 1.5, 1],
[5.6, 2.5, 3.9, 1.1, 1], [5.9, 3.2, 4.8, 1.8, 1], [6.1, 2.8, 4.0, 1.3, 1],
[6.3, 2.5, 4.9, 1.5, 1], [6.1, 2.8, 4.7, 1.2, 1], [6.4, 2.9, 4.3, 1.3, 1],
[6.6, 3.0, 4.4, 1.4, 1], [6.8, 2.8, 4.8, 1.4, 1], [6.7, 3.0, 5.0, 1.7, 1],
[6.0, 2.9, 4.5, 1.5, 1], [5.7, 2.6, 3.5, 1.0, 1], [5.5, 2.4, 3.8, 1.1, 1],
[5.5, 2.4, 3.7, 1.0, 1], [5.8, 2.7, 3.9, 1.2, 1], [6.0, 2.7, 5.1, 1.6, 1],
[5.4, 3.0, 4.5, 1.5, 1], [6.0, 3.4, 4.5, 1.6, 1], [6.7, 3.1, 4.7, 1.5, 1],
[6.3, 2.3, 4.4, 1.3, 1], [5.6, 3.0, 4.1, 1.3, 1], [5.5, 2.5, 4.0, 1.3, 1],
[5.5, 2.6, 4.4, 1.2, 1], [6.1, 3.0, 4.6, 1.4, 1], [5.8, 2.6, 4.0, 1.2, 1],
[5.0, 2.3, 3.3, 1.0, 1], [5.6, 2.7, 4.2, 1.3, 1], [5.7, 3.0, 4.2, 1.2, 1],
[5.7, 2.9, 4.2, 1.3, 1], [6.2, 2.9, 4.3, 1.3, 1], [5.1, 2.5, 3.0, 1.1, 1],
[5.7, 2.8, 4.1, 1.3, 1], [6.3, 3.3, 6.0, 2.5, 2], [5.8, 2.7, 5.1, 1.9, 2],
[7.1, 3.0, 5.9, 2.1, 2], [6.3, 2.9, 5.6, 1.8, 2], [6.5, 3.0, 5.8, 2.2, 2],
[7.6, 3.0, 6.6, 2.1, 2], [4.9, 2.5, 4.5, 1.7, 2], [7.3, 2.9, 6.3, 1.8, 2],
[6.7, 2.5, 5.8, 1.8, 2], [7.2, 3.6, 6.1, 2.5, 2], [6.5, 3.2, 5.1, 2.0, 2],
[6.4, 2.7, 5.3, 1.9, 2], [6.8, 3.0, 5.5, 2.1, 2], [5.7, 2.5, 5.0, 2.0, 2],
[5.8, 2.8, 5.1, 2.4, 2], [6.4, 3.2, 5.3, 2.3, 2], [6.5, 3.0, 5.5, 1.8, 2],
[7.7, 3.8, 6.7, 2.2, 2], [7.7, 2.6, 6.9, 2.3, 2], [6.0, 2.2, 5.0, 1.5, 2],
[6.9, 3.2, 5.7, 2.3, 2], [5.6, 2.8, 4.9, 2.0, 2], [7.7, 2.8, 6.7, 2.0, 2],
[6.3, 2.7, 4.9, 1.8, 2], [6.7, 3.3, 5.7, 2.1, 2], [7.2, 3.2, 6.0, 1.8, 2],
[6.2, 2.8, 4.8, 1.8, 2], [6.1, 3.0, 4.9, 1.8, 2], [6.4, 2.8, 5.6, 2.1, 2],
[7.2, 3.0, 5.8, 1.6, 2], [7.4, 2.8, 6.1, 1.9, 2], [7.9, 3.8, 6.4, 2.0, 2],
[6.4, 2.8, 5.6, 2.2, 2], [6.3, 2.8, 5.1, 1.5, 2], [6.1, 2.6, 5.6, 1.4, 2],
[7.7, 3.0, 6.1, 2.3, 2], [6.3, 3.4, 5.6, 2.4, 2], [6.4, 3.1, 5.5, 1.8, 2],
[6.0, 3.0, 4.8, 1.8, 2], [6.9, 3.1, 5.4, 2.1, 2], [6.7, 3.1, 5.6, 2.4, 2],
[6.9, 3.1, 5.1, 2.3, 2], [5.8, 2.7, 5.1, 1.9, 2], [6.8, 3.2, 5.9, 2.3, 2],
[6.7, 3.3, 5.7, 2.5, 2], [6.7, 3.0, 5.2, 2.3, 2], [6.3, 2.5, 5.0, 1.9, 2],
[6.5, 3.0, 5.2, 2.0, 2], [6.2, 3.4, 5.4, 2.3, 2], [5.9, 3.0, 5.1, 1.8, 2],
];
/**
 * Convert Iris data arrays to `tf.Tensor`s.
 *
 * @param data The Iris input feature data, an `Array` of `Array`s, each element
 *   of which is assumed to be a length-4 `Array` (sepal length, sepal width,
 *   petal length, petal width).
 * @param targets An `Array` of numbers, with values from the set {0, 1, 2}:
 *   representing the true category of the Iris flower. Assumed to have the same
 *   array length as `data`.
 * @param testSplit Fraction of the data at the end to split as test data: a
 *   number between 0 and 1.
 * @return A length-4 `Array`, with
 *   - training data as `tf.Tensor` of shape [numTrainExamples, 4].
 *   - training one-hot labels as a `tf.Tensor` of shape [numTrainExamples, 3]
 *   - test data as `tf.Tensor` of shape [numTestExamples, 4].
 *   - test one-hot labels as a `tf.Tensor` of shape [numTestExamples, 3]
 */
function convertToTensors(data, targets, testSplit) {
  const numExamples = data.length;
  if (numExamples !== targets.length) {
    throw new Error('data and targets have different numbers of examples');
  }
  // Randomly shuffle `data` and `targets` by shuffling a shared index array.
  const indices = [];
  for (let i = 0; i < numExamples; ++i) {
    indices.push(i);
  }
  // tf.util.shuffle shuffles the array in place.
  tf.util.shuffle(indices);
  const shuffledData = [];
  const shuffledTargets = [];
  for (let i = 0; i < numExamples; ++i) {
    shuffledData.push(data[indices[i]]);
    shuffledTargets.push(targets[indices[i]]);
  }
  // Split the data into a training set and a test set, based on `testSplit`.
  const numTestExamples = Math.round(numExamples * testSplit);
  const numTrainExamples = numExamples - numTestExamples;
  const xDims = shuffledData[0].length;
  // Create a 2D `tf.Tensor` to hold the feature data.
  const xs = tf.tensor2d(shuffledData, [numExamples, xDims]);
  // Create a 1D `tf.Tensor` to hold the labels, and convert the number label
  // from the set {0, 1, 2} into one-hot encoding (e.g., 0 --> [1, 0, 0]).
  const ys = tf.oneHot(tf.tensor1d(shuffledTargets).toInt(), IRIS_NUM_CLASSES);
  // Split the data into training and test sets, using `slice`.
  const xTrain = xs.slice([0, 0], [numTrainExamples, xDims]);
  const xTest = xs.slice([numTrainExamples, 0], [numTestExamples, xDims]);
  const yTrain = ys.slice([0, 0], [numTrainExamples, IRIS_NUM_CLASSES]);
  // The test labels start after the training rows, mirroring `xTest`.
  const yTest = ys.slice([numTrainExamples, 0], [numTestExamples, IRIS_NUM_CLASSES]);
  return [xTrain, yTrain, xTest, yTest];
}
/**
 * Obtains Iris data, split into training and test sets.
 *
 * @param testSplit Fraction of the data at the end to split as test data: a
 *   number between 0 and 1.
 *
 * @return A length-4 `Array`, with
 *   - training data as a `tf.Tensor` of shape [numTrainExamples, 4].
 *   - training one-hot labels as a `tf.Tensor` of shape [numTrainExamples, 3].
 *   - test data as a `tf.Tensor` of shape [numTestExamples, 4].
 *   - test one-hot labels as a `tf.Tensor` of shape [numTestExamples, 3].
 */
export function getIrisData(testSplit) {
  // tf.tidy disposes the intermediate tensors created inside the callback,
  // which reduces memory use.
  return tf.tidy(() => {
    // Feature rows (sepal length, sepal width, petal length, petal width),
    // grouped by class. The class index 0, 1, or 2 is the last element of each raw row.
    const dataByClass = [];
    // Class labels grouped the same way: setosa, versicolor, virginica.
    const targetsByClass = [];
    for (let i = 0; i < IRIS_CLASSES.length; ++i) {
      dataByClass.push([]);
      targetsByClass.push([]);
    }
    for (const example of IRIS_DATA) {
      const target = example[example.length - 1];
      const data = example.slice(0, example.length - 1);
      dataByClass[target].push(data);
      targetsByClass[target].push(target);
    }
    const xTrains = [];
    const yTrains = [];
    const xTests = [];
    const yTests = [];
    for (let i = 0; i < IRIS_CLASSES.length; ++i) {
      // Convert the plain JS arrays for this class into tensors.
      const [xTrain, yTrain, xTest, yTest] =
          convertToTensors(dataByClass[i], targetsByClass[i], testSplit);
      xTrains.push(xTrain);
      yTrains.push(yTrain);
      xTests.push(xTest);
      yTests.push(yTest);
    }
    const concatAxis = 0;
    return [
      tf.concat(xTrains, concatAxis), tf.concat(yTrains, concatAxis),
      tf.concat(xTests, concatAxis), tf.concat(yTests, concatAxis)
    ];
  });
}
[Figure: contents of dataByClass]
[Figure: contents of targetsByClass]
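As a rough illustration of what getIrisData builds before any tensors are created, the plain-JavaScript sketch below groups a few rows by their trailing class index and one-hot encodes a label. The helper names `groupByClass` and `oneHot` and the three sample rows are illustrative assumptions, not part of the original data.js, which uses `tf.oneHot` for the encoding step.

```javascript
// Hypothetical helper: split rows of [f1, f2, f3, f4, classIndex] by class.
function groupByClass(rows, numClasses) {
  const dataByClass = Array.from({ length: numClasses }, () => []);
  const targetsByClass = Array.from({ length: numClasses }, () => []);
  for (const row of rows) {
    const target = row[row.length - 1];
    dataByClass[target].push(row.slice(0, row.length - 1));
    targetsByClass[target].push(target);
  }
  return { dataByClass, targetsByClass };
}

// Hypothetical helper: plain-JS equivalent of one-hot encoding a class index.
function oneHot(index, numClasses) {
  return Array.from({ length: numClasses }, (_, i) => (i === index ? 1 : 0));
}

// Three sample rows, one per class (chosen for illustration).
const sampleRows = [
  [5.1, 3.5, 1.4, 0.2, 0],
  [5.7, 2.8, 4.1, 1.3, 1],
  [6.3, 3.3, 6.0, 2.5, 2],
];
const grouped = groupByClass(sampleRows, 3);
console.log(grouped.dataByClass[1]); // [[5.7, 2.8, 4.1, 1.3]]
console.log(grouped.targetsByClass); // [[0], [1], [2]]
console.log(oneHot(2, 3)); // [0, 0, 1]
```

Grouping by class first means each class is shuffled and split separately, so the train and test sets keep the same class proportions.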
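Data: black curve. Model: blue straight line.
When a model is too powerful, it fits the training data too closely and then performs poorly on additional or new data. (The green line is the overfitted model.)
Overfitting loss curves: red is the validation loss, blue is the training loss.
Slightly noisy data makes overfitting easier to reproduce. An overly complex model, in trying to fit every training point, also fits the noise. Because the resulting model is so complex, it performs worse on data outside the training set.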
// Takes the number of samples and the variance (the larger the variance, the noisier the data).
export function getData(numSamples, variance) {
  const points = [];
  // Generate numSamples / 2 points around the center (cx, cy) with the given label.
  function genGauss(cx, cy, label) {
    for (let i = 0; i < numSamples / 2; i++) {
      const x = normalRandom(cx, variance);
      const y = normalRandom(cy, variance);
      points.push({ x, y, label });
    }
  }
  genGauss(2, 2, 1);
  genGauss(-2, -2, 0);
  return points;
}
/**
 * Samples from a normal (Gaussian) distribution using the polar method, with
 * `Math.random` as the random generator.
 *
 * @param mean The mean. Default is 0.
 * @param variance The variance. Default is 1.
 */
function normalRandom(mean = 0, variance = 1) {
  let v1, v2, s;
  do {
    v1 = 2 * Math.random() - 1;
    v2 = 2 * Math.random() - 1;
    s = v1 * v1 + v2 * v2;
    // Reject points outside the unit circle, and s === 0, where log(s) / s is undefined.
  } while (s >= 1 || s === 0);
  const result = Math.sqrt(-2 * Math.log(s) / s) * v1;
  return mean + Math.sqrt(variance) * result;
}
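To check that a polar-method sampler like the one above behaves as described, the sketch below re-implements it with an injectable random source and compares the sample mean and variance with the requested values. The `mulberry32` seeded generator, the `normalRandomWith` name, the seed, and the sample count are assumptions made so the run is repeatable; the listing above calls `Math.random` directly.

```javascript
// Small seeded PRNG (mulberry32), assumed here only to make the run repeatable.
function mulberry32(seed) {
  let a = seed >>> 0;
  return function () {
    a = (a + 0x6D2B79F5) >>> 0;
    let t = a;
    t = Math.imul(t ^ (t >>> 15), t | 1);
    t ^= t + Math.imul(t ^ (t >>> 7), t | 61);
    return ((t ^ (t >>> 14)) >>> 0) / 4294967296;
  };
}

// Polar method with the random source passed in as `rand`.
function normalRandomWith(rand, mean = 0, variance = 1) {
  let v1, v2, s;
  do {
    v1 = 2 * rand() - 1;
    v2 = 2 * rand() - 1;
    s = v1 * v1 + v2 * v2;
  } while (s >= 1 || s === 0);
  return mean + Math.sqrt(variance) * Math.sqrt(-2 * Math.log(s) / s) * v1;
}

const rand = mulberry32(42);
const n = 20000;
const samples = Array.from({ length: n }, () => normalRandomWith(rand, 2, 2));
const sampleMean = samples.reduce((acc, v) => acc + v, 0) / n;
const sampleVariance =
  samples.reduce((acc, v) => acc + (v - sampleMean) ** 2, 0) / (n - 1);
// Both values should land close to the requested mean 2 and variance 2.
console.log(sampleMean.toFixed(2), sampleVariance.toFixed(2));
```

Note that the second argument is a variance, so the spread of the points grows with its square root.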
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getData } from './data';
window.onload = async () => {
  // 200 samples with variance 2.
  const data = getData(200, 2);
  // Plot the two classes as separate series.
  tfvis.render.scatterplot(
    { name: 'Training data' },
    {
      values: [
        data.filter(p => p.label === 1),
        data.filter(p => p.label === 0),
      ]
    }
  );
};
[Figures: training data generated with variance 2, variance 1, and variance 3]
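Underfitting tends to appear when a simple model is applied to a complex dataset, or when the model is not trained for long enough. It shows up as both the training loss and the validation loss staying high. The remedy is to increase model complexity: use more neurons and more layers.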
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getData } from '../xor/data.js';
window.onload = async () => {
  const data = getData(200);
  tfvis.render.scatterplot(
    { name: 'Training data' },
    {
      values: [
        data.filter(p => p.label === 1),
        data.filter(p => p.label === 0),
      ]
    }
  );
  // A single sigmoid unit: deliberately too simple for the XOR data, so it underfits.
  const model = tf.sequential();
  model.add(tf.layers.dense({
    units: 1,
    inputShape: [2],
    activation: 'sigmoid'
  }));
  model.compile({
    loss: tf.losses.logLoss,
    optimizer: tf.train.adam(0.1)
  });
  const inputs = tf.tensor(data.map(p => [p.x, p.y]));
  const labels = tf.tensor(data.map(p => p.label));
  await model.fit(inputs, labels, {
    // Hold out 20% of the data for validation.
    validationSplit: 0.2,
    epochs: 200,
    callbacks: tfvis.show.fitCallbacks(
      { name: 'Training performance' },
      ['loss', 'val_loss'],
      { callbacks: ['onEpochEnd'] }
    )
  });
};
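The loss plotted by the script above is a binary log loss. The sketch below is a plain-JavaScript version of that formula, meant only to show why an underfit model's loss stays high; the function name, the sample predictions, and the `1e-7` epsilon are illustrative assumptions, not values taken from the tutorial.

```javascript
// Binary log loss averaged over examples. `epsilon` guards against log(0);
// the 1e-7 default is an assumption for this sketch.
function logLoss(labels, predictions, epsilon = 1e-7) {
  let total = 0;
  for (let i = 0; i < labels.length; i++) {
    const y = labels[i];
    const p = predictions[i];
    total += -y * Math.log(p + epsilon) - (1 - y) * Math.log(1 - p + epsilon);
  }
  return total / labels.length;
}

// A confident, correct model has a low loss; a model that can only output
// values near 0.5 stays near ln 2, about 0.693.
const confident = logLoss([1, 0, 1, 0], [0.9, 0.1, 0.8, 0.2]);
const undecided = logLoss([1, 0, 1, 0], [0.5, 0.5, 0.5, 0.5]);
console.log(confident < undecided); // true
```

If both `loss` and `val_loss` level off near 0.69 on a balanced two-class problem, the model is typically doing little better than guessing, which matches the underfitting symptom described above.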
With overfitting, the training loss drops very low while the validation loss stays high.
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getData } from './data';
window.onload = async () => {
  const data = getData(200, 2);
  tfvis.render.scatterplot(
    { name: 'Training data' },
    {
      values: [
        data.filter(p => p.label === 1),
        data.filter(p => p.label === 0),
      ]
    }
  );
  // A hidden layer of 10 tanh units: more capacity than this noisy
  // two-cluster data needs, so the model tends to overfit.
  const model = tf.sequential();
  model.add(tf.layers.dense({
    units: 10,
    inputShape: [2],
    activation: 'tanh',
  }));
  model.add(tf.layers.dense({
    units: 1,
    activation: 'sigmoid'
  }));
  model.compile({
    loss: tf.losses.logLoss,
    optimizer: tf.train.adam(0.1)
  });
  const inputs = tf.tensor(data.map(p => [p.x, p.y]));
  const labels = tf.tensor(data.map(p => p.label));
  await model.fit(inputs, labels, {
    validationSplit: 0.2,
    epochs: 200,
    callbacks: tfvis.show.fitCallbacks(
      { name: 'Training performance' },
      ['loss', 'val_loss'],
      { callbacks: ['onEpochEnd'] }
    )
  });
};
Early stopping: stop training as soon as the validation loss starts to rise noticeably.
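The rule can be sketched in plain JavaScript as a scan over the per-epoch validation losses. The function name, the `patience` and `minDelta` parameters, and the loss values below are hypothetical and chosen for illustration; TensorFlow.js also provides a `tf.callbacks.earlyStopping` callback that could be passed to `model.fit` for the same purpose.

```javascript
// Returns the epoch index at which training would stop, or -1 if it never stops.
// `patience` is the number of consecutive epochs without improvement to tolerate.
function earlyStopEpoch(valLosses, patience = 2, minDelta = 0) {
  let best = Infinity;
  let wait = 0;
  for (let epoch = 0; epoch < valLosses.length; epoch++) {
    if (valLosses[epoch] < best - minDelta) {
      // New best validation loss: reset the counter.
      best = valLosses[epoch];
      wait = 0;
    } else {
      wait += 1;
      if (wait >= patience) {
        return epoch;
      }
    }
  }
  return -1;
}

// Hypothetical validation losses: improving until epoch 2, then rising.
const valLosses = [0.70, 0.55, 0.48, 0.50, 0.53, 0.60];
console.log(earlyStopEpoch(valLosses, 2)); // 4
```

A small `patience` keeps one noisy epoch from ending training too early.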
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getData } from './data';
window.onload = async () => {
  const data = getData(200, 2);
  tfvis.render.scatterplot(
    { name: 'Training data' },
    {
      values: [
        data.filter(p => p.label === 1),
        data.filter(p => p.label === 0),
      ]
    }
  );
  const model = tf.sequential();
  model.add(tf.layers.dense({
    units: 10,
    inputShape: [2],
    activation: 'tanh',
    // 1. Weight decay (L2 regularization): the L2 penalty shrinks the weights of the
    // complex model, which makes it effectively simpler and less prone to overfitting.
    // kernelRegularizer: tf.regularizers.l2({ l2: 1 })
  }));
  // 2. Add a dropout layer with a drop rate: a random subset of the units above is
  // dropped during training, which reduces the effective number of units.
  // With a rate of 0.9, each unit is dropped with probability 0.9, so on average 9 of the 10 units are dropped.
  // model.add(tf.layers.dropout({ rate: 0.9 }));
  model.add(tf.layers.dense({
    units: 1,
    activation: 'sigmoid'
  }));
  model.compile({
    loss: tf.losses.logLoss,
    optimizer: tf.train.adam(0.1)
  });
  const inputs = tf.tensor(data.map(p => [p.x, p.y]));
  const labels = tf.tensor(data.map(p => p.label));
  await model.fit(inputs, labels, {
    validationSplit: 0.2,
    epochs: 200,
    callbacks: tfvis.show.fitCallbacks(
      { name: 'Training performance' },
      ['loss', 'val_loss'],
      { callbacks: ['onEpochEnd'] }
    )
  });
};
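[Figure: training curves with L2 regularization only]
[Figure: training curves with dropout only]
[Figure: training curves with both weight decay and dropout]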
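To make the two techniques concrete, here is a minimal plain-JavaScript sketch of an L2 penalty term and of dropout applied to a layer's activations. The function names, the fixed list of "random" draws, and the numbers are illustrative assumptions; the sketch assumes the usual conventions (penalty equal to lambda times the sum of squared weights, and survivors scaled by 1 / (1 - rate) during training) rather than reproducing the library's internals.

```javascript
// L2 penalty added to the loss: lambda * sum(w^2).
function l2Penalty(weights, lambda) {
  return lambda * weights.reduce((acc, w) => acc + w * w, 0);
}

// Inverted dropout at training time: drop each unit with probability `rate`
// and scale the survivors by 1 / (1 - rate). `rand` is injected for repeatability.
function dropout(activations, rate, rand = Math.random) {
  const scale = 1 / (1 - rate);
  return activations.map((a) => (rand() < rate ? 0 : a * scale));
}

console.log(l2Penalty([1, -2, 3], 0.5)); // 7

// Fixed "random" draws (hypothetical) so the result is repeatable.
const draws = [0.1, 0.9, 0.6, 0.3];
let k = 0;
const dropped = dropout([1, 2, 3, 4], 0.5, () => draws[k++]);
console.log(dropped); // [0, 4, 6, 0]
```

Large weights make the penalty grow quadratically, which pushes the optimizer toward smaller weights. Dropout forces the network not to rely on any single unit.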
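MNIST is a dataset of handwritten digit images. It was initiated and compiled by the National Institute of Standards and Technology (NIST) and contains handwritten digits from 250 different people, 50% of them high school students and 50% Census Bureau employees. The dataset was collected so that algorithms could be used to recognize handwritten digits.
In 1998, Yann LeCun et al. published the paper "Gradient-Based Learning Applied to Document Recognition", which first proposed the LeNet-5 network and used this dataset to recognize handwritten characters.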
data.js:
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import * as tf from '@tensorflow/tfjs';
const IMAGE_SIZE = 784;
const NUM_CLASSES = 10;
const NUM_DATASET_ELEMENTS = 65000;
const TRAIN_TEST_RATIO = 5 / 6;
const NUM_TRAIN_ELEMENTS = Math.floor(TRAIN_TEST_RATIO * NUM_DATASET_ELEMENTS);
const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;
const MNIST_IMAGES_SPRITE_PATH =
'http://127.0.0.1:8080/mnist/mnist_images.png';
const MNIST_LABELS_PATH =
'http://127.0.0.1:8080/mnist/mnist_labels_uint8';
/**
* A class that fetches the sprited MNIST dataset and returns shuffled batches.
*
* NOTE: This will get much easier. For now, we do data fetching and
* manipulation manually.
*/
export class MnistData {
constructor() {
this.shuffledTrainIndex = 0;
this.shuffledTestIndex = 0;
}
async load() {
// Make a request for the MNIST sprited image.
const img = new Image();
const canvas = document.createElement('canvas');
const ctx = canvas.getContext('2d');
const imgRequest = new Promise((resolve, reject) => {
img.crossOrigin = '';
img.onload = () => {
img.width = img.naturalWidth;
img.height = img.naturalHeight;
const datasetBytesBuffer =
new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4);
const chunkSize = 5000;
canvas.width = img.width;
canvas.height = chunkSize;
for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
const datasetBytesView = new Float32Array(
datasetBytesBuffer, i * IMAGE_SIZE * chunkSize * 4,
IMAGE_SIZE * chunkSize);
ctx.drawImage(
img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width,
chunkSize);
const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
for (let j = 0; j < imageData.data.length / 4; j++) {
// All channels hold an equal value since the image is grayscale, so
// just read the red channel.
datasetBytesView[j] = imageData.data[j * 4] / 255;
}
}
this.datasetImages = new Float32Array(datasetBytesBuffer);
resolve();
};
img.src = MNIST_IMAGES_SPRITE_PATH;
});
const labelsRequest = fetch(MNIST_LABELS_PATH);
const [imgResponse, labelsResponse] =
await Promise.all([imgRequest, labelsRequest]);
this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());
// Create shuffled indices into the train/test set for when we select a
// random dataset element for training / validation.
this.trainIndices = tf.util.createShuffledIndices(NUM_TRAIN_ELEMENTS);
this.testIndices = tf.util.createShuffledIndices(NUM_TEST_ELEMENTS);
// Slice the images and labels into train and test sets.
this.trainImages =
this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
this.trainLabels =
this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS);
this.testLabels =
this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS);
}
nextTrainBatch(batchSize) {
return this.nextBatch(
batchSize, [this.trainImages, this.trainLabels], () => {
this.shuffledTrainIndex =
(this.shuffledTrainIndex + 1) % this.trainIndices.length;
return this.trainIndices[this.shuffledTrainIndex];
});
}
nextTestBatch(batchSize) {
return this.nextBatch(batchSize, [this.testImages, this.testLabels], () => {
this.shuffledTestIndex =
(this.shuffledTestIndex + 1) % this.testIndices.length;
return this.testIndices[this.shuffledTestIndex];
});
}
nextBatch(batchSize, data, index) {
const batchImagesArray = new Float32Array(batchSize * IMAGE_SIZE);
const batchLabelsArray = new Uint8Array(batchSize * NUM_CLASSES);
for (let i = 0; i < batchSize; i++) {
const idx = index();
const image =
data[0].slice(idx * IMAGE_SIZE, idx * IMAGE_SIZE + IMAGE_SIZE);
batchImagesArray.set(image, i * IMAGE_SIZE);
const label =
data[1].slice(idx * NUM_CLASSES, idx * NUM_CLASSES + NUM_CLASSES);
batchLabelsArray.set(label, i * NUM_CLASSES);
}
const xs = tf.tensor2d(batchImagesArray, [batchSize, IMAGE_SIZE]);
const labels = tf.tensor2d(batchLabelsArray, [batchSize, NUM_CLASSES]);
return {xs, labels};
}
}
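The class above stores all images in one flat `Float32Array` (784 values per image) and all labels in one flat `Uint8Array` (10 one-hot values per image). The sketch below shows how a single example could be read back out of such arrays. The `getExample` helper and the two fake examples are assumptions for illustration and are not part of data.js.

```javascript
const IMAGE_SIZE = 784; // 28 * 28 pixels per image
const NUM_CLASSES = 10; // length of each one-hot label

// Hypothetical helper: read example `idx` out of the flat image and label arrays.
function getExample(images, labels, idx) {
  const pixels = images.slice(idx * IMAGE_SIZE, (idx + 1) * IMAGE_SIZE);
  const oneHot = labels.slice(idx * NUM_CLASSES, (idx + 1) * NUM_CLASSES);
  return { pixels, digit: oneHot.indexOf(1) };
}

// Two fake examples: an all-black image labeled 3 and an all-white image labeled 7.
const images = new Float32Array(2 * IMAGE_SIZE);
images.fill(1, IMAGE_SIZE);
const labels = new Uint8Array(2 * NUM_CLASSES);
labels[3] = 1;
labels[NUM_CLASSES + 7] = 1;

const second = getExample(images, labels, 1);
console.log(second.pixels.length, second.pixels[0], second.digit); // 784 1 7
```

`nextBatch` performs the same index arithmetic, copying `batchSize` such slices into new arrays before wrapping them in tensors.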
script.js:
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { MnistData } from './data';
window.onload = async () => {
  // Create the data helper.
  const data = new MnistData();
  // Load the image sprite and the binary labels file.
  await data.load();
  // Take 20 examples from the test set to display as input samples.
  const examples = data.nextTestBatch(20);
  const surface = tfvis.visor().surface({ name: 'Input examples' });
  for (let i = 0; i < 20; i += 1) {
    // Extract the i-th example as a [28, 28, 1] image tensor.
    const imageTensor = tf.tidy(() => {
      return examples.xs
        .slice([i, 0], [1, 784])
        .reshape([28, 28, 1]);
    });
    // Create a canvas element.
    const canvas = document.createElement('canvas');
    canvas.width = 28;
    canvas.height = 28;
    canvas.style = 'margin: 4px';
    // Draw the image tensor onto the canvas with toPixels.
    await tf.browser.toPixels(imageTensor, canvas);
    // Show the canvas on the page through the tfvis surface.
    surface.drawArea.appendChild(canvas);
  }
};
https://setosa.io/ev/image-kernels/
https://cs231n.github.io/convolutional-networks/
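As a small companion to the two links above, the following plain-JavaScript sketch applies a kernel to an image by sliding it across and summing the elementwise products, without padding. The `conv2dValid` name, the tiny image, and the edge kernel are illustrative assumptions; this is a sketch of the operation, not how TensorFlow.js implements it.

```javascript
// "Valid" 2D convolution as used in CNN layers (strictly, cross-correlation):
// slide the kernel over the image and sum the elementwise products.
function conv2dValid(image, kernel, stride = 1) {
  const kh = kernel.length;
  const kw = kernel[0].length;
  const out = [];
  for (let r = 0; r + kh <= image.length; r += stride) {
    const row = [];
    for (let c = 0; c + kw <= image[0].length; c += stride) {
      let sum = 0;
      for (let i = 0; i < kh; i++) {
        for (let j = 0; j < kw; j++) {
          sum += image[r + i][c + j] * kernel[i][j];
        }
      }
      row.push(sum);
    }
    out.push(row);
  }
  return out;
}

// A vertical-edge kernel responds where brightness changes from left to right.
const image = [
  [0, 0, 1, 1],
  [0, 0, 1, 1],
  [0, 0, 1, 1],
];
const kernel = [
  [-1, 1],
  [-1, 1],
];
console.log(conv2dValid(image, kernel)); // [[0, 2, 0], [0, 2, 0]]
```

The output is non-zero only in the column where the dark region meets the bright one, which is the sense in which a kernel "extracts a feature".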
script.js:
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { MnistData } from './data';
window.onload = async () => {
  // Create the data helper.
  const data = new MnistData();
  // Load the image sprite and the binary labels file.
  await data.load();
  // Take 20 examples from the test set to display as input samples.
  const examples = data.nextTestBatch(20);
  const surface = tfvis.visor().surface({ name: 'Input examples' });
  for (let i = 0; i < 20; i += 1) {
    // Extract the i-th example as a [28, 28, 1] image tensor.
    const imageTensor = tf.tidy(() => {
      return examples.xs
        .slice([i, 0], [1, 784])
        .reshape([28, 28, 1]);
    });
    // Create a canvas element.
    const canvas = document.createElement('canvas');
    canvas.width = 28;
    canvas.height = 28;
    canvas.style = 'margin: 4px';
    // Draw the image tensor onto the canvas with toPixels.
    await tf.browser.toPixels(imageTensor, canvas);
    // Show the canvas on the page through the tfvis surface.
    surface.drawArea.appendChild(canvas);
  }
  // Initialize the neural network model (a sequential model).
  const model = tf.sequential();
  // Add a 2D convolutional layer.
  model.add(tf.layers.conv2d({
    // Image height, width, and channels (1 for grayscale; a color image would use 3, for RGB).
    inputShape: [28, 28, 1],
    // Kernel size (5x5). An odd size is recommended: it has a center pixel, which helps feature extraction.
    kernelSize: 5,
    // Number of filters.
    filters: 8,
    // Stride of the sliding window.
    strides: 1,
    // ReLU activation, which zeroes out negative responses: https://en.wikipedia.org/wiki/Rectifier_(neural_networks)
    activation: 'relu',
    // Kernel initializer; a suitable initializer can speed up convergence.
    kernelInitializer: 'varianceScaling'
  }));
  // Add a max pooling layer.
  model.add(tf.layers.maxPool2d({
    poolSize: [2, 2],
    strides: [2, 2]
  }));
  // Repeat convolution plus pooling to combine lower-level features.
  model.add(tf.layers.conv2d({
    kernelSize: 5,
    filters: 16,
    strides: 1,
    activation: 'relu',
    kernelInitializer: 'varianceScaling'
  }));
  model.add(tf.layers.maxPool2d({
    poolSize: [2, 2],
    strides: [2, 2]
  }));
  // Flatten the multi-dimensional feature maps into a 1D vector.
  model.add(tf.layers.flatten());
  // Add a dense (fully connected) output layer, one unit per digit.
  model.add(tf.layers.dense({
    units: 10,
    activation: 'softmax',
    kernelInitializer: 'varianceScaling'
  }));
};
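It can help to trace how the tensor shape changes through the layers defined above. The sketch below does that arithmetic in plain JavaScript, assuming that the convolution and pooling layers use no padding (to my understanding, `'valid'` is the default when `padding` is not set). The `validOut` helper is a hypothetical name introduced for illustration.

```javascript
// Output size along one dimension for a 'valid' (no padding) window.
function validOut(size, window, stride) {
  return Math.floor((size - window) / stride) + 1;
}

let side = 28;                        // input: 28x28x1
side = validOut(side, 5, 1);          // conv2d 5x5, stride 1 -> 24x24x8
const conv1Side = side;
side = validOut(side, 2, 2);          // maxPool2d 2x2, stride 2 -> 12x12x8
side = validOut(side, 5, 1);          // conv2d 5x5, stride 1 -> 8x8x16
side = validOut(side, 2, 2);          // maxPool2d 2x2, stride 2 -> 4x4x16
const flatLength = side * side * 16;  // flatten -> 256

// Trainable parameters: kernel weights plus one bias per filter or unit.
const conv1Params = 5 * 5 * 1 * 8 + 8;     // 208
const conv2Params = 5 * 5 * 8 * 16 + 16;   // 3216
const denseParams = flatLength * 10 + 10;  // 2570
const totalParams = conv1Params + conv2Params + denseParams;
console.log(conv1Side, side, flatLength, totalParams); // 24 4 256 5994
```

Calling `model.summary()` on the real model is a quick way to compare these numbers with what the library reports.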
script.js:
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { MnistData } from './data';
window.onload = async () => {
  // Create the data helper.
  const data = new MnistData();
  // Load the image sprite and the binary labels file.
  await data.load();
  // Take 20 examples from the test set to display as input samples.
  const examples = data.nextTestBatch(20);
  const surface = tfvis.visor().surface({ name: 'Input examples' });
  for (let i = 0; i < 20; i += 1) {
    // Extract the i-th example as a [28, 28, 1] image tensor.
    const imageTensor = tf.tidy(() => {
      return examples.xs
        .slice([i, 0], [1, 784])
        .reshape([28, 28, 1]);
    });
    // Create a canvas element.
    const canvas = document.createElement('canvas');
    canvas.width = 28;
    canvas.height = 28;
    canvas.style = 'margin: 4px';
    // Draw the image tensor onto the canvas with toPixels.
    await tf.browser.toPixels(imageTensor, canvas);
    // Show the canvas on the page through the tfvis surface.
    surface.drawArea.appendChild(canvas);
  }
  // Initialize the neural network model (a sequential model).
  const model = tf.sequential();
  // Add a 2D convolutional layer.
  model.add(tf.layers.conv2d({
    // Image height, width, and channels (1 for grayscale; a color image would use 3, for RGB).
    inputShape: [28, 28, 1],
    // Kernel size (5x5). An odd size is recommended: it has a center pixel, which helps feature extraction.
    kernelSize: 5,
    // Number of filters.
    filters: 8,
    // Stride of the sliding window.
    strides: 1,
    // ReLU activation, which zeroes out negative responses: https://en.wikipedia.org/wiki/Rectifier_(neural_networks)
    activation: 'relu',
    // Kernel initializer; a suitable initializer can speed up convergence.
    kernelInitializer: 'varianceScaling'
  }));
  // Add a max pooling layer.
  model.add(tf.layers.maxPool2d({
    poolSize: [2, 2],
    strides: [2, 2]
  }));
  // Repeat convolution plus pooling to combine lower-level features.
  model.add(tf.layers.conv2d({
    kernelSize: 5,
    filters: 16,
    strides: 1,
    activation: 'relu',
    kernelInitializer: 'varianceScaling'
  }));
  model.add(tf.layers.maxPool2d({
    poolSize: [2, 2],
    strides: [2, 2]
  }));
  // Flatten the multi-dimensional feature maps into a 1D vector.
  model.add(tf.layers.flatten());
  // Add a dense (fully connected) output layer, one unit per digit.
  model.add(tf.layers.dense({
    units: 10,
    activation: 'softmax',
    kernelInitializer: 'varianceScaling'
  }));
  // Configure training.
  model.compile({
    // Loss function: categorical cross-entropy.
    loss: 'categoricalCrossentropy',
    // Optimizer.
    optimizer: tf.train.adam(),
    // Metric to report: accuracy.
    metrics: ['accuracy']
  });
  // Prepare the training set.
  // Wrapping this in tidy disposes the intermediate tensors promptly, so they do not hurt performance.
  const [trainXs, trainYs] = tf.tidy(() => {
    const d = data.nextTrainBatch(1000);
    return [
      d.xs.reshape([1000, 28, 28, 1]),
      d.labels
    ];
  });
  // Prepare the validation set.
  const [testXs, testYs] = tf.tidy(() => {
    const d = data.nextTestBatch(200);
    return [
      d.xs.reshape([200, 28, 28, 1]),
      d.labels
    ];
  });
  // Train the model.
  await model.fit(trainXs, trainYs, {
    // Validation data.
    validationData: [testXs, testYs],
    batchSize: 500,
    epochs: 20,
    callbacks: tfvis.show.fitCallbacks(
      { name: 'Training performance' },
      ['loss', 'val_loss', 'acc', 'val_acc'],
      { callbacks: ['onEpochEnd'] }
    )
  });
};
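The final layer and the loss used in the training script can be illustrated without the library. The sketch below computes a softmax over raw scores and the categorical cross-entropy against a one-hot label; the function names, the example scores, and the `1e-7` clipping value are illustrative assumptions rather than the library's exact internals.

```javascript
// Softmax turns raw scores into probabilities that sum to 1.
function softmax(logits) {
  const max = Math.max(...logits); // subtract the max for numerical stability
  const exps = logits.map((z) => Math.exp(z - max));
  const sum = exps.reduce((acc, v) => acc + v, 0);
  return exps.map((v) => v / sum);
}

// Categorical cross-entropy for one example with a one-hot label.
// Probabilities are clipped at `epsilon` to avoid log(0).
function categoricalCrossentropy(oneHotLabel, probs, epsilon = 1e-7) {
  return -oneHotLabel.reduce(
    (acc, y, i) => acc + y * Math.log(Math.max(probs[i], epsilon)), 0);
}

const probs = softmax([2, 1, 0]);
const lossRight = categoricalCrossentropy([1, 0, 0], probs);
const lossWrong = categoricalCrossentropy([0, 0, 1], probs);
console.log(probs.map((p) => p.toFixed(3))); // ['0.665', '0.245', '0.090']
console.log(lossRight < lossWrong); // true
```

The loss depends only on the probability assigned to the true class, so it is small when the model is confident and correct, and large when the true class receives little probability.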
index.html
<script src="script.js"></script>
<canvas width="300" height="300" style="border: 2px solid #666;"></canvas>
<br>
<button onclick="window.clear();" style="margin: 4px;">Clear</button>
<button onclick="window.predict();" style="margin: 4px;">Predict</button>
script.js:
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { MnistData } from './data';
window.onload = async () => {
  // Create the data helper.
  const data = new MnistData();
  // Load the image sprite and the binary labels file.
  await data.load();
  // Take 20 examples from the test set to display as input samples.
  const examples = data.nextTestBatch(20);
  const surface = tfvis.visor().surface({ name: 'Input examples' });
  for (let i = 0; i < 20; i += 1) {
    // Extract the i-th example as a [28, 28, 1] image tensor.
    const imageTensor = tf.tidy(() => {
      return examples.xs
        .slice([i, 0], [1, 784])
        .reshape([28, 28, 1]);
    });
    // Create a canvas element.
    const canvas = document.createElement('canvas');
    canvas.width = 28;
    canvas.height = 28;
    canvas.style = 'margin: 4px';
    // Draw the image tensor onto the canvas with toPixels.
    await tf.browser.toPixels(imageTensor, canvas);
    // Show the canvas on the page through the tfvis surface.
    surface.drawArea.appendChild(canvas);
  }
  // Initialize the neural network model (a sequential model).
  const model = tf.sequential();
  // Add a 2D convolutional layer.
  model.add(tf.layers.conv2d({
    // Image height, width, and channels (1 for grayscale; a color image would use 3, for RGB).
    inputShape: [28, 28, 1],
    // Kernel size (5x5). An odd size is recommended: it has a center pixel, which helps feature extraction.
    kernelSize: 5,
    // Number of filters.
    filters: 8,
    // Stride of the sliding window.
    strides: 1,
    // ReLU activation, which zeroes out negative responses: https://en.wikipedia.org/wiki/Rectifier_(neural_networks)
    activation: 'relu',
    // Kernel initializer; a suitable initializer can speed up convergence.
    kernelInitializer: 'varianceScaling'
  }));
  // Add a max pooling layer.
  model.add(tf.layers.maxPool2d({
    poolSize: [2, 2],
    strides: [2, 2]
  }));
  // Repeat convolution plus pooling to combine lower-level features.
  model.add(tf.layers.conv2d({
    kernelSize: 5,
    filters: 16,
    strides: 1,
    activation: 'relu',
    kernelInitializer: 'varianceScaling'
  }));
  model.add(tf.layers.maxPool2d({
    poolSize: [2, 2],
    strides: [2, 2]
  }));
  // Flatten the multi-dimensional feature maps into a 1D vector.
  model.add(tf.layers.flatten());
  // Add a dense (fully connected) output layer, one unit per digit.
  model.add(tf.layers.dense({
    units: 10,
    activation: 'softmax',
    kernelInitializer: 'varianceScaling'
  }));
  // Configure training.
  model.compile({
    // Loss function: categorical cross-entropy.
    loss: 'categoricalCrossentropy',
    // Optimizer.
    optimizer: tf.train.adam(),
    // Metric to report: accuracy.
    metrics: ['accuracy']
  });
  // Prepare the training set.
  // Wrapping this in tidy disposes the intermediate tensors promptly, so they do not hurt performance.
  const [trainXs, trainYs] = tf.tidy(() => {
    const d = data.nextTrainBatch(1000);
    return [
      d.xs.reshape([1000, 28, 28, 1]),
      d.labels
    ];
  });
  // Prepare the validation set.
  const [testXs, testYs] = tf.tidy(() => {
    const d = data.nextTestBatch(200);
    return [
      d.xs.reshape([200, 28, 28, 1]),
      d.labels
    ];
  });
  // Train the model.
  await model.fit(trainXs, trainYs, {
    // Validation data.
    validationData: [testXs, testYs],
    batchSize: 500,
    epochs: 20,
    callbacks: tfvis.show.fitCallbacks(
      { name: 'Training performance' },
      ['loss', 'val_loss', 'acc', 'val_acc'],
      { callbacks: ['onEpochEnd'] }
    )
  });
  // Get the drawing canvas from index.html.
  const canvas = document.querySelector('canvas');
  // Listen for mouse movement over the canvas.
  canvas.addEventListener('mousemove', (e) => {
    // Draw only while the left mouse button is held down.
    if (e.buttons === 1) {
      const ctx = canvas.getContext('2d');
      // Use white as the brush color.
      ctx.fillStyle = 'rgb(255,255,255)';
      // Draw a 25x25 square at the mouse position.
      ctx.fillRect(e.offsetX, e.offsetY, 25, 25);
    }
  });
  window.clear = () => {
    const ctx = canvas.getContext('2d');
    // Fill the whole canvas with a black background.
    ctx.fillStyle = 'rgb(0,0,0)';
    ctx.fillRect(0, 0, 300, 300);
  };
  clear();
  window.predict = () => {
    const input = tf.tidy(() => {
      // fromPixels converts the canvas into a tensor.
      // resizeBilinear resizes the 300x300 image to the 28x28 size the model expects.
      // slice([0, 0, 0], [28, 28, 1]) keeps one of the three RGB channels, turning the color
      // image into a single-channel one (keeping all three would be slice([0, 0, 0], [28, 28, 3])).
      // The values are then cast to float and normalized: pixels are in the 0-255 range, so divide by 255.
      // Finally, the input must match the training input format, so it is reshaped to
      // [1, 28, 28, 1]: one 28x28 grayscale image.
      return tf.image.resizeBilinear(
        tf.browser.fromPixels(canvas),
        [28, 28],
        true
      ).slice([0, 0, 0], [28, 28, 1])
        .toFloat()
        .div(255)
        .reshape([1, 28, 28, 1]);
    });
    // Run the model and take the index of the highest probability.
    const pred = model.predict(input).argMax(1);
    alert(`Predicted digit: ${pred.dataSync()[0]}`);
  };
};
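The two ends of the `predict` function, turning canvas pixels into normalized grayscale values and turning the model's probabilities into a digit, can be sketched in plain JavaScript. The helper names and the sample values are assumptions for illustration; the sketch reads RGBA bytes such as those returned by a canvas `getImageData` call, whereas the script above works on the tensor produced by `tf.browser.fromPixels`.

```javascript
// Keep the red channel of RGBA bytes and scale it to [0, 1]. This is roughly
// what slicing to one channel and dividing by 255 does in the script.
function rgbaToGray01(rgba) {
  const out = new Float32Array(rgba.length / 4);
  for (let i = 0; i < out.length; i++) {
    out[i] = rgba[i * 4] / 255;
  }
  return out;
}

// Index of the largest value: the predicted digit when given softmax outputs.
function argMax(values) {
  let best = 0;
  for (let i = 1; i < values.length; i++) {
    if (values[i] > values[best]) best = i;
  }
  return best;
}

// Two pixels: opaque white and opaque black.
const gray = rgbaToGray01(new Uint8ClampedArray([255, 255, 255, 255, 0, 0, 0, 255]));
console.log(Array.from(gray)); // [1, 0]

// Hypothetical softmax output for one drawing.
const scores = [0.01, 0.02, 0.05, 0.8, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02];
console.log(argMax(scores)); // 3
```

Drawing white strokes on a black background matters here: it matches the MNIST images the model was trained on, where the digit is bright and the background is dark.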
imagenet_classes.js:
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
export const IMAGENET_CLASSES = {
0: 'tench, Tinca tinca',
1: 'goldfish, Carassius auratus',
2: 'great white shark, white shark, man-eater, man-eating shark, ' +
'Carcharodon carcharias',
3: 'tiger shark, Galeocerdo cuvieri',
4: 'hammerhead, hammerhead shark',
5: 'electric ray, crampfish, numbfish, torpedo',
6: 'stingray',
7: 'cock',
8: 'hen',
9: 'ostrich, Struthio camelus',
10: 'brambling, Fringilla montifringilla',
11: 'goldfinch, Carduelis carduelis',
12: 'house finch, linnet, Carpodacus mexicanus',
13: 'junco, snowbird',
14: 'indigo bunting, indigo finch, indigo bird, Passerina cyanea',
15: 'robin, American robin, Turdus migratorius',
16: 'bulbul',
17: 'jay',
18: 'magpie',
19: 'chickadee',
20: 'water ouzel, dipper',
21: 'kite',
22: 'bald eagle, American eagle, Haliaeetus leucocephalus',
23: 'vulture',
24: 'great grey owl, great gray owl, Strix nebulosa',
25: 'European fire salamander, Salamandra salamandra',
26: 'common newt, Triturus vulgaris',
27: 'eft',
28: 'spotted salamander, Ambystoma maculatum',
29: 'axolotl, mud puppy, Ambystoma mexicanum',
30: 'bullfrog, Rana catesbeiana',
31: 'tree frog, tree-frog',
32: 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui',
33: 'loggerhead, loggerhead turtle, Caretta caretta',
34: 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea',
35: 'mud turtle',
36: 'terrapin',
37: 'box turtle, box tortoise',
38: 'banded gecko',
39: 'common iguana, iguana, Iguana iguana',
40: 'American chameleon, anole, Anolis carolinensis',
41: 'whiptail, whiptail lizard',
42: 'agama',
43: 'frilled lizard, Chlamydosaurus kingi',
44: 'alligator lizard',
45: 'Gila monster, Heloderma suspectum',
46: 'green lizard, Lacerta viridis',
47: 'African chameleon, Chamaeleo chamaeleon',
48: 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, ' +
'Varanus komodoensis',
49: 'African crocodile, Nile crocodile, Crocodylus niloticus',
50: 'American alligator, Alligator mississipiensis',
51: 'triceratops',
52: 'thunder snake, worm snake, Carphophis amoenus',
53: 'ringneck snake, ring-necked snake, ring snake',
54: 'hognose snake, puff adder, sand viper',
55: 'green snake, grass snake',
56: 'king snake, kingsnake',
57: 'garter snake, grass snake',
58: 'water snake',
59: 'vine snake',
60: 'night snake, Hypsiglena torquata',
61: 'boa constrictor, Constrictor constrictor',
62: 'rock python, rock snake, Python sebae',
63: 'Indian cobra, Naja naja',
64: 'green mamba',
65: 'sea snake',
66: 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus',
67: 'diamondback, diamondback rattlesnake, Crotalus adamanteus',
68: 'sidewinder, horned rattlesnake, Crotalus cerastes',
69: 'trilobite',
70: 'harvestman, daddy longlegs, Phalangium opilio',
71: 'scorpion',
72: 'black and gold garden spider, Argiope aurantia',
73: 'barn spider, Araneus cavaticus',
74: 'garden spider, Aranea diademata',
75: 'black widow, Latrodectus mactans',
76: 'tarantula',
77: 'wolf spider, hunting spider',
78: 'tick',
79: 'centipede',
80: 'black grouse',
81: 'ptarmigan',
82: 'ruffed grouse, partridge, Bonasa umbellus',
83: 'prairie chicken, prairie grouse, prairie fowl',
84: 'peacock',
85: 'quail',
86: 'partridge',
87: 'African grey, African gray, Psittacus erithacus',
88: 'macaw',
89: 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita',
90: 'lorikeet',
91: 'coucal',
92: 'bee eater',
93: 'hornbill',
94: 'hummingbird',
95: 'jacamar',
96: 'toucan',
97: 'drake',
98: 'red-breasted merganser, Mergus serrator',
99: 'goose',
100: 'black swan, Cygnus atratus',
101: 'tusker',
102: 'echidna, spiny anteater, anteater',
103: 'platypus, duckbill, duckbilled platypus, duck-billed platypus, ' +
'Ornithorhynchus anatinus',
104: 'wallaby, brush kangaroo',
105: 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus',
106: 'wombat',
107: 'jelly fish',
108: 'sea anemone, anemone',
109: 'brain coral',
110: 'flatworm, platyhelminth',
111: 'nematode, nematode worm, roundworm',
112: 'conch',
113: 'snail',
114: 'slug',
115: 'sea slug, nudibranch',
116: 'chiton, coat-of-mail shell, sea cradle, polyplacophore',
117: 'chambered nautilus, pearly nautilus, nautilus',
118: 'Dungeness crab, Cancer magister',
119: 'rock crab, Cancer irroratus',
120: 'fiddler crab',
121: 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, ' +
'Paralithodes camtschatica',
122: 'American lobster, Northern lobster, Maine lobster, Homarus americanus',
123: 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea ' +
'crawfish',
124: 'crayfish, crawfish, crawdad, crawdaddy',
125: 'hermit crab',
126: 'isopod',
127: 'white stork, Ciconia ciconia',
128: 'black stork, Ciconia nigra',
129: 'spoonbill',
130: 'flamingo',
131: 'little blue heron, Egretta caerulea',
132: 'American egret, great white heron, Egretta albus',
133: 'bittern',
134: 'crane',
135: 'limpkin, Aramus pictus',
136: 'European gallinule, Porphyrio porphyrio',
137: 'American coot, marsh hen, mud hen, water hen, Fulica americana',
138: 'bustard',
139: 'ruddy turnstone, Arenaria interpres',
140: 'red-backed sandpiper, dunlin, Erolia alpina',
141: 'redshank, Tringa totanus',
142: 'dowitcher',
143: 'oystercatcher, oyster catcher',
144: 'pelican',
145: 'king penguin, Aptenodytes patagonica',
146: 'albatross, mollymawk',
147: 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, ' +
'Eschrichtius robustus',
148: 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca',
149: 'dugong, Dugong dugon',
150: 'sea lion',
151: 'Chihuahua',
152: 'Japanese spaniel',
153: 'Maltese dog, Maltese terrier, Maltese',
154: 'Pekinese, Pekingese, Peke',
155: 'Shih-Tzu',
156: 'Blenheim spaniel',
157: 'papillon',
158: 'toy terrier',
159: 'Rhodesian ridgeback',
160: 'Afghan hound, Afghan',
161: 'basset, basset hound',
162: 'beagle',
163: 'bloodhound, sleuthhound',
164: 'bluetick',
165: 'black-and-tan coonhound',
166: 'Walker hound, Walker foxhound',
167: 'English foxhound',
168: 'redbone',
169: 'borzoi, Russian wolfhound',
170: 'Irish wolfhound',
171: 'Italian greyhound',
172: 'whippet',
173: 'Ibizan hound, Ibizan Podenco',
174: 'Norwegian elkhound, elkhound',
175: 'otterhound, otter hound',
176: 'Saluki, gazelle hound',
177: 'Scottish deerhound, deerhound',
178: 'Weimaraner',
179: 'Staffordshire bullterrier, Staffordshire bull terrier',
180: 'American Staffordshire terrier, Staffordshire terrier, American pit ' +
'bull terrier, pit bull terrier',
181: 'Bedlington terrier',
182: 'Border terrier',
183: 'Kerry blue terrier',
184: 'Irish terrier',
185: 'Norfolk terrier',
186: 'Norwich terrier',
187: 'Yorkshire terrier',
188: 'wire-haired fox terrier',
189: 'Lakeland terrier',
190: 'Sealyham terrier, Sealyham',
191: 'Airedale, Airedale terrier',
192: 'cairn, cairn terrier',
193: 'Australian terrier',
194: 'Dandie Dinmont, Dandie Dinmont terrier',
195: 'Boston bull, Boston terrier',
196: 'miniature schnauzer',
197: 'giant schnauzer',
198: 'standard schnauzer',
199: 'Scotch terrier, Scottish terrier, Scottie',
200: 'Tibetan terrier, chrysanthemum dog',
201: 'silky terrier, Sydney silky',
202: 'soft-coated wheaten terrier',
203: 'West Highland white terrier',
204: 'Lhasa, Lhasa apso',
205: 'flat-coated retriever',
206: 'curly-coated retriever',
207: 'golden retriever',
208: 'Labrador retriever',
209: 'Chesapeake Bay retriever',
210: 'German short-haired pointer',
211: 'vizsla, Hungarian pointer',
212: 'English setter',
213: 'Irish setter, red setter',
214: 'Gordon setter',
215: 'Brittany spaniel',
216: 'clumber, clumber spaniel',
217: 'English springer, English springer spaniel',
218: 'Welsh springer spaniel',
219: 'cocker spaniel, English cocker spaniel, cocker',
220: 'Sussex spaniel',
221: 'Irish water spaniel',
222: 'kuvasz',
223: 'schipperke',
224: 'groenendael',
225: 'malinois',
226: 'briard',
227: 'kelpie',
228: 'komondor',
229: 'Old English sheepdog, bobtail',
230: 'Shetland sheepdog, Shetland sheep dog, Shetland',
231: 'collie',
232: 'Border collie',
233: 'Bouvier des Flandres, Bouviers des Flandres',
234: 'Rottweiler',
235: 'German shepherd, German shepherd dog, German police dog, alsatian',
236: 'Doberman, Doberman pinscher',
237: 'miniature pinscher',
238: 'Greater Swiss Mountain dog',
239: 'Bernese mountain dog',
240: 'Appenzeller',
241: 'EntleBucher',
242: 'boxer',
243: 'bull mastiff',
244: 'Tibetan mastiff',
245: 'French bulldog',
246: 'Great Dane',
247: 'Saint Bernard, St Bernard',
248: 'Eskimo dog, husky',
249: 'malamute, malemute, Alaskan malamute',
250: 'Siberian husky',
251: 'dalmatian, coach dog, carriage dog',
252: 'affenpinscher, monkey pinscher, monkey dog',
253: 'basenji',
254: 'pug, pug-dog',
255: 'Leonberg',
256: 'Newfoundland, Newfoundland dog',
257: 'Great Pyrenees',
258: 'Samoyed, Samoyede',
259: 'Pomeranian',
260: 'chow, chow chow',
261: 'keeshond',
262: 'Brabancon griffon',
263: 'Pembroke, Pembroke Welsh corgi',
264: 'Cardigan, Cardigan Welsh corgi',
265: 'toy poodle',
266: 'miniature poodle',
267: 'standard poodle',
268: 'Mexican hairless',
269: 'timber wolf, grey wolf, gray wolf, Canis lupus',
270: 'white wolf, Arctic wolf, Canis lupus tundrarum',
271: 'red wolf, maned wolf, Canis rufus, Canis niger',
272: 'coyote, prairie wolf, brush wolf, Canis latrans',
273: 'dingo, warrigal, warragal, Canis dingo',
274: 'dhole, Cuon alpinus',
275: 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus',
276: 'hyena, hyaena',
277: 'red fox, Vulpes vulpes',
278: 'kit fox, Vulpes macrotis',
279: 'Arctic fox, white fox, Alopex lagopus',
280: 'grey fox, gray fox, Urocyon cinereoargenteus',
281: 'tabby, tabby cat',
282: 'tiger cat',
283: 'Persian cat',
284: 'Siamese cat, Siamese',
285: 'Egyptian cat',
286: 'cougar, puma, catamount, mountain lion, painter, panther, ' +
'Felis concolor',
287: 'lynx, catamount',
288: 'leopard, Panthera pardus',
289: 'snow leopard, ounce, Panthera uncia',
290: 'jaguar, panther, Panthera onca, Felis onca',
291: 'lion, king of beasts, Panthera leo',
292: 'tiger, Panthera tigris',
293: 'cheetah, chetah, Acinonyx jubatus',
294: 'brown bear, bruin, Ursus arctos',
295: 'American black bear, black bear, Ursus americanus, Euarctos ' +
'americanus',
296: 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus',
297: 'sloth bear, Melursus ursinus, Ursus ursinus',
298: 'mongoose',
299: 'meerkat, mierkat',
300: 'tiger beetle',
301: 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle',
302: 'ground beetle, carabid beetle',
303: 'long-horned beetle, longicorn, longicorn beetle',
304: 'leaf beetle, chrysomelid',
305: 'dung beetle',
306: 'rhinoceros beetle',
307: 'weevil',
308: 'fly',
309: 'bee',
310: 'ant, emmet, pismire',
311: 'grasshopper, hopper',
312: 'cricket',
313: 'walking stick, walkingstick, stick insect',
314: 'cockroach, roach',
315: 'mantis, mantid',
316: 'cicada, cicala',
317: 'leafhopper',
318: 'lacewing, lacewing fly',
319: 'dragonfly, darning needle, devil\'s darning needle, sewing needle, ' +
'snake feeder, snake doctor, mosquito hawk, skeeter hawk',
320: 'damselfly',
321: 'admiral',
322: 'ringlet, ringlet butterfly',
323: 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus',
324: 'cabbage butterfly',
325: 'sulphur butterfly, sulfur butterfly',
326: 'lycaenid, lycaenid butterfly',
327: 'starfish, sea star',
328: 'sea urchin',
329: 'sea cucumber, holothurian',
330: 'wood rabbit, cottontail, cottontail rabbit',
331: 'hare',
332: 'Angora, Angora rabbit',
333: 'hamster',
334: 'porcupine, hedgehog',
335: 'fox squirrel, eastern fox squirrel, Sciurus niger',
336: 'marmot',
337: 'beaver',
338: 'guinea pig, Cavia cobaya',
339: 'sorrel',
340: 'zebra',
341: 'hog, pig, grunter, squealer, Sus scrofa',
342: 'wild boar, boar, Sus scrofa',
343: 'warthog',
344: 'hippopotamus, hippo, river horse, Hippopotamus amphibius',
345: 'ox',
346: 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis',
347: 'bison',
348: 'ram, tup',
349: 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky ' +
'Mountain sheep, Ovis canadensis',
350: 'ibex, Capra ibex',
351: 'hartebeest',
352: 'impala, Aepyceros melampus',
353: 'gazelle',
354: 'Arabian camel, dromedary, Camelus dromedarius',
355: 'llama',
356: 'weasel',
357: 'mink',
358: 'polecat, fitch, foulmart, foumart, Mustela putorius',
359: 'black-footed ferret, ferret, Mustela nigripes',
360: 'otter',
361: 'skunk, polecat, wood pussy',
362: 'badger',
363: 'armadillo',
364: 'three-toed sloth, ai, Bradypus tridactylus',
365: 'orangutan, orang, orangutang, Pongo pygmaeus',
366: 'gorilla, Gorilla gorilla',
367: 'chimpanzee, chimp, Pan troglodytes',
368: 'gibbon, Hylobates lar',
369: 'siamang, Hylobates syndactylus, Symphalangus syndactylus',
370: 'guenon, guenon monkey',
371: 'patas, hussar monkey, Erythrocebus patas',
372: 'baboon',
373: 'macaque',
374: 'langur',
375: 'colobus, colobus monkey',
376: 'proboscis monkey, Nasalis larvatus',
377: 'marmoset',
378: 'capuchin, ringtail, Cebus capucinus',
379: 'howler monkey, howler',
380: 'titi, titi monkey',
381: 'spider monkey, Ateles geoffroyi',
382: 'squirrel monkey, Saimiri sciureus',
383: 'Madagascar cat, ring-tailed lemur, Lemur catta',
384: 'indri, indris, Indri indri, Indri brevicaudatus',
385: 'Indian elephant, Elephas maximus',
386: 'African elephant, Loxodonta africana',
387: 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens',
388: 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca',
389: 'barracouta, snoek',
390: 'eel',
391: 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus ' +
'kisutch',
392: 'rock beauty, Holocanthus tricolor',
393: 'anemone fish',
394: 'sturgeon',
395: 'gar, garfish, garpike, billfish, Lepisosteus osseus',
396: 'lionfish',
397: 'puffer, pufferfish, blowfish, globefish',
398: 'abacus',
399: 'abaya',
400: 'academic gown, academic robe, judge\'s robe',
401: 'accordion, piano accordion, squeeze box',
402: 'acoustic guitar',
403: 'aircraft carrier, carrier, flattop, attack aircraft carrier',
404: 'airliner',
405: 'airship, dirigible',
406: 'altar',
407: 'ambulance',
408: 'amphibian, amphibious vehicle',
409: 'analog clock',
410: 'apiary, bee house',
411: 'apron',
412: 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, ' +
'dustbin, trash barrel, trash bin',
413: 'assault rifle, assault gun',
414: 'backpack, back pack, knapsack, packsack, rucksack, haversack',
415: 'bakery, bakeshop, bakehouse',
416: 'balance beam, beam',
417: 'balloon',
418: 'ballpoint, ballpoint pen, ballpen, Biro',
419: 'Band Aid',
420: 'banjo',
421: 'bannister, banister, balustrade, balusters, handrail',
422: 'barbell',
423: 'barber chair',
424: 'barbershop',
425: 'barn',
426: 'barometer',
427: 'barrel, cask',
428: 'barrow, garden cart, lawn cart, wheelbarrow',
429: 'baseball',
430: 'basketball',
431: 'bassinet',
432: 'bassoon',
433: 'bathing cap, swimming cap',
434: 'bath towel',
435: 'bathtub, bathing tub, bath, tub',
436: 'beach wagon, station wagon, wagon, estate car, beach waggon, station ' +
'waggon, waggon',
437: 'beacon, lighthouse, beacon light, pharos',
438: 'beaker',
439: 'bearskin, busby, shako',
440: 'beer bottle',
441: 'beer glass',
442: 'bell cote, bell cot',
443: 'bib',
444: 'bicycle-built-for-two, tandem bicycle, tandem',
445: 'bikini, two-piece',
446: 'binder, ring-binder',
447: 'binoculars, field glasses, opera glasses',
448: 'birdhouse',
449: 'boathouse',
450: 'bobsled, bobsleigh, bob',
451: 'bolo tie, bolo, bola tie, bola',
452: 'bonnet, poke bonnet',
453: 'bookcase',
454: 'bookshop, bookstore, bookstall',
455: 'bottlecap',
456: 'bow',
457: 'bow tie, bow-tie, bowtie',
458: 'brass, memorial tablet, plaque',
459: 'brassiere, bra, bandeau',
460: 'breakwater, groin, groyne, mole, bulwark, seawall, jetty',
461: 'breastplate, aegis, egis',
462: 'broom',
463: 'bucket, pail',
464: 'buckle',
465: 'bulletproof vest',
466: 'bullet train, bullet',
467: 'butcher shop, meat market',
468: 'cab, hack, taxi, taxicab',
469: 'caldron, cauldron',
470: 'candle, taper, wax light',
471: 'cannon',
472: 'canoe',
473: 'can opener, tin opener',
474: 'cardigan',
475: 'car mirror',
476: 'carousel, carrousel, merry-go-round, roundabout, whirligig',
477: 'carpenter\'s kit, tool kit',
478: 'carton',
479: 'car wheel',
480: 'cash machine, cash dispenser, automated teller machine, automatic ' +
'teller machine, automated teller, automatic teller, ATM',
481: 'cassette',
482: 'cassette player',
483: 'castle',
484: 'catamaran',
485: 'CD player',
486: 'cello, violoncello',
487: 'cellular telephone, cellular phone, cellphone, cell, mobile phone',
488: 'chain',
489: 'chainlink fence',
490: 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ' +
'ring armour',
491: 'chain saw, chainsaw',
492: 'chest',
493: 'chiffonier, commode',
494: 'chime, bell, gong',
495: 'china cabinet, china closet',
496: 'Christmas stocking',
497: 'church, church building',
498: 'cinema, movie theater, movie theatre, movie house, picture palace',
499: 'cleaver, meat cleaver, chopper',
500: 'cliff dwelling',
501: 'cloak',
502: 'clog, geta, patten, sabot',
503: 'cocktail shaker',
504: 'coffee mug',
505: 'coffeepot',
506: 'coil, spiral, volute, whorl, helix',
507: 'combination lock',
508: 'computer keyboard, keypad',
509: 'confectionery, confectionary, candy store',
510: 'container ship, containership, container vessel',
511: 'convertible',
512: 'corkscrew, bottle screw',
513: 'cornet, horn, trumpet, trump',
514: 'cowboy boot',
515: 'cowboy hat, ten-gallon hat',
516: 'cradle',
517: 'crane',
518: 'crash helmet',
519: 'crate',
520: 'crib, cot',
521: 'Crock Pot',
522: 'croquet ball',
523: 'crutch',
524: 'cuirass',
525: 'dam, dike, dyke',
526: 'desk',
527: 'desktop computer',
528: 'dial telephone, dial phone',
529: 'diaper, nappy, napkin',
530: 'digital clock',
531: 'digital watch',
532: 'dining table, board',
533: 'dishrag, dishcloth',
534: 'dishwasher, dish washer, dishwashing machine',
535: 'disk brake, disc brake',
536: 'dock, dockage, docking facility',
537: 'dogsled, dog sled, dog sleigh',
538: 'dome',
539: 'doormat, welcome mat',
540: 'drilling platform, offshore rig',
541: 'drum, membranophone, tympan',
542: 'drumstick',
543: 'dumbbell',
544: 'Dutch oven',
545: 'electric fan, blower',
546: 'electric guitar',
547: 'electric locomotive',
548: 'entertainment center',
549: 'envelope',
550: 'espresso maker',
551: 'face powder',
552: 'feather boa, boa',
553: 'file, file cabinet, filing cabinet',
554: 'fireboat',
555: 'fire engine, fire truck',
556: 'fire screen, fireguard',
557: 'flagpole, flagstaff',
558: 'flute, transverse flute',
559: 'folding chair',
560: 'football helmet',
561: 'forklift',
562: 'fountain',
563: 'fountain pen',
564: 'four-poster',
565: 'freight car',
566: 'French horn, horn',
567: 'frying pan, frypan, skillet',
568: 'fur coat',
569: 'garbage truck, dustcart',
570: 'gasmask, respirator, gas helmet',
571: 'gas pump, gasoline pump, petrol pump, island dispenser',
572: 'goblet',
573: 'go-kart',
574: 'golf ball',
575: 'golfcart, golf cart',
576: 'gondola',
577: 'gong, tam-tam',
578: 'gown',
579: 'grand piano, grand',
580: 'greenhouse, nursery, glasshouse',
581: 'grille, radiator grille',
582: 'grocery store, grocery, food market, market',
583: 'guillotine',
584: 'hair slide',
585: 'hair spray',
586: 'half track',
587: 'hammer',
588: 'hamper',
589: 'hand blower, blow dryer, blow drier, hair dryer, hair drier',
590: 'hand-held computer, hand-held microcomputer',
591: 'handkerchief, hankie, hanky, hankey',
592: 'hard disc, hard disk, fixed disk',
593: 'harmonica, mouth organ, harp, mouth harp',
594: 'harp',
595: 'harvester, reaper',
596: 'hatchet',
597: 'holster',
598: 'home theater, home theatre',
599: 'honeycomb',
600: 'hook, claw',
601: 'hoopskirt, crinoline',
602: 'horizontal bar, high bar',
603: 'horse cart, horse-cart',
604: 'hourglass',
605: 'iPod',
606: 'iron, smoothing iron',
607: 'jack-o\'-lantern',
608: 'jean, blue jean, denim',
609: 'jeep, landrover',
610: 'jersey, T-shirt, tee shirt',
611: 'jigsaw puzzle',
612: 'jinrikisha, ricksha, rickshaw',
613: 'joystick',
614: 'kimono',
615: 'knee pad',
616: 'knot',
617: 'lab coat, laboratory coat',
618: 'ladle',
619: 'lampshade, lamp shade',
620: 'laptop, laptop computer',
621: 'lawn mower, mower',
622: 'lens cap, lens cover',
623: 'letter opener, paper knife, paperknife',
624: 'library',
625: 'lifeboat',
626: 'lighter, light, igniter, ignitor',
627: 'limousine, limo',
628: 'liner, ocean liner',
629: 'lipstick, lip rouge',
630: 'Loafer',
631: 'lotion',
632: 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker ' +
'system',
633: 'loupe, jeweler\'s loupe',
634: 'lumbermill, sawmill',
635: 'magnetic compass',
636: 'mailbag, postbag',
637: 'mailbox, letter box',
638: 'maillot',
639: 'maillot, tank suit',
640: 'manhole cover',
641: 'maraca',
642: 'marimba, xylophone',
643: 'mask',
644: 'matchstick',
645: 'maypole',
646: 'maze, labyrinth',
647: 'measuring cup',
648: 'medicine chest, medicine cabinet',
649: 'megalith, megalithic structure',
650: 'microphone, mike',
651: 'microwave, microwave oven',
652: 'military uniform',
653: 'milk can',
654: 'minibus',
655: 'miniskirt, mini',
656: 'minivan',
657: 'missile',
658: 'mitten',
659: 'mixing bowl',
660: 'mobile home, manufactured home',
661: 'Model T',
662: 'modem',
663: 'monastery',
664: 'monitor',
665: 'moped',
666: 'mortar',
667: 'mortarboard',
668: 'mosque',
669: 'mosquito net',
670: 'motor scooter, scooter',
671: 'mountain bike, all-terrain bike, off-roader',
672: 'mountain tent',
673: 'mouse, computer mouse',
674: 'mousetrap',
675: 'moving van',
676: 'muzzle',
677: 'nail',
678: 'neck brace',
679: 'necklace',
680: 'nipple',
681: 'notebook, notebook computer',
682: 'obelisk',
683: 'oboe, hautboy, hautbois',
684: 'ocarina, sweet potato',
685: 'odometer, hodometer, mileometer, milometer',
686: 'oil filter',
687: 'organ, pipe organ',
688: 'oscilloscope, scope, cathode-ray oscilloscope, CRO',
689: 'overskirt',
690: 'oxcart',
691: 'oxygen mask',
692: 'packet',
693: 'paddle, boat paddle',
694: 'paddlewheel, paddle wheel',
695: 'padlock',
696: 'paintbrush',
697: 'pajama, pyjama, pj\'s, jammies',
698: 'palace',
699: 'panpipe, pandean pipe, syrinx',
700: 'paper towel',
701: 'parachute, chute',
702: 'parallel bars, bars',
703: 'park bench',
704: 'parking meter',
705: 'passenger car, coach, carriage',
706: 'patio, terrace',
707: 'pay-phone, pay-station',
708: 'pedestal, plinth, footstall',
709: 'pencil box, pencil case',
710: 'pencil sharpener',
711: 'perfume, essence',
712: 'Petri dish',
713: 'photocopier',
714: 'pick, plectrum, plectron',
715: 'pickelhaube',
716: 'picket fence, paling',
717: 'pickup, pickup truck',
718: 'pier',
719: 'piggy bank, penny bank',
720: 'pill bottle',
721: 'pillow',
722: 'ping-pong ball',
723: 'pinwheel',
724: 'pirate, pirate ship',
725: 'pitcher, ewer',
726: 'plane, carpenter\'s plane, woodworking plane',
727: 'planetarium',
728: 'plastic bag',
729: 'plate rack',
730: 'plow, plough',
731: 'plunger, plumber\'s helper',
732: 'Polaroid camera, Polaroid Land camera',
733: 'pole',
734: 'police van, police wagon, paddy wagon, patrol wagon, wagon, black ' +
'Maria',
735: 'poncho',
736: 'pool table, billiard table, snooker table',
737: 'pop bottle, soda bottle',
738: 'pot, flowerpot',
739: 'potter\'s wheel',
740: 'power drill',
741: 'prayer rug, prayer mat',
742: 'printer',
743: 'prison, prison house',
744: 'projectile, missile',
745: 'projector',
746: 'puck, hockey puck',
747: 'punching bag, punch bag, punching ball, punchball',
748: 'purse',
749: 'quill, quill pen',
750: 'quilt, comforter, comfort, puff',
751: 'racer, race car, racing car',
752: 'racket, racquet',
753: 'radiator',
754: 'radio, wireless',
755: 'radio telescope, radio reflector',
756: 'rain barrel',
757: 'recreational vehicle, RV, R.V.',
758: 'reel',
759: 'reflex camera',
760: 'refrigerator, icebox',
761: 'remote control, remote',
762: 'restaurant, eating house, eating place, eatery',
763: 'revolver, six-gun, six-shooter',
764: 'rifle',
765: 'rocking chair, rocker',
766: 'rotisserie',
767: 'rubber eraser, rubber, pencil eraser',
768: 'rugby ball',
769: 'rule, ruler',
770: 'running shoe',
771: 'safe',
772: 'safety pin',
773: 'saltshaker, salt shaker',
774: 'sandal',
775: 'sarong',
776: 'sax, saxophone',
777: 'scabbard',
778: 'scale, weighing machine',
779: 'school bus',
780: 'schooner',
781: 'scoreboard',
782: 'screen, CRT screen',
783: 'screw',
784: 'screwdriver',
785: 'seat belt, seatbelt',
786: 'sewing machine',
787: 'shield, buckler',
788: 'shoe shop, shoe-shop, shoe store',
789: 'shoji',
790: 'shopping basket',
791: 'shopping cart',
792: 'shovel',
793: 'shower cap',
794: 'shower curtain',
795: 'ski',
796: 'ski mask',
797: 'sleeping bag',
798: 'slide rule, slipstick',
799: 'sliding door',
800: 'slot, one-armed bandit',
801: 'snorkel',
802: 'snowmobile',
803: 'snowplow, snowplough',
804: 'soap dispenser',
805: 'soccer ball',
806: 'sock',
807: 'solar dish, solar collector, solar furnace',
808: 'sombrero',
809: 'soup bowl',
810: 'space bar',
811: 'space heater',
812: 'space shuttle',
813: 'spatula',
814: 'speedboat',
815: 'spider web, spider\'s web',
816: 'spindle',
817: 'sports car, sport car',
818: 'spotlight, spot',
819: 'stage',
820: 'steam locomotive',
821: 'steel arch bridge',
822: 'steel drum',
823: 'stethoscope',
824: 'stole',
825: 'stone wall',
826: 'stopwatch, stop watch',
827: 'stove',
828: 'strainer',
829: 'streetcar, tram, tramcar, trolley, trolley car',
830: 'stretcher',
831: 'studio couch, day bed',
832: 'stupa, tope',
833: 'submarine, pigboat, sub, U-boat',
834: 'suit, suit of clothes',
835: 'sundial',
836: 'sunglass',
837: 'sunglasses, dark glasses, shades',
838: 'sunscreen, sunblock, sun blocker',
839: 'suspension bridge',
840: 'swab, swob, mop',
841: 'sweatshirt',
842: 'swimming trunks, bathing trunks',
843: 'swing',
844: 'switch, electric switch, electrical switch',
845: 'syringe',
846: 'table lamp',
847: 'tank, army tank, armored combat vehicle, armoured combat vehicle',
848: 'tape player',
849: 'teapot',
850: 'teddy, teddy bear',
851: 'television, television system',
852: 'tennis ball',
853: 'thatch, thatched roof',
854: 'theater curtain, theatre curtain',
855: 'thimble',
856: 'thresher, thrasher, threshing machine',
857: 'throne',
858: 'tile roof',
859: 'toaster',
860: 'tobacco shop, tobacconist shop, tobacconist',
861: 'toilet seat',
862: 'torch',
863: 'totem pole',
864: 'tow truck, tow car, wrecker',
865: 'toyshop',
866: 'tractor',
867: 'trailer truck, tractor trailer, trucking rig, rig, articulated ' +
'lorry, semi',
868: 'tray',
869: 'trench coat',
870: 'tricycle, trike, velocipede',
871: 'trimaran',
872: 'tripod',
873: 'triumphal arch',
874: 'trolleybus, trolley coach, trackless trolley',
875: 'trombone',
876: 'tub, vat',
877: 'turnstile',
878: 'typewriter keyboard',
879: 'umbrella',
880: 'unicycle, monocycle',
881: 'upright, upright piano',
882: 'vacuum, vacuum cleaner',
883: 'vase',
884: 'vault',
885: 'velvet',
886: 'vending machine',
887: 'vestment',
888: 'viaduct',
889: 'violin, fiddle',
890: 'volleyball',
891: 'waffle iron',
892: 'wall clock',
893: 'wallet, billfold, notecase, pocketbook',
894: 'wardrobe, closet, press',
895: 'warplane, military plane',
896: 'washbasin, handbasin, washbowl, lavabo, wash-hand basin',
897: 'washer, automatic washer, washing machine',
898: 'water bottle',
899: 'water jug',
900: 'water tower',
901: 'whiskey jug',
902: 'whistle',
903: 'wig',
904: 'window screen',
905: 'window shade',
906: 'Windsor tie',
907: 'wine bottle',
908: 'wing',
909: 'wok',
910: 'wooden spoon',
911: 'wool, woolen, woollen',
912: 'worm fence, snake fence, snake-rail fence, Virginia fence',
913: 'wreck',
914: 'yawl',
915: 'yurt',
916: 'web site, website, internet site, site',
917: 'comic book',
918: 'crossword puzzle, crossword',
919: 'street sign',
920: 'traffic light, traffic signal, stoplight',
921: 'book jacket, dust cover, dust jacket, dust wrapper',
922: 'menu',
923: 'plate',
924: 'guacamole',
925: 'consomme',
926: 'hot pot, hotpot',
927: 'trifle',
928: 'ice cream, icecream',
929: 'ice lolly, lolly, lollipop, popsicle',
930: 'French loaf',
931: 'bagel, beigel',
932: 'pretzel',
933: 'cheeseburger',
934: 'hotdog, hot dog, red hot',
935: 'mashed potato',
936: 'head cabbage',
937: 'broccoli',
938: 'cauliflower',
939: 'zucchini, courgette',
940: 'spaghetti squash',
941: 'acorn squash',
942: 'butternut squash',
943: 'cucumber, cuke',
944: 'artichoke, globe artichoke',
945: 'bell pepper',
946: 'cardoon',
947: 'mushroom',
948: 'Granny Smith',
949: 'strawberry',
950: 'orange',
951: 'lemon',
952: 'fig',
953: 'pineapple, ananas',
954: 'banana',
955: 'jackfruit, jak, jack',
956: 'custard apple',
957: 'pomegranate',
958: 'hay',
959: 'carbonara',
960: 'chocolate sauce, chocolate syrup',
961: 'dough',
962: 'meat loaf, meatloaf',
963: 'pizza, pizza pie',
964: 'potpie',
965: 'burrito',
966: 'red wine',
967: 'espresso',
968: 'cup',
969: 'eggnog',
970: 'alp',
971: 'bubble',
972: 'cliff, drop, drop-off',
973: 'coral reef',
974: 'geyser',
975: 'lakeside, lakeshore',
976: 'promontory, headland, head, foreland',
977: 'sandbar, sand bar',
978: 'seashore, coast, seacoast, sea-coast',
979: 'valley, vale',
980: 'volcano',
981: 'ballplayer, baseball player',
982: 'groom, bridegroom',
983: 'scuba diver',
984: 'rapeseed',
985: 'daisy',
986: 'yellow lady\'s slipper, yellow lady-slipper, Cypripedium calceolus, ' +
'Cypripedium parviflorum',
987: 'corn',
988: 'acorn',
989: 'hip, rose hip, rosehip',
990: 'buckeye, horse chestnut, conker',
991: 'coral fungus',
992: 'agaric',
993: 'gyromitra',
994: 'stinkhorn, carrion fungus',
995: 'earthstar',
996: 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola ' +
'frondosa',
997: 'bolete',
998: 'ear, spike, capitulum',
999: 'toilet tissue, toilet paper, bathroom tissue'
};
index.html
<script src="script.js"></script>
<input type="file" onchange="predict(this.files[0])">
utils.js:
export function file2img(f) {
return new Promise(resolve => {
const reader = new FileReader();
reader.readAsDataURL(f);
reader.onload = (e) => {
const img = document.createElement('img');
img.src = e.target.result;
img.width = 224;
img.height = 224;
img.onload = () => resolve(img);
};
});
}
script.js:
import * as tf from '@tensorflow/tfjs';
import { IMAGENET_CLASSES } from './imagenet_classes';
import { file2img } from './utils';
// URL of the MobileNet model
const MOBILENET_MODEL_PATH = 'http://127.0.0.1:8080/mobilenet/web_model/model.json';
window.onload = async () => {
// Load the model
const model = await tf.loadLayersModel(MOBILENET_MODEL_PATH);
// Run a prediction
window.predict = async (file) => {
const img = await file2img(file);
document.body.appendChild(img);
const pred = tf.tidy(() => {
// Convert the img element to a tensor and normalize it: subtract 127.5 to shift the values into -127.5~127.5, divide by 127.5 to scale them into -1~1, then reshape into a batch of one 224*224 color image
const input = tf.browser.fromPixels(img)
.toFloat()
.sub(255 / 2)
.div(255 / 2)
.reshape([1, 224, 224, 3]);
// Pass the prepared tensor to the model's predict method
return model.predict(input);
});
// Index of the predicted class
const index = pred.argMax(1).dataSync()[0];
setTimeout(() => {
alert(`Prediction: ${IMAGENET_CLASSES[index]}`);
}, 0);
};
};
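The normalization and the class lookup in the script above can be checked on plain numbers, without loading a model. This is a minimal sketch: `normalizePixel`, `argMax`, and the three-entry `CLASSES` table are hypothetical names and made-up data for illustration, not part of TensorFlow.js.

```javascript
// Mirrors .sub(255 / 2).div(255 / 2) for a single pixel value (hypothetical helper).
function normalizePixel(value) {
  return (value - 255 / 2) / (255 / 2);
}

// Index of the largest score: the plain-array counterpart of argMax (hypothetical helper).
function argMax(scores) {
  let best = 0;
  for (let i = 1; i < scores.length; i++) {
    if (scores[i] > scores[best]) best = i;
  }
  return best;
}

// Made-up labels and scores, standing in for IMAGENET_CLASSES and the model output.
const CLASSES = { 0: 'cat', 1: 'dog', 2: 'bird' };
const scores = [0.1, 0.7, 0.2];

console.log(normalizePixel(0), normalizePixel(255)); // -1 1
console.log(CLASSES[argMax(scores)]); // dog
```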
data.js:
const IMAGE_SIZE = 224;
const loadImg = (src) => {
return new Promise(resolve => {
const img = new Image();
img.crossOrigin = "anonymous";
img.src = src;
img.width = IMAGE_SIZE;
img.height = IMAGE_SIZE;
img.onload = () => resolve(img);
});
};
export const getInputs = async () => {
const loadImgs = [];
const labels = [];
for (let i = 0; i < 30; i += 1) {
['android', 'apple', 'windows'].forEach(label => {
const src = `http://127.0.0.1:8080/brand/train/${label}-${i}.jpg`;
const img = loadImg(src);
loadImgs.push(img);
labels.push([
label === 'android' ? 1 : 0,
label === 'apple' ? 1 : 0,
label === 'windows' ? 1 : 0,
]);
});
}
const inputs = await Promise.all(loadImgs);
return {
inputs,
labels,
};
}
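The three ternaries in data.js build a one-hot label for each image. A possible generalization is sketched below; `oneHot` is a hypothetical helper name, and only the class names come from the code above.

```javascript
const BRANDS = ['android', 'apple', 'windows'];

// Returns an array with 1 at the position of `label` and 0 elsewhere (hypothetical helper).
function oneHot(label, classes) {
  const index = classes.indexOf(label);
  if (index === -1) {
    throw new Error(`Unknown label: ${label}`);
  }
  return classes.map((_, i) => (i === index ? 1 : 0));
}

console.log(oneHot('apple', BRANDS)); // [ 0, 1, 0 ]
```

Written this way, adding a fourth class only requires extending the class list.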
script.js:
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getInputs } from './data'
window.onload = async () => {
const { inputs, labels } = await getInputs();
const surface = tfvis.visor().surface({ name: 'Input examples', styles: { height: 250 } });
inputs.forEach(img => {
surface.drawArea.appendChild(img);
});
};
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getInputs } from './data';
const MOBILENET_MODEL_PATH = 'http://127.0.0.1:8080/mobilenet/web_model/model.json';
const NUM_CLASSES = 3;
window.onload = async () => {
const { inputs, labels } = await getInputs();
const surface = tfvis.visor().surface({ name: 'Input examples', styles: { height: 250 } });
inputs.forEach(img => {
surface.drawArea.appendChild(img);
});
// Load the model
const mobilenet = await tf.loadLayersModel(MOBILENET_MODEL_PATH);
// Print a summary of the model
mobilenet.summary();
// Look up one of the model's intermediate layers by name
const layer = mobilenet.getLayer('conv_pw_13_relu');
// The truncated model takes MobileNet's inputs as its inputs and the chosen layer's output as its output
const truncatedMobilenet = tf.model({
inputs: mobilenet.inputs,
outputs: layer.output
});
// New model that will hold the two-layer network
const model = tf.sequential();
// Flatten the truncated layer's output into a 1-D vector
model.add(tf.layers.flatten({
inputShape: layer.outputShape.slice(1)
}));
// Build the two-layer network
model.add(tf.layers.dense({
units: 10,
activation: 'relu'
}));
model.add(tf.layers.dense({
units: NUM_CLASSES,
activation: 'softmax'
}));
};
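The flatten layer receives `layer.outputShape.slice(1)`, that is, the layer's output shape without the leading batch dimension. The sketch below shows how many values the flattened vector then holds. The shape `[null, 7, 7, 256]` is an assumption for illustration only; the real shape is whatever `mobilenet.summary()` prints for `conv_pw_13_relu`. `flattenedLength` is a hypothetical helper.

```javascript
// Number of values left after flattening everything except the batch dimension (hypothetical helper).
function flattenedLength(outputShape) {
  // Drop the leading batch dimension, as layer.outputShape.slice(1) does.
  return outputShape.slice(1).reduce((total, dim) => total * dim, 1);
}

// Assumed shape for illustration; check mobilenet.summary() for the actual one.
const assumedShape = [null, 7, 7, 256];
console.log(flattenedLength(assumedShape)); // 12544
```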
utils.js:
import * as tf from '@tensorflow/tfjs';
export function img2x(imgEl){
return tf.tidy(() => {
const input = tf.browser.fromPixels(imgEl)
.toFloat()
.sub(255 / 2)
.div(255 / 2)
.reshape([1, 224, 224, 3]);
return input;
});
}
script.js:
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getInputs } from './data';
import { img2x } from './utils';
const MOBILENET_MODEL_PATH = 'http://127.0.0.1:8080/mobilenet/web_model/model.json';
const NUM_CLASSES = 3;
const BRAND_CLASSES = ['android', 'apple', 'windows'];
window.onload = async () => {
const { inputs, labels } = await getInputs();
const surface = tfvis.visor().surface({ name: 'Input examples', styles: { height: 250 } });
inputs.forEach(img => {
surface.drawArea.appendChild(img);
});
// Load the model
const mobilenet = await tf.loadLayersModel(MOBILENET_MODEL_PATH);
// Print a summary of the model
mobilenet.summary();
// Look up one of the model's intermediate layers by name
const layer = mobilenet.getLayer('conv_pw_13_relu');
// The truncated model takes MobileNet's inputs as its inputs and the chosen layer's output as its output
const truncatedMobilenet = tf.model({
inputs: mobilenet.inputs,
outputs: layer.output
});
// New model that will hold the two-layer network
const model = tf.sequential();
// Flatten the truncated layer's output into a 1-D vector
model.add(tf.layers.flatten({
inputShape: layer.outputShape.slice(1)
}));
// Build the two-layer network
model.add(tf.layers.dense({
units: 10,
activation: 'relu'
}));
model.add(tf.layers.dense({
units: NUM_CLASSES,
activation: 'softmax'
}));
// Set the loss function and the optimizer
model.compile({ loss: 'categoricalCrossentropy', optimizer: tf.train.adam() });
// Feed the inputs through the truncated model first, then hand its output to the new model
// Convert each image to the format MobileNet expects before passing it to the truncated model
const { xs, ys } = tf.tidy(() => {
const xs = tf.concat(inputs.map(imgEl => truncatedMobilenet.predict(img2x(imgEl))));
const ys = tf.tensor(labels);
return { xs, ys };
});
// Train the new model on the truncated model's output
await model.fit(xs, ys, {
epochs: 20,
callbacks: tfvis.show.fitCallbacks(
{ name: 'Training performance' },
['loss'],
{ callbacks: ['onEpochEnd'] }
)
});
};
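The last dense layer uses a softmax activation and the model is compiled with the `categoricalCrossentropy` loss. What those two steps compute for a single sample can be sketched in plain JavaScript. This is an illustration of the textbook formulas, not the library's implementation; the function names, the example logits, and the small `eps` guard are assumptions.

```javascript
// Turns raw scores into probabilities that sum to 1 (textbook softmax).
function softmax(logits) {
  const max = Math.max(...logits); // subtract the max for numerical stability
  const exps = logits.map(v => Math.exp(v - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map(v => v / sum);
}

// Cross-entropy between a one-hot label and predicted probabilities for one sample.
function categoricalCrossentropy(oneHotLabel, probs) {
  const eps = 1e-7; // assumed guard against log(0)
  return -oneHotLabel.reduce((acc, y, i) => acc + y * Math.log(probs[i] + eps), 0);
}

// Made-up logits for the three brand classes; the true class is the first one.
const probs = softmax([2, 1, 0]);
const loss = categoricalCrossentropy([1, 0, 0], probs);
console.log(probs, loss);
```

The loss shrinks toward 0 as the probability assigned to the true class approaches 1, which is the trend the loss chart should show over the epochs.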
index.html
<script src="script.js"></script>
<input type="file" onchange="predict(this.files[0])">
utils.js:
import * as tf from '@tensorflow/tfjs';
// Convert an img element into the format the MobileNet model expects
export function img2x(imgEl){
return tf.tidy(() => {
const input = tf.browser.fromPixels(imgEl)
.toFloat()
.sub(255 / 2)
.div(255 / 2)
.reshape([1, 224, 224, 3]);
return input;
});
}
export function file2img(f) {
return new Promise(resolve => {
const reader = new FileReader();
reader.readAsDataURL(f);
reader.onload = (e) => {
const img = document.createElement('img');
img.src = e.target.result;
img.width = 224;
img.height = 224;
img.onload = () => resolve(img);
};
});
}
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getInputs } from './data';
import { img2x, file2img } from './utils';
const MOBILENET_MODEL_PATH = 'http://127.0.0.1:8080/mobilenet/web_model/model.json';
const NUM_CLASSES = 3;
const BRAND_CLASSES = ['android', 'apple', 'windows'];
window.onload = async () => {
const { inputs, labels } = await getInputs();
const surface = tfvis.visor().surface({ name: 'Input examples', styles: { height: 250 } });
inputs.forEach(img => {
surface.drawArea.appendChild(img);
});
// Load the model
const mobilenet = await tf.loadLayersModel(MOBILENET_MODEL_PATH);
// Print a summary of the model
mobilenet.summary();
// Look up one of the model's intermediate layers by name
const layer = mobilenet.getLayer('conv_pw_13_relu');
// The truncated model takes MobileNet's inputs as its inputs and the chosen layer's output as its output
const truncatedMobilenet = tf.model({
inputs: mobilenet.inputs,
outputs: layer.output
});
// New model that will hold the two-layer network
const model = tf.sequential();
// Flatten the truncated layer's output into a 1-D vector
model.add(tf.layers.flatten({
inputShape: layer.outputShape.slice(1)
}));
// Build the two-layer network
model.add(tf.layers.dense({
units: 10,
activation: 'relu'
}));
model.add(tf.layers.dense({
units: NUM_CLASSES,
activation: 'softmax'
}));
// Set the loss function and the optimizer
model.compile({ loss: 'categoricalCrossentropy', optimizer: tf.train.adam() });
// Feed the inputs through the truncated model first, then hand its output to the new model
// Convert each image to the format MobileNet expects before passing it to the truncated model
const { xs, ys } = tf.tidy(() => {
const xs = tf.concat(inputs.map(imgEl => truncatedMobilenet.predict(img2x(imgEl))));
const ys = tf.tensor(labels);
return { xs, ys };
});
// Train the new model on the truncated model's output
await model.fit(xs, ys, {
epochs: 20,
callbacks: tfvis.show.fitCallbacks(
{ name: 'Training performance' },
['loss'],
{ callbacks: ['onEpochEnd'] }
)
});
// Predict with the trained model
window.predict = async (file) => {
const img = await file2img(file);
document.body.appendChild(img);
const pred = tf.tidy(() => {
const x = img2x(img);
const input = truncatedMobilenet.predict(x);
return model.predict(input);
});
const index = pred.argMax(1).dataSync()[0];
setTimeout(() => {
alert(`Prediction: ${BRAND_CLASSES[index]}`);
}, 0);
};
};
index.html
<script src="script.js"></script>
<input type="file" onchange="predict(this.files[0])">
<button onclick="download()">Download model</button>
script.js:
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getInputs } from './data';
import { img2x, file2img } from './utils';
const MOBILENET_MODEL_PATH = 'http://127.0.0.1:8080/mobilenet/web_model/model.json';
const NUM_CLASSES = 3;
const BRAND_CLASSES = ['android', 'apple', 'windows'];
window.onload = async () => {
const { inputs, labels } = await getInputs();
const surface = tfvis.visor().surface({ name: 'Input examples', styles: { height: 250 } });
inputs.forEach(img => {
surface.drawArea.appendChild(img);
});
// Load the model
const mobilenet = await tf.loadLayersModel(MOBILENET_MODEL_PATH);
// Print a summary of the model
mobilenet.summary();
// Look up one of the model's intermediate layers by name
const layer = mobilenet.getLayer('conv_pw_13_relu');
// The truncated model takes MobileNet's inputs as its inputs and the chosen layer's output as its output
const truncatedMobilenet = tf.model({
inputs: mobilenet.inputs,
outputs: layer.output
});
// New model that will hold the two-layer network
const model = tf.sequential();
// Flatten the truncated layer's output into a 1-D vector
model.add(tf.layers.flatten({
inputShape: layer.outputShape.slice(1)
}));
// Build the two-layer network
model.add(tf.layers.dense({
units: 10,
activation: 'relu'
}));
model.add(tf.layers.dense({
units: NUM_CLASSES,
activation: 'softmax'
}));
// Set the loss function and the optimizer
model.compile({ loss: 'categoricalCrossentropy', optimizer: tf.train.adam() });
// Feed the inputs through the truncated model first, then hand its output to the new model
// Convert each image to the format MobileNet expects before passing it to the truncated model
const { xs, ys } = tf.tidy(() => {
const xs = tf.concat(inputs.map(imgEl => truncatedMobilenet.predict(img2x(imgEl))));
const ys = tf.tensor(labels);
return { xs, ys };
});
// Train the new model on the truncated model's output
await model.fit(xs, ys, {
epochs: 20,
callbacks: tfvis.show.fitCallbacks(
{ name: 'Training performance' },
['loss'],
{ callbacks: ['onEpochEnd'] }
)
});
window.predict = async (file) => {
const img = await file2img(file);
document.body.appendChild(img);
const pred = tf.tidy(() => {
const x = img2x(img);
const input = truncatedMobilenet.predict(x);
return model.predict(input);
});
const index = pred.argMax(1).dataSync()[0];
setTimeout(() => {
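alert(`Prediction: ${BRAND_CLASSES[index]}`);
}, 0);
};
// Download the trained model
window.download = async () => {
// Call the model's save method; the downloads:// scheme saves the model files through the browser's download mechanism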
await model.save('downloads://model');
};
};
<script src="script.js"></script>
<input type="file" onchange="predict(this.files[0])">
<br>
utils.js:
import * as tf from '@tensorflow/tfjs';
export function img2x(imgEl){
return tf.tidy(() => {
const input = tf.browser.fromPixels(imgEl)
.toFloat()
.sub(255 / 2)
.div(255 / 2)
.reshape([1, 224, 224, 3]);
return input;
});
}
export function file2img(f) {
return new Promise(resolve => {
const reader = new FileReader();
reader.readAsDataURL(f);
reader.onload = (e) => {
const img = document.createElement('img');
img.src = e.target.result;
img.width = 224;
img.height = 224;
img.onload = () => resolve(img);
};
});
}
script.js:
import * as tf from '@tensorflow/tfjs';
import { img2x, file2img } from './utils';
const MODEL_PATH = 'http://127.0.0.1:8080';
const BRAND_CLASSES = ['android', 'apple', 'windows'];
window.onload = async () => {
const mobilenet = await tf.loadLayersModel(MODEL_PATH + '/mobilenet/web_model/model.json');
mobilenet.summary();
const layer = mobilenet.getLayer('conv_pw_13_relu');
const truncatedMobilenet = tf.model({
inputs: mobilenet.inputs,
outputs: layer.output
});
const model = await tf.loadLayersModel(MODEL_PATH + '/brand/web_model/model.json');
window.predict = async (file) => {
const img = await file2img(file);
document.body.appendChild(img);
const pred = tf.tidy(() => {
const x = img2x(img);
const input = truncatedMobilenet.predict(x);
return model.predict(input);
});
const index = pred.argMax(1).dataSync()[0];
setTimeout(() => {
alert(`Prediction: ${BRAND_CLASSES[index]}`);
}, 0);
};
};
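In the script above, pred.argMax(1).dataSync()[0] picks the index of the largest score in the single row of the prediction, and that index is then looked up in BRAND_CLASSES. A minimal plain JavaScript sketch of the same lookup follows. It assumes the scores are available as an ordinary array with one entry per class; the argMax function and the sample scores are made up for illustration.

```javascript
const BRAND_CLASSES = ['android', 'apple', 'windows'];

// Return the index of the largest value in a non-empty array of scores.
function argMax(scores) {
  let best = 0;
  for (let i = 1; i < scores.length; i++) {
    if (scores[i] > scores[best]) best = i;
  }
  return best;
}

// Invented scores standing in for one row of model output.
const sampleScores = [0.1, 0.7, 0.2];
console.log(BRAND_CLASSES[argMax(sampleScores)]); // apple
```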
https://github.com/tensorflow/tfjs-models
https://github.com/tensorflow/tfjs-models/tree/master/speech-commands
https://www.tensorflow.org/datasets/catalog/speech_commands
<script src="script.js"></script>
<style>
#result>div {
float: left;
padding: 20px;
}
</style>
<div id="result"></div>
import * as speechCommands from '@tensorflow-models/speech-commands';
const MODEL_PATH = 'http://127.0.0.1:8080/speech';
window.onload = async () => {
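// Create the recognizer. Arguments: (1) 'BROWSER_FFT' selects the browser's native Fourier transform; (2) the vocabulary to recognize, where null means the default vocabulary; (3) the URL of the custom model; (4) the URL of the custom metadata
const recognizer = speechCommands.create(
'BROWSER_FFT',
null,
MODEL_PATH + '/model.json',
MODEL_PATH + '/metadata.json'
);
// Call the recognizer's ensureModelLoaded to make sure the model has finished loading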
await recognizer.ensureModelLoaded();
};
<script src="script.js"></script>
<style>
#result>div {
float: left;
padding: 20px;
}
</style>
<div id="result"></div>
import * as speechCommands from '@tensorflow-models/speech-commands';
const MODEL_PATH = 'http://127.0.0.1:8080/speech';
window.onload = async () => {
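// Create the recognizer. Arguments: (1) 'BROWSER_FFT' selects the browser's native Fourier transform; (2) the vocabulary to recognize, where null means the default vocabulary; (3) the URL of the custom model; (4) the URL of the custom metadata
const recognizer = speechCommands.create(
'BROWSER_FFT',
null,
MODEL_PATH + '/model.json',
MODEL_PATH + '/metadata.json'
);
// Call the recognizer's ensureModelLoaded to make sure the model has finished loading
await recognizer.ensureModelLoaded();
// List the words the model can recognize, skipping the first two entries of wordLabels() (typically the background-noise and unknown-word classes)
const labels = recognizer.wordLabels().slice(2);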
const resultEl = document.querySelector('#result');
resultEl.innerHTML = labels.map(l => `
<div>${l}</div>
`).join('');
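// Turn on the browser microphone, listen to its input, and convert that input into model input
recognizer.listen(result => {
// The result contains a score for each word
const { scores } = result;
// Get the highest score
const maxValue = Math.max(...scores);
// Find the index of the highest score, shifted by 2 to match the sliced label list
const index = scores.indexOf(maxValue) - 2;
// Re-render the word list and highlight the recognized word
resultEl.innerHTML = labels.map((l, i) => `
<div style="background: ${i === index ? 'green' : 'transparent'}">${l}</div>
`).join('');
}, {
// Recognition frequency: overlapFactor sets how much consecutive audio windows overlap; the higher the value, the more often recognition runs
overlapFactor: 0.3,
// Probability threshold: the callback above runs only when the top score for the heard word reaches 0.9
probabilityThreshold: 0.9
});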
});
};
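Because the displayed labels drop the first two entries of wordLabels(), the index of the highest score has to be shifted by the same amount before it is compared with a position in labels. The sketch below shows that bookkeeping in plain JavaScript. The vocabulary and scores are invented sample data, and highlightIndex is a hypothetical helper rather than part of the speech-commands API.

```javascript
// Invented sample vocabulary; the first two entries stand for the
// non-word classes that the script skips with slice(2).
const allLabels = ['_background_noise_', '_unknown_', 'up', 'down', 'left'];
const SKIP = 2;
const labels = allLabels.slice(SKIP);

// Index of the top score, expressed relative to the sliced label list.
// A negative result means one of the skipped classes scored highest.
function highlightIndex(scores, skip) {
  const maxValue = Math.max(...scores);
  return scores.indexOf(maxValue) - skip;
}

const scores = [0.01, 0.02, 0.05, 0.9, 0.02];
const index = highlightIndex(scores, SKIP);
console.log(labels[index]); // down
```

When a skipped class wins, the shifted index is negative and no entry in labels matches it, so nothing is highlighted.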
<script src="script.js"></script>
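<!-- The button text is used as the class label: 上一张 = previous image, 下一张 = next image, 背景噪音 = background noise -->
<button onclick="collect(this)">上一张</button>
<button onclick="collect(this)">下一张</button>
<button onclick="collect(this)">背景噪音</button>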
<pre id="count"></pre>
import * as speechCommands from '@tensorflow-models/speech-commands';
import * as tfvis from '@tensorflow/tfjs-vis';
const MODEL_PATH = 'http://127.0.0.1:8080';
let transferRecognizer;
window.onload = async () => {
const recognizer = speechCommands.create(
'BROWSER_FFT',
null,
MODEL_PATH + '/speech/model.json',
MODEL_PATH + '/speech/metadata.json'
);
await recognizer.ensureModelLoaded();
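// Create the transfer-learning recognizer
transferRecognizer = recognizer.createTransfer('轮播图');
};
// Collect a voice sample when a button is clicked
window.collect = async (btn) => {
// Disable the button while recording to prevent repeated clicks
btn.disabled = true;
const label = btn.innerText;
// Collect a voice training sample under its label; the button text is the label, except that 背景噪音 (background noise) maps to the special '_background_noise_' label
await transferRecognizer.collectExample(
label === '背景噪音' ? '_background_noise_' : label
);
btn.disabled = false;
// Show the per-label sample counts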
document.querySelector('#count').innerHTML = JSON.stringify(transferRecognizer.countExamples(), null, 2);
};
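The collect handler uses the button text as the class label, with one exception: the background-noise button is mapped to '_background_noise_'. The mapping and the count display can be tried in isolation with the sketch below. toExampleLabel is a hypothetical helper, and the counts object is invented sample data standing in for the result of countExamples().

```javascript
// Hypothetical helper: maps a button's text to the label passed to collectExample.
const BACKGROUND_NOISE_BUTTON = '背景噪音'; // "background noise"
const BACKGROUND_NOISE_LABEL = '_background_noise_';

function toExampleLabel(buttonText) {
  return buttonText === BACKGROUND_NOISE_BUTTON ? BACKGROUND_NOISE_LABEL : buttonText;
}

console.log(toExampleLabel('上一张'));   // 上一张
console.log(toExampleLabel('背景噪音')); // _background_noise_

// Invented counts, formatted the same way the page fills the <pre id="count"> element.
const sampleCounts = { '上一张': 3, '下一张': 3, '_background_noise_': 2 };
console.log(JSON.stringify(sampleCounts, null, 2));
```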
<script src="script.js"></script>
<button onclick="collect(this)">上一张</button>
<button onclick="collect(this)">下一张</button>
<button onclick="collect(this)">背景噪音</button>
<pre id="count"></pre>
<button onclick="train()">Train</button>
<br><br>
Listening: <input type="checkbox" onchange="toggle(this.checked)">
import * as speechCommands from '@tensorflow-models/speech-commands';
import * as tfvis from '@tensorflow/tfjs-vis';
const MODEL_PATH = 'http://127.0.0.1:8080';
let transferRecognizer;
window.onload = async () => {
const recognizer = speechCommands.create(
'BROWSER_FFT',
null,
MODEL_PATH + '/speech/model.json',
MODEL_PATH + '/speech/metadata.json'
);
await recognizer.ensureModelLoaded();
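// Create the transfer-learning recognizer
transferRecognizer = recognizer.createTransfer('轮播图');
};
// Collect a voice sample when a button is clicked
window.collect = async (btn) => {
// Disable the button while recording to prevent repeated clicks
btn.disabled = true;
const label = btn.innerText;
// Collect a voice training sample under its label; the button text is the label, except that 背景噪音 (background noise) maps to the special '_background_noise_' label
await transferRecognizer.collectExample(
label === '背景噪音' ? '_background_noise_' : label
);
btn.disabled = false;
// Show the per-label sample counts
document.querySelector('#count').innerHTML = JSON.stringify(transferRecognizer.countExamples(), null, 2);
};
// Train the model and visualize the training process
window.train = async () => {
await transferRecognizer.train({
epochs: 30,
callback: tfvis.show.fitCallbacks(
{ name: 'Training performance' },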
['loss', 'acc'],
{ callbacks: ['onEpochEnd'] }
)
});
};
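// Callback for the checkbox that turns microphone listening on and off
window.toggle = async (checked) => {
// If the switch is on
if (checked) {
// Call the transfer recognizer's listen method
await transferRecognizer.listen(result => {
// This callback runs after each recognized utterance
// Scores for every label being recognized
const { scores } = result;
// Get all labels
const labels = transferRecognizer.wordLabels();
// Get the index of the highest score
const index = scores.indexOf(Math.max(...scores));
// Log the label with the highest score
console.log(labels[index]);
}, {
overlapFactor: 0, // recognition frequency
probabilityThreshold: 0.75 // probability threshold
});
} else {
// Stop listening when the switch is turned off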
transferRecognizer.stopListening();
}
};
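The probabilityThreshold option is described in these notes as a gate: the listen callback only fires when the best score is high enough. The sketch below imitates that gate in plain JavaScript for a single score vector. It illustrates the described behavior and is not the library's implementation: maybeInvoke is a hypothetical function, the scores are invented, and treating a score exactly equal to the threshold as passing is an assumption of this sketch.

```javascript
// Hypothetical gate: call the callback only if the top score reaches the threshold.
function maybeInvoke(scores, threshold, callback) {
  const maxScore = Math.max(...scores);
  if (maxScore < threshold) return false; // too uncertain, skip the callback
  callback({ scores });
  return true;
}

const labels = ['上一张', '下一张']; // "previous image", "next image"
const heard = [];
const onResult = ({ scores }) => {
  heard.push(labels[scores.indexOf(Math.max(...scores))]);
};

maybeInvoke([0.6, 0.4], 0.75, onResult); // below the threshold, ignored
maybeInvoke([0.1, 0.9], 0.75, onResult); // above the threshold, callback runs
console.log(heard); // [ '下一张' ]
```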
<script src="script.js"></script>
<button onclick="collect(this)">上一张</button>
<button onclick="collect(this)">下一张</button>
<button onclick="collect(this)">背景噪音</button>
<button onclick="save()">Save</button>
<pre id="count"></pre>
<button onclick="train()">Train</button>
<br><br>
Listening: <input type="checkbox" onchange="toggle(this.checked)">
import * as speechCommands from '@tensorflow-models/speech-commands';
import * as tfvis from '@tensorflow/tfjs-vis';
const MODEL_PATH = 'http://127.0.0.1:8080';
let transferRecognizer;
window.onload = async () => {
const recognizer = speechCommands.create(
'BROWSER_FFT',
null,
MODEL_PATH + '/speech/model.json',
MODEL_PATH + '/speech/metadata.json'
);
await recognizer.ensureModelLoaded();
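// Create the transfer-learning recognizer
transferRecognizer = recognizer.createTransfer('轮播图');
};
// Collect a voice sample when a button is clicked
window.collect = async (btn) => {
// Disable the button while recording to prevent repeated clicks
btn.disabled = true;
const label = btn.innerText;
// Collect a voice training sample under its label; the button text is the label, except that 背景噪音 (background noise) maps to the special '_background_noise_' label
await transferRecognizer.collectExample(
label === '背景噪音' ? '_background_noise_' : label
);
btn.disabled = false;
// Show the per-label sample counts
document.querySelector('#count').innerHTML = JSON.stringify(transferRecognizer.countExamples(), null, 2);
};
// Train the model and visualize the training process
window.train = async () => {
await transferRecognizer.train({
epochs: 30,
callback: tfvis.show.fitCallbacks(
{ name: 'Training performance' },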
['loss', 'acc'],
{ callbacks: ['onEpochEnd'] }
)
});
};
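// Callback for the checkbox that turns microphone listening on and off
window.toggle = async (checked) => {
// If the switch is on
if (checked) {
// Call the transfer recognizer's listen method
await transferRecognizer.listen(result => {
// This callback runs after each recognized utterance
// Scores for every label being recognized
const { scores } = result;
// Get all labels
const labels = transferRecognizer.wordLabels();
// Get the index of the highest score
const index = scores.indexOf(Math.max(...scores));
// Log the label with the highest score
console.log(labels[index]);
}, {
overlapFactor: 0, // recognition frequency
probabilityThreshold: 0.75 // probability threshold
});
} else {
// Stop listening when the switch is turned off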
transferRecognizer.stopListening();
}
};
window.save = () => {
// Serialize the collected voice examples into binary data (an ArrayBuffer)
const arrayBuffer = transferRecognizer.serializeExamples();
// Wrap the binary data in a Blob and download it as a local file
const blob = new Blob([arrayBuffer]);
const link = document.createElement('a');
link.href = window.URL.createObjectURL(blob);
link.download = 'data.bin';
link.click();
};
<script src="script.js"></script>
Listening: <input type="checkbox" onchange="toggle(this.checked)">
<style>
.slider {
width: 600px;
overflow: hidden;
margin: 10px auto;
}
.slider > div{
display: flex;
align-items: center;
}
</style>
<div class="slider">
<div>
<img src="https://cdn.pixabay.com/photo/2019/10/29/15/57/vancouver-4587302__480.jpg" alt="" width="600">
<img src="https://cdn.pixabay.com/photo/2019/10/31/07/14/coffee-4591159__480.jpg" alt="" width="600">
<img src="https://cdn.pixabay.com/photo/2019/11/01/11/08/landscape-4593909__480.jpg" alt="" width="600">
<img src="https://cdn.pixabay.com/photo/2019/11/02/21/45/maple-leaf-4597501__480.jpg" alt="" width="600">
<img src="https://cdn.pixabay.com/photo/2019/11/02/03/13/in-xinjiang-4595560__480.jpg" alt="" width="600">
<img src="https://cdn.pixabay.com/photo/2019/11/01/22/45/reschensee-4595385__480.jpg" alt="" width="600">
</div>
</div>
import * as speechCommands from '@tensorflow-models/speech-commands';
const MODEL_PATH = 'http://127.0.0.1:8080';
let transferRecognizer;
let curIndex = 0;
window.onload = async () => {
// Create the recognizer
const recognizer = speechCommands.create(
'BROWSER_FFT',
null,
MODEL_PATH + '/speech/model.json',
MODEL_PATH + '/speech/metadata.json',
);
// Make sure the model has finished loading
await recognizer.ensureModelLoaded();
// Create the transfer-learning recognizer
transferRecognizer = recognizer.createTransfer('轮播图');
// Fetch the saved voice data
const res = await fetch(MODEL_PATH + '/slider/data.bin');
// Convert the response to an ArrayBuffer
const arrayBuffer = await res.arrayBuffer();
// Load the ArrayBuffer into the transfer recognizer
transferRecognizer.loadExamples(arrayBuffer);
// Train the model
await transferRecognizer.train({ epochs: 30 });
console.log('done');
};
// Listen to the microphone and predict with the trained model
window.toggle = async (checked) => {
if (checked) {
await transferRecognizer.listen(result => {
const { scores } = result;
const labels = transferRecognizer.wordLabels();
const index = scores.indexOf(Math.max(...scores));
// Trigger the interaction that corresponds to the recognized label
window.play(labels[index]);
}, {
overlapFactor: 0,
probabilityThreshold: 0.5
});
} else {
transferRecognizer.stopListening();
}
};
window.play = (label) => {
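// Get the div inside the .slider element
const div = document.querySelector('.slider>div');
// Update the current index according to the command (上一张 = previous image, 下一张 = next image); the index must stay between the first and last image
if (label === '上一张') {
if (curIndex === 0) { return; }
curIndex -= 1;
} else if (label === '下一张') {
if (curIndex === document.querySelectorAll('img').length - 1) { return; }
curIndex += 1;
} else {
// Ignore any other label
return;
}
// Set the transition animation
div.style.transition = "transform 1s";
// Translate the div inside .slider according to the image index to produce the carousel effect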
div.style.transform = `translateX(-${100 * curIndex}%)`;
};
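The index bookkeeping in play can be isolated into a pure function, which makes the boundary behavior easy to check without a DOM. The sketch below is one way to write it, assuming the image count is passed in explicitly. nextIndex and toTransform are hypothetical names, and returning the index unchanged for unrecognized labels is a choice made for this sketch.

```javascript
// Hypothetical pure version of the index update in window.play.
// '上一张' means "previous image" and '下一张' means "next image".
function nextIndex(curIndex, label, imageCount) {
  if (label === '上一张') return Math.max(0, curIndex - 1);
  if (label === '下一张') return Math.min(imageCount - 1, curIndex + 1);
  return curIndex; // unknown label: stay on the current image
}

// Build the CSS transform that shifts the strip of images to the given index.
function toTransform(index) {
  return `translateX(-${100 * index}%)`;
}

let cur = 0;
cur = nextIndex(cur, '上一张', 6); // already at the first image, stays 0
cur = nextIndex(cur, '下一张', 6); // moves to 1
console.log(cur, toTransform(cur)); // 1 translateX(-100%)
```

Each step of 100% equals the width of the inner div, which matches the 600px slider and the 600px images in the page above.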
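TensorFlow.js Converter depends on Python. If other applications on your machine need a Python version different from the one the converter requires, the two can conflict. To let each Python application run under its own Python version, create a separate Python environment for each one. Installing Conda lets you create such isolated virtual environments. https://mirror.tuna.tsinghua.edu.cn/help/anaconda/
To list the installed virtual environments:
To delete a virtual environment:
To enter (activate) and leave the virtual environment:
To check the current Python version:
To switch between virtual environments with conda: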
https://github.com/tensorflow/tfjs
https://github.com/tensorflow/tfjs/tree/master/tfjs-converter
Check whether the installation succeeded:
<script src="script.js"></script>
<input type="file" onchange="predict(this.files[0])">
export function file2img(f) {
return new Promise(resolve => {
const reader = new FileReader();
reader.readAsDataURL(f);
reader.onload = (e) => {
const img = document.createElement('img');
img.src = e.target.result;
img.width = 224;
img.height = 224;
img.onload = () => resolve(img);
};
});
}
import * as tf from '@tensorflow/tfjs';
import { IMAGENET_CLASSES } from './imagenet_classes';
import { file2img } from './utils';
const MOBILENET_MODEL_PATH = 'http://127.0.0.1:8080/mobilenet/web_model2/model.json';
window.onload = async () => {
const model = await tf.loadLayersModel(MOBILENET_MODEL_PATH);
window.predict = async (file) => {
const img = await file2img(file);
document.body.appendChild(img);
const pred = tf.tidy(() => {
const input = tf.browser.fromPixels(img)
.toFloat()
.sub(255 / 2)
.div(255 / 2)
.reshape([1, 224, 224, 3]);
return model.predict(input);
});
const index = pred.argMax(1).dataSync()[0];
setTimeout(() => {
alert(`Prediction: ${IMAGENET_CLASSES[index]}`);
}, 0);
};
};
export const IMAGENET_CLASSES = {
0: 'tench, Tinca tinca',
1: 'goldfish, Carassius auratus',
// ...
// ...
// ...
999: 'toilet tissue, toilet paper, bathroom tissue'
}
Converting to a graph_model also works: