GAN 生成对抗网络入门
1. 简介
神经网络分很多种,有普通的前向传播网络,有分析图片的CNN卷积神经网络,有分析序列化数据的RNN循环神经网络,这些都是根据数据去预测结果。而对于生成网络而言,则是”捏造结果”了。GAN(Generative Adversarial Networks,生成对抗网络)就是其中一种生成网络,是根据随机数来生成结果,提供了一种不需要大量的标注训练数据就能学习深度表征的方式。
相比于其他的生成模型,GAN有两大特点:
- 不依赖任何的先验假设:传统的许多方法会假设数据服从某一分布,然后使用极大似然去估计数据分布。
- 生成 real-like 样本的方式非常简单:通过生成器(Generator)的前向传播。
1.1 生成模型
生成模型是指能够随机生成观测数据的模型,尤其是在给定某些隐含参数的条件下。
1.2 形象描述
生成对抗网络的过程可以用一个新手画家和一个新手鉴赏家的例子来说明。新手画家要画一幅达芬奇的画,但是他刚开始学,不知道达芬奇的画风格要怎么呈现,于是他用新手的灵感画画,画得一团糟。这时候有一个新手鉴赏家,他接受到了一些画作,但是他不知道哪些是新手画家画的,哪些是达芬奇画的。新手鉴赏家说出了自己的判断,你来纠正他的判断。于是新手鉴赏家就一边判断,一边告诉新手画家要怎么作画才能更加接近达芬奇的画作。新手画家也就一步步学会怎么画的更像达芬奇的风格了。如下,包含一个生成网络和一个判别网络。
1.3 原理
有两个网络,G(Generator)和D(Descriminator)。G是生成图片的网络,接收一个隐变量z(通常为服从高斯分布的随机噪声),并生成图片G(z);D是判别网络,判定一张图片是不是真实的,输入x为一张图片,输出的D(x)代表x为真实图片的概率,输出为1表示百分百是真实图片,输出为0表示不可能为真实的图片。
训练的过程,就是G生成尽量真实的图片去欺骗D,而D则是尽量把G生成的图片和真实的图片区分开,二者就形成了一个博弈的过程。而在最理想的情况下,博弈的最终结果,就是D完全无法区分开G的作品和真实的图片,也就是D(G(z)) = 0.5。
- 寻找生成模型和判别模型之间的一个纳什均衡!
数学语言描述:(优化的目标函数)
对判别器D来说,这是一个二分类问题。V(D,G) 为二分类问题中常见的交叉熵损失。
对于生成器 G 来说,为了尽可能欺骗 D,所以需要最大化生成样本的判别概率 D(G(z))(最小化 log(1-D(G(z))) )。生成器优化的目标是:最小化 V(D,G) 的最大值。
D和G采取交替训练的方法,先训练D,目标是希望V(G, D)越大越好,所以采用梯度上升;之后训练G,目标是希望V(G, D)越小越好,所以采用梯度下降
- 生成器G固定时,可以对V(D, G)求导,求出最佳判别器(已经被作者证明该D*(x)存在且唯一)
- 判别函数代入上面的目标函数,可以求出在最佳判别器下,生成器的目标函数等价于优化 Pdata(x) , Pg(x) 的 JS 散度 (JS-Divergence)。
- 全局的优化目标为Pg = Pdata。经过若干次训练之后,如果G和D有一定的复杂度,那么二者会达到Pg = Pdata这个均衡点,即生成器的密度概率函数等于真实数据的密度概率函数,即生成数据和密度数据一样,D(x) = 1/2。
训练总结起来有以下步骤:
- 参数优化过程
要寻找最优的生成器,那给定一个判别器,可以将 max V(G,D) 看作训练生成器的损失函数 L(G)。设定了损失函数之后,就可以利用Adam等优化算法通过梯度下降来更新生成器G的参数。
- 给定 G_0,最大化 V(G_0,D),求得 D_0*,即 max[JSD(P_data(x)||P_G0(x)];
- 固定 D_0*,计算θ_G1 ← θ_G0 −η(dV(G,D_0*) /dθ_G) 以求得更新后的 G_1;
- 固定 G_1,最大化 V(G_1,D_0*) 以求得 D_1*,即 max[JSD(P_data(x)||P_G1(x)];
- 固定 D_1*,计算θ_G2 ← θ_G1 −η(dV(G,D_0*) /dθ_G) 以求得更新后的 G_2;
- 。。。
- 实际训练过程
根据价值函数 V(G,D) 的定义,要求两个数学期望, E[log(D(x))] 和 E[log(1-D(G(z)))], x 服从真实数据分布,z 服从初始化分布。实践中没法利用积分求解,所以一般是从无穷的真实数据和无穷的生成器中做采样以逼近真实的数学期望。
最大化价值函数:P_data(x) 采样 m 个样本, P_G(x) 采样 m 个样本
最小化损失函数:P_data(x) 抽取样本作为正样本,从 P_G(x) 抽取样本作为负样本,同时将逼近 -V(G,D) 的函数作为损失函数。
——用迭代和数值计算的方式实现极小极大化博弈过程。
从真实数据分布 P_data 抽取 m 个样本
从先验分布 P_prior(z) 抽取 m 个噪声样本
将噪声样本投入 G 而生成数据,通过最大化 V 的近似而更新判别器参数θ_d,即极大化,判别器参数的更新迭代式:
—— 以上是学习判别器D的过程。学习D的过程是计算JS散度的过程,我们希望最大化价值函数,所以会进行K次迭代。实践中好像一般K取1也足够。
从先验分布 P_prior(z) 中抽取另外 m 个噪声样本 {z^1,…,z^m}
生成器参数的更新迭代式:
——以上是学习生成器G的过程。为避免更新太多使得JS散度上升,在一次迭代中只进行一次。
1.4 优缺点
优点:无需马尔可夫链,仅仅使用反向传播来获得梯度,学习间无需推理,且模型中可以融合入多种函数。
缺点:需要同步D和G;p(x)的隐式表示(???不太懂)
- 生成对抗网络模型和其他生成模型之间的对比:
2. 其他常见生成式模型
2.1 PixelCNN 和 PixelRNN
对图像数据的概率分布Pdata(x)进行显式建模,并利用极大似然估计优化模型。给定 x1,x2,…,xi-1 条件下,所有 p(xi) 的概率乘起来就是图像数据的分布。如果使用 RNN 对上述依然关系建模,就是 pixelRNN。如果使用 CNN,则是 pixelCNN。
优点:定义了一个易于处理的密度函数,可以直接优化训练数据的似然。
缺点:像素值是从图像的一个角落开始,一个个生成的,所以速度会很慢。
2.2 VAE: Variational Auto-Encode 变分自编码器
真实样本X通过神经网络计算出均值方差,假设隐变量服从正态分布。然后通过采样得到采样变量Z进行重构。VAE和GAN均学习了隐变量 z 到真实数据分布的映射,但VAE的不同之处在于:
- GAN思路直接,使用一个判别器去度量分布转换模块(即生成器)生成分布与真实数据分布的距离。
- VAE委婉,通过约束隐变量 z 服从标准正态分布以及重构数据实现了分布转换映射 X=G(z)。
3. GAN常见模型结构
3.1 DCGAN
提出使用CNN结构来稳定GAN的训练,这允许了生成器和判别器学习优秀的上采样和下采样操作,这些操作可能提升图像合成的质量。
3.2 层级结构
GAN 对于高分辨率图像生成一直存在许多问题,层级结构的 GAN 通过逐层次,分阶段生成,一步步提生图像的分辨率。
- 使用多对GAN:StackGAN、GoGAN
- 单一GAN,多阶段生成:ProgressiveGAN
3.3 自编码结构
BEGAN,EBGAN,MAGAN ……
4. GAN存在的问题
不收敛问题:由于GAN是采用极小极大博弈,D 在进行梯度下降时,使得在损失流形上下降,而G使其上升,可能造成两者的梯度相互抵消,最终在最优点附近徘徊。因此,不收敛问题也是GAN所面临的最大的问题。目前采用的优化方法都是采用的启发式的方法。WGAN也在一定程度上解决了收敛不稳定的问题。
模式崩溃:生成器”崩溃”,即用不同的输入生成相似的样本。可以理解成多样性问题,当G生成了一张比较真实的图片之后,就不再学习其他的分布,而仅靠这一张图片来欺骗,仅仅收敛到一种模式。这样即便训练时间再长也不会有好的结果。
梯度消失:判别器的损失很快收敛为零,从而没有足够强的梯度路径可以继续更新生成器
离散输出问题:GAN对生成器的唯一要求就是——生成器表示的函数必须可导,因此,GAN似乎无法用于离散输出,如文本。目前,仍未有将GAN应用于NLP领域。目前可能解决该问题三个可能的方向:采用强化学习、采用具体的分布、训练生成器产生连续的输出值,并将其编码为离散值。
5. GAN的发展
5.1 GAN: Generative Adversarial Networks, 2014
论文地址: arxiv.org/abs/1406.26…
“GAN之父” Ian Goodfellow 发表的第一篇提出 GAN 的论文,提出了 GAN 这个模型框架,讨论了非饱和的损失函数,然后对于最佳判别器(optimal discriminator)给出其导数,然后进行证明;最后是在 Mnist、TFD、CIFAR-10 数据集上进行了实验。
结论和未来的研究方向:
- 条件生成模型p(x∣c)可以通过将c作为G和D的输入来获得。
- 半监督学习:提供适量的带标签数据,以提高判别网络或推理网络的特征分类效果。
- 效率改善:设计更好的方法来协调D和G,或确定更好的分布来对 z 进行采样,以提高训练速率。
- ……
5.2 Conditional GAN, 2014
之前的GAN是无监督模型,但给生成器提供随机噪声的话效果往往没有那么好。cGAN的提出将其拉回监督学习领域,缓和了GAN的训练不稳定的问题。
5.3 DCGAN
Deep Convolutional GAN
第一次采用CNN结构实现GAN。将G和D换成两个CNN,但对CNN的结构做出了一些改变来提高样本质量和收敛速度:
- 取消池化层,G网络中使用转置卷积层(transposed convolutional layer)进行上采样,D网络中用加入stride的卷积代替pooling。
- D和G中均使用batch normalization(???)
- 去掉全连接层,使网络变为全卷积网络
- G网络中使用ReLU作为激活函数,最后一层使用tanh
- D网络中使用LeakyReLU作为激活函数
- 采用Adam优化算法,学习率是0.0002,beta1=0.5
DCGAN中的G网络:
这篇论文介绍了如何使用卷积层,并给出一些额外的结构上的指导建议来实现。另外,它还讨论如何可视化 GAN 的特征、隐空间的插值、利用判别器特征训练分类器以及评估结果。
5.4 Improved Techniques for Training GANs
作者之一是 Ian Goodfellow。论文介绍了很多如何构建一个 GAN 结构的建议,可以帮助理解 GAN 不稳定的原因,给出稳定训练 DCGANs 的建议,比如特征匹配(feature matching)、最小批次判别(minibatch discrimination)、单边标签平滑(one-sided label smoothing)、虚拟批归一化(virtual batch normalization)等等。
5.5 Pix2Pix
Image-to-Image Translation with Conditional Adversarial Networks
目标是实现图像转换:语义图转街景,黑白图片上色,素描图变真实照片等。
在训练时候需要采用成对的训练数据,并对 GAN 模型采用了不同的配置。其中它应用到了 PatchGAN 这个模型,PatchGAN 对图片的一块 70*70 大小的区域进行观察来判断该图片是真是假,而不需要观察整张图片。生成器部分使用 U-Net 结构,即结合了 ResNet 网络中的 skip connections 技术,编码器和解码器对应层之间有相互连接。
5.6 CycleGAN
Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks
和上面的Pix2Pix不同,不需要原图和转换后的图来训练,仅仅需要准备两个领域的数据集即可,比如说普通马的图片和斑马的图片,但不需要一一对应。这篇论文提出了一个非常好的方法–循环一致性(Cycle-Consistency)损失函数。
5.7 Progressive Growing of GANs
Progressive Growing of GANs for Improved Quality, Stability, and Variation
利用一个多尺度结构,从 4*4
到 8*8
一直提升到 1024*1024
的分辨率,如下图所示的结构,这篇论文提出了一些如何解决由于目标图片尺寸导致的不稳定问题。
5.8 StackGAN
StackGAN: Text to Photo-realistic Image Synthesis with Stacked Generative Adversarial Networks
StackGAN 和 cGAN、Progressively GANs 两篇论文比较相似,它同样采用了先验知识,以及多尺度方法。整个网络结构如下图所示,第一阶段根据给定文本描述和随机噪声,然后输出 64*64
的图片,接着将其作为先验知识,再次生成 256*256
大小的图片。相比前面 7 篇论文,StackGAN 通过一个文本向量来引入文本信息,并提取一些视觉特征。
5.9 BigGAN
Large Scale GAN Training for High Fidelity Natural Image Synthesis
当前 ImageNet 上图片生成最好的模型,但这篇论文比较难在本地电脑上进行复现。它同时结合了很多结构和技术,包括自注意机制(Self-Attention)、谱归一化(Spectral Normalization)等。
6. GAN的实现
6.1 MNIST数据集上的应用
- 进行GAN实验时候,只是将二维的数据拉伸成一维,没有用到卷积,只是多层神经网络的叠加。
- 生成器和判别器使用不同的激活函数
1 |
|
完整代码:
1 |
|
1 |
|
运行结果:
Epoch 1:
Epoch 10:
Epoch 50:
Epoch 100:
Epoch 300:
6.2 在图网络中的应用
GraphGAN采用GAN网络中常见的对抗机制:生成器G尽可能的逼近Ptrue(V|Vc)以找到与Vc的相邻节点极其相似的节点来欺骗判别器D,而判别器D则会反过来检测给定的节点V是Vc的真实邻居还是由生成器生成的。
Reference
- 什么是 GAN 生成对抗网络 (深度学习)? - 莫烦Python
- GAN学习指南:从原理入门到制作生成Demo
- 万字综述之生成对抗网络 (GAN)
- 生成式对抗网络GAN的研究进展与展望
- 必读的10篇关于GAN的论文
- GAN笔记——理论与实现