AdamW:解耦权重衰减的自适应优化器

FreeGuideOnline 最新 2026-06-21

初始化:参数 θ₀,学习率 η,权重衰减 λ,矩估计衰减 β₁, β₂,ε = 1e-8 m₀ = 0, v₀ = 0, t = 0

while 未收敛: t = t + 1 gₜ = 计算当前mini-batch的损失梯度 (无正则化) mₜ = β₁ * mₜ₋₁ + (1-β₁) * gₜ vₜ = β₂ * vₜ₋₁ + (1-β₂) * gₜ² m̂ₜ = mₜ / (1 - β₁ᵗ) v̂ₜ = vₜ / (1 - β₂ᵗ) θₜ = θₜ₋₁ - η * m̂ₜ / (√v̂ₜ + ε) # 自适应更新 θₜ = θₜ - η * λ * θₜ₋₁ # 权重衰减(有些实现直接使用 θₜ₋₁) # 注意:在代码实现中常直接写为一行更新


## 7. 实际代码示例(PyTorch)

PyTorch 从 1.0 版本开始直接提供 `torch.optim.AdamW`,使用方法与 Adam 几乎一致,但需正确设置 `weight_decay`。

```python
import torch
import torch.nn as nn
import torch.optim as optim

# 定义一个简单模型
model = nn.Linear(10, 2)

# AdamW 优化器
# 注意:weight_decay 参数就是解耦的权重衰减系数 λ
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)

# 训练循环片段
for epoch in range(100):
    for inputs, targets in dataloader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()                 # 这里的梯度不包含权重衰减项
        optimizer.step()                # AdamW 内部完成解耦权重衰减

关键提示

  • weight_decay 的取值通常比 L2 正则化的 λ 略大,常用范围 1e-21e-1,建议从 1e-2 开始调节。
  • 如果之前使用 Adam 并添加 L2 正则化(在损失中加惩罚项),迁移到 AdamW 时,应去掉损失中的正则化,只依赖优化器的 weight_decay
  • 对于学习率调度,AdamW 同样兼容,如配合 torch.optim.lr_scheduler.CosineAnnealingLR 使用。

8. AdamW 的优势

  • 更好的泛化性:解耦后的权重衰减使模型的测试性能(尤其是大型网络)优于带 L2 正则化的 Adam,与经过精细调参的 SGD with momentum 相当甚至更好。
  • 超参数分离:学习率 η 控制自适应更新步长,权重衰减 λ 独立控制正则化强度,调试更容易。
  • 适用于现代网络:在 Transformer、CNN 等架构中,AdamW 已成为事实标准,与学习率预热、余弦退火等策略配合丝滑流畅。
  • 与自适应算法天然兼容:保留了 Adam 对稀疏梯度、非平稳目标的自适应优势,同时修正了正则化偏差。

9. 注意事项与调参建议

  • 权重衰减和 L2 正则化不可同时使用:若在损失函数中添加了 weight_decay 对应的 L2 项,同时又使用 AdamW 的 weight_decay,会导致衰减过度。
  • 权重衰减系数的量级:典型值位于 0.010.1,某些场景(如 fine-tuning)可能低至 1e-4。过大可能引起欠拟合,过小则正则化不足。
  • 与 BatchNorm 的交互:通常 BatchNorm 的 γ 和 β 参数不应施加权重衰减,可在 PyTorch 中通过 no_decay 分组策略排除:
    param_groups = [
        {'params': [p for n, p in model.named_parameters() if 'bias' not in n and 'bn' not in n], 'weight_decay': 1e-2},
        {'params': [p for n, p in model.named_parameters() if 'bias' in n or 'bn' in n], 'weight_decay': 0.0}
    ]
    optimizer = optim.AdamW(param_groups, lr=1e-3)