生成对抗网络 GAN:生成器与判别器博弈

FreeGuideOnline 最新 2026-06-17

什么是生成对抗网络 (GAN)?

生成对抗网络(Generative Adversarial Network,简称 GAN)是一种深度学习模型,由 Ian Goodfellow 等人在 2014 年提出。它通过两个神经网络的相互博弈来生成逼真的数据,是当前人工智能领域最具创造力的技术之一。

GAN 的核心思想非常直观:假设你想伪造一幅名画,你会不断根据专家的反馈改进赝品,直到专家也无法分辨真伪。在 GAN 中,生成器(Generator) 就像那位伪造者,而判别器(Discriminator) 则扮演鉴定专家的角色。

为什么需要 GAN?

传统生成模型(如变分自编码器 VAE)在生成高维数据时容易产生模糊输出。GAN 通过对抗训练,能够生成极其逼真的图像、音频乃至文本,极大推动了无监督学习的发展。


核心架构:两个玩家的博弈

GAN 的整体结构可以看作一场零和博弈,以下为详细拆解。

生成器(Generator, G)

生成器的任务是从随机噪声中生成数据。它接收一个服从简单分布(通常是高斯分布或均匀分布)的潜变量 z,输出与真实数据形状相同的假样本 G(z)

  • 输入:随机噪声向量(例如长度为 100 的向量)
  • 输出:伪造的数据(例如 28×28 的灰度图像)
  • 目标:生成足以欺骗判别器的样本,最终使判别器输出“接近真”的概率

判别器(Discriminator, D)

判别器是一个二分类器,用于区分输入样本是来自真实数据分布还是生成器。

  • 输入:真实样本 x 或生成样本 G(z)
  • 输出:一个标量,表示输入属于“真”的概率(通常使用 Sigmoid 将值压缩到 [0,1] 区间)
  • 目标:对真实数据输出接近 1,对生成数据输出接近 0

博弈关系

两者相互对抗、交替优化:

  • 生成器试图最大化判别器的判断错误率
  • 判别器试图最小化自己的分类错误率

这个动态过程可以被形式化为如下 Min-Max 游戏:

min_G max_D V(D, G) = E_{x~p_data}[log D(x)] + E_{z~p_z}[log(1 - D(G(z)))]

训练过程详解

GAN 的训练不是一次性更新所有参数,而是交替进行。以下为标准的训练步骤(以随机梯度下降为例)。

步骤一:训练判别器(最大化 D 的目标)

  1. 从真实数据集中取出一批样本 {x},标注为“真”(标签 1)
  2. 从噪声分布中采样一批潜向量 {z},通过生成器得到伪样本 {G(z)},标注为“假”(标签 0)
  3. 计算判别器损失:- [log D(x) + log(1 - D(G(z)))]
  4. 更新判别器参数(在梯度上升方向,或在使用优化器时令损失最小化)

步骤二:训练生成器(最小化 G 的目标)

  1. 重新采样一批噪声 {z}
  2. 生成伪样本 {G(z)},并将这些样本的目标标签设为 1(因为我们想欺骗判别器,让判别器认为它们是真的)
  3. 计算生成器损失:- log D(G(z))(非饱和损失,早期推荐使用,避免梯度消失)
  4. 冻结判别器参数,仅更新生成器参数

超参数与技巧

  • 交替频率:通常每更新一次生成器,更新 k 次判别器(k=1 或 k=5)。当判别器太强时,可设 k>1。
  • 学习率:建议使用 Adam 优化器,学习率设为 2e-4 左右,beta1=0.5 在实践中效果较好。
  • 标签平滑:将真实数据的标签从 1 改为 0.9,可缓解判别器过度自信,稳定训练。

损失函数与改进

原始 GAN 使用交叉熵损失,但存在梯度消失问题。以下是几种重要的损失函数变体。

最小二乘 GAN(LSGAN)

将交叉熵损失替换为最小二乘损失,使判别器输出接近连续值,缓解梯度消失,生成质量更高。

min_D 0.5 * E[(D(x) - 1)^2] + 0.5 * E[(D(G(z)))^2]
min_G 0.5 * E[(D(G(z)) - 1)^2]

Wasserstein GAN(WGAN)

引入 Earth-Mover 距离(Wasserstein-1 距离),从根本上改善训练稳定性。关键改动:

  • 判别器最后一层去掉 Sigmoid,输出无界实数(此时称为“评论器” Critic)
  • 对 critic 执行权重裁剪或使用梯度惩罚(WGAN-GP)
  • 损失函数变为:max_D E[D(x)] - E[D(G(z))],生成器最小化 - E[D(G(z))]

主流 GAN 变体

深度卷积生成对抗网络(DCGAN)

首次将卷积神经网络成功应用于 GAN 架构,奠定了图像生成的基础框架。核心设计原则:

  • 在判别器中使用跨步卷积(strided convolution)替代池化层
  • 在生成器中使用转置卷积(fractionally-strided convolution)进行上采样
  • 生成器和判别器均使用批归一化(batch normalization),但生成器输出层与判别器输入层通常不加
  • 生成器除输出层用 Tanh 外,其余用 ReLU;判别器用 LeakyReLU

条件生成对抗网络(CGAN)

通过向生成器和判别器同时输入额外条件信息(如类别标签),实现可控生成。例如生成指定数字“7”的手写体图片。


实践:用 PyTorch 搭建一个简单 DCGAN

以下代码展示了在 MNIST 手写数字数据集上训练一个简易 DCGAN 的关键片段(仅展示模型定义与训练循环轮廓)。

# 生成器定义
class Generator(nn.Module):
    def __init__(self, latent_dim=100):
        super().__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 256, 4, 1, 0, bias=False),
            nn.BatchNorm2d(256), nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128), nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64), nn.ReLU(True),
            nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, z):
        z = z.view(-1, 100, 1, 1)
        return self.main(z)

# 训练循环片段(抽象)
for epoch in range(num_epochs):
    for real_imgs, _ in dataloader:
        # 训练判别器
        z = torch.randn(batch_size, latent_dim)
        fake_imgs = generator(z)
        d_loss = - torch.mean(torch.log(discriminator(real_imgs)) + torch.log(1 - discriminator(fake_imgs.detach())))
        discriminator.zero_grad(); d_loss.backward(); optim_d.step()

        # 训练生成器
        z = torch.randn(batch_size, latent_dim)
        g_loss = - torch.mean(torch.log(discriminator(generator(z))))
        generator.zero_grad(); g_loss.backward(); optim_g.step()

注意:实际应用中推荐使用 WGAN-GP 等更稳定的损失,上述仅作为入门示例。


常见应用领域

  • 图像生成与编辑:生成高清人脸(如 StyleGAN)、超分辨率重建(SRGAN)、图像修复
  • 视频生成:预测下一帧视频、生成短动画
  • 风格迁移与合成:将照片转换为艺术风格(CycleGAN)
  • 数据增强:为医学影像等小样本领域生成更多训练数据
  • 文本到图像生成:基于描述性文字生成对应图片(如 DALL·E 的前身)

训练难点与应对策略

模式崩塌(Mode Collapse)

生成器只生成少数几种样本,无法覆盖真实数据分布的全部模式。

  • 应对:使用小批量判别(Minibatch discrimination)、历史平均(Unrolled GAN)、或改用 WGAN 架构。

训练不稳定

损失振荡、无法收敛。

  • 应对:降低学习率、增加批归一化、使用 WGAN-GP 或 R1 正则化、保持判别器更新次数稍多。

评估困难

缺少像分类准确率那样直观的指标。

  • 常用指标:Inception Score (IS)、Fréchet Inception Distance (FID),FID 越低越好。

总结

生成对抗网络通过生成器与判别器的动态博弈,让模型具备了惊人的生成能力。尽管训练存在挑战,但一系列改进(DCGAN、WGAN、StyleGAN 等)已将其推向实际应用。掌握 GAN 的基本原理和训练技巧,你就拿到了通往创造性 AI 大门的钥匙。

下一步可以尝试动手实现一个 DCGAN,并逐步探索条件生成或风格迁移任务。