【ArchSummit】如何通过AIOps推动可量化的业务价值增长和效率提升?>>> 了解详情
写点什么

MXNet API 入门 —第 4 篇

  • 2017-07-16
  • 本文字数:3009 字

    阅读完需:约 10 分钟

第3 篇文章中,我们构建并训练了第一个神经网络,接下来可以处理一些更复杂的样本了。

最顶尖的深度学习模型通常都复杂到让人难以置信。其中可能包含数百层,就算用不了数周,往往也要数天时间来使用海量数据进行训练。这类模型的构建和优化需要大量经验。

好在这些模型的使用还是很简单的,通常只需要编写几行代码。本文将使用一个名为 Inception v3的预训练模型进行图片分类。

Inception v3

诞生于 2015 年 12 月的 Inception v3 GoogleNet 模型(曾赢得 2014 年度 ImageNet 挑战赛)的改进版。本文不准备深入介绍该模型的研究论文,不过打算强调一下论文的结论:相比当时最棒的模型,Inception v3 的准确度高出了15%–25%,同时计算的经济性方面低六倍,并且至少将参数的数量减少了五倍(例如使用该模型对内存的要求更低)。

简直就是神器!那么我们该如何使用?

MXNet model zoo

Model zoo 提供了一系列可直接使用的预训练模型,并且通常还会提供模型定义模型参数(例如神经元权重),(也许还会提供)使用说明。

首先来下载定义和参数(你也许需要更改文件名)。第一个文件可以直接打开:其中包含了每一层的定义。第二个文件是一个二进制文件,请不要打开 ;)

复制代码
$ wget http://data.dmlc.ml/models/imagenet/inception-bn/Inception-BN-symbol.json
$ wget http://data.dmlc.ml/models/imagenet/inception-bn/Inception-BN-0126.params
$ mv Inception-BN-0126.params Inception-BN-0000.params

该模型已通过 ImageNet 数据集进行了训练,因此我们还需要下载对应的图片分类清单(共有 1000 个分类)。

复制代码
$ wget http://data.dmlc.ml/models/imagenet/synset.txt
$ wc -l synset.txt
1000 synset.txt
$ head -5 synset.txt
n01440764 tench, Tinca tinca
n01443537 goldfish, Carassius auratus
n01484850 great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias
n01491361 tiger shark, Galeocerdo cuvieri
n01494475 hammerhead, hammerhead shark

搞定,开始实战。

加载模型

我们需要:

  • 加载处于保存状态的模型:MXNet 将其称之为检查点 (Checkpoint)。随后即可得到输入的 Symbol 和模型参数。 ```

    import mxnet as mx

    sym, arg_params, aux_params = mx.model.load_checkpoint(‘Inception-BN’, 0)

复制代码
- 新建一个 Module 并为其指派输入 Symbol。我们还可以使用一个 Context 参数决定要在哪里运行该模型:默认值为 cpu(0),但也可改为 gpu(0) 以便通过 GPU 运行。 ```
mod = mx.mod.Module(symbol=sym)
  • 将输入 Symbol 绑定至输入数据。将其称之为“数据”是因为在网络的输入层中就使用了这样的名称(可以从 JSON 文件的前几行代码中看到)。
  • 将“数据”的形态 (Shape)定义为 1x3x224x224。别慌 ;),“224x224”是图片的分辨率,模型就是这样训练出来的。“3”是通道数量:红绿蓝(严格按照这样的顺序),“1”是批大小:我们将一次预测一张图片。
复制代码
mod.bind(for_training=False, data_shapes=[('data', (1,3,224,224))])
  • 设置模型参数。 ```

    mod.set_params(arg_params, aux_params)

复制代码
这样就可以了。只需要四行代码!随后可以放入一些数据看看会发生什么。嗯……先别急。
## 准备数据
数据准备:从七十年代以来,这一直是个痛苦的过程……从关系型数据库到机器学习,再到深度学习,这方面没有任何改进。虽然乏味但很必要。开始吧。
还记得吗,这个模型需要通过四维 NDArray 来保存一张 224x224 分辨率图片的红、绿、蓝通道数据。我们将使用流行的 [OpenCV](http://www.opencv.org/) 库从输入图片中构建这样的 NDArray。如果还没安装 OpenCV,考虑到本例的要求,直接运行 pip install opencv-python 就够了 :)。
随后的步骤如下:
- ** 读取 ** 图片:将返回一个 Numpy 数组,其形态为(图片高度, 图片宽度, 3),按顺序代表 **BGR**(蓝、绿、红)三个通道。 ```
img = cv2.imread(filename)
{1}
  • 将图片转换为 RGB。 ```

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

复制代码
- 将图片 ** 调整大小 ** 至 **224x224**。 ```
img = cv2.resize(img, (224, 224,))
  • 重塑数组的形态,从(图片高度, 图片宽度, 3)重塑为(3, 图片高度, 图片宽度)。 ```

    img = np.swapaxes(img, 0, 2)
    img = np.swapaxes(img, 1, 2)

复制代码
- 添加一个 ** 第四维度 ** 并构建 NDArray ```
img = img[np.newaxis, :]
array = mx.nd.array(img)
>>> print array.shape
(1L, 3L, 224L, 224L)

晕了?一起用个例子看看吧。输入下列这张图片:

输入 448x336 的图片(来源:metaltraveller.com)

处理完毕后,该图会被缩小尺寸并拆分为 RGB 通道,存储在 array[0] 中(生成下文图片的代码可参阅这里)。

array[0][0]:224x224,红色通道

array 0 :224x224,绿色通道

array 0 :224x224,蓝色通道

如果批大小大于 1,那么可以通过 array 1 指定第二张图片,使用 array 2 指定第三张图片,以此类推。

无论这个过程是乏味还是有趣,接下来我们开始预测吧!

开始预测

你可能还记得第 3 篇文章中提到,Module 对象必须以为单位向模型提供数据:最常见的做法是使用数据迭代器(因此我们使用了 NDArrayIter 对象)。

在这里我们想要预测一张图片,因此尽管可以使用数据迭代器,不过也没啥必要。但我们可以创建一个名为 Batch 的具名元组 (Named tuple), 它可以充当假的迭代器,在引用数据属性时返回输入的 NDArray。

复制代码
from collections import namedtuple
Batch = namedtuple('Batch', ['data'])

随后即可将这个“Batch”传递给模型开始预测。

复制代码
mod.forward(Batch([array]))

这个模型会输出一个包含1000 个可能性的 NDArray,每个可能性对应一个分类。由于批大小等于 1,因此只需要一行代码。

复制代码
prob = mod.get_outputs()[0].asnumpy()
>>> prob.shape
(1, 1000)

使用 squeeze() 将其转换为数组,随后使用 argsort() 创建第二个数组,其中保存了这些可能性按照降序排列的指数

复制代码
prob = np.squeeze(prob)
>>> prob.shape
(1000,)
>> prob
[ 4.14978594e-08 1.31608676e-05 2.51907986e-05 2.24045834e-05
2.30327873e-06 3.40798979e-05 7.41563645e-06 3.04062659e-08 etc.
sortedprob = np.argsort(prob)[::-1]
>> sortedprob.shape
(1000,)

根据模型的计算,这张图片最可能的分类是#546,可能性为58%

复制代码
>> sortedprob
[546 819 862 818 542 402 650 420 983 632 733 644 513 875 776 917 795
etc.
>> prob[546]
0.58039135

这个分类叫什么名字呢?我们可以使用 synset.txt 文件构建分类清单,并找出 546 号的名称。

复制代码
synsetfile = open('synset.txt', 'r')
categorylist = []
for line in synsetfile:
categorylist.append(line.rstrip())
>>> categorylist[546]
'n03272010 electric guitar'

可能性第二大的分类是什么?

复制代码
>>> prob[819]
0.27168664
>>> categorylist[819]
'n04296562 stage

挺棒的,你说呢?

就是这样,我们已经了解了如何使用预训练的顶尖模型进行图片分类。而这一切只需要4 行代码……除此之外只要准备好数据就够了。

完整代码如下,请自行尝试并继续保持关注 ??

代码已发布至 GitHub: mxnet_example2.py

后续内容:

  • 第 5 篇:进一步了解预训练模型(VGG16 和 ResNet-152)
  • 第 6 篇:通过树莓派进行实时物体检测(并让它讲话!)

作者 Julien Simon 阅读英文原文 An introduction to the MXNet API?—?part 4


感谢杜小芳对本文的审校。

给InfoQ 中文站投稿或者参与内容翻译工作,请邮件至 editors@cn.infoq.com 。也欢迎大家通过新浪微博( @InfoQ @丁晓昀),微信(微信号: InfoQChina )关注我们。

2017-07-16 17:037213
用户头像

发布了 283 篇内容, 共 101.9 次阅读, 收获喜欢 61 次。

关注

评论

发布
暂无评论
发现更多内容

活动回顾|火山引擎DataLeap分享:DataOps、数据治理、指标体系最佳实践(文中领取PPT)

字节跳动数据平台

数据中台 数据治理 抖音 DataOps 企业号 7 月 PK 榜

探索Linux命名空间和控制组:实现资源隔离与管理的双重利器

柠檬汁Code(binbin0325)

Linux 容器 namespace 底层原理 Cgroups

飞桨AI Studio可以玩多模态了?MiniGPT4实战演练!

飞桨PaddlePaddle

人工智能 百度 paddle 飞桨 百度飞桨

未来前端框架将持续推进组件化开发

没有用户名丶

Android 架构模式如何选择

vivo互联网技术

mvc Compose MVVM 解耦 MVI

分享一些常用的开源博客社区网站

兮动人

博客 开源社区

如何为Spring和Mybatis增加可逆计算支持

canonical

Spring Boot mybatis 低代码 可逆计算 Nop平台

在 Go 语言单元测试中如何解决 MySQL 存储依赖问题

江湖十年

golang Web 后端 单元测试 测试 单元测试

TypeScript 玩转类型操作之字符串处理能力

小乌龟快跑

typescript 面试 前端

这次是运行在 Intel AIxBoard™ 开发板上的 TDengine 预测“未来”

爱倒腾的程序员

细数不懂Spring底层原理带来的伤与痛

java易二三

spring 程序员 Spring Boot 计算机 底层原理

Linux系统安装gcc详细教程。

百度搜索:蓝易云

云计算 Linux 运维 服务器 GCC

工赋开发者社区 | 面向CPS的制造执行系统(MES)实验平台验证

工赋开发者社区

Linux系统安装MySQL详细教程

百度搜索:蓝易云

MySQL 云计算 Linux 运维 服务器

2023 云原生编程挑战赛火热报名中!导师解析 Serverless 冷启动赛题

阿里巴巴云原生

阿里云 Serverless 云原生

RLHF 技术:如何能更有效?又有何局限性?

Baihai IDP

人工智能 强化学习 白海科技 RLHF 大语言模型

PoseiSwap:通过 RWA 的全新叙事,反哺 Nautilus Chain 生态

大瞿科技

【华秋干货铺】一文轻松搞定PCB叠层和阻抗设计

华秋电子

合并k个已排序的链表

攻城狮Wayne

Nodejs快速搭建简单的HTTP服务器详细教程。

百度搜索:蓝易云

node.js 云计算 Linux 运维 HTTP

企业号 8 月 PK 榜,火热开启!

InfoQ写作社区官方

热门活动 企业号 8 月 PK 榜

面试官:说出 Java 中的 7 种重试机制

java易二三

编程 程序员 面试 计算机

特性快闪:使用 Databend 玩转 Iceberg

Databend

解析游戏陪练app源码的开发与意义

山东布谷网络科技

游戏 开源代码 APP软件开发

7月征文活动结果出炉,快来看看有没有你

InfoQ写作社区官方

热门活动 年中技术盘点

直播平台源码开发,信息收发功能搭建

山东布谷科技

软件开发 直播 源码搭建 消息发送 直播平台源码

火山引擎AB测试:广告实验深度打通巨量引擎,高效测试广告素材

字节跳动数据平台

大数据 A/B测试 对比试验 企业号 7 月 PK 榜 数字化增长

Spring AOP 中的代理对象是怎么创建出来的?

江南一点雨

Java spring

Centos7系统中找不到yum及安装方法。

百度搜索:蓝易云

云计算 Linux centos 运维 yum

工赋开发者社区 | 复杂电子装备制造数字化工厂实现逻辑与实施步骤

工赋开发者社区

零信任体系化能力建设(1):身份可信与访问管理

权说安全

MXNet API入门 —第4篇_语言 & 开发_Julien Simon_InfoQ精选文章