50万奖金+官方证书,深圳国际金融科技大赛正式启动,点击报名 了解详情
写点什么

如何应用 TFGAN 快速实践生成对抗网络?

  • 2018-06-03
  • 本文字数:4252 字

    阅读完需:约 14 分钟

前言

生成对抗网络(Generative Adversarial Nets ,GAN)目前已广泛应用于图像生成、超分辨率图片生成、图像压缩、图像风格转换、数据增强、文本生成等场景。越来越多的研发人员从事 GAN 网络的研究,提出了各种 GAN 模型的变种,包括 CGAN、InfoGAN、WGAN、CycleGAN 等。为了更容易地应用及实践 GAN 模型,谷歌开源了名为 TFGAN 的 TensorFlow 库,可快速实践各种 GAN 模型。本文主要讲解 TFGAN 如何应用于原生 GAN、CGAN、InfoGAN、WGAN 等场景,如下所示:

其中,原生GAN 生成的Mnist 图像不可控:CGAN 可按照数字标签生成相应标签的数字图像;InfoGAN 可认为是无监督的CGAN,前两行表示用分类潜变量控制数字的生成类别,中间两行表示用连续型潜变量控制数字的粗细,最后两行表示用连续型潜变量控制数字的倾斜方向;ImageToImage 是CGAN 的一种,实现图像的风格转换。

生成对抗网络与TFGAN

GAN 由 Goodfellow 首先提出,主要由两部分构成:Generator(生成器),简称 G;Discriminator(判别器), 简称 D。生成器主要用噪声 z 生成一个类似真实数据的样本,样本越逼真越好;判别器用于估计一个样本来自于真实数据还是生成数据,判定越准确越好。如下图所示:

上图中,对于真实的采样数据,通过判别网络后,生成D(x)。D(x) 的输出是0-1 范围内的一个实数,用来判断这个图片是一个真实图片的概率是多大。这样对于真实数据,D(x) 越接近1 越好。对于随机噪声z,通过生成网络G 后,G 将这个随机噪声转化为生成数据x。如果是图片生成问题,G 网络的输出就是一张生成的假图片,用G(z) 表示。判别模型D 要使得D(G(z)) 接近与0,即能够判断生成的图片是假的;生成模型G 要使得D(G(z)) 接近于1,即要能够要欺骗判别模型,使得D 认为G(z) 生成的假数据是真的。这样通过判别模型D 和生成模型G 的博弈,使得D 无法判断一张图片是生成出来的还是真实的而结束。

假设P_r 和P_g 分别代表真实数据的分布与生成数据的分布,这样判别模型的目标函数可以表示为:

而生成模型的是让判别模型D 无法区别真实数据与生成数据,这样优化目标函数为:

TFGAN 库的地址为 https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/gan ,主要包含以下几个组件:

  1. 核心架构,主要包括创建 TFGAN 模型,添加 Loss 值,创建训练 operation,运行训练 operation。
  2. 常用操作,主要提供了梯度修剪操作,归一化操作及条件化操作等。
  3. 损失函数,主要提供了 GAN 中常用的损失和惩罚函数,如 Wasserstein 损失、梯度惩罚、互信息惩罚等。
  4. 模型评估,提供了 Inception Score 和 Frechet Distance 指标,用于评估无条件生成模型。
  5. 示例,谷歌同时开源了常用的 GAN 网络示例代码,包括 unconditional GAN,conditional GAN, InfoGAN,WGAN 等。相关用例可从 https://github.com/tensorflow/models/tree/master/research/gan/ 地址下载。

使用 TFGAN 库训练 GAN 网络主要包含如下几个步骤:

1. 确定 GAN 网络的输入,如下所示:

复制代码
images = mnist_data_provider.provide_data(FLAGS.batch_size)
noise = tf.random_normal([FLAGS.batch_size, FLAGS.noise_dims])

2. 设定 GANModel 中的生成模型和判别模型,如下所示:

复制代码
gan_model = tfgan.gan_model(
generator_fn=mnist.unconditional_generator, # you define
discriminator_fn=mnist.unconditional_discriminator, # you define
real_data=images,
generator_inputs=noise)

3. 设定 GANLoss 中的损失方程,如下所示:

复制代码
gan_loss = tfgan.gan_loss(
gan_model,
generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss)

4. 设定 GANTrainOps 中的训练操作,如下所示:

复制代码
train_ops = tfgan.gan_train_ops(
gan_model,
gan_loss,
generator_optimizer=tf.train.AdamOptimizer(gen_lr, 0.5),
discriminator_optimizer=tf.train.AdamOptimizer(dis_lr, 0.5))

5. 运行模型训练,如下所示:

复制代码
tfgan.gan_train(
train_ops,
hooks=[tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps)],
logdir=FLAGS.train_log_dir)

CGAN

CGAN(Conditional Generative Adversarial Nets),针对 GAN 本身不可控的缺点,加入监督信息,训练从无监督变成有监督,指导 GAN 网络进行生成。例如输入分类的标签,可生成相应标签的图像。这样 CGAN 的目标方程可以转换为:

其中,y 是加入的监督信息,D(x|y) 表示在y 的条件下判定真实数据x,D(G(z|y)) 表示在y 的条件下判定生成数据G(z|y)。例如,MNIST 数据集可根据数字label 信息,生成相应标签的图片;人脸生成数据集,可根据性别、是否微笑、年龄等信息,生成相应的人脸图片。CGAN 的架构如下图所示:

在TFGAN 中提供了,基于one_hot_labels 变量和输入tensor 生成condition tensor 的API,如下所示:

tfgan.features.condition_tensor_from_onehot (tensor, one_hot_labels, embedding_size)其中,tensor 为输入数据,one_hot_labels 为 onehot 标签,shape 为 [batch_size, num_classes],embedding_size 为每个 label 对应的 embedding 大小,返回值为 condition tensor。

ImageToImage

Phillip Isola 等提出了基于 CGAN 的图片生成图片的对抗神经网络《Image-to-Image Translation with Conditional Adversarial Networks》。网络设计的基本思想如下所示:

其中,x 为输入的线条图,G(x) 为生成图片,y 为线条图x 对应渲染后的真图片,生成模型G 用于生成图片,判断模型D 用于判定生成图片的真假。判别网络能够最大化判断(x,y) 的数据为真,判断(x,G(x)) 数据为假。而生成网络使得判别网络判断(x,G(x)) 数据为真,从而进行生成模型和判别模型的相互博弈。为了使生成模型不仅能够欺骗判别模型,还要使得生成图像要像真实图片,这样在目标函数中加入了真实图像和生成图像的L1 距离,如下所示:

TFGAN 库,提供了 ImageToImage 生成对抗网络的相关损失方程 API 使用示例,如下所示:

复制代码
# 定义真实数据与生成数据的 L1 损失
l1_pixel_loss = tf.norm(gan_model.real_data - gan_model.generated_data, ord=1) / FLAGS.patch_size ** 2
# gan_loss 为目标函数损失
gan_loss = tfgan.losses.combine_adversarial_loss(gan_loss, gan_model, l1_pixel_loss, weight_factor=FLAGS.weight_factor)

InfoGAN

在 GAN 中,生成器用噪声 z 生成数据时,没有加任何的条件限制,很难用 z 的任何一个维度信息表示相关的语义特征。所以在数据生成过程中,无法控制什么样的噪声 z 可以生成什么样的数据,在很大程度上限制了 GAN 的使用。InfoGAN 可以认为是无监督的 CGAN,在噪声 z 上增加潜变量 c,使得生成模型生成的数据与浅变量 c 具有较高的互信息,其中 Info 就是代表互信息的含义。互信息定义为两个熵的差值,H(x) 是先验分布的熵,H(x|y) 代表后验分布的熵。如果 x,y 是相互独立的变量,那么互信息的值为 0,表示 x,y 没有关系;如果 x,y 有相关性,那么互信息大于 0。这样在已知 y 的情况下,可以推断出那些 x 的值出现高。这样 InfoGAN 的目标方程为:

InfoGAN 的网络结构如下所示:

上图中InfoGAN 与GAN 的区别在于,对应判别网络的输出D(x),生成变分分布Q(c|x),从而能用Q(c|x) 来逼近P(c|x),从而增大生成数据与潜变量c 的互信息。

TFGAN 中提供了 InfoGan 相关 API,如下所示:

复制代码
#通过 tfgan.infogan_model,定义 infogan 模型
infogan_model = tfgan.infogan_model(
generator_fn=generator_fn,
discriminator_fn=discriminator_fn,
real_data=real_images,
unstructured_generator_inputs=unstructured_inputs,
structured_generator_inputs=structured_inputs)
#通过 tfgan.gan_loss,生成 infogan 模型的 loss 值:
infogan_loss = tfgan.gan_loss(
infogan_model,
gradient_penalty_weight=1.0,
mutual_information_penalty_weight=1.0)

#InfoGan 的 Loss 值为在 GAN 的 loss 值上,加上互信息 I(c;G(z,c)),TFGAN 中提供了互信息计算的 API,如下所示。其中 structured_generator_inputs 为潜变量的噪音信息,predicted_distributions 为变分分布 Q(c|x)。

def mutual_information_penalty(structured_generator_inputs, predicted_distributions)## WGAN

Martin Arjovsky 等提出了 WGAN(Wasserstein GAN),解决了传统 GAN 训练困难、生成器和判别器的 loss 很难指示训练进程、生成样本缺乏多样性等问题,主要有以下优点:

  1. 能够平衡生成器和判别器的训练程度,使得 GAN 的模型训练稳定。
  2. 能够保证生产样本的多样性。
  3. 提出使用 Wasserstein 距离来衡量模型训练的程度,数值越小表示训练得越好,成器生成的图像质量越高。

WGAN 的算法与原始 GAN 算法的差异主要体现在:

  1. 去掉判别模型最后一层的 sigmoid 操作。
  2. 生成模型和判别模型的 loss 值不取 log 操作。
  3. 每次更新判别模型的参数之后把模型参数的绝对值截断到不超过固定常数 c。
  4. 使用 RMSProp 算法,不用基于动量的优化算法,例如 momentum 和 Adam。

WGAN 的算法结构如下所示:

TFGAN 中提供了 WGan 相关 API,如下所示:

复制代码
#生成网络损失方程
generator_loss_fn=tfgan_losses.wasserstein_generator_loss
#判别网络损失方程
discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss

总结

本文首先介绍了生成对抗网络和 TFGAN,生成对抗网络模型用于图像生成、超分辨率图片生成、图像压缩、图像风格转换、数据增强、文本生成等场景;TFGAN 是 TensorFlow 库,用于快速实践各种 GAN 模型。然后讲解了 CGAN、ImageToImage、InfoGAN、WGAN 模型的主要思想,并对关键技术进行了分析,主要包括目标函数、网络架构、损失方程及相应的 TFGAN API。用户可基于 TFGAN 快速实践生成对抗网络模型,并应用到工业领域中的相关场景。

参考文献

[1] Generative Adversarial Networks.
[2] Conditional Generative Adversarial Nets.
[3] InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets.
[4] Wasserstein GAN.
[5] Image-to-Image Translation with Conditional Adversarial Networks.
[6] https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/gan .
[7] https://github.com/tensorflow/models/tree/master/research/gan .

作者简介

武维(微信:allawnweiwu):博士,现为 IBM 架构师。主要从事深度学习平台及应用研究,大数据领域的研发工作。

2018-06-03 17:582451

评论

发布
暂无评论
发现更多内容

袋鼠云产品功能更新一探究竟|实时开发,效率再升级!

袋鼠云数栈

数据治理 数据产品 数据智能 数字孪生 空间智能

AI Compass前沿速览:Nano Banana玩法教学、AgentScope、Hunyuan-MT-7B、HunyuanWorld-Voyager、AudioStory

汀丶人工智能

当Playwright遇见MCP,AI智能体实现自主化UI回归测试

测试人

软件测试

mybatis中<if>条件判断带数字的字符串失效问题

刘大猫

人工智能 算法 智慧城市 光电科学 材料科学

HyperMesh几何修复、清理和简化

智造软件

CAE软件 Hypermesh hyperworks

厌倦了日复一日?不如从升级鸿蒙5.1,换个手机主题开始!

最新动态

设备点检 设备维护经验总结(5)

万里无云万里天

工业 设备维护 工厂运维 设备点检

Coze教程:核心功能 - 智能体创建与角色设计

测试人

什么是IPv6?和IPv4相比,IPv6具有哪些优势和特点?

国科云

AI Agent重构SOC:下一代智能安全运营平台的能力跃迁

日志易

AI SOC 日志易 安全运营中心

京东商品详情API数据解析(附代码)

tbapi

京东API 京东商品详情API 京东数据采集 京东数据分析

火山引擎数智平台发布 Data Agent"一客一策"与 AI 数据湖"算子广场"

极客天地

企业怎么挑合适的数据集成工具?

谷云科技RestCloud

Apache 数据同步 ETL 数据集成工具 informatica

代码可读性与命名艺术:空间布局与命名的核心原则

qife122

代码可读性 命名约定

设备点检 设备维护经验总结(4)

万里无云万里天

工业 设备维护 工厂运维 设备点检

Data Agent 再升级:一客一策,营销服务的理想型来了!

北京中暄互动广告传媒有限公司

AI算子广场,大幅降低多模态数据处理门槛

北京中暄互动广告传媒有限公司

刷新记录:TapData Oracle 日志同步性能达 80K TPS,重塑实时同步新标准

tapdata

Tapdata 实时数据同步 Oracle日志解析 Oracle实时同步 Oracle数据同步工具

Jenkins 可观测最佳实践

观测云

CI/CD

智能推荐新纪元:快手生成式技术对系统边界的消融与重建

老周聊架构

AICon

开源能源管理系统 MyEMS:技术深耕与实践赋能的深度解析

开源能源管理系统

开源 能源管理系统

告别 Hadoop,拥抱 StarRocks!政采云数据平台升级之路

镜舟科技

hadoop 数据仓库 数字化转型 存算分离 StarRocks

开源能源管理系统 MyEMS:智能化升级与跨场景适配的全新探索

开源能源管理系统

开源 开源能源管理系统

API管理进入新阶段:iPaaS如何统一接口治理与运维?

谷云科技RestCloud

数据治理 数据传输 API治理 API管理 ipaas

融云:当我们谈论 AI 重构业务时,我们到底在谈论什么

融云 RongCloud

MyEMS 开源能源管理系统:赋能高效能源管控与可持续发展

开源能源管理系统

开源 能源管理系统

Java小程序调用物流接口服务:快递鸟API集成指南

快递鸟

百度智能云「智能集锦」自动生成短剧解说,三步实现专业级素材生产

Baidu AICLOUD

视频云 智能剪辑

如何通过Python SDK获取Collection列表

DashVector

人工智能 数据库 向量检索 大模型

百亿数据,秒级响应:YMatrix 如何助力孚能科技实现工厂“智造”升级?

YMatrix 超融合数据库

超融合数据库 数智化转型 YMatrix 孚能科技

在AI技术唾手可得的时代,挖掘新需求成为核心竞争力——某知名教育游戏辅助工具需求洞察

qife122

需求分析 功能优化

如何应用TFGAN快速实践生成对抗网络?_Google_武维_InfoQ精选文章