Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【译】前端开发者们,快带上 Tensorflow.js 拥抱机器学习吧 #41

Open
JChehe opened this issue Sep 12, 2019 · 1 comment

Comments

@JChehe
Copy link
Owner

JChehe commented Sep 12, 2019

原文:Machine Learning For Front-End Developers With Tensorflow.js——Charlie Gerard

摘要:结合 JavaScript 和诸如 Tensorflow.js 等框架是入门机器学习的好办法。本文将涵盖 Tensorflow.js 目前提供的三大主要功能,并阐明了在前端使用机器学习的局限性。

机器学习常给人的感觉是属于数据科学家和 Python 开发者的领域。然而,在过去数年,开源框架的涌现使得语言不再成为限制,JavaScript 就是其一。在本文,我们将使用 Tensorflow.js 并结合示例项目去探索在浏览器中使用机器学习的不同可能性。

什么是机器学习?

在开始深入代码前,我们先简单讲解机器学习是什么,及其核心概念和术语。

定义

通用定义是赋予计算机从数据中获得学习能力而没有显式编程的能力。

人工智能领域的先驱者,Arthur Samuel 在 1959 年创造“机器学习”这个概念时,对它下的定义:“Field of study that gives computers the ability to learn without being explicitly programmed”。

如果将其与传统编程相比,意味着让计算机去识别数据中的模式,然后拥有预测能力,而无需我们明确它该做什么。

以欺诈检测为例。这显然没有固定的标准来判断交易是否存在欺诈行为;欺诈可以发生在任何国家/地区、任何账户、任何客户、任何时刻。手动跟踪所有这一切几乎是不可能的。

然而,我们可以利用所收集的欺诈相关数据去训练一个机器学习算法,使其理解数据中的模式,最终生成一个能预测任何新交易是否为诈骗的模型(model)。

核心概念

为理解后续代码案例,我们还需要先学习一些常用术语。

Google 官方机器学习术语表

模型(Model)

当使用数据集训练机器学习算法时,模型就是该训练的输出(结果)。它有点类似函数,将新数据作为输入,产生一个预测作为输出。

标签和特征(Labels And Features)

标签和特征是与你在训练过程中向算法提供的数据有关。

标签是指如何对数据集中每个样本进行分类,以及如何对其打标签。例如,数据集是描述不同动物的 CSV 文件,那么我们的标签可以是 “cat”、“dog” 或 “snake” 之类的词语。

特征是数据集中每个样本的特征。以上述动物为例,它可以是“胡须、喵叫”、“顽皮、犬吠”、“爬行动物、猖獗”等。

这样,机器学习算法就能够找到特征与其标签之间的联系,并用于将来的预测。

神经网络(Neural Networks)

神经网络是一组机器学习算法,其试图通过使用人工神经元层来模仿大脑的工作方式。

本文并不需要你深入了解它的工作方式,但如果想了解更多,下面有一个非常棒的视频:

But what is a Neural Network? | Deep learning, chapter 1

至此,我们已经定义了一些机器学习的常用术语。下面让我们谈谈使用 JavaScript 和 Tensorflow.js 框架能做些什么。

功能

目前支持三大功能:

  1. 使用预训练模型
  2. 迁移学习
  3. 定义、运行并使用自己的模型

我们先从最简单的一个说起。

1. 使用预训练模型

对于你打算解决的问题,可能存在已使用特定数据集训练过的模型,那么你就可以在代码中导入并使用它。

比如说,我们将构建一个用于预测图片是否是猫的网站。那么流行的图像分类模型 MobileNet 可作为 Tensorflow.js 的预训练模型来使用。

实现代码如下:(译者注:mobilenet 需要全局翻墙才能加载成功)

<html lang="en">
  <head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <meta http-equiv="X-UA-Compatible" content="ie=edge">
    <title>Cat detection</title>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.0.1"> </script>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet@1.0.0"> </script>
  </head>
  <body>
    <img id="image" alt="cat laying down" src="cat.jpeg"/>

    <script>
      const img = document.getElementById('image');

      const predictImage = async () => {
        console.log("Model loading...");
        const model = await mobilenet.load();
        console.log("Model is loaded!")

        const predictions = await model.classify(img);
        console.log('Predictions: ', predictions);
      }
      predictImage();
    </script>
  </body>
</html>

首先在 HTML 头部引入 Tensorflow.js 和 MobileNet 模型:

<script src="https://cdnjs.cloudflare.com/ajax/libs/tensorflow/1.0.1/tf.js"> </script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet@1.0.0"> </script>

然后,在 body 内放置一张用于预测的 img 元素:

<img id="image" alt="cat laying down" src="cat.jpeg"/>

最后,在 script 标签内加载预训练 MobileNet 模型,并对 #image 元素进行分类。该分类会返回一个含有 3 个预测结果的数组,并根据概率分数进行排序(分数最高的排第一位)。

const predictImage = async () => {
  console.log("Model loading...");
  const model = await mobilenet.load();
  console.log("Model is loaded!")
  const predictions = await model.classify(img);
  console.log('Predictions: ', predictions);
}

predictImage();

以上就是在浏览器上使用预训练模型和 Tensorflow.js 的方式。

Note:如果你想知道 MobileNet 模型还能进行哪些分类,可以在 Github 的 这里 寻找。

需要引起注意的是:在浏览器中加载一个预训练模型会比较耗时(可能超过 10 秒),所以你需要做预加载或调整界面交互以减轻对用户的影响。

如果你更倾向于将 Tensorflow.js 视为一个 NPM 模块,那么可通过以下方式导入:

import * as mobilenet from '@tensorflow-models/mobilenet';

点击 CodeSandbox 可以随意玩耍这个案例。

现在我们已经学会了如何使用预训练模型,下面我们看看第二个功能:迁移学习。

2. 迁移学习(Transfer Learning)

迁移学习是一种将预训练模型和自定义训练数据相结合的能力。换句话说,你可以利用模型的功能并添加自己的样例,而无需从零开始创建所有内容。

例如,现在有一个图片分类模型,它由一个已被数千张训练过的算法得到。此时你无需从零开始,因为迁移学习允许你将新的自定义图片样本与预训练模型组合得到一个新的图片分类器。

为了更好地说明,我们在原来代码的基础上进行了调整,以对新图像进行分类:

Note:下面就是最终实验结果,可以点击 这里 体验。

transfer learning demo

下面是该案例中最重要的代码段,可在 CodeSandbox 查看完整代码。

开始部分仍需要引入 Tensorflow.js 和 MobileNet,但这次还需要加上 KNN(k-nearest neighbor) 分类器:

<!-- Load TensorFlow.js -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
<!-- Load MobileNet -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet"></script>
<!-- Load KNN Classifier -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/knn-classifier"></script>

需要分类器(而不只是使用 MobileNet 模块)的原因是:我们加入了先前没有的自定义样本数据。所以 KNN 分类器的作用是:将所有东西结合在一起,从而能基于结合后的数据作出预测行为。

然后,我们用 video 标签替换掉猫的图像,即从摄像头获取图像。

<video autoplay id="webcam" width="227" height="227"></video>

最后,需要在页面上添加几个按钮,用于记录一些视频样本作为标签和启动预测。

<section>
  <button class="button">Left</button>

  <button class="button">Right</button>

  <button class="test-predictions">Test</button>
</section>

现在,开始编写 JavaScript 文件。首先,定义一些重要的变量:

// 类别的数量
const NUM_CLASSES = 2;
// 类别的标签
const classes = ["Left", "Right"];
// 网络摄像头的图片尺寸,必须为 227
const IMAGE_SIZE = 227;
// KNN 的 K 值
const TOPK = 10;

const video = document.getElementById("webcam");

本案例要对摄像头中的用户头部向左/右倾斜进行分类,所以需要两个类别标签 leftright

之所以将图片尺寸设置为 227 像素,是为了匹配已训练过 MobileNet 模型的数据的格式。即后者必须采用相同格式才能对新数据进行分类。

如果样本确实较大,就需要在输入 KNN 分类器前对数据进行大小调整。

接着,我们将 K 值设为 10。K 值对于 KNN 算法很重要,它表示对新输入进行分类时考虑多少实例。

在该案例中,10 表示在预测新数据的标签时,将从训练数据中最邻近的 10 个数据查找。

最后,获取 video 元素。

对于实现逻辑,我们从加载模型和分类器开始:

async load() {
    const knn = knnClassifier.create();
    const mobilenetModule = await mobilenet.load();
    console.log("model loaded");
}

然后,获取视频源:

navigator.mediaDevices
  .getUserMedia({ video: true, audio: false })
  .then(stream => {
    video.srcObject = stream;
    video.width = IMAGE_SIZE;
    video.height = IMAGE_SIZE;
  });

紧接着,为按钮绑定事件,以记录样例数据:

setupButtonEvents() {
    for (let i = 0; i < NUM_CLASSES; i++) {
      let button = document.getElementsByClassName("button")[i];

      button.onmousedown = () => {
        this.training = i;
        this.recordSamples = true;
      };
      button.onmouseup = () => (this.training = -1);
    }
  }

获取摄像头的图像样例,并重新格式化它们,最后将它们结合到 MobileNet 模块:

// 从 video 元素获取图像数据
const image = tf.browser.fromPixels(video);

let logits;
// 'conv_preds' 是 MobileNet 的对数激活(logits activation)
const infer = () => this.mobilenetModule.infer(image, "conv_preds");

// 当其中一个按钮按下时,则训练类别
if (this.training != -1) {
  logits = infer();

  // 添加当前图像到分类器
  this.knn.addExample(logits, this.training);
}

最后,一旦我们收集到一些摄像头图像,我们就可以使用以下代码进行预测:

logits = infer();
const res = await this.knn.predictClass(logits, TOPK);
const prediction = classes[res.classIndex];

当不再需要摄像头数据时,我们可以将其释放:

image.dispose();
if (logits != null) {
  logits.dispose();
}

再次提醒,如果想查看完整代码,请点击 CodeSandbox

3. 在浏览器中训练一个模型

最后一个功能是完全在浏览器上定义、训练并使用模型。我们将通过构建一个识别鸢尾花种类的案例来阐述。

为此,我们将基于开源数据集创建一个能对鸢尾花进行分类(三种类别,分别是:Setosa、Virginica 和 Versicolor)的神经网络。

在线案例地址完整代码

Edit tfjs-all

每个机器学习项目的核心是数据集。项目初期首先要做的事情之一是将数据集拆分为训练集和测试集。

这样做的原因是我们将用训练集训练算法,用测试集检查预测的准确性,以验证模型是可用还是需要调整。

Note:为了让事情变简单,我已将训练集和测试集拆分为两个 JSON 文件,可以在 CodeSanbox 找到。

训练集含有 130 项,测试集含有 14 项。数据看起来像这样:

{
  "sepal_length": 5.1,
  "sepal_width": 3.5,
  "petal_length": 1.4,
  "petal_width": 0.2,
  "species": "setosa"
}

如你所见,萼片(sepal)和花瓣(petal)的长宽是四个不同的特征,以及作为标签的物种(species)。

为了能被 Tensorflow.js 所用,我们需要将数据转成框架能理解的数据格式。对于本例,训练数据的 [130, 4] 表示有 130 个带有 4 个特征的鸢尾花。

import * as trainingSet from "training.json";
import * as testSet from "testing.json";

const trainingData = tf.tensor2d(
  trainingSet.map(item => [
    item.sepal_length,
    item.sepal_width,
    item.petal_length,
    item.petal_width
  ]),
  [130, 4]
);

const testData = tf.tensor2d(
  testSet.map(item => [
    item.sepal_length,
    item.sepal_width,
    item.petal_length,
    item.petal_width
  ]),
  [14, 4]
);

接着,我们需要将输出数据进行转换:

const output = tf.tensor2d(trainingSet.map(item => [
    item.species === 'setosa' ? 1 : 0,
    item.species === 'virginica' ? 1 : 0,
    item.species === 'versicolor' ? 1 : 0

]), [130,3])

然后,一旦数据准备就绪就可以开始创建模型:

const model = tf.sequential();

model.add(tf.layers.dense(
    {
        inputShape: 4,
        activation: 'sigmoid',
        units: 10
    }
));

model.add(tf.layers.dense(
    {
        inputShape: 10,
        units: 3,
        activation: 'softmax'
    }
));

上述代码中,我们首先实例化一个序列模型,添加一个输入层和输出层。

对于参数(inputShapre,activation 和 units),它们超出了本文的叙述范围。它们取决于你所创建的模型、数据的类型等。

一旦模型就绪,我们将能使用测试数据进行验证:

async function train_data(){
    for(let i=0;i<15;i++){
      const res = await model.fit(trainingData, outputData,{epochs: 40});
    }
}

async function main() {
  await train_data();
  model.predict(testSet).print();
}

验证通过后,我们就可以供用户使用。

每次调用 main 函数,预测的输出大概像这样:

[1,0,0] // Setosa
[0,1,0] // Virginica
[0,0,1] // Versicolor

预测返回的结果是一个含有 3 个数值的数组,它们分别表示属于三个种类之一的概率。即数值越接近 1 ,概率越高。

如果分类的输出结果是 [0.0002, 0.9494, 0.0503],数组的第二个数值最大,那么该模型预测新输入大概率是 Virginica。

这就是 Tensorflow.js 的一个简单神经网络!

在这里,我们仅讨论了一个关于鸢尾花的小数据集,但对于更大的数据集或图像,步骤是一致的:

  • 收集数据;
  • 分割为训练集和测试集;
  • 将数据转换为 Tensorflow.js 能理解的格式;
  • 选择算法;
  • 拟合数据;
  • 预测。

如果你想分享该模型到另一个应用,那么可以保存它:

await model.save('file:///path/to/my-model'); // in Node.js

Note:关于保存模型的更多可选项,可看 这里

局限性

是的!我们刚刚介绍了 Tensorflow.js 目前提供的三大功能。

在文章结尾之前,我认为有必要简单提一下在前端使用机器学习的局限性。

1. 性能

导入一个预训练的模型会对应用产生性能问题。例如,一些对象检测模型超过 10MB,这无疑会显著减慢网站的加载速度。所以,请务必考虑用户体验,并优化资源的加载,以改善感知性能。

感知性能(Perceived Performance):计算机工程中的感知性能是指软件功能在执行其任务时的速度。该概念主要适用于用户接受方面。 通过显示启动屏幕或文件进度对话框,应用程序启动或下载文件所需的时间不会更快。但是,它满足了一些人的需求:它对用户来说似乎更快,并提供了一个视觉提示,让他们知道系统正在处理他们的请求。——维基百科

2. 输入数据的质量

如果从零开始构建模型,那么必须要自己收集数据或从一些开源的数据集中寻找。

在进行任何类型的数据处理或尝试不同的算法之前,请务必检查输入数据的质量。举例来说,你尝试构建一个用于识别文本情感的情感分析模型,那么就需要确保数据的准确性和多样性。如果使用的数据质量较低,那么训练的结果将毫无用处。

3. 法律责任

使用开源的预训练模型无疑是简单快速的。然而,这也意味着你可能不知道它如何生成、数据集由什么构成,甚至使用了何种算法。有些模型被称为“黑盒子”,这意味着你并不知道它们是如何预测某种输出。

对于部分构建意图,这可能存在一定问题。例如,假设你使用一个基于扫描影象的机器学习模型来帮助人们检查患癌可能性。万一发生假阴性(模型预测病人未患癌,但实际已患癌),这可能牵涉到法律责任。此时,你必须能够解释为何该模型作出这一预测。

总结

总的来说,结合 JavaScript 和诸如 Tensorflow.js 等框架是入门机器学习的好办法。尽管生产环境的应用程序应该使用 Python 之类的语言,但 JavaScript 降低了门槛,让开发者能体验各种功能和理解基础概念。这无疑降低了在决定投入精力学习另一种语言前的成本。

本教程仅介绍了 Tensorflow.js 的可能性,但其他库和工具的生态正在成长。更多特定领域的框架能让你结合机器学习进行探索,例如音乐方面的 Magenta.js,或预测网站用户浏览行为的 guess.js

随着工具越来越高效,使得构建支持机器学习的 JavaScript 应用成为了可能,这无疑令人兴奋。现在就是学习它的好时机,因为社区正努力改善其可用性。

更多资源

如果想学习更多,下面提供一些资源:

其他框架和工具

案例、模型和数据集

灵感

感谢阅读!

@wangrongding
Copy link

好文👍

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants