内存优化技巧:在有限显存上训练更大模型
为什么内存会成为训练大模型的瓶颈
训练深度学习模型时,显存(VRAM)需要同时存储模型参数、梯度、优化器状态以及前向传播产生的中间激活值。模型越大,这些数据占用的显存就越多,一旦超出物理显存上限,训练就会直接报错中断。对于个人开发者或使用中低端显卡的场景,学会在有限显存上训练更大模型,比盲目升级硬件更具实际意义。
本教程将从机制出发,逐一拆解最有效的内存优化技巧,并提供可直接应用的代码示例。所有方法均基于 PyTorch 生态,但原理同样适用于其他框架。
1. 先量化你的内存占用
在优化之前,必须清楚内存都消耗在哪里。以混合精度训练为例,单张 GPU 上的内存分布大致如下(以字节计):
| 组件 | 精度 | 参数量公式 | 说明 |
|---|---|---|---|
| 模型参数 | FP32 / FP16 | 参数量 × 字节数 |
FP32 占 4 字节,FP16 占 2 字节 |
| 梯度 | FP32 / FP16 | 同参数 | 反向传播后存储 |
| 优化器状态 | FP32 | 2 × 参数量 × 4(Adam) |
Adam 需要存储一阶矩和二阶矩 |
| 激活值 | 取决于批次大小 | 批次大小 × 序列长度 × 隐藏维度 × 层数 × 字节数 |
前向传播的中间结果 |
举例来说,一个 70 亿参数模型(7B),如果全部使用 FP32 并配合 Adam 优化器,仅参数、梯度和优化器状态就需要约 7e9 × 4 × (1 + 1 + 2) = 112 GB 显存,远高于消费级显卡的 24 GB 或 48 GB。优化就是从这些分量入手。
2. 梯度累积:用小批次模拟大批次
原理:优化器通常在 loss.backward() 时立即更新参数。梯度累积将多个小批次(micro-batch)的梯度累加起来,待累加到等效大批次后再执行一次参数更新。
为什么能省内存:它不直接减少显存峰值占用,但可以让你在不增加显存的情况下,使用与大显存方案相同的有效批次大小,从而间接“训练更大模型”——因为你原本因为批次太大跑不动,现在可以跑起来了。
PyTorch 实现:
accumulation_steps = 4
optimizer.zero_grad()
for i, data in enumerate(dataloader):
loss = model(data) / accumulation_steps # 关键:按累积步数缩放损失
loss.backward()
if (i + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
注意:Batch Normalization 这类依赖批次统计的层在梯度累积时行为会改变,可改用 Group Norm 或同步 Batch Norm。
3. 混合精度训练:半是 FP16,半是 FP32
原理:前向传播和反向传播使用 FP16 进行计算,这直接让激活值和梯度占用的显存减半。但参数更新仍然在 FP32 精度下完成,以保持数值稳定性。同时,通过**损失缩放(Loss Scaling)**防止 FP16 下的小梯度下溢。
工具:PyTorch 的 torch.cuda.amp 封装了所有细节。
scaler = torch.cuda.amp.GradScaler()
optimizer.zero_grad()
for data in dataloader:
with torch.cuda.amp.autocast():
loss = model(data)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
内存收益:激活值和梯度内存减半,对于大批次训练效果尤其明显。结合梯度累积,可以进一步突破显存限制。
4. 梯度检查点(激活重计算)
原理:默认情况下,前向传播每层产生的激活值都会被保留以供反向传播计算梯度。梯度检查点技术选择性地不保存中间激活值,在反向传播需要时,从最近的检查点开始重新向前计算这部分激活值。这是一种典型的以时间换空间策略。
PyTorch 实现:使用 torch.utils.checkpoint.checkpoint 包裹需要节省显存的模块。
from torch.utils.checkpoint import checkpoint
class CheckpointBlock(torch.nn.Module):
def forward(self, x):
return checkpoint(self._forward, x, use_reentrant=False)
def _forward(self, x):
# 多层的计算逻辑
return x
适用场景:对 Transformer 的每个 encoder/decoder 层应用检查点,可以在几乎不影响吞吐量的前提下,将激活值内存从 O(n) 降到 O(√n) 甚至 O(1)。在大模型训练中是标配技巧。
5. 优化器状态分片:ZeRO 优化
原理:Adam 等优化器需要在训练过程中维护与参数同等规模的动量方差状态,这部分内存占用极大。微软的 DeepSpeed 库实现了 ZeRO(零冗余优化器),将优化器状态、梯度、甚至参数分片到多张 GPU 上,每张卡仅持有自己负责的那部分,极大降低单卡显存压力。
ZeRO 分三个阶段:
- ZeRO-1:分片优化器状态。
- ZeRO-2:额外分片梯度。
- ZeRO-3:额外分片模型参数。
单卡也能受益于 ZeRO 思想。DeepSpeed 可配置阶段 1 或 2 运行于单 GPU,从而减少优化器状态内存。
配置示例(DeepSpeed JSON):
{
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu"
}
}
}
收益:ZeRO-2 在 8 张 GPU 上可节省近 8 倍优化器状态内存,单卡理论上也能大幅压缩。
6. CPU 卸载:把用不上的数据挪走
原理:训练过程中,优化器状态、部分模型参数或激活值并不是每时每刻都需要高带宽显存访问。把这些数据卸载到 CPU 内存中,仅在计算必要时挪回 GPU,可以进一步降低显存峰值。
工具:
- DeepSpeed 的
offload_optimizer和offload_param配置,可将优化器状态和参数卸载到 CPU 或 NVMe 硬盘。 - PyTorch 原生:使用
tensor.to('cpu')和tensor.to('cuda')手动控制,但 DeepSpeed 的自动化卸载更高效。
示例配置开启优化器 CPU 卸载:
{
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
}
}
}
代价:CPU-GPU 之间的数据传输会拖慢训练速度,但可以让原来跑不起来的模型成功运行。
7. 模型并行:把模型切开
当单卡连模型参数都放不下时,就必须采用模型并行,将模型的不同层或同一层的不同部分分布到多张 GPU 上。
- 朴素流水线并行:将模型按层切分,比如 GPU0 放前 6 层,GPU1 放后 6 层。数据像流水线一样传递,但存在 GPU 空闲等待时间。
- 张量并行:将单个层内的矩阵乘法按列或行切分到多卡并行计算。对 Transformer 的注意力头或前馈网络切分效果很好。
- 推荐的库:使用 Megatron-LM 或 Hugging Face 的
transformers结合 DeepSpeed 集成张量并行。
对于多数个人使用者,处理十亿参数级模型可用 DeepSpeed ZeRO-3 + CPU 卸载;百亿参数则需结合张量并行。
8. 使用低比特量化与 QLoRA 训练
如果只是进行微调,而非从零训练,可以利用量化技术将预训练模型转为 4 位或 8 位精度存储,大幅降低显存占用,再结合低秩适配(LoRA)仅训练少量额外参数。
QLoRA 的核心思路:
- 将冻结的基座模型量化到 4 位(NF4)。
- 在前向反向传播时反量化为 FP16 进行计算。
- 只训练附加的低秩矩阵,优化器状态仅针对这几个参数,内存消耗极低。
代码示例(bitsandbytes + PEFT):
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16
)
model = AutoModelForCausalLM.from_pretrained("模型名", quantization_config=bnb_config)
lora_config = LoraConfig(r=8, lora_alpha=32, target_modules=["q_proj", "v_proj"])
model = get_peft_model(model, lora_config)
这样可以在单张 24 GB 显卡上微调 65B 甚至更大的模型。
9. 精简数据类型与 FlashAttention
FlashAttention:一种在 GPU 上实现高效注意力计算的算法,通过分块计算减少对高带宽显存的读写,从而将中间注意力矩阵的内存从 O(n²) 降低到 O(n)。PyTorch 2.0 以上已通过 torch.backends.cuda.sdp_kernel 或 flash_attn 包支持。
开启后,长序列训练的显存降低非常显著,同时计算更快。
# 使用 PyTorch 内置的缩放点积注意力,自动调用 FlashAttention
with torch.backends.cuda.sdp_kernel(enable_flash=True):
output = torch.nn.functional.scaled_dot_product_attention(query, key, value)
数据类型:除了混合精度,还可以使用 TF32(Ampere 架构以上),通过设置 torch.backends.cuda.matmul.allow_tf32 = True 可在不牺牲太多精度的情况下加速计算,但不直接节省显存;显存节省主要来自 FP16 或 BF16。
10. 内存优化的组合拳策略
实际训练时,不应该孤立使用某一个技巧,而是按需组合:
-
基础组合:梯度累积 + 混合精度(FP16/BF16)+ FlashAttention。
- 适合可放入显存的模型但批次受限的情况。
-
进阶组合:在上述基础上增加梯度检查点 + ZeRO-2(DeepSpeed)。
- 解决激活值和优化器状态爆显存的问题,适用于十亿参数级模型。
-
极限方案:ZeRO-3 + CPU 卸载 + 模型并行(张量并行/流水线并行),或者直接使用 QLoRA 进行微调。
- 让消费级显卡也能触达百亿参数模型的训练和微调。
调试显存的工具:
torch.cuda.memory_summary()查看详细内存分配。nvidia-smi持续监控。- DeepSpeed 的
ds_report和内存监控日志。
总结:从“跑不动”到“跑得稳”
有限显存训练更大模型的本质是在时间、计算、通信之间寻找最优平衡。掌握以上技巧后,你可以:
- 用一张 24 GB 显卡微调 70B 语言模型(QLoRA)
- 用两张 24 GB 显卡从零训练 13B 模型(ZeRO-2 + 梯度检查点 + 混合精度)
- 大幅降低云 GPU 实例成本,让实验迭代更快
将这些技巧内化为工程习惯,你会发现显存不再是限制你探索大模型的枷锁,而是一道可以通过工程智慧解开的数学题。