模型检查点策略:何时保存以及保存什么
if current_val_score > best_score + min_delta: save_checkpoint() best_score = current_val_score
关键概念:
- **主指标方向**:正确设定 `mode`(如 loss 为 `min`,accuracy 为 `max`)。
- **min_delta**:必须超过的改进绝对量,防止因噪声导致的无意义保存。
- **patience**:允许连续多少轮无提升后才触发其他动作(如学习率衰减或早停),有时也与检查点保存结合。
### 3.3 周期余弦退火与快照集成 (Snapshot Ensemble)
在一个余弦退火学习率周期内保存多个局部最低点的模型,用于后续集成。
- **区别**:不仅仅保存一个最佳模型,而是保存每个周期收敛时的模型。
- **优势**:不增加额外训练成本即可获得多个多样性较高的模型进行集成,提升最终精度。
- **时机**:在每次学习率下降到周期最低位置时保存。
### 3.4 异常中断时保护 (Emergency Checkpoint)
训练过程可能因异常(如显存溢出、超时)而终止,此时需要自动保存一份紧急检查点。
- **实现方式**:捕获系统信号(如 `SIGTERM`)或使用框架回调的 `on_train_end` 钩子。
- **保存内容**:必须包含完整的优化器和调度器状态。
### 3.5 时间限制下的分段保存 (Time-based Checkpoint)
在集群任务有时间限制(如 Slurm 队列最长运行 48 小时)时,需要在运行结束前自动保存最后状态,以便重新提交继续训练。
- **策略**:计算剩余时间,在结束前 5~10 分钟保存检查点。
- **额外需求**:保存当前数据加载器的状态(shuffle种子、已遍历样本数),但这在大多数框架中较难实现,因此通常采用每个 epoch 开始时重置数据迭代的方式。
---
## 4. 保存什么:构建完整的检查点工件
一个健壮的检查点应该是**自包含**且**可直接恢复训练**的。推荐使用字典结构统一打包。
```python
checkpoint = {
'epoch': current_epoch,
'global_step': global_step,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
'best_score': best_score,
'arch': model_arch_name,
'config': training_config,
'scaler_state_dict': scaler.state_dict() if use_amp else None,
}
torch.save(checkpoint, 'checkpoint.pt')
4.1 模型权重部分
- state_dict 是最佳选择,相对于保存整个模型对象(
torch.save(model, ...)),它更安全、文件更小、不绑定代码目录结构。 - 如果仅用于推断或部署,可以额外导出一份 jit 脚本模型 或 ONNX 格式,但这不属于训练检查点范畴。
4.2 优化器状态
优化器的动量、方差等是继续训练的关键。丢失优化器状态会导致收敛曲线出现“毛刺”,甚至需要重新 warmup。从不保存优化器状态的检查点恢复训练,等于从一个随机初始化重新开始动量累积,这极大可能损害当前阶段的优化效果。
4.3 学习率调度器状态
学习率变化曲线决定了训练的动态。尤其是复杂调度器(如带有热身的余弦退火),状态中可能包含已经执行的步数、周期计数等。
4.4 混合精度训练的 Gradient Scalers
若使用 PyTorch 的 torch.cuda.amp.GradScaler,必须保存其状态,否则恢复训练后缩放因子无法匹配,会导致梯度下溢或溢出。
4.5 随机种子与数据加载器状态
这属于高级需求。要做到精确复现中断点,还需要记录:
- Python、NumPy、PyTorch 的随机种子状态。
- 数据加载器当前索引(对于可索引数据集)。通常通过记录
global_step并在恢复时跳过相应批次来近似实现。
4.6 配置文件与元数据
在检查点同级目录或检查点内部放置一份 config.yaml,记录所有超参数,如:
model: "resnet50"
optimizer: "adamw"
learning_rate: 0.001
batch_size: 128
epochs: 200
这样即使数月后也能即刻了解该检查点的上下文。
5. 存储规划与命名规则
5.1 文件命名规范
混乱的命名是检查点管理的头号敌人。建议采用具有辨识度的命名格式:
{experiment_name}-{model_arch}-{epoch}ep-{metric_name}-{metric_value:.4f}.pt
例如: exp001-resnet50-epoch25-val_loss-0.1521.pt
定期保存与最佳模型文件可区分存放:
periodic/文件夹存放每 N epoch 的检查点。best/文件夹仅存放基于某指标的最佳模型(可直接覆盖同名文件)。
5.2 保留策略
磁盘空间有限,需要自动清理过期检查点:
- 保留最近 K 个:始终保留最近 K 个周期间隔的检查点,删除更老的。
- 仅保留最优:定期检查点全部删除,只留最佳模型。
- Top N 最佳:保留验证分数最高的 N 个模型。
- 动态加权:结合时间和指标排名,保留近期较好模型和历史上 Top N。
5.3 使用分布式文件系统
在多机训练中,检查点可能由多个 rank 生成。通常只需让 rank 0 保存完整检查点。但若需要从任何节点独立恢复,每个 rank 需要保存各自的优化器状态。此时文件体系设计要避免冲突,并考虑挂载的共享存储性能。
6. 常见陷阱与错误
-
只保存模型参数,忽略优化器
恢复训练后初期 loss 突然飙升,收敛打乱,完全背离原有轨迹。 -
覆盖保存导致最佳模型丢失
在无判断逻辑下每次都覆盖同一个文件名,最终只保留了训练最后的过拟合状态。 -
检查点体积爆炸
保存了完整模型对象(包括分类层、代码依赖等),单个文件可能达到数GB;或使用了不当的数据并行存储方式保存了重复参数。 -
指标误差方向
例如将 loss 的mode错误设为max,导致永远不触发最佳模型保存。 -
丢失 amp scaler 状态
使用混合精度训练但未保存 scaler,恢复后 scale 被重置为初始值,可能造成较长时间无法有效更新。 -
checkpoint 与代码版本不匹配
检查点中模型结构依赖某个仓库分支版本,重新加载时因代码更新导致state_dict加载失败。建议在检查点中记录 git commit hash 或模型类名。
7. 框架工具与实现示例
PyTorch + Ignite / Lightning
- PyTorch Lightning 内置了
ModelCheckpoint回调,支持monitor、mode、every_n_epochs、save_last、save_top_k等参数,同时自动保存优化器和调度器状态,是初学者上手的最佳选择。 - Ignite 中也提供了
Checkpoint和ModelCheckpoint,可将任意对象字典存入 hdf5 或普通文件。
自定义示例(最小化实现):
# 定期保存 + 最佳保存逻辑
def save_ckpt(state, filename):
torch.save(state, filename)
# 在train loop中:
if epoch % save_every == 0:
save_ckpt({...}, f'periodic/ckpt_epoch{epoch}.pt')
if val_score > best_score:
best_score = val_score
save_ckpt({...}, 'best/model_best.pt')