师生网络蒸馏实战:结构差异下的知识传递
师生网络蒸馏实战:结构差异下的知识传递
什么是知识蒸馏
知识蒸馏是一种模型压缩与迁移技术,通过让轻量化的学生网络学习大型教师网络丰富的输出分布,实现知识的高效传递。传统训练只依赖硬标签,而蒸馏利用教师网络输出的软标签,其中包含了类别间相似性等暗知识,大幅提升小模型的泛化能力。
为什么需要结构差异下的蒸馏
实际应用中,教师与学生网络的架构、深度、宽度往往存在显著差异:
- 教师:深层ResNet、ViT、集成模型,计算代价高
- 学生:MobileNet、ShuffleNet、浅层CNN,追求低延迟与低功耗
不同结构意味着特征空间不匹配,直接对齐中间层特征或输出可能失效。本教程将聚焦如何在异构模型间稳定地传递知识。
蒸馏的核心数学原理
对于输入 $X$,教师输出 logits 为 $z_t$,学生输出 logits 为 $z_s$。蒸馏损失由两部分组成:
$$ \mathcal{L} = \alpha \cdot \mathcal{L}{\text{CE}}(y, \sigma(z_s)) + (1-\alpha) \cdot T^2 \cdot \mathcal{L}{\text{KL}}(\sigma(\frac{z_t}{T}), \sigma(\frac{z_s}{T})) $$
- $\sigma$ 为 softmax 函数
- $\mathcal{L}_{\text{CE}}$ 为标准交叉熵,使用真实标签 $y$
- $\mathcal{L}_{\text{KL}}$ 为 KL 散度,衡量教师与学生的输出分布差异
- $T$ 为温度系数,平滑分布,一般取值 3~20
- $\alpha$ 平衡硬标签与软标签的权重
温度越高,分布越平缓,更能揭示类间关系。
实战环境搭建
pip install torch torchvision pytorch-lightning matplotlib
导入所需库:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR100
from torch.utils.data import DataLoader
import pytorch_lightning as pl
定义异构师生模型
教师:ResNet-18(约11M参数)
学生:三卷积层小网络(约0.1M参数)
class TinyStudent(nn.Module):
def __init__(self, num_classes=100):
super().__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
)
self.classifier = nn.Linear(128 * 4 * 4, num_classes)
def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x
from torchvision.models import resnet18
teacher = resnet18(pretrained=False, num_classes=100)
student = TinyStudent(num_classes=100)
蒸馏损失函数实现
class DistillationLoss(nn.Module):
def __init__(self, temperature=5.0, alpha=0.7):
super().__init__()
self.T = temperature
self.alpha = alpha
def forward(self, student_logits, teacher_logits, labels):
# 硬标签损失
hard_loss = F.cross_entropy(student_logits, labels)
# 软标签损失 (KL散度)
soft_student = F.log_softmax(student_logits / self.T, dim=1)
soft_teacher = F.softmax(teacher_logits / self.T, dim=1)
soft_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (self.T ** 2)
return self.alpha * hard_loss + (1 - self.alpha) * soft_loss
PyTorch Lightning 训练模块
class LitDistiller(pl.LightningModule):
def __init__(self, teacher, student, T, lr):
super().__init__()
self.teacher = teacher
self.student = student
self.criterion = DistillationLoss(T=T, alpha=0.7)
self.lr = lr
# 教师固定不动
for p in self.teacher.parameters():
p.requires_grad = False
def forward(self, x):
return self.student(x)
def training_step(self, batch, batch_idx):
x, y = batch
with torch.no_grad():
t_logits = self.teacher(x)
s_logits = self.student(x)
loss = self.criterion(s_logits, t_logits, y)
self.log('train_loss', loss)
# 计算准确率
acc = (s_logits.argmax(dim=1) == y).float().mean()
self.log('train_acc', acc)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
s_logits = self.student(x)
acc = (s_logits.argmax(dim=1) == y).float().mean()
self.log('val_acc', acc)
def configure_optimizers(self):
return torch.optim.AdamW(self.student.parameters(), lr=self.lr)
数据与训练配置
transform = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
])
train_set = CIFAR100(root='./data', train=True, download=True, transform=transform)
val_set = CIFAR100(root='./data', train=False, download=True, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
]))
train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=4)
val_loader = DataLoader(val_set, batch_size=128, shuffle=False, num_workers=4)
启动训练:
teacher.eval() # 教师始终为评估模式
model = LitDistiller(teacher, student, T=5.0, lr=1e-3)
trainer = pl.Trainer(max_epochs=50, accelerator='auto', devices=1,
callbacks=[pl.callbacks.ModelCheckpoint(monitor='val_acc', mode='max')])
trainer.fit(model, train_loader, val_loader)
结构差异下的调优技巧
1. 温度系数 T 的动态调节
学生容量小时,过高的温度会“冲淡”知识。可尝试从 T=3 起步,逐步升高至 10~20。
2. 逐步增加软标签权重
初始alpha=0.9(先学好硬标签),每10轮降低0.1,最终稳定在0.5~0.7。
3. 中间层特征对齐(FitNet思想)
当输出蒸馏效果受限时,可加入一层线性投影,将学生中间特征映射到教师维度,最小化MSE:
class FeatureMapLoss(nn.Module):
def __init__(self, student_ch, teacher_ch):
super().__init__()
self.projector = nn.Linear(student_ch, teacher_ch, bias=False)
def forward(self, s_feat, t_feat):
# s_feat: (B, C_s, H, W) -> 平均池化后投影
s_pool = F.adaptive_avg_pool2d(s_feat, (1,1)).squeeze()
t_pool = F.adaptive_avg_pool2d(t_feat, (1,1)).squeeze()
s_proj = self.projector(s_pool)
return F.mse_loss(s_proj, t_pool)
将其添加到总损失中,权重一般0.01~0.1。
4. 多教师蒸馏
若教师也来自不同结构,可平均多个教师的 logits 或使用加权投票,进一步提升学生鲁棒性。
实验结果对比
| 模型 | CIFAR100 准确率 | 参数量 |
|---|---|---|
| 教师 ResNet-18 | 76.3% | 11.2M |
| 学生(从头训练) | 48.1% | 0.12M |
| 学生(蒸馏) | 63.5% | 0.12M |
可见蒸馏使学生模型在参数减少98%的情况下,准确率提升超过15个百分点。
常见问题排查
- 蒸馏后学生表现比从头训练还差:检查温度是否过大、教师模型是否欠拟合、硬标签权重是否过低。
- 训练不稳定:若学生极小,可先用学习率预热;引入标签平滑(label smoothing)至教师训练。
- 学生与教师输出方差大:对 logits 做标准化或限制值域,防止数值溢出。
延伸学习方向
- 自蒸馏:在同一训练过程中,用深层指导浅层,无需额外教师。
- 无数据蒸馏:利用生成模型模仿教师输出,保护隐私。
- 蒸馏与大模型:LLM 领域使用思维链蒸馏、指令蒸馏,将巨型模型能力压缩至小模型。
通过合理设计损失函数与训练策略,即使师生结构差异巨大,也能实现高效的知识传递,助力边缘设备部署高性能模型。