FSDP:PyTorch 原生的全分片数据并行

FreeGuideOnline 最新 2026-06-14

为什么需要 FSDP?

大语言模型的参数量呈指数级增长,单张 GPU 的显存已经远远不够。传统的分布式数据并行(DDP)虽然能把训练扩展到多卡,但它在每张卡上都维护了一份完整的模型副本,显存消耗巨大。FSDP(Fully Sharded Data Parallel,全分片数据并行)正是在这样的背景下诞生的——它通过模型分片,把参数、梯度和优化器状态都切分到所有 GPU 上,从而让单卡显存可以容纳更大的模型。换句话说,有了 FSDP,你就能用同样的 GPU 数量训练更大、更强的模型,而无需求助于复杂的模型并行或流水线并行。

从 DDP 到 ZeRO:分片理念的由来

在深入了解 FSDP 之前,我们先建立两个核心概念:

  • 数据并行:每张 GPU 都有一个完整的模型副本,处理不同的数据子集(mini‑batch),然后通过 All‑Reduce 同步梯度。DDP 就是这种方案的代表。
  • 分片数据并行:不再让每张 GPU 持有完整模型,而是将模型参数、梯度和优化器状态等均匀切分到所有 worker 上。计算时通过通信临时重组需要的完整层,用完即释放,从而大幅降低单卡内存占用。

微软的 ZeRO(Zero Redundancy Optimizer)把这种分片思想分成了三个阶段:

阶段 分片内容 节省显存比例(近似)
ZeRO‑1 优化器状态
ZeRO‑2 优化器状态 + 梯度
ZeRO‑3 优化器状态 + 梯度 + 参数 与 GPU 数量 N 成线性关系,每卡显存 = 总参数 / N + 剩余开销

FSDP 正是 PyTorch 对 ZeRO‑3 的原生实现,它完全集成在 torch.distributed.fsdp 模块中,让你无需第三方库即可享受全分片数据并行带来的显存红利。

FSDP 的核心工作原理

FSDP 的每次前向/反向传播,背后都伴随着精巧的分片与通信机制:

  1. 参数分片
    模型的所有参数在初始化和加载 checkpoint 之后,就被平坦化并等分到所有 rank 上。每个 rank 只持有整个模型参数的一个分片(shard)。

  2. 前向传播
    当计算到某个模块时,FSDP 会通过 All‑Gather 将该模块的全部参数收集到当前 rank 上(其他 rank 也会同步获得)。模块计算完成后,如果策略要求立即释放,这些参数会被丢弃(或写回分片),只保留本地分片。这种“需要时才重建,用后即焚”的方式,确保了任意时刻每张卡上只有极少数层的完整参数。

  3. 反向传播
    梯度在反向传播时同样遵循分片逻辑:计算某模块的梯度时,先 All‑Gather 参数,计算本地梯度,然后通过 Reduce‑Scatter 把梯度聚合并按分片分布到各个 rank 上,从而每个 rank 只持有自己负责的那部分参数的梯度分片。

  4. 优化器步骤
    由于优化器状态(如 Adam 的 momentum 和 variance)也和梯度一一对应,FSDP 使得每个 rank 的优化器只需要更新本地的参数分片,完全无需额外通信。更新完成后,这些参数分片就是最新的权重,等待下一次前向传播的 All‑Gather。

如何快速启用 FSDP?

PyTorch 提供了高层 API,让你可以用极少的代码量从 DDP 切换为 FSDP。最常用的入口是 FullyShardedDataParallel 类或以 fsdp 开头的包装函数。

基础 API 示例

import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    StateDictType,
    FullStateDictConfig,
)

# 1. 标准分布式初始化
dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# 2. 构建模型并将其用 FSDP 包裹
model = MyModel().to(local_rank)
model = FSDP(model)

# 3. 定义优化器(注意:优化器接收的是 FSDP 包裹后的模型参数)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# 4. 训练循环不变,但需注意梯度同步已在 FSDP 内部处理
for data, target in dataloader:
    data, target = data.cuda(), target.cuda()
    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()

只需一行 FSDP(model),原本耗显存的完整模型就会自动分片。但真实世界的大模型训练往往需要更精细的配置,我们接下来会详细讲解。

更精细的 FSDP 配置:包裹策略与参数设置

直接对整个模型使用 FSDP 有一个潜在问题:如果模型包含很多不可训练的层(如 LayerNorm、Dropout),包裹策略会将其也分片,可能影响性能。FSDP 提供了灵活的包裹策略(wrapping policy),让你可以指定哪些子模块应该被独立包装为 FSDP 单元。

基于 module 类型的包裹策略

PyTorch 内置了针对 Transformer 的常用策略:

from torch.distributed.fsdp.wrap import (
    transformer_auto_wrap_policy,
    size_based_auto_wrap_policy,
)
from functools import partial

# 假设使用 HuggingFace 风格的 GPT-2,我们希望把每个 TransformerBlock 作为独立的 FSDP 单元
auto_wrap_policy = partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={GPT2Block},
)

model = FSDP(
    model,
    auto_wrap_policy=auto_wrap_policy,
    device_id=local_rank,
)

transformer_auto_wrap_policy 会识别指定类型的子模块并将它们分别包裹,这样前向传播时可以按块执行 All‑Gather 和释放,进一步降低峰值显存。

FSDP 的核心超参数

创建 FSDP 时可以传入多个参数控制行为,下面是几个关键项:

参数 含义 默认值 建议
sharding_strategy 分片策略 FULL_SHARD 全分片即 ZeRO‑3;设为 SHARD_GRAD_OP 对应 ZeRO‑2;NO_SHARD 即 DDP
cpu_offload 是否将参数和梯度卸载到 CPU None 当 GPU 显存极度紧张时设置 CPUOffload(offload_params=True),但会牺牲训练速度
mixed_precision 混合精度配置 None 推荐为 MixedPrecision(param_dtype=torch.float16, reduce_dtype=torch.float16),可节省显存并加速
backward_prefetch 反向传播时是否预取下一层参数 BACKWARD_PRE 设为 BACKWARD_PRE 可以重叠通信与计算,提升吞吐
activation_checkpointing 是否使用激活检查点(重计算) 需要搭配 torch.utils.checkpoint 使用,进一步节省显存
ignored_modules 忽略不分片的模块列表 [] 适合放 LayerNorm 等小参数层,避免这些层被分片带来不必要的通信开销

实战配置:用 FSDP 训练 7B 模型

以下是一个接近生产环境的 FSDP 配置片段,整合了包裹策略、混合精度和 CPU 卸载:

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    CPUOffload,
    ShardingStrategy,
    BackwardPrefetch,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

mixed_precision_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

# 对于 7B 模型,若显存不够,启用 CPU offload
cpu_offload = CPUOffload(offload_params=True)

model = FSDP(
    model,
    auto_wrap_policy=partial(transformer_auto_wrap_policy, transformer_layer_cls={MyDecoderLayer}),
    mixed_precision=mixed_precision_policy,
    cpu_offload=cpu_offload,
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # 等效于默认
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    device_id=torch.cuda.current_device(),
    ignored_modules=[model.lm_head],  # 输出层通常不分片,避免每次同步
)

FSDP 与 DeepSpeed ZeRO 的对比

特性 FSDP DeepSpeed ZeRO‑3
集成方式 PyTorch 原生,无需额外安装 第三方库,需额外安装 deepspeed
API 复杂度 简单,包裹即可 需编写配置文件或使用 ds_config.json
性能 相近,部分场景 FSDP 通信效率更高 优化历史悠久,社区生态成熟
灵活性 易与 PyTorch 其他特性混合(如 TorchDynamo) 对某些 PyTorch 新特性支持可能滞后
卸载选项 CPU offload(参数/梯度) CPU/NVMe offload,更丰富的卸载层次
大模型支持 可训练达到数百 B 参数 同样支持超大规模

如果你的团队重度使用 PyTorch 生态,并追求更少的依赖和原生的兼容性,FSDP 几乎是首选。而如果你需要更细粒度的卸载到 NVMe 或者一些特定优化(如 1‑bit Adam),则 DeepSpeed 仍然是不错的补充。

模型 checkpoint 的保存与加载

FSDP 重新定义了 state dict 的管理方式。因为每个 rank 只持有分片,直接调用 model.state_dict() 会得到分片后的结果。为了获得完整的模型权重,需要使用特定的上下文管理器:

from torch.distributed.fsdp import (
    FullStateDictConfig,
    StateDictType,
)

# 保存完整模型(只在 rank 0 上执行,并合并分片)
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
    state_dict = model.state_dict()
    if dist.get_rank() == 0:
        torch.save(state_dict, "full_model.pt")

加载时,可以先将完整权重加载到 rank 0,然后广播给所有 worker,或者使用 FSDP 的分片 state dict 加载方式以直接恢复分片状态。具体细节可参考官方文档。

常见问题与调试

  1. 报错 AssertionError: Expected to run on a single device
    通常是因为在 FSDP 内部对环境变量 LOCAL_RANK 或设备放置处理不当,确保已正确调用 torch.cuda.set_device(local_rank)

  2. 通信效率低,GPU 利用率不高
    尝试调整 backward_prefetchBACKWARD_PRE,或者增加单个 FSDP 单元的大小(减少包裹模块数量)。同时确保使用 NCCL 高速通信后端,并考虑用 torch.distributed.barrier() 检查同步点。

  3. 启用 CPU offload 后训练速度极慢
    CPU offload 本质上是以带宽换显存,确认是否真的需要。可优先尝试混合精度、激活检查点,最后再考虑 offload。

  4. 梯度出现 NaN 或 Loss 发散
    混合精度训练下,务必设置合理的 reduce_dtype 和梯度缩放。FSDP 与 PyTorch 的 torch.cuda.amp 完全兼容,建议使用 GradScaler

FSDP 最佳实践速查

  • 先从小模型开始验证:用 DDP 能跑通的脚本,直接替换为 FSDP 包装,确保分布式通信无误,再逐步放大模型尺寸。
  • 组合使用激活检查点:在 FSDP 包裹的子模块内使用 torch.utils.checkpoint.checkpoint,可以节省 30‑50% 的激活内存。
  • 善用 ignored_modules:将 LayerNorm、bias 层等参数极少的模块排除在分片外,避免不必要的 All‑Gather 通信。
  • 监控显存:使用 torch.cuda.memory_summary() 或 PyTorch Profiler 观察分片前后的显存变化,验证 FSDP 的节省效果。
  • 合理搭配 sync_module_states:从磁盘加载预训练权重时,设置 sync_module_states=True 可以让 FSDP 自动广播权重到各个 rank,避免手动操作。

总结

FSDP 让 PyTorch 用户无需离开舒适区,就能享受到完整的 ZeRO‑3 体验。它通过原生 API 将模型分片、混合精度、CPU 卸载等高级能力无缝融入训练框架,显著降低了大模型训练的硬件门槛。从单机多卡到千卡集群,FSDP 都表现出良好的可扩展性。掌握了本文介绍的包裹策略、关键参数和调试技巧,你就可以信心满满地将 FSDP 应用到自己的大模型训练任务中,用更少的显存,撬动更大的模型。