前端未来的主流技术方向有哪些?腾讯、京东、同城旅行等大厂都是怎么布局的?戳此了解 了解详情
写点什么

TensorFlow 工程实战(二):用 tf.layers API 在动态图上识别手写数字

2019 年 8 月 14 日

TensorFlow工程实战(二):用tf.layers API在动态图上识别手写数字

通过使用卷积网络在 MNIST 数据集上进行识别任务的实例,演示如何用 tf.layers API 构建模型,并在动态图中进行训练。

本文摘选自电子工业出版社出版、李金洪编著的《深度学习之TensorFlow工程化项目实战》一书的实例 21:用 tf.layers API 在动态图上识别手写数字。


实例描述

有一组手写数字图片,要求用 tf.layers API 在动态图上搭建模型,将其识别出来。


一、启动动态图并加载手写图片数据集

本例加载 TFDS 模块中集成好的 MNIST 数据集。该数据集常用于验证模型的功能性实验中。


用 tf.enable_eager_execution 函数启动动态图,并加载 MNIST 数据集。具体代码如下:


代码 1 tf_layer 模型


import tensorflow as tfimport tensorflow.contrib.eager as tfetf.enable_eager_execution()print("TensorFlow 版本: {}".format(tf.VERSION))import tensorflow_datasets as tfdsimport numpy as np#加载训练和验证数据集ds_train, ds_test = tfds.load(name="mnist", split=["train", "test"])ds_train = ds_train.shuffle(1000).batch(10).prefetch(tf.data.experimental.AUTOTUNE)
复制代码


代码第 8 行,调用 tfds.load 方法加载 MNIST 数据集。该方法返回的两个变量 ds_train 与 ds_test 都属于 DatasetV1Adapter 类型。DatasetV1Adapter 类型的数据集的使用方式与 tf.data.Dataset 接口的数据集的使用方式非常相似。


代码第 9 行,对数据集 ds_train 进行打乱顺序、按批次组合和设置缓存操作。


二、定义模型的类

下面定义 MNISTModel 类对模型进行封装。MNISTModel 类继承于 tf.layers.Layer 类。其中有两个方法——__init__与 call。


  • __init__用于定义网络的各个操作层。本实例中所用到的卷积网络、全连接网络都是用 tf.layers 实现的,其用法与 TF-slim 接口非常相似。

  • call 用于将网络中的各层链接起来,形成正向运算的神经网络。


整个网络结构是:卷积操作+最大池化+卷积操作+最大池化+全连接+dropout 方法+全连接。其中,卷积和池化部分在第 8 章还会深入探讨。


全连接是最基础的神经网络模型之一,该网络的结构是将所有的下层节点与每一个上层节点全部连在一起。


dropout 是一种改善过拟合的方法。通过随机丢弃部分网络节点来忽略数据集中的小概率样本。


具体代码如下:


代码 1 tf_layer 模型(续)


class MNISTModel(tf.layers.Layer):              #定义模型类  def __init__(self, name):    super(MNISTModel, self).__init__(name=name)
self._input_shape = [-1, 28, 28, 1] #定义输入形状 #定义卷积层 self.conv1 =tf.layers.Conv2D(32, 5, activation=tf.nn.relu) #定义卷积层 self.conv2 = tf.layers.Conv2D(64, 5, activation=tf.nn.relu) #定义全连接层 self.fc1 =tf.layers.Dense(1024, activation=tf.nn.relu) self.fc2 = tf.layers.Dense(10) self.dropout = tf.layers.Dropout(0.5) #定义dropout层 #定义池化层 self.max_pool2d = tf.layers.MaxPooling2D( (2, 2), (2, 2), padding='SAME')
def call(self, inputs, training): #定义call方法 x = tf.reshape(inputs, self._input_shape) #将网络连接起来 x = self.conv1(x) x = self.max_pool2d(x) x = self.conv2(x) x = self.max_pool2d(x) x = tf.keras.layers.Flatten()(x) x = self.fc1(x) if training: x = self.dropout(x) x = self.fc2(x) return x
复制代码


三、定义网络的反向传播

定义 loss 函数,并建立优化器及梯度 OP。具体代码如下:


代码 1 tf_layer 模型(续)


def loss(model,inputs, labels):    predictions = model(inputs, training=True)        cost = tf.nn.sparse_softmax_cross_entropy_with_logits( logits=predictions, labels=labels )    return tf.reduce_mean( cost )#训练optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)grad = tfe.implicit_gradients(loss)
复制代码


四、训练模型

定义 loss 函数,并建立优化器及梯度 OP。具体代码如下:


代码 1 tf_layer 模型(续)


model = MNISTModel("net")           #实例化模型global_step = tf.train.get_or_create_global_step()for epoch in range(1):          #按照指定次数迭代数据集    for i,data  in enumerate (ds_train):                        inputs, targets = tf.cast( data["image"],tf.float32), data["label"]        optimizer.apply_gradients(grad( model,inputs, targets) , global_step=global_step)        if i % 100 == 0:          print("Step %d: Loss on training set : %f" %            (i, loss(model,inputs, targets).numpy()))          #获取要保存的变量          all_variables = ( model.variables + optimizer.variables() + [global_step])          tfe.Saver(all_variables).save(    #生成检查点文件          "./tfelog/linermodel.cpkt", global_step=global_step)ds = tfds.as_numpy(ds_test.batch(100))onetestdata = next(ds)print("Loss on test set: %f" % loss( model,onetestdata["image"].astype(np.float32), onetestdata["label"]).numpy())
复制代码


代码第 11 行(书中第 57 行),手动将要保存的文件一起传入 tfe.Saver 进行保存。这是动态图接口使用起来相对不方便的地方。它并不能自动将全局的变量都搜集起来。


代码运行后,输出以下结果:


TensorFlow 版本: 1.13.1


Step 0: Loss on training set : 2.252767


……


Step 5600: Loss on training set : 0.002125


Loss on test set: 0.055677


本文摘选自电子工业出版社出版、李金洪编著的《深度学习之TensorFlow工程化项目实战》一书,更多实战内容点此查看。



本文经授权发布,转载请联系电子工业出版社。


系列文章:


TensorFlow 工程实战(一):用 TF-Hub 库微调模型评估人物年龄


TensorFlow 工程实战(二):用 tf.layers API 在动态图上识别手写数字(本文)


TensorFlow 工程实战(三):结合知识图谱实现电影推荐系统


TensorFlow 工程实战(四):使用带注意力机制的模型分析评论者是否满意


TensorFlow 工程实战(五):构建 DeblurGAN 模型,将模糊相片变清晰


TensorFlow 工程实战(六):在 iPhone 手机上识别男女并进行活体检测


2019 年 8 月 14 日 08:008150

评论

发布
暂无评论
发现更多内容

基于机器学习的自动化测试弹窗处理实践

bilibili游戏技术

手机 自动化测试 yolo 弹窗

文件IO

拾贝

Java 版学生成绩管理系统,附源码!

村雨遥

Java 6 月日更

谁也讲不明白的SQL注入攻击被我讲明白了(中)?

网络安全学海

程序员 网络安全 计算机 渗透测试 SQL注入

创业者需要知道的13种思维模型

俞凡

创业 认知

sftp的使用

拾贝

(序)【Spring源码专题】展开Spring源码构建之旅(利用IDEA和Gradle)

李浩宇/Alex

spring 6月日更 6 月日更 源码搭建

2021最新Spring Security知识梳理

北游学Java

Java spring

硬刚Hbase - 17道题你能秒我?我Hbase八股文反手就甩你一脸

王知无

硬刚Presto | Presto原理&调优&面试&实战全面升级版

王知无

硬刚数据仓库|SQL Boy的福音之数据仓库体系建模&实施&注意事项小总结

王知无

[译] D8 优化

Antway

6 月日更

架构实践营模块7作业

Geek_649372

架构训练营

如何做好技术选型和分析决策

Man

技术选型 CMMI

zip解压缩

拾贝

demo

秦时明月

架构实战训练营 - 模块七课后作业

Johnny

架构实战营

这样理解Mysql索引,阿里面试官也给你点赞

慕枫技术笔记

MySQL 后端 索引

「SQL数据分析系列」10. 重谈连接

数据与智能

数据库 sql 连接

我的书要出版啦~

石璞东

深度学习 tensorflow 计算机视觉 前端开发 卷积神经网络

kubelet分析-csi driver注册分析-Node Driver Registrar源码分析

良凯尔

源码 Kubernetes CSI Kubernetes Plugin

基于FPGA系统合成两条视频流实现3D视频效果

不脱发的程序猿

智能硬件 FPGA系统 视频流 合成3D视频

最佳的管理者-库克

卢卡多多

苹果 管理者 六月日更

计算机网络概述

若尘

计算机网络 六月日更

硬刚Apache Iceberg | 技术调研&在各大公司的实践应用大总结

王知无

硬刚ClickHouse | 4万字长文ClickHouse基础&实践&调优全视角解析

王知无

可编程网关 Pipy 第三弹:事件模型设计

张晓辉

Spring事件发布与监听机制

陈皮的JavaLib

Java spring 事件监听

硬刚Hive | 4万字基础调优面试小总结

王知无

硬刚用户画像(一) | 标签体系下的用户画像建设小指南

王知无

Data Mesh,数据网格的道与术

王知无

技术为帆,纵横四海- Lazada技术东南亚探索和成长之旅

技术为帆,纵横四海- Lazada技术东南亚探索和成长之旅

TensorFlow工程实战(二):用tf.layers API在动态图上识别手写数字-InfoQ