LAMB 与 LARS:大规模批次训练的优化器

FreeGuideOnline 最新 2026-06-21

LAMB 与 LARS:大规模批次训练优化器完全指南

在深度学习领域,为了加速模型训练,研究人员和工程师通常会采用数据并行并增大批次大小(Batch Size)的策略。然而,简单地增大 Batch Size 往往会导致模型精度下降。LARS 和 LAMB 作为专为大规模批次训练设计的优化器,有效解决了这一困境。本教程将从基础概念出发,带你深入理解这两种强大的优化算法。

为什么需要大 Batch 训练?

使用大 Batch 训练主要有以下优势:

  • 硬件利用率高:更大的矩阵乘法更充分地利用 GPU 的计算能力,提升硬件吞吐量。
  • 分布式扩展性强:通过数据并行,将大 Batch 分散到多个节点,可线性缩短训练时间。
  • 降低通信开销:在分布式系统中,更大的计算量与固定的通信量之比更高,使得扩展效率更优。

然而,当 Batch Size 超过某个阈值(如 8K 或 32K)后,训练往往变得不稳定,且最终模型的泛化性能会显著下降。

大 Batch 训练面临的优化困境

经典优化器(如 SGD with Momentum、Adam)在处理超大 Batch 时,会遇到两个核心问题:

  1. 尖锐极小值:小 Batch 训练倾向于收敛到较 "平缓" 的极小值,其泛化能力更强;而大 Batch 训练更易陷入 "尖锐" 极小值,导致测试误差增大。直观理解是,平缓的极小值对输入扰动不敏感,因此更鲁棒。

  2. 层间更新不平衡:神经网络不同层的权重和梯度的尺度往往差异极大。例如,卷积层的梯度范数可能远小于全连接层。在全局使用统一的学习率时,梯度较小的层几乎无法更新,而梯度巨大的层则可能发散。这种不平衡在大 Batch 下被进一步放大,因为梯度估计的方差更小,这种差异更加固化。

为了同时享受大 Batch 的训练速度与小 Batch 的泛化能力,需要设计专门的优化器。

LARS:层级自适应率缩放

LARS(Layer-wise Adaptive Rate Scaling)于 2017 年提出,成功实现了在不损失精度的情况下,使用高达 32K 的巨大 Batch 训练 ResNet。它的核心思想是:信任比例

核心原理

对于每一层 ( l ),LARS 不再使用全局学习率,而是为该层计算一个局部学习率 ( \eta^l ),其计算公式为:

[ \eta^l = \eta \times \frac{||w^l||}{||\nabla L(w^l)|| + \lambda ||w^l||} ]

其中:

  • ( \eta ) 是全局基础学习率。
  • ( w^l ) 是第 ( l ) 层的权重。
  • ( \nabla L(w^l) ) 是第 ( l ) 层的梯度。
  • ( || \cdot || ) 通常取 L2 范数。
  • ( \lambda ) 是权重衰减系数。

分母中加入 ( \lambda ||w^l|| ) 是为了在权重衰减存在时保持缩放比的稳定性。

直观理解

“信任比例” 的意思是:我们对该层权重更新的相对大小设定一个上限。这是一个比率,表示该层参数更新的范数相对于权重范数的变化幅度。如果梯度相对权重过大(即 ( \frac{||\nabla L||}{||w||} ) 很大),层学习率会被缩小,防止更新破坏已有知识;反之,如果梯度相对权重很小,学习率会被放大,让该层也能有效学习。

LARS 更新步骤

LARS 通常与带动的 SGD 结合,流程如下:

  1. 计算常规的动量更新 ( v_{t+1} = \mu v_t + g_t )。
  2. 对每一层,计算缩放因子 ( \gamma = \frac{||w^l||}{||v_{t+1}|| + \lambda ||w^l||} )。
  3. 该层的实际学习率为 ( \eta^l = \eta \times \gamma )。
  4. 参数更新:( w_{t+1} = w_t - \eta^l v_{t+1} )。

这样,各层都能以适合自身尺度的步长进行更新,极大缓解了大 Batch 下的层间不平衡问题。

LAMB:带有动量和归一化的层自适应优化器

LAMB(Layer-wise Adaptive Moments optimizer for Batch training)是在 LARS 基础上为自适应矩估计(Adam)家族量身定做的优化器,专为 Transformer 等架构的大规模训练而设计(如训练 BERT)。

核心改进

Adam 这类优化器通过计算梯度的第一和第二阶矩(( m_t, v_t ))来自适应调整每个参数的学习率。然而在大 Batch 下,仍然存在层间不平衡导致的训练困难。LAMB 巧妙地将 LARS 的信任比例思想与 Adam 的更新规则相结合。

LAMB 更新公式

对于每一层 ( l ),LAMB 的更新规则如下:

  1. 计算带动的梯度一阶矩 ( m_t ) 和二阶矩 ( v_t )(与 Adam 完全相同)。
  2. 计算 Adam 风格的更新方向: [ r_t = \frac{m_t}{\sqrt{v_t} + \epsilon} ]
  3. 应用 LARS 式的层级缩放: [ \eta^l = \eta \times \frac{||w^l||}{||r_t + \lambda w^l||} ]
  4. 最终更新: [ w_{t+1} = w_t - \eta^l (r_t + \lambda w^l) ]

这里,缩放因子作用于归一化后的更新方向 ( r_t ) 与权重衰减项 ( \lambda w^l ) 的合向量。这种做法既保留了 Adam 自适应每个参数步长的优点,又通过范数比实现了层间的平衡,还能天然地将权重衰减与自适应学习率解耦。

LAMB 的优势

  • 对学习率不敏感:由于每层都有自己的缩放因子,LAMB 可以在极宽的全局学习率范围内正常工作。
  • 训练极深网络:在训练 BERT 等深度 Transformer 模型时,使用 32K 甚至 64K 的超大 Batch 也能达到与基线相同的精度。
  • 快速收敛:配合适当的学习率 warmup,LAMB 能极大缩短训练时间。

LAMB 与 LARS 的核心对比

特性 LARS LAMB
基础优化器 SGD with Momentum Adam (带有第一/二阶矩估计)
自适应方式 仅层级自适应 层级自适应 + 元素级自适应 (来自Adam)
更新公式 缩放因子作用于动量后的梯度 缩放因子作用于归一化后的梯度 ( r_t )
典型应用场景 卷积神经网络 (ResNet等) Transformer 模型 (BERT, ViT等)
大 Batch 稳定性 优秀,可达 32K+ 极佳,可达 64K+
学习率 warmup 推荐使用 几乎必备,通常配合线性 warmup

简单来说,LAMB 是 LARS 在自适应学习率优化器上的自然延伸,将层级自适应的优势带入了 Adam 体系。

实践:如何配置和使用

关键超参数设置

  • 基础学习率 ( \eta ):通常需要选用比单机小 Batch 训练更大的值。例如,训练 ResNet-50 时,LARS 的基础学习率可设为 0.1 甚至更高;训练 BERT 时,LAMB 的学习率常设置在 1e-3 附近。核心原则是采用线性缩放法则:当 Batch Size 增大 ( k ) 倍时,学习率也相应增大 ( k ) 倍。
  • 学习率 Warmup:大 Batch 训练初期梯度估计方差较大,突然使用大学习率容易导致发散。Warmup 策略(如线性增加学习率从 0 到目标值)至关重要。对于 LAMB,通常设置 warmup 步数为总步数的 2%~10%。
  • 权重衰减:应与学习率解耦。LAMB 论文中明确推荐将权重衰减与自适应学习率分离开,公式中的 ( \lambda w^l ) 直接加入更新项,而非先缩放后衰减。

代码示例(PyTorch 风格)

目前 PyTorch 官方并未内置 LARS/LAMB,但可通过第三方库或自定义实现。以下以简化的 LAMB 为例展示核心逻辑:

import torch
from torch.optim import Optimizer

class LAMB(Optimizer):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super(LAMB, self).__init__(params, defaults)

    def step(self):
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('LAMB does not support sparse gradients')

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p.data)
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']
                state['step'] += 1
                
                # Adam update
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                
                # Bias correction
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                
                # Adam direction
                adam_step = exp_avg / bias_correction1 / (exp_avg_sq.sqrt() / bias_correction2.sqrt() + group['eps'])
                
                # Weight decay term (decoupled)
                wd_term = p.data
                update = adam_step + group['weight_decay'] * wd_term
                
                # Layer-wise scaling
                w_norm = p.data.norm(2)
                update_norm = update.norm(2)
                trust_ratio = 1.0
                if w_norm > 0 and update_norm > 0:
                    trust_ratio = w_norm / update_norm
                
                # Final update
                p.data.add_(update, alpha=-group['lr'] * trust_ratio)

注意事项

  • BatchNorm 层处理:实践中 LARS/LAMB 通常不应用层级自适应缩放于 BatchNorm 的参数(因为其尺度含义不同),直接将它们的 trust_ratio 设为 1.0。
  • 混合精度训练:LAMB 配合半精度 (FP16) 训练通常会使用动态损失缩放,需保持缩放因子的一致性。

总结与建议

  • 何时选择 LARS:当你的模型以卷积网络为主(如 ResNet、EfficientNet),Batch Size 超过 4K,且你希望继续使用带动量的 SGD 框架时,LARS 是一个成熟的选择。
  • 何时选择 LAMB:训练 Transformer 系列模型(BERT、GPT、ViT),或者任何使用 Adam 系列优化器且需要扩展到大 Batch 的场景。LAMB 几乎是大 Batch 训练 BERT 的事实标准。
  • 通用最佳实践
    1. 采用线性学习率缩放,并配合足够的 warmup。
    2. 将权重衰减与自适应学习率解耦。
    3. 排除 BatchNorm 层,不进行层次自适应缩放。
    4. 监控各层的信任比例 ( \frac{||w||}{||update||} ),若发现某些层该比率剧烈波动,需调整基础学习率或 warmup 策略。

借助 LARS 和 LAMB,大 Batch 训练不再是泛化杀手,而是实现极速分布式训练的关键技术。希望你通过本教程能熟练驾驭这两种优化利器。