Skip to content

生成对抗网络

Generative adversarial network [2014]

Goodfellow I , Pouget-Abadie J , Mirza M ,et al.Generative Adversarial Nets[J].MIT Press, 2014.DOI:10.3156/JSOFT.29.5_177_2.

核心思想

GAN 基于零和博弈思想:

  • 生成器(Generator):学习生成与真实数据相似的假数据。
  • 判别器(Discriminator):区分输入数据是来自真实分布还是生成器。
  • 两者对抗优化,最终达到纳什均衡(生成器生成的样本与真实数据无法区分)。

结构

生成器

  • 作用:将随机噪声映射到数据空间,生成假样本
  • 输入:随机噪声向量 zpz(z)(通常为高斯分布或均匀分布)
  • 输出:生成样本 G(z)

判别器

  • 作用:区分真实样本与生成样本
  • 输入:真实数据样本 x 或生成样本 G(z)
  • 输出:样本属于真实数据的概率 D(x)[0,1]

目标函数

minGmaxDV(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]
符号含义
G生成器网络
D判别器网络
x真实数据样本
z随机噪声向量
pdata(x)真实数据分布
pz(z)噪声先验分布
G(z)生成器生成的样本
D(x)判别器对真实样本的判断概率
D(G(z))判别器对生成样本的判断概率

判别器的最大化目标

对于第一项:Expdata(x)[logD(x)],判别器需要对真实数据输出接近于 1 的值(D(x)1

对于第二项:Ezpz(z)[log(1D(G(z)))],判别器需要对生成数据输出接近于 0 的值(D(G(z))0

生成器的最小化目标

生成器试图通过调整参数使判别器的性能最差,即让生成数据 G(z) 被判别器误判为真实数据。这使得:判别器对生成数据的输出 D(G(z))1,从而导致第二项 log(1D(G(z)))

训练

固定生成器 G,更新判别器 D

训练目标:

D=argmaxDV(D,G)

有:

V(G,D)=xpdata(x)logD(x)dx+xpz(z)log(1D(G(z)))dz=x[pdata(x)logD(x)+pg(x)log(1D(x))]dx

求导求出 D

D(x)(pdata (x)log(D(x))+pg(x)log(1D(x)))=0pdata (x)D(x)pg(x)1D(x)=0D(x)=pdata (x)pdata (x)+pg(x)C(G)=maxDV(G,D)=Expdata [logpdata (x)Pdata (x)+pg(x)]+Expg[logpg(x)pdata (x)+pg(x)]=Expdata [logpdata (x)Pdata (x)+pg(x)2]+Expg[logpg(x)Pdata (x)+pg(x)2]=log(4)+KL(pdata pdata +pg2)+KL(pgpdata +pg2)=log(4)+2JSD(pdata pg)

GAN 的目标是使:pdata=pg,则有:D=1/2,此时,minGC(G)=log4