GAN 毕业手册:从零到一构建自己的 GAN 模型

阅读数:4753 2019 年 5 月 16 日

本文首先对生成对抗网络(GAN)的发展进行了简单的介绍。随后深入探讨了 GAN 的网络原理,并介绍了网络的评价指标。最后带领读者创建并运行自己的 GAN,利用 MNIST 数据集训练网络,用 Comet.ml 对实验数据和参数进行分析,并生成手写数字。

image

生成对抗网络(GAN)的能力已经超乎公众的想象——由AI 生成的名人照片影响了流行文化,创造的艺术品在高级艺术品拍卖会上售出了数千美元的价格

在本文中,我们将探讨:

  • GAN 的简要介绍

  • 了解和评价 GAN

  • 运行你自己的 GAN

现在网络上有大量的 GAN 学习资源,因此本文的重点是了解如何评价 GAN。我们还将引导读者通过运行自己的 GAN,来生成像 MNIST 数据集那样的手写数字。

image

GAN 的训练过程图,从图中可以看出生成的手写数字随着训练过程变得越来越真实

GAN 入门

自 2014 年 Ian Goodfellow 的《生成对抗网络(Generative Adversarial Networks)》论文发表以来,GAN 的进展突飞猛进,生成结果也越来越具有照片真实感。

就在三年前,Ian Goodfellow 在reddit 上回答 GAN 是否可以应用在文本领域的问题时,还认为 GAN 不能扩展到文本领域。

image

“由于 GAN 定义在实值数据上,因此 GAN 不能应用于 NLP。

GAN 的工作原理是训练一个生成网络,输出合成数据,然后利用判别网络判别合成数据。判别网络根据合成数据输出的梯度告诉你该如何对合成数据进行微调,使其更真实。

因此只有当合成数据是基于连续数字时,才能对其进行微调。如果是基于离散的数字,就没有办法做微小的改变。

例如,如果输出像素值为 1.0 的图像,则下一步可以将该像素值更改为 1.0001。

但如果输出单词‘penguin’,不能在下一步直接将其更改为‘penguin+.001’,因为没有‘penguin+.001’这样的单词。你必须从‘penguin’直接转变到‘ostrich’。

由于所有的 NLP 都是基于离散的值,如单词、字符或字节,所以目前还没有人知道该如何将 GAN 应用于 NLP。”

但是现在,GAN 已经可用于生成各种内容,包括图像、视频、音频和文本。这些输出的合成数据既可以用于训练其他的模型,也可以用于创建一些有趣的项目,如thispersonnesnotexist.com(这个人不存在)thisairbnbdoensotexist.com(这家酒店不存在),甚至This Machine Learning Medium post does not exist(这个机器学习帖子不存在)

GAN 原理

GAN 由两个神经网络组成,一个是合成新样本的生成器,另一个是对比训练样本与生成样本的判别器。判别器的目标是区分“真实”和“虚假”的输入(对样本来自模型分布还是真实分布进行分类)。这些样本可以是图像、视频、音频片段和文本。

image

为了合成这些新的样本,生成器的输入为随机噪声,然后尝试从训练数据中学习到的分布中生成真实的图像。

判别器网络(卷积神经网络)输出相对于合成数据的梯度,其中包含着如何改变合成数据以使其更具真实感的信息。最终生成器收敛,它可以生成符合真实数据分布的样本,而判别器无法区分生成数据和真实数据。

GAN Lab是一个可以在浏览器上训练 GAN 的项目,可以利用这个项目查看网络训练过程中数据分布逐渐收敛的动画过程。

这里给大家推荐一些 GAN 的精选介绍指南:

了解和评价 GAN

量化 GAN 的表现让人感觉是个非常主观的问题——“这张生成的脸看起来够真实吗?”,“这些生成的图像足够多样化吗?”——GAN 就像是黑匣子,我们不清楚模型的哪些部分会影响学习过程或结果的质量。

为此,麻省理工学院计算机科学与人工智能实验室(CSAIL)的一个小组最近发表了一篇论文《GAN 解剖:可视化和理解生成对抗网络(GAN Dissection: Visualizing and Understanding Generative Adversarial Networks)》,其中介绍了一种可视化 GAN 的方法,并且分析了 GAN 单元如何与图像中的对象相关联,以及对象之间的关系。

该论文采用基于分割的网络解析方法,对生成神经网络的内部原理进行了分析和可视化。通过寻找 GAN 单元(神经元)和输出图像中的概念(如树、天空、云等)之间的一致性,我们能够鉴别出负责特定对象(如建筑物或云)的神经元。

image

通过强制激活和冻结(消除)这些对象的对应神经元,可以从这种粒度级别控制神经元,对生成图像进行编辑(例如,添加或删除图像中的树)。

然而,我们现在还不清楚网络是否能够解释场景中的对象,或者网络只是记忆这些对象。回答这个问题一个方法是用不实际的方式对图像进行变形。MIT CSAIL 给出了一个GAN 绘画交互网页演示,其中最令人印象深刻的是,模型似乎能够将这些人为编辑限制在“真实照片”的变化范围内。如果你试图把草地强行画在天空上,会发生以下情况:

image

尽管我们激活了相应的神经元,但看起来好像 GAN 后面几层的神经元抑制了相应的信号。

image

对象的局部上下文对合成对象可能性的影响(在图中,在建筑物上生成门的概率比在树上或天空中生成门的概率要高).

另一个可视化 GAN 的方法是进行隐空间插值(GAN 通过从所学的隐空间采样生成新的实例)。这是查看生成样本之间的转换是否平滑的一种有效方法。

image

虽然这些可视化方法能够帮助我们理解 GAN 的内部表示,但是如何找到量化的方法来理解 GAN 的训练过程和输出质量仍然是一个活跃的研究方向。

衡量生成图像质量和多样性的两个常用评价指标是:起始分数(Inception Score,IS)和Fréchet 初始距离(Fréchet Inception Distance, FID)。在论文《关于起始分数的笔记(A Note on the Inception Score)》指出了前者的重要缺点之后,大多数研究人员已经从 IS 转向了 FID。

Inception Score

IS 由 Salimans 等人在 2016 年的论文《训练 GAN 的技术改进(Improved Techniques for Training GANs)》中提出。Inception Score 的提出受到一个观点的启发,即真实样本应该能被预训练的分类网络分类,例如预训练的 ImageNet 网络。从技术层面来讲,样本的 softmax 预测矢量的熵应该较低。

除了高可预测性(低熵),Inception Score 也从生成样本的多样性对 GAN 进行评价(例如生产样本分布的方差或熵较高),这意味着分类结果中不应该出现任何占支配性的类别。

如果这两个特点都满足,Inception Score 的值就会很高。将这两个标准结合起来的方法是计算样本的条件标签分布与所有样本的边缘分布之间的 KL 散度。

Fréchet Inception Distance

FID 由Heusel 等人在 2017 年提出,FID 通过测量图像生成分布和真实分布之间的距离来衡量生成样本的真实性。FID 通过 Inception Net 的一个特殊层将一组生成的样本嵌入特征空间。该嵌入层可以被视为一个连续的多元高斯分布,然后计算生成数据和真实数据的平均值和协方差。这两个高斯分布之间的 Fréchet 距离(即 Wasserstein-2 距离)可以对生成样本的质量进行量化。FID 较低,则代表生成样本与真实样本更相似。

重要的一点是,FID 需要一定的样本数量才能得到良好的结果(建议 50K 个样本)。如果使用的样本数量太少,就会出现过度估计,并且估计值会有很大的方差。

Neal Jean 的博客中对比了不同论文中的 Inception Score 和 FID 分数:https://nealjean.com/ml/frechet-inception-distance/

了解更多

Aji Borji 的论文《GAN 评价策略的好与坏(Pros and Cons of GAN Evaluation Measures)》中给出了一个更全面的表格,涵盖了更多的 GAN 评价指标:

image

有趣的是,其他研究人员并没有采用这两种方法,而是采用了其研究领域特定的评价指标。对于文本 GAN,Guy Tevet 和他的团队在论文《将文本 GAN 作为语言模型进行评价(Evaluating Text GANs as Language Models)》中提出,使用传统的基于概率的语言模型指标来评价 GAN 生成文本的分布。

在论文《我的 GAN 有多好(How good is my GAN)》中,Konstantin Shmelkov 和他的团队使用了两种基于图像分类的评价指标:GAN 训练和 GAN 测试,分别估算 GAN 的召回率(多样性)和准确率(图像质量)。你可以在 Google Brain 的研究论文《GAN 是生而平等的吗?(Are GANs created equal)》中看到这些评价指标的实际应用,他们使用三角形数据集来测量不同 GAN 模型的精度和召回率。

image

运行你自己的 GAN

为了进一步介绍 GAN,我们采用了来自 Wouter Bulten 的教程,该教程使用 Keras 和 MNIST 数据集训练 GAN,生成手写数字。

(完整笔记链接:https://gist.github.com/ceceshao1/935ea6000c8509a28130d4c55b32fcd6

image

我们对损失和准确率曲线进行可视化,并且通过Comet.ml检查输出,来记录 GAN 的训练过程。

该 GAN 模型将 MNIST 训练数据和随机噪声作为输入(噪声的随机矢量)来生成以下内容:

  • 图像(在本例中,为手写数字的图像)。最终,这些生成的图像将模拟 MNIST 数据集的数据分布。

  • 判别器对生成图像的预测。

生成器判别器一起构成对抗模型,在本例中,如果对抗模型将生成的图像分类为真实图像,说明生成器训练良好。

完整代码链接:
https://gist.github.com/ceceshao1/935ea6000c8509a28130d4c55b32fcd6

完整实验结果链接:
https://www.comet.ml/ceceshao1/mnist-gan

记录模型的进度

我们可以使用Comet.ml记录生成器判别器模型的训练进度。

我们为判别器和对抗模型绘制准确率和损失曲线图,两个需要记录的最重要的指标如下:

  • 判别器损失(见右图蓝线)-dis_loss

  • 对抗模型的准确率(见左图蓝线)-acc_adv

实验训练过程链接:
https://www.comet.ml/ceceshao1/mnist-gan/cf310adacd724bf280323e2eef92d1cd/chart

image

还需要确认训练过程使用的是 GPU,可以在Comet 系统指标选项卡中检查这一项。

image

请注意,我们的训练代码中包括从测试向量中显示图像的代码:

复制代码
if i % 500 == 0:
# Visualize the performance of the generator by producing images from the test vector
images = net_generator.predict(vis_noise)
# Map back to original range
#images = (images + 1 ) * 0.5
plt.figure(figsize=(10,10))
for im in range(images.shape[0]):
plt.subplot(4, 4, im+1)
image = images[im, :, :, :]
image = np.reshape(image, [28, 28])
plt.imshow(image, cmap='gray')
plt.axis('off')
plt.tight_layout()
# plt.savefig('/home/ubuntu/cecelia/deeplearning-resources/output/mnist-normal/{}.png'.format(i))
plt.savefig(r'output/mnist-normal/{}.png'.format(i))
experiment.log_image(r'output/mnist-normal/{}.png'.format(i))
plt.close('all')

我们每隔几步显示生成的图像,部分原因是为了让我们能够更直观地分析生成器在生成手写数字方面的表现,以及判别器将生成的数字进行正确分类的表现。

让我们看看这些生成的输出!

在这个Comet 实验中,你可以看到生成的输出图像。

可以看到生成器模型是从这个模糊的、灰色的输出(下面的 0.png)开始的,它看起来完全不像我们所期望的手写数字。

image

随着训练的进展,模型的损失下降,生成的数字变得越来越清晰。在不同步数生成的输出:

步数 500:

image

步数 1000:

image

步数 1500:

image

最后,在步数 10000时,下图红框标出了一些 GAN 生成的手写数字示例。

image

一旦 GAN 模型完成了训练,我们可以在 Comet 的图形选项卡中查看由输出组成的动画,只需要按下播放按钮即可。

image

要完成实验,请确保运行 experiment.end(),可以查看有关模型和 GPU 使用情况的一些统计信息。

image

模型迭代

我们可以增加模型的训练时间,观察它对性能的影响,但是首先让我们尝试使用一些不同的参数进行模型迭代。

我们调整的参数包括:

  • 判别器的优化器

  • 学习率

  • dropout 随机失活率

  • batchsize 批尺寸

Wouter 在自己的博客文章中,提到了他在测试参数方面做的工作:

“我测试了 SGD、RMSprop 和 Adam 作为判别器的优化器的效果,其中 RMSprop 的性能最好。RMSprop 的学习率很低,我将学习率的值压缩到 -1 和 1 之间。学习速率的轻微衰减有助于网络稳定。”

我们尝试将判别器的 dropout 随机失活率从 0.4 提高到 0.5,并同时提高判别器的学习率(从 0.008 提高到 0.0009)和生成器的学习率(从 0.0004 提高到 0.0006)。很容易看到这些变化对训练造成的影响。

要创建一个不同的实验,只需再次运行实验定义单元,Comet将为你的新实验分配一个新的 url。记录下每次的实验是一个好习惯,可以方便比较不同之处:

image

不幸的是,我们的调整没有提高模型的性能,反而得到了一些奇怪的输出图像:

image

这就是输出图像呈现出来的效果。面对这样的结果,应该说点什么好呢?

收藏

评论

微博

发表评论

注册/登录 InfoQ 发表评论