写点什么

AWS 与微软合作发布 Gluon API 可快速构建机器学习模型

  • 2017-10-18
  • 本文字数:1422 字

    阅读完需:约 5 分钟

2017 年 10 月 12 日, AWS 与微软合作发布了 Gluon 开源项目,该项目旨在帮助开发者更加简单快速的构建机器学习模型,同时保留了较好的性能。

根据 Gluon 项目官方 Github 页面上的描述,Gluon API 支持任意一种深度学习框架,其相关规范已经在 Apache MXNet 项目中实施,开发者只需安装最新版本的 MXNet(master)即可体验。AWS 用户可以创建一个AWS Deep Learning AMI 进行体验。

该页面提供了一段简易使用说明,摘录如下:

本教程以一个两层神经网络的构建和训练为例,我们将它称呼为多层感知机(multilayer perceptron)。(本示范建议使用Python 3.3 或以上,并且使用 Jupyter notebook 来运行。详细教程可参考这个页面。)

首先,进行如下引用声明:

复制代码
import mxnet as mx
from mxnet import gluon, autograd, ndarray
import numpy as np

然后,使用gluon.data.DataLoader承载训练数据和测试数据。这个 DataLoader 是一个 iterator 对象类,非常适合处理规模较大的数据集。

复制代码
train_data = mx.gluon.data.DataLoader(mx.gluon.data.vision.MNIST(train=True, transform=lambda data, label: (data.astype(np.float32)/255, label)),
batch_size=32, shuffle=True)
test_data = mx.gluon.data.DataLoader(mx.gluon.data.vision.MNIST(train=False, transform=lambda data, label: (data.astype(np.float32)/255, label)),
batch_size=32, shuffle=False)

接下来,定义神经网络:

复制代码
# 先把模型做个初始化
net = gluon.nn.Sequential()
# 然后定义模型架构
with net.name_scope():
net.add(gluon.nn.Dense(128, activation="relu")) # 第一层设置 128 个节点
net.add(gluon.nn.Dense(64, activation="relu")) # 第二层设置 64 个节点
net.add(gluon.nn.Dense(10)) # 输出层

然后把模型的参数设置一下:

复制代码
# 先随机设置模型参数
# 数值从一个标准差为 0.05 正态分布曲线里面取
net.collect_params().initialize(mx.init.Normal(sigma=0.05))
# 使用 softmax cross entropy loss 算法
# 计算模型的预测能力
softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()
# 使用随机梯度下降算法 (sgd) 进行训练
# 并且将学习率的超参数设置为 .1
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': .1})

之后就可以开始跑训练了,一共分四个步骤。一、把数据放进去;二、在神经网络模型算出输出之后,比较其与实际结果的差距;三、用 Gluon 的autograd计算模型各参数对此差距的影响;四、用 Gluon 的trainer方法优化这些参数以降低差距。以下我们先让它跑 10 轮的训练:

复制代码
epochs = 10
for e in range(epochs):
for i, (data, label) in enumerate(train_data):
data = data.as_in_context(mx.cpu()).reshape((-1, 784))
label = label.as_in_context(mx.cpu())
with autograd.record(): # Start recording the derivatives
output = net(data) # the forward iteration
loss = softmax_cross_entropy(output, label)
loss.backward()
trainer.step(data.shape[0])
# Provide stats on the improvement of the model over each epoch
curr_loss = ndarray.mean(loss).asscalar()
print("Epoch {}. Current Loss: {}.".format(e, curr_loss))

若想了解更多 Gluon 说明与用法,可以查看 gluon.mxnet.io 这个网站。

2017-10-18 20:242266

评论

发布
暂无评论
发现更多内容

Go 专栏|变量和常量的声明与赋值

AlwaysBeta

Go 语言

从一个并发异常问题引起的想法

卢卡多多

并发编程 9月日更

Go 专栏|流程控制,一网打尽

AlwaysBeta

Go 语言

Electron团队为什么要干掉remote模块

刘晓伦

Electron Node

升级mysql-connector-java-8.x踩坑纪实

小江

Java MySQL 时间戳 服务器时区 夏令时

【HTML5游戏】从敲打空格键开始

devpoint

HTML5游戏 9月日更

未来10年,5个C/C++吃香的细分领域技术

奔着腾讯去

云原生 网络安全 音视频 DPDK 虚拟化技术

链路压测中各接口性能统计

FunTester

性能测试 测试框架 测试开发 FunTester 链路测试

Go 专栏|错误处理:defer,panic 和 recover

AlwaysBeta

Go 语言

Go 专栏|说说方法

AlwaysBeta

Go 语言

在线JSON转JAVA工具

入门小站

工具

(深入篇)漫游语音识别技术—带你走进语音识别技术的世界

RTE开发者社区

深度学习 音视频 语音识别

Go 专栏|基础数据类型:整数、浮点数、复数、布尔值和字符串

AlwaysBeta

Go 语言

Linux之lastlog命令

入门小站

Linux

HTTP系列之:HTTP中的cookies

程序那些事

Java 网络协议 HTTP cookies

Go 专栏|函数那些事

AlwaysBeta

Go 语言

多线程知识体系01-线程池源码阅读讲解-Executor

小马哥

多线程 高并发 源码阅读 源码剖析 日更

Java + opencv 实现图片修复(图片去水印)

张音乐

Java OpenCV 音视频 9月日更 图片去水印

线程同步类CyclicBarrier在性能测试集合点应用

FunTester

多线程 性能测试 线程安全 测试框架 FunTester

【报名】飞桨中国行丨企业零门槛AI创新应用-智能制造专场

百度大脑

人工智能

【Flutter 专题】58 图解 Flutter 嵌入原生 AndroidView 小尝试

阿策小和尚

Flutter 小菜 0 基础学习 Flutter Android 小菜鸟 9月日更

【LeetCode】 二叉树中和为某一值的路径Java题解

Albert

算法 LeetCode 9月日更

Go 专栏|复合数据类型:数组和切片 slice

AlwaysBeta

Go 语言

Go 专栏|复合数据类型:字典 map 和 结构体 struct

AlwaysBeta

Go 语言

Vue进阶(九十):过滤器

No Silver Bullet

Vue 9月日更

ShardingSphere 语句解析生成初探

源码 ShardingSphere

LeetCode刷题278-简单-第一个错误版本

ベ布小禅

9月日更

Linux内核四大核心框架

hanaper

Go 专栏|接口 interface

AlwaysBeta

Go 语言

【Vue2.x 源码学习】第四十三篇 - 组件部分 - 组件相关流程总结

Brave

源码 vue2 9月日更

模块(二)如何设计架构

我是一只小小鸟

AWS与微软合作发布Gluon API 可快速构建机器学习模型_微软_sai_InfoQ精选文章