TensorFlow.js 运行现有模型进行图像分类
这篇文章介绍使用 TensorFlow.js 通过 Javascript 在浏览器上运行现有的模型进行图像分类预测,示例源码在文章末尾。
介绍
TensorFlow.js 是一个用于使用 JavaScript 进行机器学习开发的库,可以使用 JavaScript 开发机器学习模型,并直接在浏览器或 Node.js 中使用机器学习模型。
TensorFlow.js 支持以下三种方式进行机器学习模型的开发、训练以及预测:
运行现有模型
使用现成的 JavaScript 模型或转换 Python TensorFlow 模型以在浏览器中或 Node.js 下运行。
重新训练现有模型
使用自己的数据重新训练现有的机器学习模型。
开发机器学习模型
使用灵活且直观的 API 直接用 JavaScript 构建和训练模型。
可谓是极其的强悍,更多的介绍文档,可到官网查阅:https://tensorflow.google.cn/js
实践
1. 创建空项目
为了方便,这里我们使用 Vue CLI 创建一个 TypeScript 的空项目,并修改 src/App.vue 文件中的内容为如下代码:
<template>
<div id="app">
TODO
</div>
</template>
<script lang="ts">
import { Component, Vue } from 'vue-property-decorator'
@Component
export default class App extends Vue {}
</script>
如何创建项目工程可参考这篇文章:使用VueCLI快速搭建TS+Vue工程
启动项目:
npm run serve
2. 准备工作
一张进行分类预测的图片,地址(可以是任意图):
https://github.com/tensorflow/tfjs-models/blob/master/mobilenet/demo/coffee.jpg
将图片下载下来保存到public目录下,并重命名为:coffee.jpg
3. 构建页面
<template>
<div id="app">
<div>
<img id="img" src="/coffee.jpg" :width="imgSize" :height="imgSize" />
</div>
<Button @click="onPredict" :disabled="isRunning || !isLoadModel">预测</Button>
<div v-html="text"></div>
</div>
</template>
<script lang="ts">
import { Component, Vue } from 'vue-property-decorator'
@Component
export default class App extends Vue {
// 图片尺寸
imgSize = 224
// 模型是否已加载
isLoadModel = false
// 正在预测
isRunning = false
// 结果
text = ''
onPredict () {
if (!this.isLoadModel) {
alert('模型加载失败,无法进行预测')
return
} else if (this.isRunning) {
alert('当前正在预测中')
return
}
this.predict()
}
}
</script>
4. 安装TensorFlow.js
执行安装命令:
npm install @tensorflow/tfjs
当前安装的版本为:2.7.0
安装现有模型,这里使用 MobileNet:
npm install @tensorflow-models/mobilenet
MobileNet 模型项目地址:https://github.com/tensorflow/tfjs-models/tree/master/mobilenet
5. 使用TensorFlow.js
①. 引入库
import * as tf from '@tensorflow/tfjs'
import * as MobileNet from '@tensorflow-models/mobilenet'
②. 加载模型
模型会通过网络下载到浏览器中,似乎需要科学上网。
// 图像分类模型 <-- 状态
model: tf.LayersModel | any
mounted () {
this.text = '正在加载模型...'
// 使用cpu
tf.setBackend('cpu')
// 加载模型
MobileNet.load()
.then(mobileNet => {
this.model = mobileNet
this.isLoadModel = true
this.text = '模型加载成功'
})
.catch(e => {
console.error(e)
this.text = e.message
})
}
③. 运行预测
/**
* 预测
*/
predict () {
this.isRunning = true
this.text = '正在预测...'
const img: any = document.getElementById('img')
// 调用模型进行预测,取出前5个
this.model.classify(img, 5)
.then((result: Array<any>) => {
// 输出
const items: Array<string> = []
result.forEach((item: any) => items.push(`${item.className}:${Math.round(100 * item.probability)}%`))
this.text = items.join('<br>')
})
.catch((e: any) => this.text = e)
.finally(() => this.isRunning = false)
}
6. 运行结果
tfjs_demo
模型
官方已支持并封装的模型:
详见模型库项目地址:https://github.com/tensorflow/tfjs-models
链接
TensorFlow.js 官网:https://tensorflow.google.cn/js
TensorFlow.js 已支持的模型:https://github.com/tensorflow/tfjs-models
TensorFlow.js API文档:https://js.tensorflow.org/api/lates