负载均衡损失:防止 MoE 中的专家坍塌
负载均衡损失:入门指南
为什么混合专家模型(MoE)需要负载均衡?
混合专家模型(MoE)通过多个“专家”子网络协同工作,在扩大模型容量的同时保持计算成本可控。每个输入 token 经过门控网络(Router)选择最相关的 1 个或 top-k 个专家进行处理。然而,未经约束的门控网络极易陷入专家坍塌(Expert Collapse):少数几个专家被反复选择,大部分输入被路由到同一批专家,其余专家则被“闲置”,导致模型容量浪费、训练不稳定且泛化能力下降。
专家坍塌的表现:
- 门控网络输出的选择概率高度集中,某几个专家被分配了绝大部分 token。
- 被忽略的专家梯度近乎为零,无法有效学习。
- 实际有效参数量远低于设计值,模型退化为几个大专家,丧失 MoE 的优势。
什么是负载均衡损失?
负载均衡损失(Load Balancing Loss)是一种辅助损失函数,在原有主任务损失(如语言建模交叉熵)上附加一个正则化项,目的是迫使门控网络将 token 均匀地分配给所有专家,从而避免专家坍塌。它在训练时与主损失共同优化,不参与推理。
数学上,假设有 $N$ 个专家,一个批次中有 $T$ 个 token。每个 token 的门控输出向量为 $p(x) \in \mathbb{R}^N$,表示它被路由到各个专家的概率。定义:
- $f_i$:批次中被分配给专家 $i$ 的 token 比例(重要性分数之和,见下文)。
- $P_i$:门控网络对专家 $i$ 的平均分配概率。
负载均衡损失通常设计为鼓励 $f_i$ 与 $1/N$ 的均匀分布接近。最常见的两种形式是 Switch Transformer 损失 和 Auxiliary Loss of GShard。
1. Switch Transformer 负载均衡损失
Switch Transformer(Fedus et al., 2021)提出的辅助损失简单有效,定义如下:
$$\mathcal{L}{\text{aux}} = N \cdot \sum{i=1}^{N} f_i \cdot P_i$$
其中:
- 对于每个 token $x$,门控输出概率为 $p_i(x)$。每个 token 的调度向量 $g(x)$ 为 one-hot 的 top-1 选择(即概率最高的专家)。
- $f_i = \frac{1}{T} \sum_{x \in \text{batch}} \mathbb{1}_{{\text{argmax } p(x) = i}}$,即分配给专家 $i$ 的 token 比例。
- $P_i = \frac{1}{T} \sum_{x \in \text{batch}} p_i(x)$,即被分配到专家 $i$ 的平均概率。
该损失在均匀分配时取得最小值 $1$,因此实际使用时系数 $\alpha$ 作用于 $\mathcal{L}_{\text{aux}}$ 并与主损失相加。PyTorch 风格的伪代码:
def load_balancing_loss(router_probs, expert_indices):
"""
router_probs: (batch_size * seq_len, num_experts) 每个token的专家概率
expert_indices: (batch_size * seq_len,) top-1选择的专家索引
"""
num_experts = router_probs.shape[-1]
# 计算每个专家的平均概率 P_i
P_i = router_probs.mean(dim=0)
# 计算每个专家被选中的比例 f_i
f_i = torch.zeros(num_experts, device=router_probs.device)
expert_counts = torch.bincount(expert_indices, minlength=num_experts)
f_i = expert_counts.float() / expert_indices.shape[0]
# Switch Transformer 负载均衡损失
loss = num_experts * torch.sum(f_i * P_i)
return loss
2. GShard 的辅助损失
GShard(Lepikhin et al., 2020)采用的损失形式稍有不同,它使用门控输出的 top-k 路由后的“重要性分数”之和:
$$\mathcal{L}{\text{aux}} = \sum{i=1}^{N} f_i \cdot g_i$$
其中 $f_i$ 是分配到专家 $i$ 的 token 比例(基于调度决策),$g_i$ 是门控网络对专家 $i$ 的重要性分数之和的均值。本质上与 Switch 损失思想一致,但可以扩展到 top-k 路由场景。
3. Z-loss / 负载均衡的变体
为了避免门控网络产生过大的 logits(导致训练不稳定),有时会引入 Z-loss:
$$\mathcal{L}z = \frac{1}{T} \sum{x} \left(\log \sum_{i} e^{z_i(x)}\right)^2$$
其中 $z_i(x)$ 是门控 logits。这倾向于让门控输出接近均匀,同时也是一种软负载均衡,常与上述负载均衡损失联用。
如何在训练中集成负载均衡损失?
在标准训练循环中,总损失为:
$$\mathcal{L}{\text{total}} = \mathcal{L}{\text{task}} + \alpha \cdot \mathcal{L}_{\text{aux}}$$
- $\alpha$ 是负载均衡系数,通常取值
0.01(Switch Transformer 推荐)。过大的 $\alpha$ 会过度强制均匀分配,损害模型根据内容选择最优专家的能力;过小则无法防止坍塌。 - 仅在训练时计算辅助损失,推理时移除。
代码集成示例(简化):
outputs = model(input_ids, labels=labels)
task_loss = outputs.loss
# 从模型获取门控概率和索引
router_probs = outputs.router_logits # (batch, seq, num_experts)
expert_indices = outputs.expert_indices # (batch, seq)
aux_loss = load_balancing_loss(router_probs.view(-1, num_experts),
expert_indices.view(-1))
total_loss = task_loss + 0.01 * aux_loss
total_loss.backward()
进阶调参与常见陷阱
-
负载均衡系数的选择
从0.001到0.1扫描,监控在验证集上的困惑度(PPL)及各专家被分配的 token 比例直方图。理想状态是各专家占比大致在 $1/N$ 附近,同时任务损失没有明显恶化。 -
与容量因子(Capacity Factor)联合使用
负载均衡损失只解决“选择偏向”,无法解决某专家在单步内接收过多 token 导致显存溢出。常结合容量因子强制每个专家处理的上限 token 数,溢出的 token 被丢弃或走残差连接。负载均衡损失 + 容量约束是 MoE 训练的标配。 -
灵活负载与异构专家
在某些设计中,专家能力可能不对称,此时可以调整目标分布为非均匀,例如让某专家承担更大负载。可以通过修改损失的目标函数实现,但实践中较少使用。 -
动态调整 $\alpha$
训练初期模型可能更需要学习合理路由,可先将 $\alpha$ 设为 0 进行预热,之后再逐渐引入或增大负载均衡损失,防止过早固化路由。 -
监控指标
建议记录并可视化:- 每个专家的 token 分配比例随时间变化曲线。
- 门控概率的熵(越高表示分配越均匀)。
- 辅助损失数值与主损失的比值。
总结
负载均衡损失是混合专家模型训练中守卫性的关键技术。它通过一个简单的辅助损失约束门控网络,确保所有专家都能得到充分训练,杜绝专家坍塌。实现仅需数行代码,却对模型收敛、资源利用率和最终性能影响巨大。掌握其原理和调参方法,是训练大规模 MoE 模型的基本功。
参考实现提示:在实际工程中,您可以在模型前向过程的末尾计算负载均衡损失并返回,或直接在训练脚本中从模型输出提取必要信息即时计算。