数值稳定的 softmax 实现与技巧

FreeGuideOnline 最新 2026-06-21

为什么需要数值稳定的 Softmax

Softmax 函数是深度学习和机器学习中最常见的激活函数之一,它将一个实数向量转换为概率分布。其标准数学定义为:

给定向量 ( \mathbf{z} = [z_1, z_2, \dots, z_n] ),softmax 函数的第 ( i ) 个分量为:

[ \text{softmax}(z_i) = \frac{e^{z_i}}{\sum_{j=1}^n e^{z_j}} ]

这个公式在数学上完全正确,但在计算机中直接实现时存在严重的数值稳定性问题。指数运算会将微小的输入差异急剧放大:

  • 当 ( z_i ) 较大(如 1000)时,( e^{1000} ) 会超过浮点数表示范围,导致上溢,结果为 infNaN
  • 当 ( z_i ) 较小(如 -1000)时,( e^{-1000} ) 会趋近于 0,造成下溢,分母为 0 时甚至会产生除零错误。

更隐蔽的问题是:即使没有完全上溢或下溢,当某个分量远大于其他分量时,分母中的其他 ( e^{z_j} ) 会被近似忽略,但若仅有一个极大值,计算出的概率会变成 1 和多个 0,导致对数 softmax(交叉熵损失的前置计算)取对数时出现 -inf

解决这些问题的方法就是使用数值稳定的 Softmax


核心技巧:平移不变性

Softmax 具备一个关键的数学性质:对输入向量的每个分量减去同一个常数,输出结果不变。

设 ( c ) 为任意常数,有: [ \text{softmax}(z_i) = \frac{e^{z_i - c}}{\sum_{j} e^{z_j - c}} ]

因为分子分母同时乘以 ( e^{-c} ),原式的值保持不变。我们可以利用这个性质,选择一个恰当的 ( c ) 来约束指数输入的范围,从而避免溢出。

在工程实践中,通常选择 ( c = \max_k z_k ),即减去输入向量中的最大值。这样:

  • 所有输入偏移后的最大值变为 0,其余值 ≤ 0。
  • 指数的输入范围变为 ( (-\infty, 0] ),( e^0 = 1 ),其余值在 0 到 1 之间。
  • 从根本上杜绝了上溢(因为没有任何大于 0 的输入),而下溢虽仍可能发生(( e^{-k} ) 非常趋近于 0),但分母中至少有一个 1,保证了分母不为零,数值完全可控。

分步实现:从朴素到稳定

1. 朴素实现(不稳定)

import numpy as np

def unstable_softmax(z):
    exp_z = np.exp(z)          # 可能产生 inf
    return exp_z / np.sum(exp_z)

[1000, 1000, 1000] 这样的向量,np.exp(1000) 会返回 inf,最终结果变为 [nan, nan, nan]

2. 稳定实现:减去最大值

def stable_softmax(z):
    z_max = np.max(z)
    exp_z = np.exp(z - z_max)  # 所有输入 ≤ 0,安全
    return exp_z / np.sum(exp_z)

现在即使输入是 [1000, 1000, 1000],减去最大值 1000 后所有值变为 0,exp(0)=1,softmax 输出 [0.333..., 0.333..., 0.333...],正确且稳定。

对于更加极端的输入 [1000, 0, -1000],减去 1000 后得到 [0, -1000, -2000],指数分别为 [1, ~0, ~0],最终概率约为 [1, 0, 0]。计算过程没有任何溢出现象。


三维推广:矩阵与批量计算

在实际神经网络中,logits 的形状通常是 (N, C)(N 个样本,C 个类别)或 (N, C, H, W)(像素级分类)。我们需要沿着类别维度 axis 应用 softmax。

稳定的批量实现只需在指定维度上计算最大值并广播即可。

NumPy 示例:

def stable_softmax_2d(x, axis=-1):
    x_max = np.max(x, axis=axis, keepdims=True)
    exp_x = np.exp(x - x_max)
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)

PyTorch 示例:

import torch
import torch.nn.functional as F

# PyTorch 的 F.softmax 内部已经实现了数值稳定
output = F.softmax(logits, dim=-1)

# 手动实现(用于教学)
def manual_stable_softmax(x, dim=-1):
    x_max = torch.max(x, dim=dim, keepdim=True).values
    exp_x = torch.exp(x - x_max)
    return exp_x / torch.sum(exp_x, dim=dim, keepdim=True)

TensorFlow / Keras 示例:

import tensorflow as tf

# Keras 内置 softmax 也是数值稳定的
output = tf.nn.softmax(logits, axis=-1)

深度学习框架中的 softmax 实现都自动应用了减去最大值的技巧,因此直接调用即可。


对数 Softmax 与 Log-Sum-Exp 技巧

直接计算 log softmax 的必要性

在训练分类模型时,我们通常先计算 softmax 得到概率,再取对数以配合交叉熵损失。但这样做会导致两次舍入误差,且当 softmax 输出概率为 0 时,log(0) 会得到 -inf

更优的方式是直接计算 log softmax,这需要借助 Log-Sum-Exp 技巧。

log softmax 的定义: [ \log\left(\frac{e^{z_i}}{\sum_j e^{z_j}}\right) = z_i - \log\left(\sum_j e^{z_j}\right) ]

直接对 (\sum e^{z_j}) 取对数会遇到上溢问题。我们可以同样减去最大值: [ \log\left(\sum_j e^{z_j}\right) = c + \log\left(\sum_j e^{z_j - c}\right) ] 通常令 ( c = \max_k z_k ),则: [ \text{log softmax}(z_i) = z_i - \max_k z_k - \log\left(\sum_j e^{z_j - \max_k z_k}\right) ]

实现 log softmax

def stable_log_softmax(z):
    z_max = np.max(z)
    log_sum = np.log(np.sum(np.exp(z - z_max)))
    return z - z_max - log_sum

这个函数返回的结果等价于 np.log(stable_softmax(z)),但数值上更加精确,且永远不会产生 -inf(除非输入本身全为 -inf)。

PyTorch 直接提供:

log_probs = F.log_softmax(logits, dim=-1)

TensorFlow 提供:

log_probs = tf.nn.log_softmax(logits, axis=-1)

交叉熵损失的稳定计算

交叉熵损失通常结合 log softmax 实现,以获得最佳的数值性能。切勿单独调用 softmax 再取 log 后计算交叉熵。

推荐做法是使用框架提供的复合函数,它们内部利用了 log-softmax 的稳定性:

PyTorch:

loss = F.cross_entropy(logits, labels)
# 等价于:loss = F.nll_loss(F.log_softmax(logits, dim=1), labels)

TensorFlow:

loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels)
loss = tf.reduce_mean(loss)

这种做法避免了从 softmax 到概率再到 log 的数值往返,将整个计算融合为一个稳定的操作。


常见误区与调试清单

  1. 在低精度数据类型下使用 softmax
    FP16(半精度)的表示范围有限,最大值仅为 65504。偏移后可能仍出现上溢。解决方案:在 FP16 训练时,框架通常会在内部将 softmax 计算提升为 FP32,或使用分块计算。

  2. 手动实现时忘记保持维度
    在广播计算 exp_x / sum(exp_x) 时,必须正确设置 keepdim=True 或显式重塑形状,否则维度不匹配会导致错误或非预期的广播。

  3. 混淆 log softmax 和 softmax
    在计算交叉熵损失时,传入 softmax 而非 logits,然后框架再次应用 softmax,等于进行了双重指数运算,结果完全错误且损失不下降。

  4. 忽略负无穷的输入
    如果输入向量中存在 -inf(如掩码处理后的填充位置),减去最大值后仍可能得到 -infexp(-inf) 为 0,这是合法行为。但若所有值都是 -inf,log-sum-exp 会返回 -inf。需根据业务逻辑进行掩码豁免。


总结

  • 数值稳定 Softmax 的核心是减去最大值(平移不变性)。
  • 直接计算 log softmax 使用 Log-Sum-Exp 技巧,避免中间下溢和取反对 0 的不稳定。
  • 在训练分类网络时,始终使用框架内置的 cross_entropy 等复合损失函数,它们已融合 log softmax,数值最优。
  • 自己实现时注意维度保持和数据类型精度,尤其是在半精度或混合精度环境下。

掌握了这些技巧,你就能写出在任何输入范围和硬件条件下都稳健运行的 softmax 计算,远离 NaN-inf 的困扰。