知识蒸馏入门:教师-学生网络的软化标签训练

FreeGuideOnline 最新 2026-06-13

什么是知识蒸馏

知识蒸馏是一种模型压缩技术,核心思想是将一个大型、复杂的“教师网络”所学到的知识,迁移到一个更小、更快的“学生网络”中,从而在保持较高精度的同时降低计算成本和存储需求。

在传统监督学习中,模型直接从“硬标签”中学习。例如,一张猫的图片,标签为 [1, 0, 0](猫、狗、汽车三分类)。硬标签只告诉模型“这是猫”,但忽略了类别之间丰富的相似性关系。猫和狗在外观上有相似之处,而猫和汽车则完全不同,这些信息在硬标签中荡然无存。

知识蒸馏的关键在于“软化标签”,也就是教师网络输出的概率分布。虽然教师网络最终预测猫的概率可能高达 0.98,但它在狗类别上可能给出 0.015,在汽车类别上仅给出 0.005。这些相对较小的概率携带了教师网络学到的“暗知识”——类别结构、样本间的微妙差异等信息。学生网络通过模仿这一软标签分布,可以学习到比单纯拟合硬标签更丰富的特征表达。

知识蒸馏的核心原理

知识蒸馏的训练过程通常包含以下几个关键要素:

  • 教师网络:一个预训练好的大型模型,通常具有很高的准确率,但参数多、推理慢。教师网络的输出被用作学生网络的训练目标之一。
  • 学生网络:一个结构更简单的模型,参数量少,推理速度快。学生网络同时接受两种监督信号。
  • 软化标签:通过调节 softmax 函数的“温度”参数 T 来控制概率分布的平滑程度。温度越高,类别间的概率差异越小,分布越“软”,能暴露更多类间关系。
  • 蒸馏损失:指导学生网络去匹配教师网络软化后的输出。
  • 学生损失:学生网络自身的硬标签预测损失,用于保证基本分类准确率。

最终的训练目标是将蒸馏损失和学生损失按一定权重相加,共同优化学生网络。

为什么软标签有效

软标签之所以能传递更多知识,是因为它编码了类别之间的相似性。以一个手写数字识别任务为例,数字“7”和数字“1”在书写上可能有些相似,教师网络在预测“7”时,会在“1”的类别上给出相对较高的概率(如 0.12),而其他不相关的数字则概率很低。学生网络通过拟合这种概率分布,能够学到“7和1相似”这一先验,从而在没有看到大量数据的情况下也能更好地泛化。

从信息论角度看,硬标签所包含的信息量极少,每样本只能提供约 log(K) 比特的信息(K为类别数)。而软标签相当于对每个样本提供了一个分布,其熵远大于硬标签,意味着教师网络为每个训练样本注入了更多的监督信息,等效于对数据进行了有效扩充。

温度参数 T 的作用

温度 T 是知识蒸馏中最重要的超参数。softmax 函数的一般形式为:

[ q_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)} ]

  • T = 1 时,就是标准的 softmax,输出概率会较为集中。
  • T > 1 时,概率分布被“软化”,大值和小值之间的差距缩小,使得原本接近于0的概率值也变为可感知的数值,从而暴露出教师网络对于非最可能类别的置信度。
  • T → ∞ 时,所有概率趋于均一化,此时知识全部碾平,无学习价值。
  • T < 1 时,分布更陡峭,会放大峰值,不利于传递类间信息。

通常实践中将 T 设置在 220 之间,需要根据教师网络的置信度水平进行调整。蒸馏时,对教师和学生网络使用相同的温度参数,学生在推理时仍然使用 T=1 进行标准 softmax。

损失函数设计

知识蒸馏的总损失函数一般由两部分组成:

[ \mathcal{L} = \alpha \cdot \mathcal{L}{\text{KD}} + (1 - \alpha) \cdot \mathcal{L}{\text{CE}} ]

其中,α 为平衡系数(通常取 0.1~0.9)。

蒸馏损失(L_KD)

蒸馏损失度量学生网络软化输出与教师网络软化输出之间的差异。常用 KL 散度或均方误差(MSE)。

若使用 KL 散度,定义如下: [ \mathcal{L}_{\text{KD}} = T^2 \cdot \text{KL}( \text{softmax}(z_t / T) \parallel \text{softmax}(z_s / T) ) ] 乘以 是为了在反向传播时保持梯度的尺度,因为温度缩放会改变损失的大小。

学生损失(L_CE)

学生损失是学生网络常规 softmax(T=1)输出与真实硬标签之间的交叉熵损失,它确保学生网络最终能给出正确的分类结果,避免过度跟随教师网络可能的偏差。

一个简单的代码实现示例

下面以 PyTorch 为例,展示知识蒸馏训练的核心代码片段。

import torch
import torch.nn as nn
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T, alpha):
    # 教师和学生的软标签
    soft_teacher = F.softmax(teacher_logits / T, dim=1)
    soft_student = F.log_softmax(student_logits / T, dim=1)
    
    # 蒸馏损失(KL 散度)乘以 T^2
    kd_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (T ** 2)
    
    # 学生硬标签损失
    ce_loss = F.cross_entropy(student_logits, labels)
    
    return alpha * kd_loss + (1 - alpha) * ce_loss

# 训练循环伪代码
for inputs, labels in dataloader:
    optimizer.zero_grad()
    
    with torch.no_grad():
        teacher_outputs = teacher_model(inputs)
    student_outputs = student_model(inputs)
    
    loss = distillation_loss(student_outputs, teacher_outputs, labels, T=4.0, alpha=0.7)
    loss.backward()
    optimizer.step()

需要注意,教师模型必须设置为评估模式并禁用梯度计算,以减少显存消耗。

知识蒸馏的变体与应用场景

知识蒸馏经过多年发展衍生出多种变体,以下介绍两种典型扩展:

特征蒸馏(Feature-based Distillation)

不仅让学生模仿教师最终的输出层,还要求学生在中间隐藏层的特征表示上与教师对齐。例如,通过最小化教师与学生对应层特征图之间的 MSE 损失或最大化互信息,让学生网络学到更接近教师内部的表征,往往能带来更高的精度提升。这类方法对视觉任务尤为有效。

在线蒸馏(Online Distillation)

教师和学生同时训练,而不是事先存在一个预训练好的教师。多个学生网络之间互相充当彼此的教师,共同提升。或者一个大型网络和一个小网络在训练过程中动态交换知识,省去了单独训练教师模型的步骤,适合端到端训练要求。

知识蒸馏广泛应用于:

  • 模型压缩:在移动端或嵌入式设备上部署高精度模型。
  • 半监督学习:用无标签数据,通过教师模型生成伪软标签来训练学生。
  • 隐私保护与联邦学习:在本地学生模型上蒸馏全局教师模型的知识,用于聚合更新。
  • 多任务学习:从多个教师模型中蒸馏不同任务的知识到统一的学生模型中。

关键注意事项与调参建议

  • 教师模型的选择:教师模型并非越大越好,应与学生模型的结构具有一定的同源性或兼容性。差异过大时,学生可能难以模仿。
  • 温度与 α 的配合:如果 T 较大,软标签非常平滑,此时应增大 KD 损失的权重(α 较大),因为硬标签本身的监督强度较低。反之,如果 T 接近 1,KD 损失贡献的信号弱,应降低 α。
  • 避免过拟合:学生网络容量小,如果完全模仿教师可能丢失一些特异性。引入少量噪声或使用标签平滑可以缓解。
  • 批量大小:由于软标签的维度包含丰富信息,知识蒸馏对较大的批量大小更友好,有助于稳定训练分布。
  • 评估指标:最终以学生网络在验证集上的标准准确率(T=1)为评判依据,而不是蒸馏损失的大小。

知识蒸馏以简单有效的方式弥合了大模型与小模型之间的性能鸿沟。理解软化标签和温度机制,便掌握了这一技术的核心心法。随着大模型的普及,知识蒸馏作为一种低成本的知识迁移手段,实用价值愈发凸显。