SRGAN:生成对抗网络驱动的图像超分辨率
SRGAN:生成对抗网络驱动的图像超分辨率
1. 什么是超分辨率重建?
超分辨率重建(Super-Resolution,简称SR)是指从一张或多张低分辨率图像中恢复出高分辨率图像的技术。它广泛用于医学影像、卫星遥感、安防监控以及老照片修复等场景。传统方法依赖插值(如双三次插值)或基于稀疏编码的方式,但这些方法在恢复高频细节时能力有限,生成的图像往往过于平滑,缺乏真实感。
近年来,深度学习极大地推动了超分辨率的发展,其中 SRCNN、VDSR 等模型显著提升了峰值信噪比(PSNR)和结构相似性(SSIM)。然而,这些模型使用的像素级损失(如MSE)会使输出趋于模糊,无法恢复纹理等高频细节。SRGAN 的出现正是为了突破这一瓶颈——它首次将生成对抗网络(GAN)引入超分辨率任务,追求更符合人类视觉感知的逼真效果。
2. 生成对抗网络(GAN)基础
生成对抗网络由生成器 G 和判别器 D 组成。生成器负责从随机噪声或条件输入中生成“假”样本,判别器则试图区分真实样本和生成样本。两者通过极小极大博弈进行对抗训练:
- 生成器目标:最大化判别器将生成样本误判为真实样本的概率。
- 判别器目标:最大化正确分类真实样本和生成样本的概率。
这种对抗机制迫使生成器学习真实数据的分布,从而产生更自然、更清晰的输出。在 SRGAN 中,生成器的输入不再是噪声,而是低分辨率图像,输出则为相应的高分辨率图像。
3. SRGAN 核心架构
SRGAN 架构由生成器网络和判别器网络组成,两者通过对抗损失和内容损失联合优化。
3.1 生成器:深度残差网络
生成器采用类似 ResNet 的结构,主要由三部分构成:
- 浅层特征提取:一个 9×9 卷积 + PReLU 激活,将低分辨率图像(通常先经过双三次上采样至目标尺寸)映射到高维特征空间。
- 残差块堆叠:包含多个(如 16 个)相同的残差块。每个残差块内部使用 3×3 卷积、批归一化(BN)和 PReLU,再通过跳跃连接保留输入信息。残差结构有助于稳定深层网络的训练。
- 上采样与重建:通过两个亚像素卷积层(PixelShuffle)实现 2× 或 4× 上采样。亚像素卷积是一种高效的上采样方式,它将通道维度的特征重新排列成空间维度,避免了转置卷积带来的棋盘效应。最后使用一个 9×9 卷积生成三通道 RGB 高分辨率图像。
整体生成器可以端到端地将低分辨率输入直接映射为高分辨率输出,并通过残差块专注于学习高频残差信息。
3.2 判别器:VGG 风格卷积网络
判别器采用类似 VGG 网络的深度卷积结构,专门用于区分生成的高分辨率图像与真实高分辨率图像。主要设计为:
- 从 64 到 512 通道数的连续卷积块(3×3 卷积 + BN + LeakyReLU),步长为 2 或 1,逐步压缩空间尺寸并提取高层特征。
- 使用 LeakyReLU(斜率 0.2)代替 ReLU 防止梯度稀疏。
- 经过多次卷积后得到 512 通道的特征图,通过全局平均池化或自适应平均池化后接两个全连接层,最终输出一个标量,经 Sigmoid 函数映射为真实图像的概率。
判别器不做像素级分类,而是对整张图像给出一个真实性评分,促使生成器在全局结构上更自然。
3.3 损失函数设计
SRGAN 的损失函数是内容损失和对抗损失的加权和,这是其产生逼真纹理的关键。
(1)内容损失——感知损失
传统超分辨率常用逐像素 MSE 损失(像素空间内容损失),但 SRGAN 转而采用基于 VGG-19 网络的感知损失。它计算生成图像与真实高分辨率图像在预训练 VGG-19 某一高层特征图(如 relu5_4)之间的欧氏距离:
$$ L_{\text{content}} = \frac{1}{W_{i,j}H_{i,j}} \sum_{x=1}^{W_{i,j}} \sum_{y=1}^{H_{i,j}} (\phi_{i,j}(I^{HR}){x,y} - \phi{i,j}(G_{\theta_G}(I^{LR}))_{x,y})^2 $$
这里的 $\phi_{i,j}$ 表示 VGG-19 网络中第 i 个池化层前第 j 个卷积层输出的特征映射。感知损失更关注图像语义和结构相似性,而非逐像素数值匹配,因此允许生成器重建出更清晰的边缘和纹理。
(2)对抗损失
对抗损失鼓励生成器产生难以被判别器区分的图像,使生成结果进入更真实的流形:
$$ L_{\text{adv}} = \sum_{n=1}^{N} -\log D_{\theta_D}(G_{\theta_G}(I^{LR})) $$
其中 $D_{\theta_D}(G(I^{LR}))$ 是判别器认为生成图像为真实图像的概率。
(3)整体生成器损失
生成器的总损失为:
$$ L_G = L_{\text{content}} + 10^{-3} L_{\text{adv}} $$
权重 $10^{-3}$ 用于平衡两项损失的尺度,保证以内容损失为主导,同时对抗损失提供细节驱动。判别器损失则采用传统二分类交叉熵:
$$ L_D = -\log D(I^{HR}) - \log(1 - D(G(I^{LR}))) $$
4. SRGAN 的训练流程
-
数据准备:使用高分辨率图像数据集(如 DIV2K、Flickr2K)。利用双三次下采样生成对应的低分辨率图像,并将低分辨率图像上采样至目标尺寸(如 4× 时使用双三次上采样后送入网络)。图像通常以随机裁剪的 96×96(高分辨率对应块)进行训练,并做随机水平翻转等增强。
-
预训练:为了稳定训练,通常先单独使用内容损失(仅感知损失或 MSE)训练生成器一定轮数,使其获得基本的超分能力。
-
联合对坑训练:固定生成器,训练判别器区分真实 HR 图像和生成的 SR 图像;然后固定判别器,训练生成器以最小化生成器总损失。这一步交替进行,通常判别器更新一次,生成器更新一次或更多次(如 1:1)。
-
超参数设置:使用 Adam 优化器,初始学习率常设为 $10^{-4}$,并逐渐衰减。批次大小根据 GPU 显存选择 16 或 32。
-
收敛判断:观察感知损失和对抗损失在验证集上的表现。当生成图像的视觉效果稳定且无明显伪影时停止训练。
5. 代码实现关键片段(PyTorch)
以下展示生成器的核心结构代码(简化版,4× 上采样):
import torch
import torch.nn as nn
class ResidualBlock(nn.Module):
def __init__(self, channels):
super().__init__()
self.block = nn.Sequential(
nn.Conv2d(channels, channels, 3, 1, 1),
nn.BatchNorm2d(channels),
nn.PReLU(),
nn.Conv2d(channels, channels, 3, 1, 1),
nn.BatchNorm2d(channels)
)
def forward(self, x):
return x + self.block(x)
class Generator(nn.Module):
def __init__(self, scale_factor=4, num_res_blocks=16):
super().__init__()
self.conv1 = nn.Sequential(nn.Conv2d(3, 64, 9, 1, 4), nn.PReLU())
res_blocks = [ResidualBlock(64) for _ in range(num_res_blocks)]
self.res_blocks = nn.Sequential(*res_blocks)
self.conv2 = nn.Sequential(nn.Conv2d(64, 64, 3, 1, 1), nn.BatchNorm2d(64))
# 亚像素卷积上采样 4×
self.upsample = nn.Sequential(
nn.Conv2d(64, 256, 3, 1, 1),
nn.PixelShuffle(2),
nn.PReLU(),
nn.Conv2d(64, 256, 3, 1, 1),
nn.PixelShuffle(2),
nn.PReLU(),
)
self.final = nn.Conv2d(64, 3, 9, 1, 4)
def forward(self, x):
out1 = self.conv1(x)
out = self.res_blocks(out1)
out = self.conv2(out) + out1 # 全局跳跃连接
out = self.upsample(out)
out = self.final(out)
return out
判别器部分示例(基于 VGG 风格):
class Discriminator(nn.Module):
def __init__(self):
super().__init__()
def conv_block(in_c, out_c, stride):
return nn.Sequential(
nn.Conv2d(in_c, out_c, 3, stride, 1),
nn.BatchNorm2d(out_c),
nn.LeakyReLU(0.2)
)
self.net = nn.Sequential(
nn.Conv2d(3, 64, 3, 1, 1), nn.LeakyReLU(0.2),
conv_block(64, 64, 2),
conv_block(64, 128, 1),
conv_block(128, 128, 2),
conv_block(128, 256, 1),
conv_block(256, 256, 2),
conv_block(256, 512, 1),
conv_block(512, 512, 2),
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Linear(512, 1024),
nn.LeakyReLU(0.2),
nn.Linear(1024, 1),
nn.Sigmoid()
)
def forward(self, x):
return self.net(x)
训练循环伪代码:
for epoch in range(n_epochs):
for lr, hr in dataloader:
sr = generator(lr)
# 训练判别器
loss_d = -torch.log(discriminator(hr)) - torch.log(1 - discriminator(sr.detach()))
optimizer_d.zero_grad(); loss_d.backward(); optimizer_d.step()
# 训练生成器
content_loss = mse_loss(vgg(sr), vgg(hr))
adv_loss = -torch.log(discriminator(sr))
loss_g = content_loss + 1e-3 * adv_loss
optimizer_g.zero_grad(); loss_g.backward(); optimizer_g.step()
6. 结果与评估
SRGAN 在多个基准数据集上展示了惊人的视觉效果。与传统方法(如 SRCNN、VDSR)以及同样使用感知损失的方法相比,其特点如下:
- PSNR/SSIM 反而可能略低:因为感知损失不追求逐像素精确匹配,而更关注纹理真实感,所以 PSNR 和 SSIM 指标上未必最高。这一现象说明像素度量与人类感知之间存在不一致。
- 视觉质量大幅提升:SRGAN 重建的图像边缘锐利、纹理清晰,更接近真实照片。在人眼主观测试中,SRGAN 的 MOS(平均意见分)显著优于基于 MSE 优化的模型。
- 典型效果:对于头发丝、羽毛、文字等高频信息,SRGAN 能生成清晰可辨的细节,而非均匀的模糊色块。
7. 局限性与改进
尽管 SRGAN 开创了感知驱动的超分辨率方法,但其仍存在一些不足:
- 伪影问题:对抗训练可能引入不自然的条纹、棋盘格或错误的高频模式,尤其是在纹理较少的平滑区域。
- 不稳定训练:GAN 的训练本身较难收敛,需要细致的超参数调节与技巧,有时还会出现模式坍塌。
- 图像保真度下降:某些场景下生成图像可能偏离原始语义,例如将砖墙纹路重构成植物纹理。
为了解决这些问题,后续工作提出了一系列改进模型,如 ESRGAN(增强型 SRGAN):
- 移除残差块中的批归一化,使用残差密集块(RRDB)作为基本单元,使网络容量更大。
- 采用相对平均判别器(Relativistic GAN),使判别器估计相对真实性而非绝对真实性,提升生成图像的纹理质感。
- 使用激活前的特征计算感知损失(如 VGG 的激活前特征),进一步改善边缘和轮廓的重建。
这些改进使得生成结果在真实感和保真度之间取得了更好的平衡,目前基于 GAN 的超分辨率研究仍在不断发展。
8. 动手实践建议
如果你是初学者,想快速体验 SRGAN,可以参考以下步骤:
- 克隆官方或社区实现的 SRGAN 仓库(如
srgan-pytorch)。 - 下载 DIV2K 数据集,并按需准备低分辨率图像。
- 调整配置文件中的缩放因子、批量大小和训练轮次。
- 先使用预训练模型在几张测试图上推理,观察效果。
- 尝试训练时监控感知损失和判别器损失曲线,调整学习率和对抗损失权重。
通过亲手实验,你可以更直观地理解感知损失与对抗训练如何合力打造出令人惊叹的超分辨率效果。