FSDP:PyTorch 原生的全分片数据并行
为什么需要 FSDP?
大语言模型的参数量呈指数级增长,单张 GPU 的显存已经远远不够。传统的分布式数据并行(DDP)虽然能把训练扩展到多卡,但它在每张卡上都维护了一份完整的模型副本,显存消耗巨大。FSDP(Fully Sharded Data Parallel,全分片数据并行)正是在这样的背景下诞生的——它通过模型分片,把参数、梯度和优化器状态都切分到所有 GPU 上,从而让单卡显存可以容纳更大的模型。换句话说,有了 FSDP,你就能用同样的 GPU 数量训练更大、更强的模型,而无需求助于复杂的模型并行或流水线并行。
从 DDP 到 ZeRO:分片理念的由来
在深入了解 FSDP 之前,我们先建立两个核心概念:
- 数据并行:每张 GPU 都有一个完整的模型副本,处理不同的数据子集(mini‑batch),然后通过 All‑Reduce 同步梯度。DDP 就是这种方案的代表。
- 分片数据并行:不再让每张 GPU 持有完整模型,而是将模型参数、梯度和优化器状态等均匀切分到所有 worker 上。计算时通过通信临时重组需要的完整层,用完即释放,从而大幅降低单卡内存占用。
微软的 ZeRO(Zero Redundancy Optimizer)把这种分片思想分成了三个阶段:
| 阶段 | 分片内容 | 节省显存比例(近似) |
|---|---|---|
| ZeRO‑1 | 优化器状态 | 4× |
| ZeRO‑2 | 优化器状态 + 梯度 | 8× |
| ZeRO‑3 | 优化器状态 + 梯度 + 参数 | 与 GPU 数量 N 成线性关系,每卡显存 = 总参数 / N + 剩余开销 |
FSDP 正是 PyTorch 对 ZeRO‑3 的原生实现,它完全集成在 torch.distributed.fsdp 模块中,让你无需第三方库即可享受全分片数据并行带来的显存红利。
FSDP 的核心工作原理
FSDP 的每次前向/反向传播,背后都伴随着精巧的分片与通信机制:
-
参数分片
模型的所有参数在初始化和加载 checkpoint 之后,就被平坦化并等分到所有 rank 上。每个 rank 只持有整个模型参数的一个分片(shard)。 -
前向传播
当计算到某个模块时,FSDP 会通过 All‑Gather 将该模块的全部参数收集到当前 rank 上(其他 rank 也会同步获得)。模块计算完成后,如果策略要求立即释放,这些参数会被丢弃(或写回分片),只保留本地分片。这种“需要时才重建,用后即焚”的方式,确保了任意时刻每张卡上只有极少数层的完整参数。 -
反向传播
梯度在反向传播时同样遵循分片逻辑:计算某模块的梯度时,先 All‑Gather 参数,计算本地梯度,然后通过 Reduce‑Scatter 把梯度聚合并按分片分布到各个 rank 上,从而每个 rank 只持有自己负责的那部分参数的梯度分片。 -
优化器步骤
由于优化器状态(如 Adam 的 momentum 和 variance)也和梯度一一对应,FSDP 使得每个 rank 的优化器只需要更新本地的参数分片,完全无需额外通信。更新完成后,这些参数分片就是最新的权重,等待下一次前向传播的 All‑Gather。
如何快速启用 FSDP?
PyTorch 提供了高层 API,让你可以用极少的代码量从 DDP 切换为 FSDP。最常用的入口是 FullyShardedDataParallel 类或以 fsdp 开头的包装函数。
基础 API 示例
import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
StateDictType,
FullStateDictConfig,
)
# 1. 标准分布式初始化
dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
# 2. 构建模型并将其用 FSDP 包裹
model = MyModel().to(local_rank)
model = FSDP(model)
# 3. 定义优化器(注意:优化器接收的是 FSDP 包裹后的模型参数)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# 4. 训练循环不变,但需注意梯度同步已在 FSDP 内部处理
for data, target in dataloader:
data, target = data.cuda(), target.cuda()
optimizer.zero_grad()
output = model(data)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
只需一行 FSDP(model),原本耗显存的完整模型就会自动分片。但真实世界的大模型训练往往需要更精细的配置,我们接下来会详细讲解。
更精细的 FSDP 配置:包裹策略与参数设置
直接对整个模型使用 FSDP 有一个潜在问题:如果模型包含很多不可训练的层(如 LayerNorm、Dropout),包裹策略会将其也分片,可能影响性能。FSDP 提供了灵活的包裹策略(wrapping policy),让你可以指定哪些子模块应该被独立包装为 FSDP 单元。
基于 module 类型的包裹策略
PyTorch 内置了针对 Transformer 的常用策略:
from torch.distributed.fsdp.wrap import (
transformer_auto_wrap_policy,
size_based_auto_wrap_policy,
)
from functools import partial
# 假设使用 HuggingFace 风格的 GPT-2,我们希望把每个 TransformerBlock 作为独立的 FSDP 单元
auto_wrap_policy = partial(
transformer_auto_wrap_policy,
transformer_layer_cls={GPT2Block},
)
model = FSDP(
model,
auto_wrap_policy=auto_wrap_policy,
device_id=local_rank,
)
transformer_auto_wrap_policy 会识别指定类型的子模块并将它们分别包裹,这样前向传播时可以按块执行 All‑Gather 和释放,进一步降低峰值显存。
FSDP 的核心超参数
创建 FSDP 时可以传入多个参数控制行为,下面是几个关键项:
| 参数 | 含义 | 默认值 | 建议 |
|---|---|---|---|
sharding_strategy |
分片策略 | FULL_SHARD |
全分片即 ZeRO‑3;设为 SHARD_GRAD_OP 对应 ZeRO‑2;NO_SHARD 即 DDP |
cpu_offload |
是否将参数和梯度卸载到 CPU | None |
当 GPU 显存极度紧张时设置 CPUOffload(offload_params=True),但会牺牲训练速度 |
mixed_precision |
混合精度配置 | None |
推荐为 MixedPrecision(param_dtype=torch.float16, reduce_dtype=torch.float16),可节省显存并加速 |
backward_prefetch |
反向传播时是否预取下一层参数 | BACKWARD_PRE |
设为 BACKWARD_PRE 可以重叠通信与计算,提升吞吐 |
activation_checkpointing |
是否使用激活检查点(重计算) | 无 |
需要搭配 torch.utils.checkpoint 使用,进一步节省显存 |
ignored_modules |
忽略不分片的模块列表 | [] |
适合放 LayerNorm 等小参数层,避免这些层被分片带来不必要的通信开销 |
实战配置:用 FSDP 训练 7B 模型
以下是一个接近生产环境的 FSDP 配置片段,整合了包裹策略、混合精度和 CPU 卸载:
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
MixedPrecision,
CPUOffload,
ShardingStrategy,
BackwardPrefetch,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
mixed_precision_policy = MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.bfloat16,
buffer_dtype=torch.bfloat16,
)
# 对于 7B 模型,若显存不够,启用 CPU offload
cpu_offload = CPUOffload(offload_params=True)
model = FSDP(
model,
auto_wrap_policy=partial(transformer_auto_wrap_policy, transformer_layer_cls={MyDecoderLayer}),
mixed_precision=mixed_precision_policy,
cpu_offload=cpu_offload,
sharding_strategy=ShardingStrategy.FULL_SHARD, # 等效于默认
backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
device_id=torch.cuda.current_device(),
ignored_modules=[model.lm_head], # 输出层通常不分片,避免每次同步
)
FSDP 与 DeepSpeed ZeRO 的对比
| 特性 | FSDP | DeepSpeed ZeRO‑3 |
|---|---|---|
| 集成方式 | PyTorch 原生,无需额外安装 | 第三方库,需额外安装 deepspeed |
| API 复杂度 | 简单,包裹即可 | 需编写配置文件或使用 ds_config.json |
| 性能 | 相近,部分场景 FSDP 通信效率更高 | 优化历史悠久,社区生态成熟 |
| 灵活性 | 易与 PyTorch 其他特性混合(如 TorchDynamo) | 对某些 PyTorch 新特性支持可能滞后 |
| 卸载选项 | CPU offload(参数/梯度) | CPU/NVMe offload,更丰富的卸载层次 |
| 大模型支持 | 可训练达到数百 B 参数 | 同样支持超大规模 |
如果你的团队重度使用 PyTorch 生态,并追求更少的依赖和原生的兼容性,FSDP 几乎是首选。而如果你需要更细粒度的卸载到 NVMe 或者一些特定优化(如 1‑bit Adam),则 DeepSpeed 仍然是不错的补充。
模型 checkpoint 的保存与加载
FSDP 重新定义了 state dict 的管理方式。因为每个 rank 只持有分片,直接调用 model.state_dict() 会得到分片后的结果。为了获得完整的模型权重,需要使用特定的上下文管理器:
from torch.distributed.fsdp import (
FullStateDictConfig,
StateDictType,
)
# 保存完整模型(只在 rank 0 上执行,并合并分片)
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
state_dict = model.state_dict()
if dist.get_rank() == 0:
torch.save(state_dict, "full_model.pt")
加载时,可以先将完整权重加载到 rank 0,然后广播给所有 worker,或者使用 FSDP 的分片 state dict 加载方式以直接恢复分片状态。具体细节可参考官方文档。
常见问题与调试
-
报错
AssertionError: Expected to run on a single device
通常是因为在 FSDP 内部对环境变量LOCAL_RANK或设备放置处理不当,确保已正确调用torch.cuda.set_device(local_rank)。 -
通信效率低,GPU 利用率不高
尝试调整backward_prefetch为BACKWARD_PRE,或者增加单个 FSDP 单元的大小(减少包裹模块数量)。同时确保使用 NCCL 高速通信后端,并考虑用torch.distributed.barrier()检查同步点。 -
启用 CPU offload 后训练速度极慢
CPU offload 本质上是以带宽换显存,确认是否真的需要。可优先尝试混合精度、激活检查点,最后再考虑 offload。 -
梯度出现 NaN 或 Loss 发散
混合精度训练下,务必设置合理的reduce_dtype和梯度缩放。FSDP 与 PyTorch 的torch.cuda.amp完全兼容,建议使用GradScaler。
FSDP 最佳实践速查
- 先从小模型开始验证:用 DDP 能跑通的脚本,直接替换为 FSDP 包装,确保分布式通信无误,再逐步放大模型尺寸。
- 组合使用激活检查点:在
FSDP包裹的子模块内使用torch.utils.checkpoint.checkpoint,可以节省 30‑50% 的激活内存。 - 善用
ignored_modules:将 LayerNorm、bias 层等参数极少的模块排除在分片外,避免不必要的 All‑Gather 通信。 - 监控显存:使用
torch.cuda.memory_summary()或 PyTorch Profiler 观察分片前后的显存变化,验证 FSDP 的节省效果。 - 合理搭配
sync_module_states:从磁盘加载预训练权重时,设置sync_module_states=True可以让 FSDP 自动广播权重到各个 rank,避免手动操作。
总结
FSDP 让 PyTorch 用户无需离开舒适区,就能享受到完整的 ZeRO‑3 体验。它通过原生 API 将模型分片、混合精度、CPU 卸载等高级能力无缝融入训练框架,显著降低了大模型训练的硬件门槛。从单机多卡到千卡集群,FSDP 都表现出良好的可扩展性。掌握了本文介绍的包裹策略、关键参数和调试技巧,你就可以信心满满地将 FSDP 应用到自己的大模型训练任务中,用更少的显存,撬动更大的模型。