直觉:一次可微的"软查询"
把注意力想成一次数据库检索:你拿着一个查询(Query),去和一堆条目的键(Key)比对相似度,相似度高的条目,其值(Value)就被多取一点。区别在于,传统检索是硬选一条记录,而注意力是按相似度加权所有值——它是一次"软"的、处处可微的查表。正因为可微,整套机制能塞进反向传播里端到端训练。
这套 Query/Key/Value(QKV)抽象,是理解 Transformer 的钥匙。下面我们把它一步步拆到矩阵运算和数值细节。
机制:从单个 token 到矩阵形式
设输入是 个 token 的表示 。注意力不直接用 ,而是先用三个可学习的投影矩阵把它变换成 Q、K、V:
其中 ,。然后核心公式——缩放点积注意力(scaled dot-product attention):
逐步拆:
- :第 行第 列是 query 和 key 的内积,即"token 应该多关注 token "的原始打分(logits)。
- 除以 :缩放(下一节专门讲为什么)。
- 对每一行做 softmax:把打分变成和为 1 的权重分布 , 是 token 分给 token 的注意力权重。
- :用权重对所有 value 加权求和,得到每个 token 的新表示。
一句话:每个 token 的输出 = 所有 token 的 value 的加权平均,权重由 query-key 相似度经 softmax 归一化决定。
公式:为什么要除以
这是面试高频题,也是真正的数学细节。假设 和 的每个分量都是独立、均值 0、方差 1 的随机变量。它们的点积是:
每一项 均值为 0、方差为 1(独立项乘积方差 = 各自方差乘积)。 个独立项相加,方差线性叠加:
所以点积的标准差是 。当 较大(比如 64、128),未缩放的 logits 量级会很大。把这种大值喂进 softmax,会让分布极度尖锐(几乎是 one-hot),而 softmax 在饱和区的梯度趋近于 0——梯度消失,训练停滞。除以 正好把 logits 的方差拉回 ,让 softmax 工作在梯度健康的区间。
1 | 未缩放: logits ~ N(0, d_k) -> softmax 尖锐 -> 梯度近 0 |
多头:在多个子空间并行关注
单个注意力只能学一种"关注模式"。**多头注意力(multi-head)**把 切成 个头,每个头在 维的子空间独立做注意力,再拼接、过一个输出投影:
\text{MultiHead}(X) = \text{Concat}(\text{head}_1,\dots,\text{head}_h)W^O,\quad \text{head}_i = \text{Attention}(XW_i^Q, XW_i^K, XW_i^V)直觉上,不同头可以分工——有的关注相邻词,有的关注长距离依赖,有的关注语法、有的关注指代。切成多头几乎不增加总计算量(因为每头维度变小了),却让模型在多个表示子空间里并行建模关系。
掩码:因果与 padding
两种常见掩码:
- 因果掩码(causal mask):自回归生成时,token 不能看到未来的 。做法是在 softmax 之前,把上三角()的 logits 设为 ,softmax 后这些位置权重为 0。
- padding 掩码:batch 内序列长度不一,补齐的 padding 位置同样置 ,避免模型关注无意义的填充。
最小实现
1 | import numpy as np |
注意 softmax 里"减最大值"的技巧——直接 exp 大数会数值溢出,这是必备的稳定化处理。
工程权衡与边界
- 复杂度是 。 那个 是 矩阵,序列长度翻倍,计算和显存都翻四倍。长上下文的根本瓶颈就在这里。各种线性注意力、稀疏注意力、以及 IO 优化(把注意力分块、避免把完整 矩阵写回显存的 FlashAttention 类方法)都是为了缓解它。
- 显存大头是注意力矩阵。 长序列训练时,存储 的注意力权重(及其反传所需中间量)往往是显存峰值来源,而非参数本身。
- 常见误区一:以为 Q、K、V 来自不同输入。 在自注意力里它们都来自同一个 ,只是过了不同投影;在交叉注意力(如编码器-解码器)里,Q 来自解码器、K/V 来自编码器。
- 常见误区二:忘了缩放或缩放错维度。 缩放因子是 (每个头的维度),不是 。
- 常见误区三:注意力本身不含位置信息。 公式对 token 顺序是置换等变的——打乱输入顺序,输出只是相应打乱。位置信息必须靠位置编码额外注入。
小结
注意力是一次可微的软查表:Q 去匹配 K、用相似度加权 V。缩放点积的精髓在 ——它把点积方差从 拉回 ,避免 softmax 饱和导致梯度消失。多头让模型在多个子空间并行建模不同关系,掩码控制可见范围。代价是 的复杂度与显存,这既是 Transformer 强表达力的来源,也是长上下文优化永恒的战场。