训练恢复:继续被意外中断的超长时间训练
训练恢复:继续被意外中断的超长时间训练
对于初学者来说,花费数小时甚至数天训练一个机器学习模型,却因为断电、系统崩溃或手动误操作而中断,无疑是一种灾难。本教程将系统性地讲解如何实现训练恢复,让你能从断点处无缝继续训练,避免宝贵算力和时间的浪费。
1. 什么是训练恢复?为什么需要它?
训练恢复指的是在模型训练过程中,定期或在某些事件触发时保存当前训练的完整状态,并在中断发生后,从最近一次保存的状态重新加载并继续训练,而不是从头开始。
你需要在以下场景中依赖训练恢复:
- 意外中断:硬件故障、断电、进程崩溃、内存不足(OOM)等不可控因素。
- 抢占式资源:使用云平台的抢占式实例(Spot/Preemptible VM),训练可能随时被回收。
- 超长训练:训练周期跨越数天甚至数周,你无法保证环境全程稳定。
- 渐进式训练:先训练一部分,评估后再决定是否继续追加训练。
2. 你必须保存的完整状态
仅保存模型权重(如 model.pt)是不够的。一个可完全恢复的训练状态至少应包含以下内容:
| 保存对象 | 作用 | 示例(代码表述) |
|---|---|---|
| 模型参数 | 模型本身的可学习权重和偏置。 | model.state_dict() |
| 优化器状态 | 动量、自适应学习率参数(如 Adam 的 exp_avg)等。 |
optimizer.state_dict() |
| 学习率调度器状态 | 当前学习率、调度器内部计数器等,确保学习率曲线连续。 | scheduler.state_dict() |
| 训练轮次与步数 | 当前是第几个 epoch、第几个 global step,用于准确恢复循环控制和日志。 | epoch, global_step |
| 随机种子状态 | 保证数据随机增强、采样顺序可复现(重要但非绝对必需)。 | random.getstate(), np.random.get_state(), torch.random.get_rng_state() |
| 数据加载器状态 | 如果用自定义采样器且需要从断点后的数据继续,则需保存数据迭代器状态。 | data_iterator.state_dict() 或记录已消费样本数 |
关键认知:缺少优化器状态将导致恢复后训练初期不稳定,动量被重置,收敛轨迹发生偏移。
3. 检查点保存与恢复策略
在框架中实现一个强健的检查点机制,通常分为保存和加载两个核心函数。
3.1 定义统一的检查点字典
将所有必要组件打包成一个字典,以 .tar 或 .pth 文件形式存储。
checkpoint = {
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'epoch': epoch,
'global_step': global_step,
'best_metric': best_val_acc, # 你的最佳评估指标
'random_rng_state': random.getstate(),
'np_rng_state': np.random.get_state(),
'torch_rng_state': torch.random.get_rng_state(),
# 如有必要,加入数据加载器状态
}
3.2 定期保存与最佳模型保存
不要每轮都覆盖同一个文件,这样若保存过程中断电会导致整个检查点损坏。采用轮换保存或原子写入。
- 原子写入:先写入临时文件,再重命名为正式文件。
- 轮换保留:保留最新 N 个检查点(例如
checkpoint_epoch_10.pth,checkpoint_epoch_20.pth),避免磁盘爆满且可回溯。
import os
import torch
def save_checkpoint(state, filename="checkpoint.pth.tar"):
# 原子写入,防止保存时中断损坏文件
tmp_filename = filename + ".tmp"
torch.save(state, tmp_filename)
os.replace(tmp_filename, filename) # 原子操作(在 POSIX 系统上)
同时,单独保存一个只含模型权重的 best_model.pth,用于最终部署。
3.3 恢复时重建设备映射
保存的设备(GPU/CPU)可能和恢复时的环境不同。加载时应使用 map_location 参数,并确保模型与优化器的参数张量被移动到正确的设备。
def load_checkpoint(filename, model, optimizer, scheduler, device):
checkpoint = torch.load(filename, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
epoch = checkpoint['epoch']
global_step = checkpoint['global_step']
# 恢复随机状态(可选)
random.setstate(checkpoint['random_rng_state'])
np.random.set_state(checkpoint['np_rng_state'])
torch.random.set_rng_state(checkpoint['torch_rng_state'])
return epoch, global_step
4. 从断点精确继续训练
恢复后,你需要从保存的 epoch 和 global_step 继续循环,而不是从 0 开始。
4.1 正确的循环控制
start_epoch = 0
global_step = 0
# 如果检查点存在
if os.path.exists(CHECKPOINT_PATH):
start_epoch, global_step = load_checkpoint(...)
# 训练循环将从 start_epoch + 1 开始
start_epoch += 1
for epoch in range(start_epoch, num_epochs):
for batch_idx, (data, target) in enumerate(train_loader):
# ... 训练逻辑 ...
global_step += 1
# 可在每 N 步保存
4.2 数据流的中断处理
如果你的数据加载顺序对训练有影响(如课程学习),你需要保存 DataLoader 的精确状态。在某些框架中(如 PyTorch),DataLoader 不支持直接序列化状态。替代方案:
- 记录已处理的样本索引,重新创建数据集子集。
- 使用可恢复的 Sampler,保存 Sampler 状态。
- 或简单地在每个 epoch 开始时重新随机,接受某些微小差异,这通常是可接受的。
5. 跨框架实现示例
尽管上述代码基于 PyTorch,其思想完全适用于其他框架:
- TensorFlow/Keras:使用
tf.train.Checkpoint管理全部对象,调用checkpoint.save和checkpoint.restore。 - PyTorch Lightning:内置
ModelCheckpoint回调,自动保存epoch,global_step,state_dict等。 - Hugging Face Transformers:
Trainer类自动处理检查点保存,恢复时直接传入resume_from_checkpoint参数即可。
无论哪个框架,核心原则不变:完整保存可恢复的运行时状态,并在恢复时正确衔接训练循环。
6. 常见故障排查与注意事项
- 恢复后损失跳动/精度下降:通常是优化器状态未正确恢复。检查
optimizer.load_state_dict是否执行,并确认优化器创建的设备与模型一致。 - GPU 内存状态不一致:不要试图保存 CUDA 的随机状态并在不同环境中恢复。保存 CPU 随机状态,并在恢复后通过
torch.cuda.manual_seed重新设置。 - 检查点损坏:始终启用原子写入,必要时保留 CRC 校验。定时将远程备份检查点文件。
- 学习率曲线错乱:恢复后若 scheduler 的
last_epoch未正确设置,学习率会大幅跳变。建议在加载调度器状态后额外设置scheduler.last_epoch = start_epoch。
7. 总结:你的训练恢复清单
从今天起,将以下流程纳入你的训练脚本,让“中断”不再可怕。
- 定义检查点内容:模型、优化器、调度器、epoch/step、度量值。
- 原子保存:先写临时文件再替换,避免写入时损坏。
- 条件恢复:训练启动时检查是否存在检查点,存在则加载全部状态。
- 循环同步:从保存的
epoch+1开始,恢复global_step供日志使用。 - 定期清理:仅保留最近几个检查点,节省空间。
- 始终确保一个独立的最佳模型文件,与可恢复检查点分开存放。
实现一次稳健的恢复机制,便能在所有后续项目中使用。现在,你可以放心地运行那些需要千卡 GPU · 周级的训练任务了。