数值稳定的 softmax 实现与技巧
为什么需要数值稳定的 Softmax
Softmax 函数是深度学习和机器学习中最常见的激活函数之一,它将一个实数向量转换为概率分布。其标准数学定义为:
给定向量 ( \mathbf{z} = [z_1, z_2, \dots, z_n] ),softmax 函数的第 ( i ) 个分量为:
[ \text{softmax}(z_i) = \frac{e^{z_i}}{\sum_{j=1}^n e^{z_j}} ]
这个公式在数学上完全正确,但在计算机中直接实现时存在严重的数值稳定性问题。指数运算会将微小的输入差异急剧放大:
- 当 ( z_i ) 较大(如 1000)时,( e^{1000} ) 会超过浮点数表示范围,导致上溢,结果为
inf或NaN。 - 当 ( z_i ) 较小(如 -1000)时,( e^{-1000} ) 会趋近于 0,造成下溢,分母为 0 时甚至会产生除零错误。
更隐蔽的问题是:即使没有完全上溢或下溢,当某个分量远大于其他分量时,分母中的其他 ( e^{z_j} ) 会被近似忽略,但若仅有一个极大值,计算出的概率会变成 1 和多个 0,导致对数 softmax(交叉熵损失的前置计算)取对数时出现 -inf。
解决这些问题的方法就是使用数值稳定的 Softmax。
核心技巧:平移不变性
Softmax 具备一个关键的数学性质:对输入向量的每个分量减去同一个常数,输出结果不变。
设 ( c ) 为任意常数,有: [ \text{softmax}(z_i) = \frac{e^{z_i - c}}{\sum_{j} e^{z_j - c}} ]
因为分子分母同时乘以 ( e^{-c} ),原式的值保持不变。我们可以利用这个性质,选择一个恰当的 ( c ) 来约束指数输入的范围,从而避免溢出。
在工程实践中,通常选择 ( c = \max_k z_k ),即减去输入向量中的最大值。这样:
- 所有输入偏移后的最大值变为 0,其余值 ≤ 0。
- 指数的输入范围变为 ( (-\infty, 0] ),( e^0 = 1 ),其余值在 0 到 1 之间。
- 从根本上杜绝了上溢(因为没有任何大于 0 的输入),而下溢虽仍可能发生(( e^{-k} ) 非常趋近于 0),但分母中至少有一个 1,保证了分母不为零,数值完全可控。
分步实现:从朴素到稳定
1. 朴素实现(不稳定)
import numpy as np
def unstable_softmax(z):
exp_z = np.exp(z) # 可能产生 inf
return exp_z / np.sum(exp_z)
对 [1000, 1000, 1000] 这样的向量,np.exp(1000) 会返回 inf,最终结果变为 [nan, nan, nan]。
2. 稳定实现:减去最大值
def stable_softmax(z):
z_max = np.max(z)
exp_z = np.exp(z - z_max) # 所有输入 ≤ 0,安全
return exp_z / np.sum(exp_z)
现在即使输入是 [1000, 1000, 1000],减去最大值 1000 后所有值变为 0,exp(0)=1,softmax 输出 [0.333..., 0.333..., 0.333...],正确且稳定。
对于更加极端的输入 [1000, 0, -1000],减去 1000 后得到 [0, -1000, -2000],指数分别为 [1, ~0, ~0],最终概率约为 [1, 0, 0]。计算过程没有任何溢出现象。
三维推广:矩阵与批量计算
在实际神经网络中,logits 的形状通常是 (N, C)(N 个样本,C 个类别)或 (N, C, H, W)(像素级分类)。我们需要沿着类别维度 axis 应用 softmax。
稳定的批量实现只需在指定维度上计算最大值并广播即可。
NumPy 示例:
def stable_softmax_2d(x, axis=-1):
x_max = np.max(x, axis=axis, keepdims=True)
exp_x = np.exp(x - x_max)
return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
PyTorch 示例:
import torch
import torch.nn.functional as F
# PyTorch 的 F.softmax 内部已经实现了数值稳定
output = F.softmax(logits, dim=-1)
# 手动实现(用于教学)
def manual_stable_softmax(x, dim=-1):
x_max = torch.max(x, dim=dim, keepdim=True).values
exp_x = torch.exp(x - x_max)
return exp_x / torch.sum(exp_x, dim=dim, keepdim=True)
TensorFlow / Keras 示例:
import tensorflow as tf
# Keras 内置 softmax 也是数值稳定的
output = tf.nn.softmax(logits, axis=-1)
深度学习框架中的 softmax 实现都自动应用了减去最大值的技巧,因此直接调用即可。
对数 Softmax 与 Log-Sum-Exp 技巧
直接计算 log softmax 的必要性
在训练分类模型时,我们通常先计算 softmax 得到概率,再取对数以配合交叉熵损失。但这样做会导致两次舍入误差,且当 softmax 输出概率为 0 时,log(0) 会得到 -inf。
更优的方式是直接计算 log softmax,这需要借助 Log-Sum-Exp 技巧。
log softmax 的定义: [ \log\left(\frac{e^{z_i}}{\sum_j e^{z_j}}\right) = z_i - \log\left(\sum_j e^{z_j}\right) ]
直接对 (\sum e^{z_j}) 取对数会遇到上溢问题。我们可以同样减去最大值: [ \log\left(\sum_j e^{z_j}\right) = c + \log\left(\sum_j e^{z_j - c}\right) ] 通常令 ( c = \max_k z_k ),则: [ \text{log softmax}(z_i) = z_i - \max_k z_k - \log\left(\sum_j e^{z_j - \max_k z_k}\right) ]
实现 log softmax
def stable_log_softmax(z):
z_max = np.max(z)
log_sum = np.log(np.sum(np.exp(z - z_max)))
return z - z_max - log_sum
这个函数返回的结果等价于 np.log(stable_softmax(z)),但数值上更加精确,且永远不会产生 -inf(除非输入本身全为 -inf)。
PyTorch 直接提供:
log_probs = F.log_softmax(logits, dim=-1)
TensorFlow 提供:
log_probs = tf.nn.log_softmax(logits, axis=-1)
交叉熵损失的稳定计算
交叉熵损失通常结合 log softmax 实现,以获得最佳的数值性能。切勿单独调用 softmax 再取 log 后计算交叉熵。
推荐做法是使用框架提供的复合函数,它们内部利用了 log-softmax 的稳定性:
PyTorch:
loss = F.cross_entropy(logits, labels)
# 等价于:loss = F.nll_loss(F.log_softmax(logits, dim=1), labels)
TensorFlow:
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels)
loss = tf.reduce_mean(loss)
这种做法避免了从 softmax 到概率再到 log 的数值往返,将整个计算融合为一个稳定的操作。
常见误区与调试清单
-
在低精度数据类型下使用 softmax
FP16(半精度)的表示范围有限,最大值仅为 65504。偏移后可能仍出现上溢。解决方案:在 FP16 训练时,框架通常会在内部将 softmax 计算提升为 FP32,或使用分块计算。 -
手动实现时忘记保持维度
在广播计算exp_x / sum(exp_x)时,必须正确设置keepdim=True或显式重塑形状,否则维度不匹配会导致错误或非预期的广播。 -
混淆 log softmax 和 softmax
在计算交叉熵损失时,传入softmax而非logits,然后框架再次应用 softmax,等于进行了双重指数运算,结果完全错误且损失不下降。 -
忽略负无穷的输入
如果输入向量中存在-inf(如掩码处理后的填充位置),减去最大值后仍可能得到-inf,exp(-inf)为 0,这是合法行为。但若所有值都是-inf,log-sum-exp 会返回-inf。需根据业务逻辑进行掩码豁免。
总结
- 数值稳定 Softmax 的核心是减去最大值(平移不变性)。
- 直接计算 log softmax 使用 Log-Sum-Exp 技巧,避免中间下溢和取反对
0的不稳定。 - 在训练分类网络时,始终使用框架内置的
cross_entropy等复合损失函数,它们已融合 log softmax,数值最优。 - 自己实现时注意维度保持和数据类型精度,尤其是在半精度或混合精度环境下。
掌握了这些技巧,你就能写出在任何输入范围和硬件条件下都稳健运行的 softmax 计算,远离 NaN 和 -inf 的困扰。