TensorFlow 工程实战(二):用 tf.layers API 在动态图上识别手写数字

阅读数:7461 2019 年 8 月 14 日 08:00

TensorFlow工程实战(二):用tf.layers API在动态图上识别手写数字

通过使用卷积网络在 MNIST 数据集上进行识别任务的实例,演示如何用 tf.layers API 构建模型,并在动态图中进行训练。

本文摘选自电子工业出版社出版、李金洪编著的《深度学习之 TensorFlow 工程化项目实战》一书的实例 21:用 tf.layers API 在动态图上识别手写数字。

实例描述

有一组手写数字图片,要求用 tf.layers API 在动态图上搭建模型,将其识别出来。

一、启动动态图并加载手写图片数据集

本例加载 TFDS 模块中集成好的 MNIST 数据集。该数据集常用于验证模型的功能性实验中。

用 tf.enable_eager_execution 函数启动动态图,并加载 MNIST 数据集。具体代码如下:

代码 1 tf_layer 模型

复制代码
import tensorflow as tf
import tensorflow.contrib.eager as tfe
tf.enable_eager_execution()
print("TensorFlow 版本: {}".format(tf.VERSION))
import tensorflow_datasets as tfds
import numpy as np
#加载训练和验证数据集
ds_train, ds_test = tfds.load(name="mnist", split=["train", "test"])
ds_train = ds_train.shuffle(1000).batch(10).prefetch(tf.data.experimental.AUTOTUNE)

代码第 8 行,调用 tfds.load 方法加载 MNIST 数据集。该方法返回的两个变量 ds_train 与 ds_test 都属于 DatasetV1Adapter 类型。DatasetV1Adapter 类型的数据集的使用方式与 tf.data.Dataset 接口的数据集的使用方式非常相似。

代码第 9 行,对数据集 ds_train 进行打乱顺序、按批次组合和设置缓存操作。

二、定义模型的类

下面定义 MNISTModel 类对模型进行封装。MNISTModel 类继承于 tf.layers.Layer 类。其中有两个方法——__init__ 与 call。

  • __init__ 用于定义网络的各个操作层。本实例中所用到的卷积网络、全连接网络都是用 tf.layers 实现的,其用法与 TF-slim 接口非常相似。
  • call 用于将网络中的各层链接起来,形成正向运算的神经网络。

整个网络结构是:卷积操作 + 最大池化 + 卷积操作 + 最大池化 + 全连接 +dropout 方法 + 全连接。其中,卷积和池化部分在第 8 章还会深入探讨。

全连接是最基础的神经网络模型之一,该网络的结构是将所有的下层节点与每一个上层节点全部连在一起。

dropout 是一种改善过拟合的方法。通过随机丢弃部分网络节点来忽略数据集中的小概率样本。

具体代码如下:

代码 1 tf_layer 模型(续)

复制代码
class MNISTModel(tf.layers.Layer): #定义模型类
def __init__(self, name):
super(MNISTModel, self).__init__(name=name)
self._input_shape = [-1, 28, 28, 1] #定义输入形状
#定义卷积层
self.conv1 =tf.layers.Conv2D(32, 5, activation=tf.nn.relu)
#定义卷积层
self.conv2 = tf.layers.Conv2D(64, 5, activation=tf.nn.relu)
#定义全连接层
self.fc1 =tf.layers.Dense(1024, activation=tf.nn.relu)
self.fc2 = tf.layers.Dense(10)
self.dropout = tf.layers.Dropout(0.5) #定义 dropout 层
#定义池化层
self.max_pool2d = tf.layers.MaxPooling2D(
(2, 2), (2, 2), padding='SAME')
def call(self, inputs, training): #定义 call 方法
x = tf.reshape(inputs, self._input_shape) #将网络连接起来
x = self.conv1(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.max_pool2d(x)
x = tf.keras.layers.Flatten()(x)
x = self.fc1(x)
if training:
x = self.dropout(x)
x = self.fc2(x)
return x

三、定义网络的反向传播

定义 loss 函数,并建立优化器及梯度 OP。具体代码如下:

代码 1 tf_layer 模型(续)

复制代码
def loss(model,inputs, labels):
predictions = model(inputs, training=True)
cost = tf.nn.sparse_softmax_cross_entropy_with_logits( logits=predictions, labels=labels )
return tf.reduce_mean( cost )
#训练
optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
grad = tfe.implicit_gradients(loss)

四、训练模型

定义 loss 函数,并建立优化器及梯度 OP。具体代码如下:

代码 1 tf_layer 模型(续)

复制代码
model = MNISTModel("net") #实例化模型
global_step = tf.train.get_or_create_global_step()
for epoch in range(1): #按照指定次数迭代数据集
for i,data in enumerate (ds_train):
inputs, targets = tf.cast( data["image"],tf.float32), data["label"]
optimizer.apply_gradients(grad( model,inputs, targets) , global_step=global_step)
if i % 100 == 0:
print("Step %d: Loss on training set : %f" %
(i, loss(model,inputs, targets).numpy()))
#获取要保存的变量
all_variables = ( model.variables + optimizer.variables() + [global_step])
tfe.Saver(all_variables).save( #生成检查点文件
"./tfelog/linermodel.cpkt", global_step=global_step)
ds = tfds.as_numpy(ds_test.batch(100))
onetestdata = next(ds)
print("Loss on test set: %f" % loss( model,onetestdata["image"].astype(np.float32), onetestdata["label"]).numpy())

代码第 11 行(书中第 57 行),手动将要保存的文件一起传入 tfe.Saver 进行保存。这是动态图接口使用起来相对不方便的地方。它并不能自动将全局的变量都搜集起来。

代码运行后,输出以下结果:

TensorFlow 版本: 1.13.1
Step 0: Loss on training set : 2.252767
……
Step 5600: Loss on training set : 0.002125
Loss on test set: 0.055677

本文摘选自电子工业出版社出版、李金洪编著的《深度学习之 TensorFlow 工程化项目实战》一书,更多实战内容点此查看。

TensorFlow工程实战(二):用tf.layers API在动态图上识别手写数字

本文经授权发布,转载请联系电子工业出版社。

系列文章:

TensorFlow 工程实战(一):用 TF-Hub 库微调模型评估人物年龄
TensorFlow 工程实战(二):用 tf.layers API 在动态图上识别手写数字(本文)
TensorFlow 工程实战(三):结合知识图谱实现电影推荐系统
TensorFlow 工程实战(四):使用带注意力机制的模型分析评论者是否满意
TensorFlow 工程实战(五):构建 DeblurGAN 模型,将模糊相片变清晰
TensorFlow 工程实战(六):在 iPhone 手机上识别男女并进行活体检测

收藏

评论

微博

用户头像
发表评论

注册/登录 InfoQ 发表评论