把一个又大又准的模型的能力"灌"进一个小模型,让小模型跑得快、装得下、又尽量不掉点——这就是知识蒸馏(Knowledge Distillation)。它的核心洞见出人意料地朴素:模型预测里错误类别之间的相对概率,本身就携带着大量信息。理解这一点,温度、软标签、师生框架就都顺理成章了。

直觉:错误答案里藏着"暗知识"

设想一个手写数字分类器看到一张 “7”。硬标签(hard label)告诉你答案是 7,仅此而已。但一个训练良好的大模型的输出分布可能是:7 占 0.9,1 占 0.08,9 占 0.015,其余接近 0。这个分布在说一件硬标签说不出的事:这个 7 长得有点像 1,更像 9 而不像 3

这种"类别间相似性结构"被称作暗知识(dark knowledge)。one-hot 标签把它全抹平了,而教师模型的完整概率分布把它保留了下来。蒸馏的本质,就是让学生不仅学"正确答案",更学"教师对整个类别空间的看法"。这等于每个样本提供的监督信号从 1 bit 变成了一整个分布,信息量大得多——这也是为什么小模型靠蒸馏能学得比单独硬标签训练好。

机制:温度如何"软化"分布

问题来了:训练好的教师对正确类往往极度自信,softmax 输出接近 one-hot(0.999…),暗知识被压在那些趋近于 0 的小概率里,几乎无法传递。解决办法是给 softmax 加温度(temperature) TT

qi=exp(zi/T)jexp(zj/T)q_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}

  • T=1T = 1:标准 softmax。
  • T>1T > 1:分布被软化,拉平峰值、抬高小概率项,类间相对关系凸显出来。
  • TT \to \infty:趋近均匀分布;T0T \to 0:趋近 one-hot(argmax)。

直觉上 TT 像放大镜,把那些 0.0001 量级的"暗知识"放大到学生可感知、梯度可利用的尺度。典型取值在 2 到 10 之间,需按任务调。

公式:蒸馏损失

学生用同样的温度 TT 算自己的软预测,与教师软标签做匹配,常用 KL 散度(等价于交叉熵差一个常数):

Lsoft=T2KL(qTteacherqTstudent)\mathcal{L}_{\text{soft}} = T^2 \cdot \mathrm{KL}\big(q^{\text{teacher}}_T \,\|\, q^{\text{student}}_T\big)

为什么乘 T2T^2?因为软化后梯度量级大约缩小为 1/T21/T^2,乘回去才能让软损失和硬损失在同一量级、好平衡。通常还保留一项用真实标签的硬损失:

L=αLsoft+(1α)Lhard\mathcal{L} = \alpha \, \mathcal{L}_{\text{soft}} + (1-\alpha)\, \mathcal{L}_{\text{hard}}

硬损失防止学生被教师的错误带偏(教师也会犯错),软损失提供丰富的类间结构。α\alpha 常偏向软损失一侧。关键细节:推理时温度务必调回 T=1T=1,软化只在训练阶段用。

1
2
3
4
5
6
7
8
9
10
11
12
import torch.nn.functional as F

def distill_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.7):
# 软损失:在温度 T 下对齐两者分布
s_log = F.log_softmax(student_logits / T, dim=-1)
t_prob = F.softmax(teacher_logits / T, dim=-1) # 教师不回传梯度
soft = F.kl_div(s_log, t_prob, reduction='batchmean') * (T * T)

# 硬损失:学生对真实标签(T=1)
hard = F.cross_entropy(student_logits, labels)

return alpha * soft + (1 - alpha) * hard

注意教师在前向时应 torch.no_grad() 且通常 eval(),它只产标签、不更新。

师生框架的几种变体

最初的蒸馏只匹配最终输出(logits),后来扩展出多种"知识"来源:

  • 基于响应(response-based): 匹配最终软标签,最经典,也最通用。
  • 基于特征(feature-based): 让学生中间层的隐藏表示去逼近教师对应层(如 hint learning)。信息更细,但要处理师生维度不一致,常加一个投影层对齐。
  • 基于关系(relation-based): 不匹配单个激活,而匹配样本之间、层之间的关系结构(如 Gram 矩阵、样本相似度)。
  • 自蒸馏(self-distillation): 师生同架构,甚至同模型不同阶段,深层教浅层,无需额外大模型。
  • 在线蒸馏: 师生同时训练、互为参照,省去预训练教师的一次性成本。

在 LLM 时代,蒸馏的形态又有延伸:用大模型生成的输出(甚至带推理链的回答)作为训练数据去微调小模型,本质上也是一种"序列级"的响应蒸馏——只不过监督信号从概率分布退化成了采样出的 token 序列。

工程权衡与常见误区

  • 容量鸿沟(capacity gap)。 学生太小、和教师差距过大时,蒸馏收益反而下降——学生根本"装不下"教师的知识。有时引入一个中等规模的助教(teaching assistant)逐级蒸馏更稳。
  • 温度不是越高越好。 TT 过高分布过平,正确类的信号被淹没;过低又退回硬标签、暗知识传不出来。TTα\alpha 要一起调。
  • 教师必须够好且校准合理。 蒸馏会忠实地把教师的偏差和过度自信也一起传给学生。若教师本身校准很差,软标签的暗知识含金量会下降。
  • logits vs 概率。 用 KL 匹配概率是主流;也有直接匹配 logits(MSE)的做法,省去温度但丢失 softmax 的归一化语义,效果因任务而异。
  • 数据要对得上。 蒸馏通常在和教师训练同分布的数据上做;分布漂移时教师的软标签可能误导学生。无标签数据也能用——这正是蒸馏相对纯监督的一大优势:软标签本身就是监督。

小结

知识蒸馏的全部魔力,源于一个朴素事实——概率分布里"错误类别之间的相对大小"携带着 one-hot 标签丢掉的暗知识。温度负责把这份藏在小概率里的知识放大到可学的尺度,软损失(配 T2T^2 缩放)+ 硬损失的组合让学生既学教师的世界观又不被其错误带偏。它不是简单的"大模型教小模型",而是一套关于如何把分布级知识转写进更小容量的方法论,从 CV 的轻量分类器一路延伸到今天的 LLM 小型化。