内存优化技巧:在有限显存上训练更大模型

FreeGuideOnline 最新 2026-06-28

为什么内存会成为训练大模型的瓶颈

训练深度学习模型时,显存(VRAM)需要同时存储模型参数、梯度、优化器状态以及前向传播产生的中间激活值。模型越大,这些数据占用的显存就越多,一旦超出物理显存上限,训练就会直接报错中断。对于个人开发者或使用中低端显卡的场景,学会在有限显存上训练更大模型,比盲目升级硬件更具实际意义。

本教程将从机制出发,逐一拆解最有效的内存优化技巧,并提供可直接应用的代码示例。所有方法均基于 PyTorch 生态,但原理同样适用于其他框架。


1. 先量化你的内存占用

在优化之前,必须清楚内存都消耗在哪里。以混合精度训练为例,单张 GPU 上的内存分布大致如下(以字节计):

组件 精度 参数量公式 说明
模型参数 FP32 / FP16 参数量 × 字节数 FP32 占 4 字节,FP16 占 2 字节
梯度 FP32 / FP16 同参数 反向传播后存储
优化器状态 FP32 2 × 参数量 × 4(Adam) Adam 需要存储一阶矩和二阶矩
激活值 取决于批次大小 批次大小 × 序列长度 × 隐藏维度 × 层数 × 字节数 前向传播的中间结果

举例来说,一个 70 亿参数模型(7B),如果全部使用 FP32 并配合 Adam 优化器,仅参数、梯度和优化器状态就需要约 7e9 × 4 × (1 + 1 + 2) = 112 GB 显存,远高于消费级显卡的 24 GB 或 48 GB。优化就是从这些分量入手。


2. 梯度累积:用小批次模拟大批次

原理:优化器通常在 loss.backward() 时立即更新参数。梯度累积将多个小批次(micro-batch)的梯度累加起来,待累加到等效大批次后再执行一次参数更新。

为什么能省内存:它不直接减少显存峰值占用,但可以让你在不增加显存的情况下,使用与大显存方案相同的有效批次大小,从而间接“训练更大模型”——因为你原本因为批次太大跑不动,现在可以跑起来了。

PyTorch 实现

accumulation_steps = 4
optimizer.zero_grad()
for i, data in enumerate(dataloader):
    loss = model(data) / accumulation_steps  # 关键:按累积步数缩放损失
    loss.backward()
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

注意:Batch Normalization 这类依赖批次统计的层在梯度累积时行为会改变,可改用 Group Norm 或同步 Batch Norm。


3. 混合精度训练:半是 FP16,半是 FP32

原理:前向传播和反向传播使用 FP16 进行计算,这直接让激活值和梯度占用的显存减半。但参数更新仍然在 FP32 精度下完成,以保持数值稳定性。同时,通过**损失缩放(Loss Scaling)**防止 FP16 下的小梯度下溢。

工具:PyTorch 的 torch.cuda.amp 封装了所有细节。

scaler = torch.cuda.amp.GradScaler()
optimizer.zero_grad()
for data in dataloader:
    with torch.cuda.amp.autocast():
        loss = model(data)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

内存收益:激活值和梯度内存减半,对于大批次训练效果尤其明显。结合梯度累积,可以进一步突破显存限制。


4. 梯度检查点(激活重计算)

原理:默认情况下,前向传播每层产生的激活值都会被保留以供反向传播计算梯度。梯度检查点技术选择性地不保存中间激活值,在反向传播需要时,从最近的检查点开始重新向前计算这部分激活值。这是一种典型的以时间换空间策略。

PyTorch 实现:使用 torch.utils.checkpoint.checkpoint 包裹需要节省显存的模块。

from torch.utils.checkpoint import checkpoint

class CheckpointBlock(torch.nn.Module):
    def forward(self, x):
        return checkpoint(self._forward, x, use_reentrant=False)
    
    def _forward(self, x):
        # 多层的计算逻辑
        return x

适用场景:对 Transformer 的每个 encoder/decoder 层应用检查点,可以在几乎不影响吞吐量的前提下,将激活值内存从 O(n) 降到 O(√n) 甚至 O(1)。在大模型训练中是标配技巧。


5. 优化器状态分片:ZeRO 优化

原理:Adam 等优化器需要在训练过程中维护与参数同等规模的动量方差状态,这部分内存占用极大。微软的 DeepSpeed 库实现了 ZeRO(零冗余优化器),将优化器状态、梯度、甚至参数分片到多张 GPU 上,每张卡仅持有自己负责的那部分,极大降低单卡显存压力。

ZeRO 分三个阶段:

  • ZeRO-1:分片优化器状态。
  • ZeRO-2:额外分片梯度。
  • ZeRO-3:额外分片模型参数。

单卡也能受益于 ZeRO 思想。DeepSpeed 可配置阶段 1 或 2 运行于单 GPU,从而减少优化器状态内存。

配置示例(DeepSpeed JSON)

{
  "zero_optimization": {
    "stage": 2,
    "offload_optimizer": {
      "device": "cpu"
    }
  }
}

收益:ZeRO-2 在 8 张 GPU 上可节省近 8 倍优化器状态内存,单卡理论上也能大幅压缩。


6. CPU 卸载:把用不上的数据挪走

原理:训练过程中,优化器状态、部分模型参数或激活值并不是每时每刻都需要高带宽显存访问。把这些数据卸载到 CPU 内存中,仅在计算必要时挪回 GPU,可以进一步降低显存峰值。

工具

  • DeepSpeedoffload_optimizeroffload_param 配置,可将优化器状态和参数卸载到 CPU 或 NVMe 硬盘。
  • PyTorch 原生:使用 tensor.to('cpu')tensor.to('cuda') 手动控制,但 DeepSpeed 的自动化卸载更高效。

示例配置开启优化器 CPU 卸载:

{
  "zero_optimization": {
    "stage": 2,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    }
  }
}

代价:CPU-GPU 之间的数据传输会拖慢训练速度,但可以让原来跑不起来的模型成功运行。


7. 模型并行:把模型切开

当单卡连模型参数都放不下时,就必须采用模型并行,将模型的不同层或同一层的不同部分分布到多张 GPU 上。

  • 朴素流水线并行:将模型按层切分,比如 GPU0 放前 6 层,GPU1 放后 6 层。数据像流水线一样传递,但存在 GPU 空闲等待时间。
  • 张量并行:将单个层内的矩阵乘法按列或行切分到多卡并行计算。对 Transformer 的注意力头或前馈网络切分效果很好。
  • 推荐的库:使用 Megatron-LM 或 Hugging Face 的 transformers 结合 DeepSpeed 集成张量并行。

对于多数个人使用者,处理十亿参数级模型可用 DeepSpeed ZeRO-3 + CPU 卸载;百亿参数则需结合张量并行。


8. 使用低比特量化与 QLoRA 训练

如果只是进行微调,而非从零训练,可以利用量化技术将预训练模型转为 4 位或 8 位精度存储,大幅降低显存占用,再结合低秩适配(LoRA)仅训练少量额外参数。

QLoRA 的核心思路

  • 将冻结的基座模型量化到 4 位(NF4)。
  • 在前向反向传播时反量化为 FP16 进行计算。
  • 只训练附加的低秩矩阵,优化器状态仅针对这几个参数,内存消耗极低。

代码示例(bitsandbytes + PEFT)

from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16
)
model = AutoModelForCausalLM.from_pretrained("模型名", quantization_config=bnb_config)
lora_config = LoraConfig(r=8, lora_alpha=32, target_modules=["q_proj", "v_proj"])
model = get_peft_model(model, lora_config)

这样可以在单张 24 GB 显卡上微调 65B 甚至更大的模型。


9. 精简数据类型与 FlashAttention

FlashAttention:一种在 GPU 上实现高效注意力计算的算法,通过分块计算减少对高带宽显存的读写,从而将中间注意力矩阵的内存从 O(n²) 降低到 O(n)。PyTorch 2.0 以上已通过 torch.backends.cuda.sdp_kernelflash_attn 包支持。

开启后,长序列训练的显存降低非常显著,同时计算更快。

# 使用 PyTorch 内置的缩放点积注意力,自动调用 FlashAttention
with torch.backends.cuda.sdp_kernel(enable_flash=True):
    output = torch.nn.functional.scaled_dot_product_attention(query, key, value)

数据类型:除了混合精度,还可以使用 TF32(Ampere 架构以上),通过设置 torch.backends.cuda.matmul.allow_tf32 = True 可在不牺牲太多精度的情况下加速计算,但不直接节省显存;显存节省主要来自 FP16 或 BF16。


10. 内存优化的组合拳策略

实际训练时,不应该孤立使用某一个技巧,而是按需组合:

  1. 基础组合:梯度累积 + 混合精度(FP16/BF16)+ FlashAttention。

    • 适合可放入显存的模型但批次受限的情况。
  2. 进阶组合:在上述基础上增加梯度检查点 + ZeRO-2(DeepSpeed)。

    • 解决激活值和优化器状态爆显存的问题,适用于十亿参数级模型。
  3. 极限方案:ZeRO-3 + CPU 卸载 + 模型并行(张量并行/流水线并行),或者直接使用 QLoRA 进行微调。

    • 让消费级显卡也能触达百亿参数模型的训练和微调。

调试显存的工具

  • torch.cuda.memory_summary() 查看详细内存分配。
  • nvidia-smi 持续监控。
  • DeepSpeed 的 ds_report 和内存监控日志。

总结:从“跑不动”到“跑得稳”

有限显存训练更大模型的本质是在时间、计算、通信之间寻找最优平衡。掌握以上技巧后,你可以:

  • 用一张 24 GB 显卡微调 70B 语言模型(QLoRA)
  • 用两张 24 GB 显卡从零训练 13B 模型(ZeRO-2 + 梯度检查点 + 混合精度)
  • 大幅降低云 GPU 实例成本,让实验迭代更快

将这些技巧内化为工程习惯,你会发现显存不再是限制你探索大模型的枷锁,而是一道可以通过工程智慧解开的数学题。