【AICon】AI 基础设施、LLM运维、大模型训练与推理,一场会议,全方位涵盖! >>> 了解详情
写点什么

TensorFlow 2.0 迁移学习实践指南

  • 2019-12-05
  • 本文字数:10948 字

    阅读完需:约 36 分钟

TensorFlow 2.0迁移学习实践指南


阅读深度学习论文总是很有趣,也很有教育意义,特别是当这些论文和你现在做的项目属于同一领域时更是如此。但是,这些论文包含的架构和解决方案通常很难训练,特别是当你想去尝试他们的方法时,比如说 ILSCVR(ImageNet Large Scale Visual Recognition)竞赛中的一些获奖者的方法。我记得我在读 VGG16 的论文时就在想“这个方法很酷,但是我的 GPU 跑这个网络时都快挂了。”为了能轻松使用这些网络,Tensorflow 2 提供了大量的预训练模型,你可以很快用上它们。而本文,我们将介绍怎样通过一些有名的 CNN(Convolutional Neural Network)架构来训练这些论文里介绍的新的神经网络模型。


这时你可能会问“预训练模型是什么?”。本质上来说,预训练模型是之前在大数据集上已经训练好并保存下来的模型,比如说在 ImageNet 数据集上训练的模型。这些模型可以在 tensorflow.keras.applications 模块里找到。有两种方式使用这些预训练模型,你可以直接使用它们,或者通过迁移学习使用它们。由于大数据集通常用于某种全局解,所以你可以让预训练模型定制化,使其特别针对性地解决某个特定的问题。通过这个方式,你可以在训练时利用一些最有名的神经网络,不会损失太多的训练时间和计算资源。另外,你可以选定网络里的一些层,修改这些层的行为,实现这些模型的微调。我们在后面的文章里会讲到这一点。

架构

在本文中,我们使用 3 个预训练模型来解决分类问题的一个例子:VGG16、GoogLeNet(Inception)和 ResNet。这每一个架构都赢得了当年的 ILSCVR 竞赛。2014 年,VGG16 与 GooLeNet 有着相同的最好成绩,而 ResNet 赢得了 2015 年的竞赛。这些模型是 Tensorflow 2 中 tensorflow.keras.applications 模块的一部分。让我们深入探究一下这几个模型。


我们首先看一下 VGG16 这个架构。它是一个大型的卷积神经网络,由 K. Simonyan 和 A. Zisserman 在“Very Deep Convolutional Networks for Large-Scale Image Recognition”这篇论文里提出。这个网络在 ImageNet 数据集上达到了 92.7%的 top-5 测试精确度。但是,训练这个网络需要好几周。下图是这个模型的高层概览:



VGG16 架构


GoogLeNet 也被称为 Inception,这是因为它使用了两个概念:1x1 卷积和 Inception 模块。第一个概念中,1x1 卷积用于降维的模块。通过降维,计算量也会减少,这也就意味着网络的深度和宽度可以增加了。GooLeNet 使用了 Inception 模块,每个卷积层的大小都不相同。



带有降维功能的 Inception 模块


如图所示,1x1 卷积层、3x3 卷积层、5x5 卷积层和 3x3 最大池化层操作组合在了一起,然后这些层的运行结果会在输出节点处堆叠在一起。GooLeNet 总共有 22 层,看起来像下面这样:



本文中,我们要使用的最后一个网络架构是残差网络,或者称作 ResNet。前面提到的网络的问题在于它们太深了,它们有太多层,导致很难训练(因为梯度消失)。所以,ResNet 使用所谓的“identity shortcut connection”(或者称作残差模块)来解决这个问题。



带有降维和不带降维的残差模块


本质上来讲,ResNet 沿用了 VGG 的 3x3 卷积层的设计,每层卷积后面都有一个 Batch Normalization 层和 ReLu 激活函数。但是,差异点在于我们的 ResNet 在最后一个 ReLu 前插入了 input 节点。另一个变种是,输入值(input value)传入了 1x1 卷积层。

数据集

在本文中,我们使用“Cats vs Dogs”的数据集。这个数据集包含了 23,262 张猫和狗的图像。



你可能注意到了,这些照片没有归一化,它们的大小是不一样的。但是非常棒的一点是,你可以在 Tensorflow Datasets 中获取这个数据集。所以,确保你的环境里安装了 Tensorflow Dataset。


pip install tensorflow-dataset
复制代码


和这个库中的其他数据集不同,这个数据集没有划分成训练集和测试集,所以我们需要自己对这两类数据集做个区分。你可以在这里找到这个数据集的更多信息。

实现

这个实现分成了几个部分。首先,我们实现了一个类,其负责载入数据和准备数据。然后,我们导入预训练模型,构建一个用于修改最顶端的几层网络。最后,我们把训练过程运行起来,并进行评估。当然,在这之前,我们必须导入一些代码库,定义一些全局常量:


import numpy as npimport matplotlib.pyplot as plt
import tensorflow as tfimport tensorflow_datasets as tfds
IMG_SIZE = 160BATCH_SIZE = 32SHUFFLE_SIZE = 1000IMG_SHAPE = (IMG_SIZE, IMG_SIZE, 3)
复制代码


好,让我们仔细来看下实现!

数据载入器

这个类负责载入数据和准备数据,用于后续的数据处理。以下是这个类的实现:


class DataLoader(object):    def __init__(self, image_size, batch_size):                self.image_size = image_size        self.batch_size = batch_size                # 80% train data, 10% validation data, 10% test data        split_weights = (8, 1, 1)        splits = tfds.Split.TRAIN.subsplit(weighted=split_weights)                (self.train_data_raw, self.validation_data_raw, self.test_data_raw), self.metadata = tfds.load(            'cats_vs_dogs', split=list(splits),            with_info=True, as_supervised=True)                # Get the number of train examples        self.num_train_examples = self.metadata.splits['train'].num_examples*80/100        self.get_label_name = self.metadata.features['label'].int2str                # Pre-process data        self._prepare_data()        self._prepare_batches()            # Resize all images to image_size x image_size    def _prepare_data(self):        self.train_data = self.train_data_raw.map(self._resize_sample)        self.validation_data = self.validation_data_raw.map(self._resize_sample)        self.test_data = self.test_data_raw.map(self._resize_sample)        # Resize one image to image_size x image_size    def _resize_sample(self, image, label):        image = tf.cast(image, tf.float32)        image = (image/127.5) - 1        image = tf.image.resize(image, (self.image_size, self.image_size))        return image, label        def _prepare_batches(self):        self.train_batches = self.train_data.shuffle(1000).batch(self.batch_size)        self.validation_batches = self.validation_data.batch(self.batch_size)        self.test_batches = self.test_data.batch(self.batch_size)       # Get defined number of  not processed images    def get_random_raw_images(self, num_of_images):        random_train_raw_data = self.train_data_raw.shuffle(1000)        return random_train_raw_data.take(num_of_images)
复制代码


这个类实现了很多功能,它实现了很多“public”方法


  • _prepare_data:内部方法,用于缩放和归一化数据集里的图像。构造函数需要用到该函数。

  • _resize_sample:内部方法,用于缩放单张图像。

  • _prepare_batches:内部方法,用于将图像打包创建为 batches。创建 train_batches、validation_batches 和 test_batches,分别用于训练、评估过程。

  • get_random_raw_images:这个方法用于从原始的、没有经过处理的数据中随机获取固定数量的图像。


但是,这个类的主要功能还是在构造函数中完成的。让我们仔细看看这个类的构造函数。


def __init__(self, image_size, batch_size):
self.image_size = image_size self.batch_size = batch_size
# 80% train data, 10% validation data, 10% test data split_weights = (8, 1, 1) splits = tfds.Split.TRAIN.subsplit(weighted=split_weights)
(self.train_data_raw, self.validation_data_raw, self.test_data_raw), self.metadata = tfds.load( 'cats_vs_dogs', split=list(splits), with_info=True, as_supervised=True)
# Get the number of train examples self.num_train_examples = self.metadata.splits['train'].num_examples*80/100 self.get_label_name = self.metadata.features['label'].int2str
# Pre-process data self._prepare_data() self._prepare_batches()
复制代码


首先我们通过传入参数定义了图像大小和 batch 大小。然后,由于该数据集本身没有区分训练集和测试集,我们通过划分权值对数据进行划分。这真是 Tensorflow Dataset 引入的非常棒的功能,因为我们可以留在 Tensorflow 生态系统中做这件事,我们不用引入其他的库(比如 Pandas 或者 Scikit Learn)。一旦我们执行了数据划分,我们就开始计算训练样本数量,然后调用辅助函数来为训练准备数据。在这之后,我们需要做的仅仅是实例化这个类的对象,然后载入数据即可。


data_loader = DataLoader(IMG_SIZE, BATCH_SIZE)
plt.figure(figsize=(10, 8))i = 0for img, label in data_loader.get_random_raw_images(20): plt.subplot(4, 5, i+1) plt.imshow(img) plt.title("{} - {}".format(data_loader.get_label_name(label), img.shape)) plt.xticks([]) plt.yticks([]) i += 1plt.tight_layout()plt.show()
复制代码


以下是输出结果:


基础模型 & Wrapper

下一个步骤就是载入预训练模型了。我们前面提到过,这些模型位于 tensorflow.kearas.applications。我们可以用下面的语句直接载入它们:


vgg16_base = tf.keras.applications.VGG16(input_shape=IMG_SHAPE, include_top=False, weights='imagenet')googlenet_base = tf.keras.applications.InceptionV3(input_shape=IMG_SHAPE, include_top=False, weights='imagenet')resnet_base = tf.keras.applications.ResNet101V2(input_shape=IMG_SHAPE, include_top=False, weights='imagenet')
复制代码


这段代码就是我们创建上述三种网络结构基础模型的方式。注意,每个模型构造函数的 include_top 参数传入的是 false。这意味着这些模型是用于提取特征的。我们一旦创建了这些模型,我们就需要修改这些模型顶部的网络层,使之适用于我们的具体问题。我们使用 Wrapper 类来完成这个步骤。这个类接收预训练模型,然后添加一个 Global Average Polling Layer 和一个 Dense Layer。本质上,这最后的 Dense Layer 会用于我们的二分类问题(猫或狗)。Wrapper 类把所有这些元素都放到了一起,放在了同一个模型中。


class Wrapper(tf.keras.Model):    def __init__(self, base_model):        super(Wrapper, self).__init__()                self.base_model = base_model        self.average_pooling_layer = tf.keras.layers.GlobalAveragePooling2D()        self.output_layer = tf.keras.layers.Dense(1)            def call(self, inputs):        x = self.base_model(inputs)        x = self.average_pooling_layer(x)        output = self.output_layer(x)        return output
复制代码


然后我们就可以创建 Cats vs Dogs 分类问题的模型了,并且编译这个模型。


base_learning_rate = 0.0001
vgg16_base.trainable = Falsevgg16 = Wrapper(vgg16_base)vgg16.compile(optimizer=tf.keras.optimizers.RMSprop(lr=base_learning_rate), loss='binary_crossentropy', metrics=['accuracy'])
googlenet_base.trainable = Falsegooglenet = Wrapper(googlenet_base)googlenet.compile(optimizer=tf.keras.optimizers.RMSprop(lr=base_learning_rate), loss='binary_crossentropy', metrics=['accuracy'])
resnet_base.trainable = Falseresnet = Wrapper(resnet_base)resnet.compile(optimizer=tf.keras.optimizers.RMSprop(lr=base_learning_rate), loss='binary_crossentropy', metrics=['accuracy'])
复制代码


注意,我们标记了基础模型是不参与训练的,这意味着在训练过程中,我们只会训练新添加到顶部的网络层,而在网络底部的权重值不会发生变化。

训练

在我们开始整个训练过程之前,让我们思考一下,这些模型的大部头其实已经被训练过了。所以,我们可以执行评估过程来看看评估结果如何:


steps_per_epoch = round(data_loader.num_train_examples)//BATCH_SIZEvalidation_steps = 20
loss1, accuracy1 = vgg16.evaluate(data_loader.validation_batches, steps = 20)loss2, accuracy2 = googlenet.evaluate(data_loader.validation_batches, steps = 20)loss3, accuracy3 = resnet.evaluate(data_loader.validation_batches, steps = 20)
print("--------VGG16---------")print("Initial loss: {:.2f}".format(loss1))print("Initial accuracy: {:.2f}".format(accuracy1))print("---------------------------")
print("--------GoogLeNet---------")print("Initial loss: {:.2f}".format(loss2))print("Initial accuracy: {:.2f}".format(accuracy2))print("---------------------------")
print("--------ResNet---------")print("Initial loss: {:.2f}".format(loss3))print("Initial accuracy: {:.2f}".format(accuracy3))print("---------------------------")
复制代码


有意思的是,这些模型在没有预先训练的情况下,我们得到的结果也还过得去(50%的精确度):


———VGG16———Initial loss: 5.30Initial accuracy: 0.51—————————-
——GoogLeNet—–Initial loss: 7.21Initial accuracy: 0.51—————————-
——–ResNet———Initial loss: 6.01Initial accuracy: 0.51—————————-
复制代码


把 50%作为训练的起点已经挺好的了。所以,就让我们把训练过程跑起来吧,看看我们是否能得到更好的结果。首先,我们训练 VGG16:


history = vgg16.fit(data_loader.train_batches,                    epochs=10,                    validation_data=data_loader.validation_batches)
复制代码


训练过程历史数据显示大致如下:



VGG16 的训练过程历史数据


然后我们可以训练 GoogLeNet。


history = googlenet.fit(data_loader.train_batches,                    epochs=10,                    validation_data=data_loader.validation_batches)
复制代码


这个网络训练过程历史数据如下:



GoogLeNet 的训练过程历史数据


最后是 ResNet 的训练:


history = resnet.fit(data_loader.train_batches,                    epochs=10,                    validation_data=data_loader.validation_batches)
复制代码


以下是 ResNet 训练过程历史数据如下:



ResNet 的训练过程历史数据


由于我们只训练了顶部的几层网络,而不是整个网络,所以训练这三个模型只用了几个小时,而不是几个星期。

评估

我们看到在训练开始前,我们已经有了 50%左右的精确度。让我们来看下训练后是什么情况:


loss1, accuracy1 = vgg16.evaluate(data_loader.test_batches, steps = 20)loss2, accuracy2 = googlenet.evaluate(data_loader.test_batches, steps = 20)loss3, accuracy3 = resnet.evaluate(data_loader.test_batches, steps = 20)
print("--------VGG16---------")print("Loss: {:.2f}".format(loss1))print("Accuracy: {:.2f}".format(accuracy1))print("---------------------------")
print("--------GoogLeNet---------")print("Loss: {:.2f}".format(loss2))print("Accuracy: {:.2f}".format(accuracy2))print("---------------------------")
print("--------ResNet---------")print("Loss: {:.2f}".format(loss3))print("Accuracy: {:.2f}".format(accuracy3))print("---------------------------")
复制代码


结果如下:


——–VGG16———Loss: 0.25Accuracy: 0.93—————————
——–GoogLeNet———Loss: 0.54Accuracy: 0.95———————————–ResNet———Loss: 0.40Accuracy: 0.97—————————
复制代码


我们可以看到这三个模型的结果都相当好,其中 ResNet 效果最好,精确度高达 97%。

结论

在本文中,我们演示了怎样使用 Tensorflow 进行迁移学习。我们创建了一个试验场,在其中可以尝试不同的数据预训练架构,并且在几个小时内就能得到较好的结果。在我们的例子里,我们使用了三个很有名的卷积架构,快速将其修改用于具体的问题。在下篇文章中,我们将微调这些模型,来看看我们是否能得到更好的结果。


原文链接:


https://rubikscode.net/2019/11/11/transfer-learning-with-tensorflow-2/


公众号推荐:

跳进 AI 的奇妙世界,一起探索未来工作的新风貌!想要深入了解 AI 如何成为产业创新的新引擎?好奇哪些城市正成为 AI 人才的新磁场?《中国生成式 AI 开发者洞察 2024》由 InfoQ 研究中心精心打造,为你深度解锁生成式 AI 领域的最新开发者动态。无论你是资深研发者,还是对生成式 AI 充满好奇的新手,这份报告都是你不可错过的知识宝典。欢迎大家扫码关注「AI前线」公众号,回复「开发者洞察」领取。

2019-12-05 08:042851
用户头像
蔡芳芳 InfoQ主编

发布了 781 篇内容, 共 496.5 次阅读, 收获喜欢 2749 次。

关注

评论

发布
暂无评论
发现更多内容

如何使用 NFTScan NFT 数据创建一个 ERC-6551 账户?

NFT Research

ERC721 NFT\

AI推理实践丨多路极致性能目标检测最佳实践设计解密

华为云开发者联盟

Go语言:通过TDD测试驱动开发学习 Mocking (模拟)的思想

不在线第一只蜗牛

TDD Go 语言

设计通用流程和可变点的方法一些思考

快乐非自愿限量之名

设计 教程 通用流程

如何用好强大的 TDengine 集群 ? 先了解 RAFT 在 3.0 中的应用

爱倒腾的程序员

涛思数据 时序数据库 ​TDengine

人脸识别技术在智能交通管理中的应用

来自四九城儿

推荐!十个平台工程工具助力开发人员提升效率和体验

SEAL安全

成都堡垒机采购选择哪家好?具体功能有哪些?具体多少钱?

行云管家

网络安全 信息安全 成都 堡垒机

白盒、黑盒、SAST、DAST傻傻分不清?

华为云PaaS服务小智

华为云 华为开发者大会2023 代码检查

CSS 属性选择器,前端开发的效率好物

伤感汤姆布利柏

TiDB 7.1 资源管控验证测试

TiDB 社区干货传送门

版本测评 新版本/特性解读 7.x 实践

国产替代:国有企业数智化转型的挑战与机遇

用友BIP

国产替代

如何利用ChatGPT革新智能合约和区块链

互联网工科生

区块链 ChatGPT

汽车电子国产化,华秋助力国产电源IC高质量发展

华秋电子

BOM/PCB/Gerber比对功能再升级,华秋DFM新版邀您体验!

华秋电子

华秋约定您!7月11-13日慕尼黑上海电子展不见不散~

华秋电子

跨文件,跨函数能力是什么?和污点分析能力有什么关系?

华为云PaaS服务小智

编程 软件开发 华为云 华为开发者大会2023 代码检查

架构成长之路 | 图解分布式共识算法 Paxos 议会协议

阿里技术

分布式 PAXOS Paxos 议会协议

干货满满!阿里、京东、网易等多位专家力荐的高并发编程速成笔记

小小怪下士

Java 编程 程序员 高并发

【HDC.Cloud 2023】华为开发者大会2023来了!这份PaaS参会指南请查收!

华为云PaaS服务小智

云计算 华为云 华为开发者大会2023

快速重拾 Tmux

高端章鱼哥

Linux tmux

【深入浅出 Yarn 架构与实现】5-3 Yarn 调度器资源抢占模型

快乐非自愿限量之名

架构 YARN

【6.30-7.7】写作社区优秀技术博文一览

InfoQ写作社区官方

热门活动 优质创作周报

不可不知的八个出色的Java项目

这我可不懂

Java 工具

上传IPA后需要多久才能在构建版本中看到应用?

雪奈椰子

云数据库是杀猪盘么,去掉中间商赚差价,aws数据库性能提升 10 倍!价格便宜十倍。

TiDB 社区干货传送门

数据库架构设计 7.x 实践

数据库领域2023上半年盘点

亚信AntDB数据库

数据库 AntDB AntDB数据库

温州是几线城市?有几家正规等保测评机构?

行云管家

等级保护 等保测评机构 温州

TensorFlow 2.0迁移学习实践指南_语言 & 开发_Rubikscode_InfoQ精选文章