【AICon】探索RAG 技术在实际应用中遇到的挑战及应对策略!AICon精华内容已上线73%>>> 了解详情
写点什么

MXNet API 入门 —第 3 篇

  • 2017-07-13
  • 本文字数:4703 字

    阅读完需:约 15 分钟

第2 篇文章中,我们介绍了如何使用Symbols 定义计算中使用的Graph,并处理存储在NDArray(在第1 篇文章中有介绍)中的数据。

本文将介绍如何使用Symbol 和NDArray 准备所需数据并构建神经网络。随后将使用 Module API 训练该网络并预测结果。

定义数据集

我们(设想中的)数据集包含1000 个数据样本

  • 每个样本有100 个特征
  • 每个特征体现为一个介于 0 和 1 之间的浮点值
  • 样本被分为10 个类别,我们将使用神经网络预测特定样本的恰当类别,
  • 我们将使用 800 个样本进行训练,使用 200 个样本进行验证
  • 训练和验证过程的批大小为 10。
复制代码
import mxnet as mx
import numpy as np
import logging
logging.basicConfig(level=logging.INFO)
sample_count = 1000
train_count = 800
valid_count = sample_count - train_count
feature_count = 100
category_count = 10
batch=10

生成数据集

我们将通过均匀分布的方式生成这 1000 个样本,将其存储在一个名为“X”的 NDArray 中:1000 行,100 列

复制代码
X = mx.nd.uniform(low=0, high=1, shape=(sample_count,feature_count))
>>> X.shape
(1000L, 100L)
>>> X.asnumpy()
array([[ 0.70029777, 0.28444085, 0.46263582, ..., 0.73365158,
0.99670047, 0.5961988 ],
[ 0.34659418, 0.82824177, 0.72929877, ..., 0.56012964,
0.32261589, 0.35627609],
[ 0.10939316, 0.02995235, 0.97597599, ..., 0.20194994,
0.9266268 , 0.25102937],
...,
[ 0.69691515, 0.52568913, 0.21130568, ..., 0.42498392,
0.80869114, 0.23635457],
[ 0.3562004 , 0.5794751 , 0.38135922, ..., 0.6336484 ,
0.26392782, 0.30010447],
[ 0.40369365, 0.89351988, 0.88817406, ..., 0.13799617,
0.40905532, 0.05180593]], dtype=float32)

这 1000 个样本的类别用介于 0-9 的整数来代表,类别是随机生成的,存储在一个名为“Y”的 NDArray 中。

复制代码
Y = mx.nd.empty((sample_count,))
for i in range(0,sample_count-1):
Y[i] = np.random.randint(0,category_count)
>>> Y.shape
(1000L,)
>>> Y[0:10].asnumpy()
array([ 3., 3., 1., 9., 4., 7., 3., 5., 2., 2.], dtype=float32)

拆分数据集

随后我们将针对训练验证两个用途对数据集进行80/20拆分。为此需要使用 NDArray.crop 函数。在这里,数据集是完全随机的,因此可以使用前 80% 的数据进行训练,用后 20% 的数据进行验证。实际运用中,我们可能需要首先搅乱数据集,这样才能避免按顺序生成的数据可能造成的偏差。

复制代码
X_train = mx.nd.crop(X, begin=(0,0), end=(train_count,feature_count-1))
X_valid = mx.nd.crop(X, begin=(train_count,0), end=(sample_count,feature_count-1))
Y_train = Y[0:train_count]
Y_valid = Y[train_count:sample_count]

至此数据已经准备完毕!

构建网络

这个网络其实很简单,一起看看其中的每一层:

  • 输入层是由一个名为“Data”的 Symbol 代表的,随后会绑定至实际的输入数据。 ```

    data = mx.sym.Variable(‘data’)

复制代码
- fc1 是 ** 第一个隐藏层 **,通过 **64 个相互连接的神经元 ** 构建而来,输入层的每个特征都会连接至所有的 64 个神经元。如你所见,我们使用了高级的 Symbol.FullyConnected 函数,相比手工建立每个连接,这种做法更方便一些! ```
fc1 = mx.sym.FullyConnected(data, name='fc1', num_hidden=64)
  • fc1 的每个输出会进入到一个激活函数 (Activation function) 。在这里我们将使用一个线性整流单元 (Rectified linear unit) ,即“Relu”。之前承诺过尽量少讲理论知识,因此可以这样理解:激活函数将用于决定是否要“启动”某个神经元,例如其输入是否由足够有意义,可以预测出正确的结果。 ```

    relu1 = mx.sym.Activation(fc1, name=‘relu1’, act_type=“relu”)

复制代码
- fc2 是 ** 第二个隐藏层 **,由 **10 个相互连接的神经元 ** 构建而来,可映射至我们的 **10 个分类 **。每个神经元可输出一个任意标度 (Arbitrary scale) 的浮点值。10 个值中最大的那个代表了数据样本 ** 最有可能的类别 **。 ```
fc2 = mx.sym.FullyConnected(relu1, name='fc2', num_hidden=category_count)
  • 输出层会将 Softmax 函数应用给来自 fc2 层的 10 个值:这些值会被转换为 10 个介于 0 和 1 之间的值,所有值的总和为 1。每个值代表预测出的每个类别的可能性,其中最大的值代表最有可能的类别。 ```

    out = mx.sym.SoftmaxOutput(fc2, name=‘softmax’)
    mod = mx.mod.Module(out)

复制代码
## 构建数据迭代器
在第 1 篇文章中,我们了解到神经网络并不会一次只训练一个样本,因为这样做从性能的角度来看效率太低。因此我们会使用 ** 批 **,即 ** 一批固定数量的样本 **
为了给神经网络提供这样的“批”,我们需要使用 NDArrayIter 函数构建一个 ** 迭代器 **。其参数包括 ** 训练数据 **、分类(MXNet 将其称之为 ** 标签 (Label)**),以及 ** 批大小 **
如你所见,我们可以对整个数据集进行迭代,同时对 10 个样本和 10 个标签执行该操作。随后即可调用 reset() 函数将迭代器恢复为初始状态。

train_iter = mx.io.NDArrayIter(data=X_train,label=Y_train,batch_size=batch)

for batch in train_iter:
… print batch.data
… print batch.label

[<NDArray 10x99 @cpu(0)>]
[<NDArray 10 @cpu(0)>]
[<NDArray 10x99 @cpu(0)>]
[<NDArray 10 @cpu(0)>]
[<NDArray 10x99 @cpu(0)>]
[<NDArray 10 @cpu(0)>]

train_iter.reset()

复制代码
网络已经准备完成,开始训练吧!
## 训练模型
首先将输入 Symbol\*\* 绑定\*\* 至实际的数据集(样本和标签),这时候就会用到迭代器。

mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)

复制代码
随后对网络中的神经元权重进行 ** 初始化 **。这个步骤非常重要:使用“恰当”的技术对齐进行初始化可以帮助网络 ** 更快速地 ** 学习。此时可用的技术很多,Xavier 初始化器(名称源自该技术的发明人 Xavier Glorot?—?[PDF](http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf))就是其中之一。

Allowed, but not efficient

mod.init_params()

Much better

mod.init_params(initializer=mx.init.Xavier(magnitude=2.))

复制代码
接着需要定义 ** 优化 ** 参数:
- 我们将使用 [随机坡降法 (Stochastic Gradient Descent)](https://en.wikipedia.org/wiki/Stochastic_gradient_descent) 算法(又名 SGD),该算法在机器学习和深度学习领域有着广泛的应用。
- 我们会将 ** 学习速率 ** 设置为 0.1,这是 SGD 算法一个非常普遍的设置。

mod.init_optimizer(optimizer=‘sgd’, optimizer_params=((‘learning_rate’, 0.1), ))

复制代码
最后,终于可以开始训练网络了!我们会执行 50 个 ** 回合 (Epoch)** 的训练,也就是说,整个数据集需要在这个网络中(以 10 个样本为一批)运行 50 次。

mod.fit(train_iter, num_epoch=50)
INFO:root:Epoch[0] Train-accuracy=0.097500
INFO:root:Epoch[0] Time cost=0.085
INFO:root:Epoch[1] Train-accuracy=0.122500
INFO:root:Epoch[1] Time cost=0.074
INFO:root:Epoch[2] Train-accuracy=0.153750
INFO:root:Epoch[2] Time cost=0.087
INFO:root:Epoch[3] Train-accuracy=0.162500
INFO:root:Epoch[3] Time cost=0.082
INFO:root:Epoch[4] Train-accuracy=0.192500
INFO:root:Epoch[4] Time cost=0.094
INFO:root:Epoch[5] Train-accuracy=0.210000
INFO:root:Epoch[5] Time cost=0.108
INFO:root:Epoch[6] Train-accuracy=0.222500
INFO:root:Epoch[6] Time cost=0.104
INFO:root:Epoch[7] Train-accuracy=0.243750
INFO:root:Epoch[7] Time cost=0.110
INFO:root:Epoch[8] Train-accuracy=0.263750
INFO:root:Epoch[8] Time cost=0.101
INFO:root:Epoch[9] Train-accuracy=0.286250
INFO:root:Epoch[9] Time cost=0.097
INFO:root:Epoch[10] Train-accuracy=0.306250
INFO:root:Epoch[10] Time cost=0.100

INFO:root:Epoch[20] Train-accuracy=0.507500

INFO:root:Epoch[30] Train-accuracy=0.718750

INFO:root:Epoch[40] Train-accuracy=0.923750

INFO:root:Epoch[50] Train-accuracy=0.998750
INFO:root:Epoch[50] Time cost=0.077

复制代码
如你所见,训练的准确度有了飞速提升,50 个回合后已经接近 **99% 以上 **。似乎我们的网络已经从训练数据集中学成了。非常惊人!
但针对验证数据集执行的效果如何呢?
## 验证模型
随后将新的数据样本放入网络,例如剩下的那 20%** 尚未 ** 在训练中使用过的数据。
首先构建一个迭代器,这一次将使用 ** 验证 ** 样本和标签。

pred_iter = mx.io.NDArrayIter(data=X_valid,label=Y_valid, batch_size=batch)

复制代码
随后要使用 Module.iter\_predict() 函数,借此让样本在网络中运行。这样做的同时,还需要对 ** 预测的标签 **** 实际标签 ** 进行对比。我们需要追踪比分并显示 ** 验证准确度 **,即,网络针对验证数据集的执行效果到底如何。

pred_count = valid_count
correct_preds = total_correct_preds = 0
for preds, i_batch, batch in mod.iter_predict(pred_iter):
label = batch.label[0].asnumpy().astype(int)
pred_label = preds[0].asnumpy().argmax(axis=1)
correct_preds = np.sum(pred_label==label)
total_correct_preds = total_correct_preds + correct_preds
print(‘Validation accuracy: %2.2f’ % (1.0*total_correct_preds/pred_count))

复制代码
这个过程中发生了不少事 :)
iter\_predict() 返回了:
{1}
- i\_batch:批编号。
- batch:一个 NDArray 数组。这里它其实保存了一个 NDArray,其中存储了当前批的内容。我们将用它找出当前批中 10 个数据样本的标签,随后将其存储在名为 Label 的 Numpy array 中(10 个元素)。
- preds:也是一个 NDArray 数组。这里它保存了一个 NDArray,其中存储了当前批预测出的标签:对于每个样本,我们提供了 ** 所有 10 个分类预测出的可能性 **(10x10 矩阵)。因此我们将使用 argmax() 找出最高值的 ** 指数 **,即 ** 最可能的分类 **。所以 pred\_label 实际上是一个 10 元素数组,其中保存了当前批中每个数据样本预测出的分类。
{1}
随后我们需要使用 Numpy.sum() 将 label 和 pred\_label 中相等值的数量进行对比。
最后需要计算并显示验证准确度。
> 验证准确度:0.09
什么?只有 9%?** 真是太悲催了 **!如果你希望证明我们的数据集真的是随机的,那么你有证据了!
底线在于,我们确实可以通过训练神经网络学习 ** 任何东西 **,但如果数据本身是 ** 无意义的 **(例如我们本例中使用的数据),那么就什么都预测不出来。** 种瓜得瓜,种豆得豆 **
如果你已经读到这里,我猜你是真心希望看到本例的完整代码 ;) 请花些时间用你自己的数据进行验证,这才是学习的最佳方法。
代码已发布至 GitHub:[mxnet\_example1.py](https://gist.github.com/juliensimon/7cfef0423b0183e891774a289e156b49#file-mxnet_example1-py)。
## 后续内容:
- 第 4 篇:使用预训练模型进行图片分类(Inception v3)
- 第 5 篇:进一步了解预训练模型(VGG16 和 ResNet-152)
- 第 6 篇:通过树莓派进行实时物体检测(并让它讲话!)
** 作者 **:[Julien Simon](https://medium.com/@julsimon),** 阅读英文原文 **:[An introduction to the MXNet API?—?part 3](https://medium.com/@julsimon/an-introduction-to-the-mxnet-api-part-3-1803112ba3a8)
- - - - - -
感谢 [杜小芳](http://www.infoq.com/cn/author/%E6%9D%9C%E5%B0%8F%E8%8A%B3) 对本文的审校。
给 InfoQ 中文站投稿或者参与内容翻译工作,请邮件至 [editors@cn.infoq.com](mailto:editors@cn.infoq.com)。也欢迎大家通过新浪微博([@InfoQ](http://www.weibo.com/infoqchina),[@丁晓昀](http://weibo.com/u/1451714913)),微信(微信号:[InfoQChina](http://www.geekbang.org/ivtw))关注我们。
2017-07-13 17:393839
用户头像

发布了 283 篇内容, 共 101.6 次阅读, 收获喜欢 61 次。

关注

评论

发布
暂无评论
发现更多内容

自动化测试的痛点与发展趋势

老张

DevOps 自动化测试

电子元器件行业MES系统能解决哪些管理难题?

万界星空科技

工业互联网 制造业 电子元器件 mes 万界星空科技

下一代积木式智能组装编排,集成开发效率10倍提升

华为云开发者联盟

开发 华为云 华为云开发者联盟 DTSE Tech Talk

【教程】无法验证app需要互联网连接以验证是否信任开发者

雪奈椰子

【愚公系列】2024远控性能大解密!5款评价最高远控软件ToDesk、TeamViewer、向日葵、Parsec、AirDroid谁与争锋?

愚公搬代码

大模型在产品原型生成中的应用实践

得物技术

大前端

新学期提效神器汇总!男大女大们准备好了吗?

飞桨PaddlePaddle

百度 BAIDU 百度飞桨 AI应用 飞桨星河社区

新闻网站封锁AI爬虫 AI与新闻媒体博弈继续

郑州埃文科技

AI 爬虫

抖音详情API:视频内容获取与解析技巧

技术冰糖葫芦

API 接口

引领测试开发新风向:模型驱动测试的魔力

测试人

软件测试

MySQL数据库中SQL语句分几类?

小魏写代码

工作两年涨薪40%,揭秘我的学习之路!

霍格沃兹测试开发学社

5月17-19日 上海线下 · CSP直通车训练营 · CST导师亲授【名额有限,先到先得】

ShineScrum捷行

ScrumMaster 敏捷教练认证 上海线下、 Scrum专业认证

一次性搞定多任务!Python自动化复用浏览器技巧大揭秘

测试人

软件测试

引领测试开发新风向:模型驱动测试的魔力

测吧(北京)科技有限公司

测试

MediaHuman YouTube Downloader mac(YouTube视频下载工具) v3.9.9.88中文注册版

iMac小白

下一代积木式智能组装编排,集成开发效率10倍提升

华为云PaaS服务小智

华为云

测试人生 | 工作两年涨薪40%,揭秘我的学习之路!

测吧(北京)科技有限公司

测试

【FAQ】HarmonyOS SDK 闭源开放能力 —Map Kit

HMS Core

HarmonyOS

网络安全审计是什么意思?与等保测评有什么区别?

行云管家

网络安全 等保测评 网络安全审计

数字经济的主要产品及使用!

青否数字人

数字人

浪潮信息边缘服务器支持英特尔第五代至强处理器

财见

软件测试工作两年涨薪40%,揭秘我的学习之路!

测试人

软件测试

企业数据内控安全就用行云防水堡!不容错过!

行云管家

数据安全 数据泄露 企业数据 防水堡

SecGPT-Mini,一个在CPU上可体验的开源网络安全大模型

云起无垠

Ableton Live 12 Suite for mac(音乐制作工具) v12.0中文激活版

iMac小白

从Language Model到Chat Application:对话接口的设计与实现

阿里技术

application Language 设计与实现 对话接口

模型驱动测试引领测试开发新风向

霍格沃兹测试开发学社

2024-03-06:用go语言,每一种货币都给定面值val[i],和拥有的数量cnt[i], 想知道目前拥有的货币,在钱数为1、2、3...m时,能找零成功的钱数有多少? 也就是说当钱数的范围是1~

福大大架构师每日一题

福大大架构师每日一题

低代码平台与MES:智能制造的新篇章

万界星空科技

制造业 低代码平台 mes 万界星空科技 机器人组装行业

MediaHuman YouTube to MP3 Converter mac(YouTube音乐转MP3转换器) v3.9.9.88中文注册版

iMac小白

MXNet API入门 —第3篇_语言 & 开发_Julien Simon_InfoQ精选文章