焦点损失 Focal Loss:解决类别不平衡的利器
什么是焦点损失 (Focal Loss)?
焦点损失是一种专门为解决类别极度不平衡问题而设计的损失函数,最初由 Facebook AI Research 在目标检测模型 RetinaNet 中提出。在二分类或多分类任务中,当负样本(背景)数量远超正样本(前景)时,常规的交叉熵损失容易被大量简单负样本主导,导致模型难以有效学习稀有正样本的特征。
焦点损失通过引入一个调制因子,动态降低对已分类正确样本的损失贡献,使模型更加聚焦于难分类样本和少数类样本,从而在不增加计算复杂度的前提下大幅提升模型处理不平衡数据的能力。
为什么交叉熵损失会失效?
标准交叉熵 (Cross-Entropy) 回顾
对于二分类问题,标准交叉熵损失定义为:
[ \text{CE}(p, y) = \begin{cases} -\log(p) & \text{if } y = 1 \ -\log(1-p) & \text{if } y = 0 \end{cases} ]
其中 ( p \in [0,1] ) 是模型预测正类的概率,( y \in {0,1} ) 是真实标签。为简化书写,通常定义 ( p_t ):
[ p_t = \begin{cases} p & \text{if } y = 1 \ 1-p & \text{if } y = 0 \end{cases} ]
于是交叉熵可统一写成:
[ \text{CE}(p_t) = -\log(p_t) ]
核心缺陷:简单样本贡献压倒多数
当数据集严重不平衡时(例如 99% 的背景样本),绝大部分训练样本都是“简单负样本”——模型很快就能以高置信度将其分类正确(( p_t \gg 0.5 ))。虽然单个简单样本的损失 ( -\log(p_t) ) 较小,但因其数量巨大,累计的损失值仍然会远大于数量极少的难样本(正样本)的总损失。结果:
- 训练早期,梯度主要被简单样本驱动,模型朝向忽略正样本的方向收敛。
- 难分类的正样本与难分类的负样本被淹没在大量简单样本的损失中,无法得到充分学习。
- 最终模型可能将所有样本都预测为负类,却仍能获得较低的平均损失,产生“高精度、零召回”的陷阱。
焦点损失的定义与核心思想
数学形式
焦点损失在交叉熵的基础上添加了一个调制因子 ((1 - p_t)^\gamma),定义如下:
[ \text{FL}(p_t) = -\alpha_t (1 - p_t)^\gamma \log(p_t) ]
其中:
- ( p_t ) 含义同上,表示模型对真实类别的预测概率。
- ( \gamma \geq 0 ) 是可调节的聚焦参数 (focusing parameter)。
- ( \alpha_t ) 是类别平衡因子 (balancing factor),与 ( p_t ) 类似,对于正负样本可以取不同的值 ( \alpha ) 和 ( 1-\alpha ),用于直接调节正负样本的权重。
调制因子如何工作?
当 ( \gamma = 0 ) 时,焦点损失退化为带权重 ( \alpha_t ) 的交叉熵。
当 ( \gamma > 0 ) 时,调制因子 ((1 - p_t)^\gamma) 产生以下效果:
- 若样本已被正确分类且置信度高(( p_t \to 1 )),则 ( 1 - p_t \to 0 ),调制因子接近 0,该样本的损失被大幅降低。
- 若样本分类错误或置信度低(( p_t ) 较小),则 ( 1 - p_t ) 接近 1,调制因子接近 1,该样本的损失几乎不变。
直观理解:( \gamma ) 越大,模型就越“嫌弃”那些已经分类得很好的简单样本,而更加“关注”那些还在苦苦挣扎的难样本。因此,焦点损失天然具有硬样本挖掘的能力,且这种挖掘是软执行、连续可微的,不需要显式地过滤样本。
参数 α 的作用
即使在加入了调制因子后,如果正负样本总数极端悬殊,仅靠聚焦可能还不够。引入平衡因子 ( \alpha_t ) 可以直接提高正样本的整体权重,抑制负样本的权重。通常 α 取值较小(如 0.25)表示降低负样本权重,或根据正样本比例设置,如 α 设为正样本比例的倒数。
实践中,同时调节 γ 和 α 往往能获得最佳效果。论文推荐的一组经典参数为 γ = 2, α = 0.25。
焦点损失的直观特性
-
自动降低简单样本的贡献
无需手动设计样本挖掘规则或级联结构,网络动态调整每个样本的损失尺度。 -
聚焦于难样本,但不完全丢弃简单样本
调制因子平滑地将简单样本损失缩小,而非硬截断,保留了简单样本提供的基础信息。 -
缓解类别不平衡带来的梯度淹没问题
少数类样本即使数量少,只要它们被错误分类,就会产生高损失,其梯度不会被多数类的简单样本完全掩盖。 -
与交叉熵无缝兼容
实现改动极小,只需在计算损失时引入项 ( (1 - p_t)^\gamma ),可轻松嵌入现有模型。
实现示例:PyTorch 代码
以下是一个可直接使用的焦点损失实现,支持二分类与多分类。
import torch
import torch.nn as nn
import torch.nn.functional as F
class FocalLoss(nn.Module):
def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
"""
alpha: 正类权重,当为多分类时可传入张量。
gamma: 聚焦参数,值越大越关注难样本。
reduction: 'mean', 'sum', 或 'none'.
"""
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.reduction = reduction
def forward(self, inputs, targets):
# inputs: 预测概率或 logits (shape: [N, C] 多分类,或 [N] 二分类)
# targets: 真实标签 (shape: [N]),二分类为 0/1,多分类为类别索引
if inputs.dim() > 1:
# 多分类情况,使用 softmax
log_p = F.log_softmax(inputs, dim=-1)
pt = torch.exp(log_p)
# 获取真实类别对应的 log_p 和 pt
log_p_gathered = log_p.gather(1, targets.unsqueeze(1)).squeeze(1)
pt_gathered = pt.gather(1, targets.unsqueeze(1)).squeeze(1)
else:
# 二分类情况,使用 sigmoid
p = torch.sigmoid(inputs)
pt = p * targets + (1 - p) * (1 - targets)
log_p = torch.log(pt + 1e-8) # 防数值溢出
log_p_gathered = log_p
pt_gathered = pt
# 计算 alpha_t,二分类根据 targets 选取
if isinstance(self.alpha, (float, int)):
alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
else:
# 多分类下 alpha 可为一维权重数组
alpha_t = self.alpha[targets]
# 焦点损失核心计算
focal_weight = (1 - pt_gathered) ** self.gamma
loss = - alpha_t * focal_weight * log_p_gathered
if self.reduction == 'mean':
return loss.mean()
elif self.reduction == 'sum':
return loss.sum()
else:
return loss
使用提示:
- 对于目标检测中的分类子网络,输入通常为 sigmoid 的输出(二分类),可直接使用上述实现。
- 如果模型输出 logits,上述代码已内部处理 softmax 或 sigmoid。
- 如果任务多类极度不平衡,可将
alpha设为与类别频次成反比的张量。
如何选择 γ 和 α ?
-
γ 的选择:
- γ = 0:退化为加权交叉熵,无聚焦效果。
- γ = 2:论文默认值,适用于大多数不平衡场景。
- γ 过大(如 5):过度聚焦于少数超难样本,可能影响拟合稳定性,需谨慎调试。
- 调参建议:从不平衡程度出发,如果正负比例在 1:100 以上,可考虑 γ 从 2 开始实验;如果比例接近但仍需聚焦,γ = 1 可能足够。
-
α 的选择:
- 二分类:α 通常设为 0.25 或 0.5,若正样本占比极低(如 1:1000),可将 α 提高至 0.75~0.9。
- 多分类:可以设为各类别样本频率的倒数(归一化后),或使用固定的 1/类别数。
- α 也可以在训练过程中动态调整,如基于有效样本数。
-
联合调参:
γ 和 α 存在相互补偿关系。增大 γ 会大幅削弱简单样本,此时适当减小 α(接近 0.5)通常可保持平衡;反之,若 γ 较小,可适当增大 α 来强调正样本。
焦点损失的优势与局限
优势
- 直接解决类别不平衡,无需过采/欠采样或生成合成样本。
- 软加权机制比硬负样本挖掘更稳定,梯度更平滑。
- 仅修改损失函数,对网络结构无侵入性,部署成本低。
- 在目标检测、语义分割、医学图像分析等任务中普遍带来显著性能提升。
局限
- 对超参数 γ 和 α 敏感,需要针对具体任务调参。
- 焦点损失主要解决“易分样本占比过大”的问题,若数据集本身极度缺乏难负样本或噪声标签过多,可能仍需数据层面的处理。
- 当 γ 较大时,损失值可能出现数量级差异,影响学习率敏感度,可能需要调整优化器设置。
- 对于类别数量达到数千级别的超多类分类,单一 α 权重向量可能无法精细控制,需考虑类层次化损失。
典型应用场景
- 目标检测:如 RetinaNet、EfficientDet 等单阶段检测器,处理大量背景锚框。
- 医学图像分割:病灶区域占比极小,背景像素占主导。
- 文本分类/情感分析:某些细分类别样本极少。
- 异常检测:正常样本远多于异常样本。
- 人脸识别/重识别:身份类别的长尾分布。
小结
焦点损失通过简单而优雅的调制因子,赋予模型自动关注难样本的能力,成为解决深度学习中类别不平衡问题的一把利器。它让损失函数本身变成一个智能的样本筛选器,堪称从“平等对待每一个样本”到“让模型选择性学习”的一次思想升级。掌握了焦点损失,你就在面对极度不平衡数据时多了一件可靠的武器。
延伸思考:你可以尝试将焦点损失的思想与标签平滑、正则化损失等结合,探索在自身任务上的最佳组合,往往能收获意想不到的增益。