TensorFlow.js 运行现有模型进行图像分类

738

这篇文章介绍使用 TensorFlow.js 通过 Javascript 在浏览器上运行现有的模型进行图像分类预测,示例源码在文章末尾。

介绍

TensorFlow.js 是一个用于使用 JavaScript 进行机器学习开发的库,可以使用 JavaScript 开发机器学习模型,并直接在浏览器或 Node.js 中使用机器学习模型。

TensorFlow.js 支持以下三种方式进行机器学习模型的开发、训练以及预测:

运行现有模型

使用现成的 JavaScript 模型或转换 Python TensorFlow 模型以在浏览器中或 Node.js 下运行。

重新训练现有模型

使用自己的数据重新训练现有的机器学习模型。

开发机器学习模型

使用灵活且直观的 API 直接用 JavaScript 构建和训练模型。

可谓是极其的强悍,更多的介绍文档,可到官网查阅:https://tensorflow.google.cn/js

实践

1. 创建空项目

为了方便,这里我们使用 Vue CLI 创建一个 TypeScript 的空项目,并修改 src/App.vue 文件中的内容为如下代码:

<template>
  <div id="app">
    TODO
  </div>
</template>

<script lang="ts">
import { Component, Vue } from 'vue-property-decorator'

@Component
export default class App extends Vue {}
</script>

如何创建项目工程可参考这篇文章:使用VueCLI快速搭建TS+Vue工程

启动项目:

npm run serve

访问地址:http://localhost:8080

2. 准备工作

  • 一张进行分类预测的图片,地址(可以是任意图):

https://github.com/tensorflow/tfjs-models/blob/master/mobilenet/demo/coffee.jpg

将图片下载下来保存到public目录下,并重命名为:coffee.jpg

3. 构建页面

<template>
  <div id="app">
    <div>
      <img id="img" src="/coffee.jpg" :width="imgSize" :height="imgSize" />
    </div>
    <Button @click="onPredict" :disabled="isRunning || !isLoadModel">预测</Button>
    <div v-html="text"></div>
  </div>
</template>

<script lang="ts">
import { Component, Vue } from 'vue-property-decorator'

@Component
export default class App extends Vue {
  // 图片尺寸
  imgSize = 224
  // 模型是否已加载
  isLoadModel = false
  // 正在预测
  isRunning = false
  // 结果
  text = ''

  onPredict () {
    if (!this.isLoadModel) {
      alert('模型加载失败,无法进行预测')
      return
    } else if (this.isRunning) {
      alert('当前正在预测中')
      return
    }
    this.predict()
  }
}
</script>

4. 安装TensorFlow.js

执行安装命令:

npm install @tensorflow/tfjs

当前安装的版本为:2.7.0

安装现有模型,这里使用 MobileNet:

npm install @tensorflow-models/mobilenet

MobileNet 模型项目地址:https://github.com/tensorflow/tfjs-models/tree/master/mobilenet

5. 使用TensorFlow.js

①. 引入库

import * as tf from '@tensorflow/tfjs'
import * as MobileNet from '@tensorflow-models/mobilenet'

②. 加载模型

模型会通过网络下载到浏览器中,似乎需要科学上网。

// 图像分类模型  <-- 状态
model: tf.LayersModel | any

mounted () {
  this.text = '正在加载模型...'
  // 使用cpu
  tf.setBackend('cpu')
  // 加载模型
  MobileNet.load()
    .then(mobileNet => {
      this.model = mobileNet
      this.isLoadModel = true
      this.text = '模型加载成功'
    })
    .catch(e => {
      console.error(e)
      this.text = e.message
    })
}

③. 运行预测

/**
 * 预测
 */
predict () {
  this.isRunning = true
  this.text = '正在预测...'
  const img: any = document.getElementById('img')
  // 调用模型进行预测,取出前5个
  this.model.classify(img, 5)
      .then((result: Array<any>) => {
        // 输出
        const items: Array<string> = []
        result.forEach((item: any) => items.push(`${item.className}:${Math.round(100 * item.probability)}%`))
        this.text = items.join('<br>')
      })
      .catch((e: any) => this.text = e)
      .finally(() => this.isRunning = false)
}

6. 运行结果

tfjs_demo

tfjs_demo

模型

官方已支持并封装的模型:

Type

Model

Demo

Details

Install

Images

MobileNet

Classify images with labels from the ImageNet database.

npm i @tensorflow-models/mobilenet

source

PoseNet

live

A machine learning model which allows for real-time human pose estimation in the browser. See a detailed description here.

npm i @tensorflow-models/posenet

source

Coco SSD

Object detection model that aims to localize and identify multiple objects in a single image. Based on the TensorFlow object detection API.

npm i @tensorflow-models/coco-ssd

source

BodyPix

live

Real-time person and body part segmentation in the browser using TensorFlow.js.

npm i @tensorflow-models/body-pix

source

DeepLab v3

Semantic segmentation

npm i @tensorflow-models/deeplab

source

Audio

Speech Commands

live

Classify 1 second audio snippets from the speech commands dataset.

npm i @tensorflow-models/speech-commands

source

Text

Universal Sentence Encoder

Encode text into a 512-dimensional embedding to be used as inputs to natural language processing tasks such as sentiment classification and textual similarity.

npm i @tensorflow-models/universal-sentence-encoder

source

Text Toxicity

live

Score the perceived impact a comment might have on a conversation, from "Very toxic" to "Very healthy".

npm i @tensorflow-models/toxicity

source

General Utilities

KNN Classifier

This package provides a utility for creating a classifier using the K-Nearest Neighbors algorithm. Can be used for transfer learning.

npm i @tensorflow-models/knn-classifier

source

详见模型库项目地址:https://github.com/tensorflow/tfjs-models

链接

源码

获取源码