AICon全球人工智能与机器学习技术大会周四开幕,点击查看完整日程>> 了解详情
写点什么

TensorFlow 工程实战(二):用 tf.layers API 在动态图上识别手写数字

  • 2019 年 8 月 14 日
  • 本文字数:2546 字

    阅读完需:约 8 分钟

TensorFlow工程实战(二):用tf.layers API在动态图上识别手写数字

通过使用卷积网络在 MNIST 数据集上进行识别任务的实例,演示如何用 tf.layers API 构建模型,并在动态图中进行训练。

本文摘选自电子工业出版社出版、李金洪编著的《深度学习之TensorFlow工程化项目实战》一书的实例 21:用 tf.layers API 在动态图上识别手写数字。


实例描述

有一组手写数字图片,要求用 tf.layers API 在动态图上搭建模型,将其识别出来。


一、启动动态图并加载手写图片数据集

本例加载 TFDS 模块中集成好的 MNIST 数据集。该数据集常用于验证模型的功能性实验中。


用 tf.enable_eager_execution 函数启动动态图,并加载 MNIST 数据集。具体代码如下:


代码 1 tf_layer 模型


import tensorflow as tfimport tensorflow.contrib.eager as tfetf.enable_eager_execution()print("TensorFlow 版本: {}".format(tf.VERSION))import tensorflow_datasets as tfdsimport numpy as np#加载训练和验证数据集ds_train, ds_test = tfds.load(name="mnist", split=["train", "test"])ds_train = ds_train.shuffle(1000).batch(10).prefetch(tf.data.experimental.AUTOTUNE)
复制代码


代码第 8 行,调用 tfds.load 方法加载 MNIST 数据集。该方法返回的两个变量 ds_train 与 ds_test 都属于 DatasetV1Adapter 类型。DatasetV1Adapter 类型的数据集的使用方式与 tf.data.Dataset 接口的数据集的使用方式非常相似。


代码第 9 行,对数据集 ds_train 进行打乱顺序、按批次组合和设置缓存操作。


二、定义模型的类

下面定义 MNISTModel 类对模型进行封装。MNISTModel 类继承于 tf.layers.Layer 类。其中有两个方法——__init__与 call。


  • __init__用于定义网络的各个操作层。本实例中所用到的卷积网络、全连接网络都是用 tf.layers 实现的,其用法与 TF-slim 接口非常相似。

  • call 用于将网络中的各层链接起来,形成正向运算的神经网络。


整个网络结构是:卷积操作+最大池化+卷积操作+最大池化+全连接+dropout 方法+全连接。其中,卷积和池化部分在第 8 章还会深入探讨。


全连接是最基础的神经网络模型之一,该网络的结构是将所有的下层节点与每一个上层节点全部连在一起。


dropout 是一种改善过拟合的方法。通过随机丢弃部分网络节点来忽略数据集中的小概率样本。


具体代码如下:


代码 1 tf_layer 模型(续)


class MNISTModel(tf.layers.Layer):              #定义模型类  def __init__(self, name):    super(MNISTModel, self).__init__(name=name)
self._input_shape = [-1, 28, 28, 1] #定义输入形状 #定义卷积层 self.conv1 =tf.layers.Conv2D(32, 5, activation=tf.nn.relu) #定义卷积层 self.conv2 = tf.layers.Conv2D(64, 5, activation=tf.nn.relu) #定义全连接层 self.fc1 =tf.layers.Dense(1024, activation=tf.nn.relu) self.fc2 = tf.layers.Dense(10) self.dropout = tf.layers.Dropout(0.5) #定义dropout层 #定义池化层 self.max_pool2d = tf.layers.MaxPooling2D( (2, 2), (2, 2), padding='SAME')
def call(self, inputs, training): #定义call方法 x = tf.reshape(inputs, self._input_shape) #将网络连接起来 x = self.conv1(x) x = self.max_pool2d(x) x = self.conv2(x) x = self.max_pool2d(x) x = tf.keras.layers.Flatten()(x) x = self.fc1(x) if training: x = self.dropout(x) x = self.fc2(x) return x
复制代码


三、定义网络的反向传播

定义 loss 函数,并建立优化器及梯度 OP。具体代码如下:


代码 1 tf_layer 模型(续)


def loss(model,inputs, labels):    predictions = model(inputs, training=True)        cost = tf.nn.sparse_softmax_cross_entropy_with_logits( logits=predictions, labels=labels )    return tf.reduce_mean( cost )#训练optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)grad = tfe.implicit_gradients(loss)
复制代码


四、训练模型

定义 loss 函数,并建立优化器及梯度 OP。具体代码如下:


代码 1 tf_layer 模型(续)


model = MNISTModel("net")           #实例化模型global_step = tf.train.get_or_create_global_step()for epoch in range(1):          #按照指定次数迭代数据集    for i,data  in enumerate (ds_train):                        inputs, targets = tf.cast( data["image"],tf.float32), data["label"]        optimizer.apply_gradients(grad( model,inputs, targets) , global_step=global_step)        if i % 100 == 0:          print("Step %d: Loss on training set : %f" %            (i, loss(model,inputs, targets).numpy()))          #获取要保存的变量          all_variables = ( model.variables + optimizer.variables() + [global_step])          tfe.Saver(all_variables).save(    #生成检查点文件          "./tfelog/linermodel.cpkt", global_step=global_step)ds = tfds.as_numpy(ds_test.batch(100))onetestdata = next(ds)print("Loss on test set: %f" % loss( model,onetestdata["image"].astype(np.float32), onetestdata["label"]).numpy())
复制代码


代码第 11 行(书中第 57 行),手动将要保存的文件一起传入 tfe.Saver 进行保存。这是动态图接口使用起来相对不方便的地方。它并不能自动将全局的变量都搜集起来。


代码运行后,输出以下结果:


TensorFlow 版本: 1.13.1


Step 0: Loss on training set : 2.252767


……


Step 5600: Loss on training set : 0.002125


Loss on test set: 0.055677


本文摘选自电子工业出版社出版、李金洪编著的《深度学习之TensorFlow工程化项目实战》一书,更多实战内容点此查看。



本文经授权发布,转载请联系电子工业出版社。


系列文章:


TensorFlow 工程实战(一):用 TF-Hub 库微调模型评估人物年龄


TensorFlow 工程实战(二):用 tf.layers API 在动态图上识别手写数字(本文)


TensorFlow 工程实战(三):结合知识图谱实现电影推荐系统


TensorFlow 工程实战(四):使用带注意力机制的模型分析评论者是否满意


TensorFlow 工程实战(五):构建 DeblurGAN 模型,将模糊相片变清晰


TensorFlow 工程实战(六):在 iPhone 手机上识别男女并进行活体检测


2019 年 8 月 14 日 08:008211

评论

发布
暂无评论
发现更多内容

「央视新闻」网上带人回血上岸靠谱吗《手机搜狐网》

王城

「央视新闻」专业带回血的导师靠谱吗《手机搜狐网》

王城

「人民日报」 网上一对一带回血的可靠吗《手机搜狐网》

王城

前端避坑指南丨辛辛苦苦开发的APP竟然被判定为简单网页打包?

APICloud

「央视新闻」实力导师带赚回血上岸《手机搜狐网》

王城

技巧哦

「央视网」求带回血的良心导师《特图》

王城

技巧

第四范式x英特尔“AI应用与异构内存编程挑战赛”圆满收官

第四范式开发者社区

「央视新闻」一对一带人回血是真的吗《手机搜狐网》

王城

「央视新闻」有没有靠谱的回血老师《手机搜狐网》

王城

「人民日报」1万本金能回血40万吗《手机搜狐网》

王城

「央视新闻」感谢老师带我回血《手机搜狐网》

王城

技巧

混沌工程:分布式系统稳定性的“疫苗”

中原银行

微服务 云原生 混沌工程

「央视新闻」网赌回血上岸真实案例《手机搜狐网》

王城

技巧

「央视新闻」最厉害的回血导师《手机搜狐网》

王城

技巧

「央视新闻」回血上岸最快的方法《手机搜狐网》

王城

技巧

「央视新闻」精准回血计划老师QQ《手机搜狐网》

王城

「人民日报」有没有真心带回血上岸《手机搜狐网》

王城

「央视新闻」2千想回血10万需要多久《手机搜狐网》

王城

技巧

「央视新闻」有没有真正带你回血的导师《手机搜狐网》

王城

「央视新闻」输了十几万求回血办法《手机搜狐网》

王城

技巧

「人民日报」真正有实力的回血导师是哪个《手机搜狐网》

王城

「央视新闻」精准计划带你分分钟回血《手机搜狐网》

王城

「央视新闻」导师真正的回血平台《手机搜狐网》

王城

技巧

「央视新闻」 真正有实力带回血上岸的导师《手机搜狐网》

王城

「央视新闻」带你回血的导师是真的吗《手机搜狐网》

王城

「央视新闻」回血上岸计划导师QQ《手机搜狐网》

王城

技巧

「央视新闻」金牌回血上岸导师《手机搜狐网》

王城

技巧

「人民日报」2万本金回血15个的方法《手机搜狐网》

王城

「央视新闻」带人回血的高级导师《手机搜狐网》

王城

技巧

「人民日报」真有人能带人回血上岸吗《手机搜狐网》

王城

「央视新闻」快三带人回血骗局《手机搜狐网》

王城

技巧

数据cool谈(第2期)寻找下一代企业级数据库

数据cool谈(第2期)寻找下一代企业级数据库

TensorFlow工程实战(二):用tf.layers API在动态图上识别手写数字-InfoQ