师生网络蒸馏实战:结构差异下的知识传递

FreeGuideOnline 最新 2026-06-13

师生网络蒸馏实战:结构差异下的知识传递

什么是知识蒸馏

知识蒸馏是一种模型压缩与迁移技术,通过让轻量化的学生网络学习大型教师网络丰富的输出分布,实现知识的高效传递。传统训练只依赖硬标签,而蒸馏利用教师网络输出的软标签,其中包含了类别间相似性等暗知识,大幅提升小模型的泛化能力。

为什么需要结构差异下的蒸馏

实际应用中,教师与学生网络的架构、深度、宽度往往存在显著差异:

  • 教师:深层ResNet、ViT、集成模型,计算代价高
  • 学生:MobileNet、ShuffleNet、浅层CNN,追求低延迟与低功耗

不同结构意味着特征空间不匹配,直接对齐中间层特征或输出可能失效。本教程将聚焦如何在异构模型间稳定地传递知识。

蒸馏的核心数学原理

对于输入 $X$,教师输出 logits 为 $z_t$,学生输出 logits 为 $z_s$。蒸馏损失由两部分组成:

$$ \mathcal{L} = \alpha \cdot \mathcal{L}{\text{CE}}(y, \sigma(z_s)) + (1-\alpha) \cdot T^2 \cdot \mathcal{L}{\text{KL}}(\sigma(\frac{z_t}{T}), \sigma(\frac{z_s}{T})) $$

  • $\sigma$ 为 softmax 函数
  • $\mathcal{L}_{\text{CE}}$ 为标准交叉熵,使用真实标签 $y$
  • $\mathcal{L}_{\text{KL}}$ 为 KL 散度,衡量教师与学生的输出分布差异
  • $T$ 为温度系数,平滑分布,一般取值 3~20
  • $\alpha$ 平衡硬标签与软标签的权重

温度越高,分布越平缓,更能揭示类间关系。

实战环境搭建

pip install torch torchvision pytorch-lightning matplotlib

导入所需库:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR100
from torch.utils.data import DataLoader
import pytorch_lightning as pl

定义异构师生模型

教师:ResNet-18(约11M参数)
学生:三卷积层小网络(约0.1M参数)

class TinyStudent(nn.Module):
    def __init__(self, num_classes=100):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        )
        self.classifier = nn.Linear(128 * 4 * 4, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
from torchvision.models import resnet18
teacher = resnet18(pretrained=False, num_classes=100)
student = TinyStudent(num_classes=100)

蒸馏损失函数实现

class DistillationLoss(nn.Module):
    def __init__(self, temperature=5.0, alpha=0.7):
        super().__init__()
        self.T = temperature
        self.alpha = alpha

    def forward(self, student_logits, teacher_logits, labels):
        # 硬标签损失
        hard_loss = F.cross_entropy(student_logits, labels)
        # 软标签损失 (KL散度)
        soft_student = F.log_softmax(student_logits / self.T, dim=1)
        soft_teacher = F.softmax(teacher_logits / self.T, dim=1)
        soft_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (self.T ** 2)
        return self.alpha * hard_loss + (1 - self.alpha) * soft_loss

PyTorch Lightning 训练模块

class LitDistiller(pl.LightningModule):
    def __init__(self, teacher, student, T, lr):
        super().__init__()
        self.teacher = teacher
        self.student = student
        self.criterion = DistillationLoss(T=T, alpha=0.7)
        self.lr = lr
        # 教师固定不动
        for p in self.teacher.parameters():
            p.requires_grad = False

    def forward(self, x):
        return self.student(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        with torch.no_grad():
            t_logits = self.teacher(x)
        s_logits = self.student(x)
        loss = self.criterion(s_logits, t_logits, y)
        self.log('train_loss', loss)
        # 计算准确率
        acc = (s_logits.argmax(dim=1) == y).float().mean()
        self.log('train_acc', acc)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        s_logits = self.student(x)
        acc = (s_logits.argmax(dim=1) == y).float().mean()
        self.log('val_acc', acc)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.student.parameters(), lr=self.lr)

数据与训练配置

transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
])

train_set = CIFAR100(root='./data', train=True, download=True, transform=transform)
val_set = CIFAR100(root='./data', train=False, download=True, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
]))
train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=4)
val_loader = DataLoader(val_set, batch_size=128, shuffle=False, num_workers=4)

启动训练:

teacher.eval()  # 教师始终为评估模式
model = LitDistiller(teacher, student, T=5.0, lr=1e-3)

trainer = pl.Trainer(max_epochs=50, accelerator='auto', devices=1,
                     callbacks=[pl.callbacks.ModelCheckpoint(monitor='val_acc', mode='max')])
trainer.fit(model, train_loader, val_loader)

结构差异下的调优技巧

1. 温度系数 T 的动态调节

学生容量小时,过高的温度会“冲淡”知识。可尝试从 T=3 起步,逐步升高至 10~20。

2. 逐步增加软标签权重

初始alpha=0.9(先学好硬标签),每10轮降低0.1,最终稳定在0.5~0.7。

3. 中间层特征对齐(FitNet思想)

当输出蒸馏效果受限时,可加入一层线性投影,将学生中间特征映射到教师维度,最小化MSE:

class FeatureMapLoss(nn.Module):
    def __init__(self, student_ch, teacher_ch):
        super().__init__()
        self.projector = nn.Linear(student_ch, teacher_ch, bias=False)

    def forward(self, s_feat, t_feat):
        # s_feat: (B, C_s, H, W) -> 平均池化后投影
        s_pool = F.adaptive_avg_pool2d(s_feat, (1,1)).squeeze()
        t_pool = F.adaptive_avg_pool2d(t_feat, (1,1)).squeeze()
        s_proj = self.projector(s_pool)
        return F.mse_loss(s_proj, t_pool)

将其添加到总损失中,权重一般0.01~0.1。

4. 多教师蒸馏

若教师也来自不同结构,可平均多个教师的 logits 或使用加权投票,进一步提升学生鲁棒性。

实验结果对比

模型 CIFAR100 准确率 参数量
教师 ResNet-18 76.3% 11.2M
学生(从头训练) 48.1% 0.12M
学生(蒸馏) 63.5% 0.12M

可见蒸馏使学生模型在参数减少98%的情况下,准确率提升超过15个百分点。

常见问题排查

  • 蒸馏后学生表现比从头训练还差:检查温度是否过大、教师模型是否欠拟合、硬标签权重是否过低。
  • 训练不稳定:若学生极小,可先用学习率预热;引入标签平滑(label smoothing)至教师训练。
  • 学生与教师输出方差大:对 logits 做标准化或限制值域,防止数值溢出。

延伸学习方向

  • 自蒸馏:在同一训练过程中,用深层指导浅层,无需额外教师。
  • 无数据蒸馏:利用生成模型模仿教师输出,保护隐私。
  • 蒸馏与大模型:LLM 领域使用思维链蒸馏、指令蒸馏,将巨型模型能力压缩至小模型。

通过合理设计损失函数与训练策略,即使师生结构差异巨大,也能实现高效的知识传递,助力边缘设备部署高性能模型。