梯度累积与激活检查点:用小显存模拟大批次

FreeGuideOnline 最新 2026-06-14

梯度累积与激活检查点:用小显存模拟大批次

为什么显存总是不够用?

在训练大型深度学习模型时,显存不足是最常见的瓶颈。模型权重、优化器状态、激活值、梯度——每一项都消耗显存。其中,激活值(前向传播时中间层的输出)往往占据最大份额。

许多人试图通过减小批次大小来缓解显存压力,但这会带来两个问题:

  • 小批次梯度估计不稳定,收敛更慢,甚至难以收敛。
  • 批归一化(Batch Normalization) 在小批次下统计量不准确,效果大打折扣。

有没有一种方法,既能让批次大小保持较小以适配显存,又能模拟出大批次训练的稳定性和效果?答案就是梯度累积。而当激活值本身成为瓶颈时,激活检查点(也称为梯度检查点、重计算技术)则是进一步节省显存的利器。

本文将深入讲解这两种技术的原理、实现方法和最佳实践,帮助你在有限硬件上训练更大的模型、使用更大的等效批次。


梯度累积:用时间换空间

核心思想

梯度累积的原理简单而优雅:不是一次性处理整个大批次,而是将大批次拆分成若干小批次,依次计算每小批次的梯度并累加起来,直到累积次数达到设定值后再执行一次参数更新。

数学上,优化器的标准更新步骤为:

  1. 对批次 B 计算损失 L(B)
  2. 反向传播得到梯度 g
  3. 参数更新:θ ← θ - η·g

在梯度累积中,我们将 B 拆分为 k 个小批次 b₁, b₂, …, bₖ

  1. 对每个 bᵢ 计算损失 L(bᵢ) 并反向传播,得到梯度 gᵢ
  2. 累积梯度g_accum += gᵢ(此时不更新参数);
  3. 重复步骤1-2共 k 次;
  4. 使用累积梯度 g_accum 执行一次参数更新,然后将 g_accum 清零。

等效批次大小 = 每小批次的样本数 × 累积步数。例如,每小批32个样本,累积4步,等效批次大小为128。

为什么有效?

在大多数优化器中,梯度只是损失的线性组合。将大批次拆开计算再求和,与一次性计算整个大批次的梯度在数学上是等价的(忽略批归一化等内部统计量的影响)。因此,梯度累积精确模拟了大批次下的参数更新,只是增加了计算时间(串行执行了 k 次前向与反向传播)。

实现中的细节与陷阱

1. 损失的正确缩放

许多训练框架中,损失函数默认会对批次进行平均(即 reduction='mean')。如果我们只是简单地累加每个小批次的梯度,实际上错误地平均了多次。

正确做法:对每个小批次的损失除以累积步数,或保持损失不做平均但需相应调整。更稳健的方式是:不对损失本身缩放,而是在累加完所有梯度后,将累积梯度除以累积步数,等效为直接计算大批次的平均梯度。

在 PyTorch 中,典型写法如下:

accumulation_steps = 4
optimizer.zero_grad()

for i, (inputs, labels) in enumerate(dataloader):
    outputs = model(inputs)
    loss = criterion(outputs, labels) / accumulation_steps  # 缩放损失
    loss.backward()
    
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

如果不希望缩放损失,也可以在 optimizer.step() 之前手动将每个参数的梯度除以累积步数(但注意梯度累积后需在清零前操作)。

2. 批归一化的影响

批归一化(BatchNorm)依赖批次内的均值和方差来标准化激活值。在梯度累积模式下,每一小批次独立计算自身的统计量,与用整个大批次计算的统计量不同,这会导致训练动态的细微差异。

  • 缓解方案:如果累积步数不太大(例如≤4),这种差异通常可以忽略。对于需要精确大批次统计的场景,可考虑使用同步批归一化(SyncBatchNorm),或记录多个小批次的统计量自行计算,但实现复杂,实际中使用较少。
  • 更彻底的替代:改用对批次大小不敏感的归一化层,如层归一化(LayerNorm)、组归一化(GroupNorm)或实例归一化(InstanceNorm)。在 Transformer 类模型中,层归一化已是标配,因此梯度累积与这些模型天然契合。

3. 优化器状态更新时机

必须确保所有小批次都用同一组模型参数进行前向和反向传播,直到累积结束。如果在累积中途更新了参数,后续小批次的计算将基于过时的参数,破坏了等效大批次训练的假设。

因此,optimizer.step()optimizer.zero_grad() 必须严格在累积步数达到后执行。

实际效果与权衡

  • 优点

    • 在低显存设备上使用大批次,稳定训练。
    • 实现简单,几乎不改变模型结构。
    • 便于分布式训练中扩展有效批次大小。
  • 代价

    • 训练时间线性增加(需要计算更多次前向/反向)。
    • 对于使用 BatchNorm 的模型,可能引入小批次噪声。

激活检查点:让激活值“失而复得”

显存去哪了?

在典型的训练循环中,前向传播产生的中间激活值需要被保存下来,供反向传播计算梯度使用。模型越深、隐藏维度越大,这些激活值消耗的显存就越多。一个简单的 50 层 ResNet 或 BERT 模型,激活值可能占用数 GB 显存。

激活检查点(Activation Checkpointing) 的核心思想是:在前向传播时不保存全部激活值,而只保存部分关键节点(检查点);反向传播需要其它激活值时,从最近的检查点开始重新执行一次前向计算,动态恢复所需的中间张量。

换句话说,我们用额外的计算量交换显存空间,可将显存消耗降低到原来的 1/√n 或更低(取决于检查点放置策略)。

工作流程

以 PyTorch 的 torch.utils.checkpoint 为例,它通过定义一个检查点段(checkpoint segment)来实现:

  1. 将一个模块或计算段标记为检查点区域。
  2. 在第一次前向传播时,该段不保存任何中间激活值,仅保留输入张量和输出张量(以及一些必要的参数状态)。
  3. 反向传播时,当梯度流经此区域,框架会从保留的输入开始,重新执行该段的前向传播,重新计算出所需的中间激活值,然后立即进行对应的反向传播。
  4. 重计算完成后,这些临时激活值被丢弃,显存被释放。

这个过程对模型性能和数值精度没有影响,仅增加一次额外的前向计算。

使用示例

在 PyTorch 中,使用 checkpoint 非常简单:

import torch
from torch.utils.checkpoint import checkpoint

class MyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(...)
    
    def forward(self, x):
        # 原本直接通过 self.layers(x)
        # 现在用检查点包装,不保存中间激活
        return checkpoint(self.layers, x, use_reentrant=False)

说明:

  • self.layers:需要作为检查点区域的模块或函数。
  • x:输入张量(可包含多个输入)。
  • use_reentrant=False:推荐设置,启用更安全的非重入模式(PyTorch 1.11+),避免一些梯度计算的不兼容问题。

也可对整个 nn.Sequential 进行包装,或对复杂的前向函数使用 checkpoint

def custom_forward(a, b):
    # 复杂计算
    return c

out = checkpoint(custom_forward, a, b, use_reentrant=False)

如何选择检查点位置?

并非所有层都需要激活检查点。一般原则是:

  • 选择计算量相对较小,但激活值占用大的层:例如 Transformer 中的前馈网络(FFN)层、卷积块。这些层的重计算开销不高,但节省显存显著。
  • 避免在计算量极大的部分设置检查点:例如大型矩阵乘法密集的自注意力计算,其重计算代价较高,性价比下降。
  • 在模型最深的几层中应用:深层网络的激活内存通常占比如高,部署检查点可带来最大收益。

许多预训练模型(如 HuggingFace Transformers)提供了内置的梯度检查点开关,只需设置 model.gradient_checkpointing_enable() 即可自动应用于 Transformer 层。

显存节省与时间开销

  • 显存节省:理想情况下,将激活存储从 O(n) 降为 O(√n)O(log n),实际可节省 30%~70% 的激活显存,使得原本无法运行的 batch size 变得可行。
  • 时间开销:每次重计算会导致约 20%~33% 的额外计算时间,具体取决于检查点区域的复杂度。对于大多数模型,这是可以接受的代价。

双剑合璧:梯度累积 + 激活检查点

这两项技术天然互补,可以无缝结合:

  1. 用激活检查点压缩单次训练的激活显存,从而在不 OOM 的前提下尽可能增大每小批次的大小。
  2. 用梯度累积将多个小批次的梯度累加,模拟更大的等效批次,达到大批次训练的稳定性。

假设原本只能设置 batch size = 8,使用激活检查点后可能允许 batch size = 16,再配合梯度累积 4 步,即可达到等效 batch size = 64。这让你在 8GB 显存的 GPU 上也能训练原本需要 32GB 显存的配置。

组合实践示例

以下整合代码展示了在 PyTorch 训练循环中同时使用两者:

from torch.utils.checkpoint import checkpoint

model = LargeModel()
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

accumulation_steps = 4
optimizer.zero_grad()

for step, batch in enumerate(train_loader):
    inputs, labels = batch
    
    # 前向传播中使用激活检查点
    outputs = checkpoint(model, inputs, use_reentrant=False)
    
    loss = criterion(outputs, labels) / accumulation_steps
    loss.backward()
    
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

注意:

  • 若模型较复杂,可以对部分块应用检查点,而不是整个模型。
  • 确保在 loss.backward() 调用前,所有涉及的参数都处于相同的 training 模式,且梯度归零仅发生在累积步数完成时。

结合时的注意事项

  1. 学习率和批次大小的缩放
    当有效批次大小大幅增加时,可能需要线性调整学习率(如大批次训练常用的平方根缩放或线性缩放规则)。例如等效批次从 16 增至 128,学习率可尝试增大 8 倍,但需监控训练动态。

  2. BatchNorm 的叠加影响
    激活检查点不会改变 BatchNorm 的行为,但梯度累积的小批次统计量偏差依然存在。若使用 SyncBN,配合检查点亦可,但代码需要额外适配。

  3. 监控显存使用
    使用 torch.cuda.memory_summary()nvidia-smi 观察实际显存下降,调整检查点位置和累积步数以找到最优平衡。


何时该用,何时不必?

适合使用梯度累积的场景

  • 显存不足以容纳所需的大批次,但模型本身可正常运行。
  • 训练不稳定,需要增大批次以获得更好的梯度估计。
  • 分布式训练中,即使单卡批次已最大,仍需进一步增大全局批次。
  • 微调预训练模型,原本冻结部分参数但仍需较大批次。

适合使用激活检查点的场景

  • 激活值显存占比极高(例如使用长序列的 Transformer、深层 CNN)。
  • 模型权重和优化器状态已占满显存,几乎没有空间留给激活值。
  • 不允许牺牲数值精度,愿意以时间换取空间。
  • 使用内存受限的硬件(边缘设备、消费级 GPU)进行研究或原型开发。

可能无需使用的场景

  • 批次大小已经足够满足训练稳定性,且显存有余量。
  • 模型极浅或很小,激活值本身消耗不大。
  • 对训练速度极度敏感,且显存不是瓶颈。

结语

梯度累积和激活检查点是深度学习工程中低显存训练的两个必备技巧。它们实现简单,效果显著,几乎已成为现代训练框架的标准特性。理解其原理和边界条件,可以让你在资源受限的环境中依然保持实验灵活性,不必因为显存而牺牲模型规模或批次设计。

当你下次遇到 CUDA Out of Memory 错误时,不妨试试先减小 batch size,再加上梯度累积;如果激活值太大,开启激活检查点。用小显存模拟大批次,既不是魔法,也不复杂,只是计算与空间的巧妙兑换。