写点什么

如何应用 TFGAN 快速实践生成对抗网络?

2018 年 6 月 03 日

前言

生成对抗网络(Generative Adversarial Nets ,GAN)目前已广泛应用于图像生成、超分辨率图片生成、图像压缩、图像风格转换、数据增强、文本生成等场景。越来越多的研发人员从事 GAN 网络的研究,提出了各种 GAN 模型的变种,包括 CGAN、InfoGAN、WGAN、CycleGAN 等。为了更容易地应用及实践 GAN 模型,谷歌开源了名为 TFGAN 的 TensorFlow 库,可快速实践各种 GAN 模型。本文主要讲解 TFGAN 如何应用于原生 GAN、CGAN、InfoGAN、WGAN 等场景,如下所示:

其中,原生GAN 生成的Mnist 图像不可控:CGAN 可按照数字标签生成相应标签的数字图像;InfoGAN 可认为是无监督的CGAN,前两行表示用分类潜变量控制数字的生成类别,中间两行表示用连续型潜变量控制数字的粗细,最后两行表示用连续型潜变量控制数字的倾斜方向;ImageToImage 是CGAN 的一种,实现图像的风格转换。

生成对抗网络与TFGAN

GAN 由 Goodfellow 首先提出,主要由两部分构成:Generator(生成器),简称 G;Discriminator(判别器), 简称 D。生成器主要用噪声 z 生成一个类似真实数据的样本,样本越逼真越好;判别器用于估计一个样本来自于真实数据还是生成数据,判定越准确越好。如下图所示:

上图中,对于真实的采样数据,通过判别网络后,生成D(x)。D(x) 的输出是0-1 范围内的一个实数,用来判断这个图片是一个真实图片的概率是多大。这样对于真实数据,D(x) 越接近1 越好。对于随机噪声z,通过生成网络G 后,G 将这个随机噪声转化为生成数据x。如果是图片生成问题,G 网络的输出就是一张生成的假图片,用G(z) 表示。判别模型D 要使得D(G(z)) 接近与0,即能够判断生成的图片是假的;生成模型G 要使得D(G(z)) 接近于1,即要能够要欺骗判别模型,使得D 认为G(z) 生成的假数据是真的。这样通过判别模型D 和生成模型G 的博弈,使得D 无法判断一张图片是生成出来的还是真实的而结束。

假设P_r 和P_g 分别代表真实数据的分布与生成数据的分布,这样判别模型的目标函数可以表示为:

而生成模型的是让判别模型D 无法区别真实数据与生成数据,这样优化目标函数为:

TFGAN 库的地址为 https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/gan ,主要包含以下几个组件:

  1. 核心架构,主要包括创建 TFGAN 模型,添加 Loss 值,创建训练 operation,运行训练 operation。
  2. 常用操作,主要提供了梯度修剪操作,归一化操作及条件化操作等。
  3. 损失函数,主要提供了 GAN 中常用的损失和惩罚函数,如 Wasserstein 损失、梯度惩罚、互信息惩罚等。
  4. 模型评估,提供了 Inception Score 和 Frechet Distance 指标,用于评估无条件生成模型。
  5. 示例,谷歌同时开源了常用的 GAN 网络示例代码,包括 unconditional GAN,conditional GAN, InfoGAN,WGAN 等。相关用例可从 https://github.com/tensorflow/models/tree/master/research/gan/ 地址下载。

使用 TFGAN 库训练 GAN 网络主要包含如下几个步骤:

1. 确定 GAN 网络的输入,如下所示:

复制代码
images = mnist_data_provider.provide_data(FLAGS.batch_size)
noise = tf.random_normal([FLAGS.batch_size, FLAGS.noise_dims])

2. 设定 GANModel 中的生成模型和判别模型,如下所示:

复制代码
gan_model = tfgan.gan_model(
generator_fn=mnist.unconditional_generator, # you define
discriminator_fn=mnist.unconditional_discriminator, # you define
real_data=images,
generator_inputs=noise)

3. 设定 GANLoss 中的损失方程,如下所示:

复制代码
gan_loss = tfgan.gan_loss(
gan_model,
generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss)

4. 设定 GANTrainOps 中的训练操作,如下所示:

复制代码
train_ops = tfgan.gan_train_ops(
gan_model,
gan_loss,
generator_optimizer=tf.train.AdamOptimizer(gen_lr, 0.5),
discriminator_optimizer=tf.train.AdamOptimizer(dis_lr, 0.5))

5. 运行模型训练,如下所示:

复制代码
tfgan.gan_train(
train_ops,
hooks=[tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps)],
logdir=FLAGS.train_log_dir)

CGAN

CGAN(Conditional Generative Adversarial Nets),针对 GAN 本身不可控的缺点,加入监督信息,训练从无监督变成有监督,指导 GAN 网络进行生成。例如输入分类的标签,可生成相应标签的图像。这样 CGAN 的目标方程可以转换为:

其中,y 是加入的监督信息,D(x|y) 表示在y 的条件下判定真实数据x,D(G(z|y)) 表示在y 的条件下判定生成数据G(z|y)。例如,MNIST 数据集可根据数字label 信息,生成相应标签的图片;人脸生成数据集,可根据性别、是否微笑、年龄等信息,生成相应的人脸图片。CGAN 的架构如下图所示:

在TFGAN 中提供了,基于one_hot_labels 变量和输入tensor 生成condition tensor 的API,如下所示:

tfgan.features.condition_tensor_from_onehot (tensor, one_hot_labels, embedding_size)其中,tensor 为输入数据,one_hot_labels 为 onehot 标签,shape 为 [batch_size, num_classes],embedding_size 为每个 label 对应的 embedding 大小,返回值为 condition tensor。

ImageToImage

Phillip Isola 等提出了基于 CGAN 的图片生成图片的对抗神经网络《Image-to-Image Translation with Conditional Adversarial Networks》。网络设计的基本思想如下所示:

其中,x 为输入的线条图,G(x) 为生成图片,y 为线条图x 对应渲染后的真图片,生成模型G 用于生成图片,判断模型D 用于判定生成图片的真假。判别网络能够最大化判断(x,y) 的数据为真,判断(x,G(x)) 数据为假。而生成网络使得判别网络判断(x,G(x)) 数据为真,从而进行生成模型和判别模型的相互博弈。为了使生成模型不仅能够欺骗判别模型,还要使得生成图像要像真实图片,这样在目标函数中加入了真实图像和生成图像的L1 距离,如下所示:

TFGAN 库,提供了 ImageToImage 生成对抗网络的相关损失方程 API 使用示例,如下所示:

复制代码
# 定义真实数据与生成数据的 L1 损失
l1_pixel_loss = tf.norm(gan_model.real_data - gan_model.generated_data, ord=1) / FLAGS.patch_size ** 2
# gan_loss 为目标函数损失
gan_loss = tfgan.losses.combine_adversarial_loss(gan_loss, gan_model, l1_pixel_loss, weight_factor=FLAGS.weight_factor)

InfoGAN

在 GAN 中,生成器用噪声 z 生成数据时,没有加任何的条件限制,很难用 z 的任何一个维度信息表示相关的语义特征。所以在数据生成过程中,无法控制什么样的噪声 z 可以生成什么样的数据,在很大程度上限制了 GAN 的使用。InfoGAN 可以认为是无监督的 CGAN,在噪声 z 上增加潜变量 c,使得生成模型生成的数据与浅变量 c 具有较高的互信息,其中 Info 就是代表互信息的含义。互信息定义为两个熵的差值,H(x) 是先验分布的熵,H(x|y) 代表后验分布的熵。如果 x,y 是相互独立的变量,那么互信息的值为 0,表示 x,y 没有关系;如果 x,y 有相关性,那么互信息大于 0。这样在已知 y 的情况下,可以推断出那些 x 的值出现高。这样 InfoGAN 的目标方程为:

InfoGAN 的网络结构如下所示:

上图中InfoGAN 与GAN 的区别在于,对应判别网络的输出D(x),生成变分分布Q(c|x),从而能用Q(c|x) 来逼近P(c|x),从而增大生成数据与潜变量c 的互信息。

TFGAN 中提供了 InfoGan 相关 API,如下所示:

复制代码
#通过 tfgan.infogan_model,定义 infogan 模型
infogan_model = tfgan.infogan_model(
generator_fn=generator_fn,
discriminator_fn=discriminator_fn,
real_data=real_images,
unstructured_generator_inputs=unstructured_inputs,
structured_generator_inputs=structured_inputs)
#通过 tfgan.gan_loss,生成 infogan 模型的 loss 值:
infogan_loss = tfgan.gan_loss(
infogan_model,
gradient_penalty_weight=1.0,
mutual_information_penalty_weight=1.0)

#InfoGan 的 Loss 值为在 GAN 的 loss 值上,加上互信息 I(c;G(z,c)),TFGAN 中提供了互信息计算的 API,如下所示。其中 structured_generator_inputs 为潜变量的噪音信息,predicted_distributions 为变分分布 Q(c|x)。

def mutual_information_penalty(structured_generator_inputs, predicted_distributions)## WGAN

Martin Arjovsky 等提出了 WGAN(Wasserstein GAN),解决了传统 GAN 训练困难、生成器和判别器的 loss 很难指示训练进程、生成样本缺乏多样性等问题,主要有以下优点:

  1. 能够平衡生成器和判别器的训练程度,使得 GAN 的模型训练稳定。
  2. 能够保证生产样本的多样性。
  3. 提出使用 Wasserstein 距离来衡量模型训练的程度,数值越小表示训练得越好,成器生成的图像质量越高。

WGAN 的算法与原始 GAN 算法的差异主要体现在:

  1. 去掉判别模型最后一层的 sigmoid 操作。
  2. 生成模型和判别模型的 loss 值不取 log 操作。
  3. 每次更新判别模型的参数之后把模型参数的绝对值截断到不超过固定常数 c。
  4. 使用 RMSProp 算法,不用基于动量的优化算法,例如 momentum 和 Adam。

WGAN 的算法结构如下所示:

TFGAN 中提供了 WGan 相关 API,如下所示:

复制代码
#生成网络损失方程
generator_loss_fn=tfgan_losses.wasserstein_generator_loss
#判别网络损失方程
discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss

总结

本文首先介绍了生成对抗网络和 TFGAN,生成对抗网络模型用于图像生成、超分辨率图片生成、图像压缩、图像风格转换、数据增强、文本生成等场景;TFGAN 是 TensorFlow 库,用于快速实践各种 GAN 模型。然后讲解了 CGAN、ImageToImage、InfoGAN、WGAN 模型的主要思想,并对关键技术进行了分析,主要包括目标函数、网络架构、损失方程及相应的 TFGAN API。用户可基于 TFGAN 快速实践生成对抗网络模型,并应用到工业领域中的相关场景。

参考文献

[1] Generative Adversarial Networks.
[2] Conditional Generative Adversarial Nets.
[3] InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets.
[4] Wasserstein GAN.
[5] Image-to-Image Translation with Conditional Adversarial Networks.
[6] https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/gan .
[7] https://github.com/tensorflow/models/tree/master/research/gan .

作者简介

武维(微信:allawnweiwu):博士,现为 IBM 架构师。主要从事深度学习平台及应用研究,大数据领域的研发工作。

2018 年 6 月 03 日 17:581124

评论

发布
暂无评论
发现更多内容

架构师 3 期 3 班 -week10- 作业

zbest

作业 week10

满满的干货!阿里开源Java程序员2021年金三银四面试指南

Java架构之路

Java 程序员 架构 面试 编程语言

驱动力读后感之一

张老蔫

28天写作

滴滴夜莺社区文章有奖征集

Obsuite

最佳实践 奖品 案例分享 滴滴夜莺

Gradle Docker插件将SpringBoot应用程序打包为Docker镜像

wjchenge

Docker SpringBoot 2 Gradle

呕心沥血!耗费我一个多月时间整理出这“全程高能得Java面试题合集”面试首选,跳槽必备!诚意之作,收藏不亏!

Java成神之路

Java 程序员 架构 面试 编程语言

云讲堂 | 5期视频带你全面了解滴滴Logi-KafkaManager

Obsuite

kafka 运维 监控 滴滴Logi

「抖音同款播放器」上市:有效解决卡顿、黑屏和模糊

字节跳动技术团队

区块链--另一场改变社会组织方式的工业革命

CECBC区块链专委会

区块链

区块链数字版权保护解决方案--开创了新的维权思路

135深圳3055源中瑞8032

爱奇艺率先上线CUVA HDR标准内容,将多端支持该标准2021央视春晚直播、点播

爱奇艺技术产品团队

十里选一终拿offer,准阿里java程序员分享面试经验!

Java架构之路

Java 程序员 架构 面试 编程语言

三顾茅庐,七面阿里,25k*16offer,还原我的大厂面经

周老师

Java 编程 程序员 架构 面试

2021最新「阿里」Java高级工程师面试高频题:JVM+Redis+并发+算法+框架

比伯

Java 编程 架构 面试 计算机

用最少人力玩转万亿级数据,我用的就是MongoDB

dbaplus社群

教你10分钟解决短信验证码接口被盗刷、轰炸、恶意点击等问题。

香芋味的猫丶

短信防刷 短信验证码 短信防轰炸 短信防火墙

区块链电子签章解决方案--区块链助力电子签章

135深圳3055源中瑞8032

webpack | 进阶用法2:代码分割和动态引入的实现方式

梁龙先森

前端工程 webpack 28天写作 2月春节不断更

如何基于Spring Aware和InitializingBean接口实现策略模式

技术进阶之路

Java spring 5 Java设计模式

欢迎来到,2021摄像机竞技场

脑极体

如何基于Spring 事件驱动模型实现业务解耦

技术进阶之路

Java spring 架构

智慧社区解决方案,平安小区平台APP搭建

135深圳3055源中瑞8032

科技,亲吻这个特别的春节

脑极体

产品经理训练营课后作业-第三周-产品思维和产品意识

.nil?

产品经理训练营

最简单的map,filter,forEach,every,some的使用教学

coolFish(呔呆)

方法 Vue 前端 数组 js

Docker开启Remote API 访问 2375端口

wjchenge

Docker 2375端口

万字长文详细总结!关于继承、重写与重载、封装、接口的硬核干货

codevald

Java 接口 封装、继承、多态 类对象

技术方案设计的方法论及案例分享

阿里巴巴云原生

数据库 流计算 云原生 监控 存储

年终总结:华为|字节|腾讯|京东|网易|滴滴面经分享(斩获6个offer)

Java架构之路

Java 程序员 架构 面试 编程语言

Varchar竟然会自动存储成lob类型?

dbaplus社群

作业3

瑾瑾呀

4月17日 HarmonyOS 开发者日·上海站

4月17日 HarmonyOS 开发者日·上海站

如何应用TFGAN快速实践生成对抗网络?-InfoQ