阿里、蚂蚁、晟腾、中科加禾精彩分享 AI 基础设施洞见,现购票可享受 9 折优惠 |AICon 了解详情
写点什么

GAN 毕业手册:从零到一构建自己的 GAN 模型

  • 2019-05-16
  • 本文字数:5339 字

    阅读完需:约 18 分钟

GAN毕业手册:从零到一构建自己的GAN模型

本文首先对生成对抗网络(GAN)的发展进行了简单的介绍。随后深入探讨了 GAN 的网络原理,并介绍了网络的评价指标。最后带领读者创建并运行自己的 GAN,利用 MNIST 数据集训练网络,用 Comet.ml 对实验数据和参数进行分析,并生成手写数字。


MNIST数据集的隐空间可视化


生成对抗网络(GAN)的能力已经超乎公众的想象——由AI生成的名人照片影响了流行文化,创造的艺术品在高级艺术品拍卖会上售出了数千美元的价格


在本文中,我们将探讨:


  • GAN 的简要介绍

  • 了解和评价 GAN

  • 运行你自己的 GAN


现在网络上有大量的 GAN 学习资源,因此本文的重点是了解如何评价 GAN。我们还将引导读者通过运行自己的 GAN,来生成像 MNIST 数据集那样的手写数字。



GAN 的训练过程图,从图中可以看出生成的手写数字随着训练过程变得越来越真实

GAN 入门

自 2014 年 Ian Goodfellow 的《生成对抗网络(Generative Adversarial Networks)》论文发表以来,GAN 的进展突飞猛进,生成结果也越来越具有照片真实感。


就在三年前,Ian Goodfellow 在reddit上回答GAN是否可以应用在文本领域的问题时,还认为 GAN 不能扩展到文本领域。


Ian Goodfellow在Reddit上针对GAN问题的回答


“由于 GAN 定义在实值数据上,因此 GAN 不能应用于 NLP。

GAN 的工作原理是训练一个生成网络,输出合成数据,然后利用判别网络判别合成数据。判别网络根据合成数据输出的梯度告诉你该如何对合成数据进行微调,使其更真实。

因此只有当合成数据是基于连续数字时,才能对其进行微调。如果是基于离散的数字,就没有办法做微小的改变。

例如,如果输出像素值为 1.0 的图像,则下一步可以将该像素值更改为 1.0001。

但如果输出单词‘penguin’,不能在下一步直接将其更改为‘penguin+.001’,因为没有‘penguin+.001’这样的单词。你必须从‘penguin’直接转变到‘ostrich’。

由于所有的 NLP 都是基于离散的值,如单词、字符或字节,所以目前还没有人知道该如何将 GAN 应用于 NLP。”


但是现在,GAN 已经可用于生成各种内容,包括图像、视频、音频和文本。这些输出的合成数据既可以用于训练其他的模型,也可以用于创建一些有趣的项目,如thispersonnesnotexist.com(这个人不存在)thisairbnbdoensotexist.com(这家酒店不存在),甚至This Machine Learning Medium post does not exist(这个机器学习帖子不存在)

GAN 原理

GAN 由两个神经网络组成,一个是合成新样本的生成器,另一个是对比训练样本与生成样本的判别器。判别器的目标是区分“真实”和“虚假”的输入(对样本来自模型分布还是真实分布进行分类)。这些样本可以是图像、视频、音频片段和文本。


GAN概览


为了合成这些新的样本,生成器的输入为随机噪声,然后尝试从训练数据中学习到的分布中生成真实的图像。


判别器网络(卷积神经网络)输出相对于合成数据的梯度,其中包含着如何改变合成数据以使其更具真实感的信息。最终生成器收敛,它可以生成符合真实数据分布的样本,而判别器无法区分生成数据和真实数据。


GAN Lab是一个可以在浏览器上训练 GAN 的项目,可以利用这个项目查看网络训练过程中数据分布逐渐收敛的动画过程。


这里给大家推荐一些 GAN 的精选介绍指南:


了解和评价 GAN

量化 GAN 的表现让人感觉是个非常主观的问题——“这张生成的脸看起来够真实吗?”,“这些生成的图像足够多样化吗?”——GAN 就像是黑匣子,我们不清楚模型的哪些部分会影响学习过程或结果的质量。


为此,麻省理工学院计算机科学与人工智能实验室(CSAIL)的一个小组最近发表了一篇论文《GAN解剖:可视化和理解生成对抗网络(GAN Dissection: Visualizing and Understanding Generative Adversarial Networks)》,其中介绍了一种可视化 GAN 的方法,并且分析了 GAN 单元如何与图像中的对象相关联,以及对象之间的关系。


该论文采用基于分割的网络解析方法,对生成神经网络的内部原理进行了分析和可视化。通过寻找 GAN 单元(神经元)和输出图像中的概念(如树、天空、云等)之间的一致性,我们能够鉴别出负责特定对象(如建筑物或云)的神经元。


通过干扰特定的GAN单元进行图像修改


通过强制激活和冻结(消除)这些对象的对应神经元,可以从这种粒度级别控制神经元,对生成图像进行编辑(例如,添加或删除图像中的树)。


然而,我们现在还不清楚网络是否能够解释场景中的对象,或者网络只是记忆这些对象。回答这个问题一个方法是用不实际的方式对图像进行变形。MIT CSAIL 给出了一个GAN绘画交互网页演示,其中最令人印象深刻的是,模型似乎能够将这些人为编辑限制在“真实照片”的变化范围内。如果你试图把草地强行画在天空上,会发生以下情况:


GAN Paint交互式页面


尽管我们激活了相应的神经元,但看起来好像 GAN 后面几层的神经元抑制了相应的信号。



对象的局部上下文对合成对象可能性的影响(在图中,在建筑物上生成门的概率比在树上或天空中生成门的概率要高).


另一个可视化 GAN 的方法是进行隐空间插值(GAN 通过从所学的隐空间采样生成新的实例)。这是查看生成样本之间的转换是否平滑的一种有效方法。


隐空间插值


虽然这些可视化方法能够帮助我们理解 GAN 的内部表示,但是如何找到量化的方法来理解 GAN 的训练过程和输出质量仍然是一个活跃的研究方向。


衡量生成图像质量和多样性的两个常用评价指标是:起始分数(Inception Score,IS)和 Fréchet 初始距离(Fréchet Inception Distance, FID)。在论文《关于起始分数的笔记(A Note on the Inception Score)》指出了前者的重要缺点之后,大多数研究人员已经从 IS 转向了 FID。

Inception Score

IS 由 Salimans 等人在 2016 年的论文《训练GAN的技术改进(Improved Techniques for Training GANs)》中提出。Inception Score 的提出受到一个观点的启发,即真实样本应该能被预训练的分类网络分类,例如预训练的 ImageNet 网络。从技术层面来讲,样本的 softmax 预测矢量的熵应该较低。


除了高可预测性(低熵),Inception Score 也从生成样本的多样性对 GAN 进行评价(例如生产样本分布的方差或熵较高),这意味着分类结果中不应该出现任何占支配性的类别。


如果这两个特点都满足,Inception Score 的值就会很高。将这两个标准结合起来的方法是计算样本的条件标签分布与所有样本的边缘分布之间的 KL 散度。

Fréchet Inception Distance

FID 由Heusel等人在 2017 年提出,FID 通过测量图像生成分布和真实分布之间的距离来衡量生成样本的真实性。FID 通过 Inception Net 的一个特殊层将一组生成的样本嵌入特征空间。该嵌入层可以被视为一个连续的多元高斯分布,然后计算生成数据和真实数据的平均值和协方差。这两个高斯分布之间的 Fréchet 距离(即 Wasserstein-2 距离)可以对生成样本的质量进行量化。FID 较低,则代表生成样本与真实样本更相似。


重要的一点是,FID 需要一定的样本数量才能得到良好的结果(建议 50K 个样本)。如果使用的样本数量太少,就会出现过度估计,并且估计值会有很大的方差。


Neal Jean 的博客中对比了不同论文中的 Inception Score 和 FID 分数:https://nealjean.com/ml/frechet-inception-distance/

了解更多

Aji Borji 的论文《GAN评价策略的好与坏(Pros and Cons of GAN Evaluation Measures)》中给出了一个更全面的表格,涵盖了更多的 GAN 评价指标:


常用GAN评价指标总结表格


有趣的是,其他研究人员并没有采用这两种方法,而是采用了其研究领域特定的评价指标。对于文本 GAN,Guy Tevet 和他的团队在论文《将文本GAN作为语言模型进行评价(Evaluating Text GANs as Language Models)》中提出,使用传统的基于概率的语言模型指标来评价 GAN 生成文本的分布。


在论文《我的GAN有多好(How good is my GAN)》中,Konstantin Shmelkov 和他的团队使用了两种基于图像分类的评价指标:GAN 训练和 GAN 测试,分别估算 GAN 的召回率(多样性)和准确率(图像质量)。你可以在 Google Brain 的研究论文《GAN是生而平等的吗?(Are GANs created equal)》中看到这些评价指标的实际应用,他们使用三角形数据集来测量不同 GAN 模型的精度和召回率。


三角形数据集衡量不同GAN的精度和召回率

运行你自己的 GAN

为了进一步介绍 GAN,我们采用了来自 Wouter Bulten 的教程,该教程使用 Keras 和 MNIST 数据集训练 GAN,生成手写数字。


(完整笔记链接:https://gist.github.com/ceceshao1/935ea6000c8509a28130d4c55b32fcd6



我们对损失和准确率曲线进行可视化,并且通过Comet.ml检查输出,来记录 GAN 的训练过程。


该 GAN 模型将 MNIST 训练数据和随机噪声作为输入(噪声的随机矢量)来生成以下内容:


  • 图像(在本例中,为手写数字的图像)。最终,这些生成的图像将模拟 MNIST 数据集的数据分布。

  • 判别器对生成图像的预测。


生成器判别器一起构成对抗模型,在本例中,如果对抗模型将生成的图像分类为真实图像,说明生成器训练良好。


完整代码链接:


https://gist.github.com/ceceshao1/935ea6000c8509a28130d4c55b32fcd6


完整实验结果链接:


https://www.comet.ml/ceceshao1/mnist-gan

记录模型的进度

我们可以使用Comet.ml记录生成器判别器模型的训练进度。


我们为判别器和对抗模型绘制准确率和损失曲线图,两个需要记录的最重要的指标如下:


  • 判别器损失(见右图蓝线)-dis_loss

  • 对抗模型的准确率(见左图蓝线)-acc_adv


实验训练过程链接:


https://www.comet.ml/ceceshao1/mnist-gan/cf310adacd724bf280323e2eef92d1cd/chart


判别器和对抗模型的损失和精度曲线图


还需要确认训练过程使用的是 GPU,可以在Comet系统指标选项卡中检查这一项。


GPU内存和使用情况


请注意,我们的训练代码中包括从测试向量中显示图像的代码:


 if i % 500 == 0:        # Visualize the performance of the generator by producing images from the test vector        images = net_generator.predict(vis_noise)        # Map back to original range        #images = (images + 1 ) * 0.5        plt.figure(figsize=(10,10))                for im in range(images.shape[0]):            plt.subplot(4, 4, im+1)            image = images[im, :, :, :]            image = np.reshape(image, [28, 28])                        plt.imshow(image, cmap='gray')            plt.axis('off')                plt.tight_layout()        # plt.savefig('/home/ubuntu/cecelia/deeplearning-resources/output/mnist-normal/{}.png'.format(i))        plt.savefig(r'output/mnist-normal/{}.png'.format(i))        experiment.log_image(r'output/mnist-normal/{}.png'.format(i))        plt.close('all')
复制代码


我们每隔几步显示生成的图像,部分原因是为了让我们能够更直观地分析生成器在生成手写数字方面的表现,以及判别器将生成的数字进行正确分类的表现。


让我们看看这些生成的输出!


在这个Comet实验中,你可以看到生成的输出图像。


可以看到生成器模型是从这个模糊的、灰色的输出(下面的 0.png)开始的,它看起来完全不像我们所期望的手写数字。



随着训练的进展,模型的损失下降,生成的数字变得越来越清晰。在不同步数生成的输出:


步数 500:



步数 1000:



步数 1500:



最后,在步数 10000 时,下图红框标出了一些 GAN 生成的手写数字示例。



一旦 GAN 模型完成了训练,我们可以在 Comet 的图形选项卡中查看由输出组成的动画,只需要按下播放按钮即可。



要完成实验,请确保运行 experiment.end(),可以查看有关模型和 GPU 使用情况的一些统计信息。


模型迭代

我们可以增加模型的训练时间,观察它对性能的影响,但是首先让我们尝试使用一些不同的参数进行模型迭代。


我们调整的参数包括:


  • 判别器的优化器

  • 学习率

  • dropout 随机失活率

  • batchsize 批尺寸


Wouter 在自己的博客文章中,提到了他在测试参数方面做的工作:


“我测试了 SGD、RMSprop 和 Adam 作为判别器的优化器的效果,其中 RMSprop 的性能最好。RMSprop 的学习率很低,我将学习率的值压缩到-1 和 1 之间。学习速率的轻微衰减有助于网络稳定。”


我们尝试将判别器的 dropout 随机失活率从 0.4 提高到 0.5,并同时提高判别器的学习率(从 0.008 提高到 0.0009)和生成器的学习率(从 0.0004 提高到 0.0006)。很容易看到这些变化对训练造成的影响。


要创建一个不同的实验,只需再次运行实验定义单元,Comet将为你的新实验分配一个新的 url。记录下每次的实验是一个好习惯,可以方便比较不同之处:



不幸的是,我们的调整没有提高模型的性能,反而得到了一些奇怪的输出图像:



这就是输出图像呈现出来的效果。面对这样的结果,应该说点什么好呢?


公众号推荐:

2024 年 1 月,InfoQ 研究中心重磅发布《大语言模型综合能力测评报告 2024》,揭示了 10 个大模型在语义理解、文学创作、知识问答等领域的卓越表现。ChatGPT-4、文心一言等领先模型在编程、逻辑推理等方面展现出惊人的进步,预示着大模型将在 2024 年迎来更广泛的应用和创新。关注公众号「AI 前线」,回复「大模型报告」免费获取电子版研究报告。

AI 前线公众号
2019-05-16 11:1517232
用户头像

发布了 52 篇内容, 共 28.1 次阅读, 收获喜欢 72 次。

关注

评论

发布
暂无评论
发现更多内容

数据库审计和日志审计的三大区别分析

行云管家

数据库 日志 日志审计 数据库审计

为安全而生!云安全漫谈开讲啦

浪潮云

云安全 云计算运维

拔掉网线几秒,再插回去,原本的 TCP 连接还存在吗?

程序员小毕

程序员 程序人生 计算机网络 java面试 TCP协议

APISIX 如何与 Hydra 集成,搭建集中认证网关助力企业安全

API7.ai 技术团队

云原生 网关 身份验证 APISIX 网关

Python|分析QQ群聊信息,记录词频并制作词云

AXYZdong

Python 7月月更

阿里云第四届全球数据库大赛火热开赛,40万奖金广纳英才

科技热闻

数据库审计部署方式有哪些?哪种比较好?

行云管家

数据库 数据库审计 数据库审计部署

数据也能进超市

天翼云开发者社区

云计算 大数据 云平台

『51单片机』十分钟学会定时器

謓泽

7月月更

Python 入门指南之标准库概览

海拥(haiyong.site)

7月月更

天翼云携手华为,强强联合,共创数据存储新生态

天翼云开发者社区

存储 数字化

2022数十位Java架构师汇总产出,最新25个技术栈“Java面经”

程序知音

Java 程序员 面试 后端 八股文

易观分析加入智能投研技术联盟,共促行业数智化发展

易观分析

易观新闻

大数据培训Spark数据倾斜问题的解决方法

@零度

spark 大数据开发

内行,阿里大牛离职带出内部“高并发系统设计”学习手册

程序知音

Java 阿里巴巴 程序员 后端 高并发

项目进度管理和风险管理记录

老猎人

java零基础入门-多态

喵手

Java 7月月更

一招,让停车管理不再难

天翼云开发者社区

数字化 云平台

火眼金睛,天翼云助力打造城市视觉中枢

天翼云开发者社区

大数据 云平台

五个核心能力打造普惠金融商业化发展模式

易观分析

普惠金融

小程序表单组件-1

小恺

7月月更

LeetCode-119. 杨辉三角II(java)

bug菌

Leet Code 7月月更

告别缺电焦虑!充电桩装上“智慧大脑”

天翼云开发者社区

云主机 云平台

关于微软 Edge 浏览器的 Tracking Prevention 特性在 Angular 应用中的影响

Jerry Wang

JavaScript typescript Web web开发 7月月更

百家号基于AE的视频渲染技术探索

百度Geek说

视频 视频渲染

大数据ZooKeeper(一):基本知识和集群搭建

Lansonli

大数据 zookeeper 7月月更

【愚公系列】2022年7月 Go教学课程 010-数据类型之布尔型和字符类型

愚公搬代码

7月月更

揭露数据不一致的利器 —— 实时核对系统

Shopee技术团队

数据分析 后端

leetcode 455. Assign Cookies 分发饼干(简)

okokabcd

LeetCode 数据结构与算法 贪心算法

大数据环境搭建:​​​​​​​​​​​​​​​​​​​​​Hadoop编译和分布式环境搭建

Lansonli

大数据 hadoop 环境搭建 7月月更

2022年中国互联网医疗年度盘点

易观分析

互联网医疗

GAN毕业手册:从零到一构建自己的GAN模型_AI&大模型_Cecelia Shao_InfoQ精选文章