直觉:为什么 Transformer 取代了 RNN

RNN 处理序列是"逐字传话"——第 tt 步依赖第 t1t-1 步的隐藏状态,天然串行,没法并行;而且长距离信息要经过很多步传递,容易梯度消失。Transformer 的核心革命是:用注意力一次性建立任意两个 token 之间的直连,把序列建模从串行变成可大规模并行的矩阵运算。任意两个位置的"路径长度"是常数 1,长依赖不再衰减。

代价是注意力的 O(n2)O(n^2) 复杂度(本文最后会谈)。下面我们从输入到输出,逐层拆开一个标准 Transformer block。

数据流总览

一个 token 从进入到产出 logits,经过的管线是:

1
2
3
4
5
6
7
token ids
-> 词嵌入 + 位置编码 (变成向量并注入位置)
-> [ Transformer Block ] × N
├─ 多头自注意力 + 残差 + LayerNorm
└─ 前馈网络 FFN + 残差 + LayerNorm
-> 最终 LayerNorm
-> 输出投影 (lm_head) -> softmax -> 下一 token 概率

每个 block 内部只有两个子层:注意力和前馈,外加残差连接和归一化把它们包起来。把这两块讲透,整个架构就清楚了。

第一层:嵌入与位置编码

输入是离散 token id,先查嵌入表得到 XRn×dmodelX \in \mathbb{R}^{n \times d_{model}}。但注意力对顺序是置换等变的——它本身分不清"猫追狗"和"狗追猫"。所以必须显式注入位置信息。

经典做法是正弦位置编码,对位置 pospos、维度 ii

PE(pos,2i)=sin(pos100002i/dmodel),PE(pos,2i+1)=cos(pos100002i/dmodel)PE_{(pos,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{model}}}\right),\quad PE_{(pos,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{model}}}\right)

不同维度用不同频率,让模型能从中解码出绝对和相对位置(相对位置可由三角恒等式线性表示)。现代模型多用可学习位置嵌入或旋转位置编码(RoPE)等变体,但目的不变:把位置信号加进 token 表示。位置编码直接和词嵌入相加:XX+PEX \leftarrow X + PE

第二层:多头自注意力子层

block 的第一个子层是上一篇讲过的多头自注意力。这里只强调它在 block 里的"包装":

X=LayerNorm(X+MultiHeadAttention(X))X' = \text{LayerNorm}\big(X + \text{MultiHeadAttention}(X)\big)

注意两个外壳:残差连接 X+sublayer(X)X + \text{sublayer}(X)层归一化。它们不是装饰,而是深层网络能训起来的关键,下面单独讲。

第三层:前馈网络(FFN)

注意力负责"跨 token 混合信息",FFN 负责"对每个 token 单独做非线性变换"。它是一个逐位置(position-wise)的两层 MLP,对每个 token 独立施加同一套权重:

FFN(x)=max(0, xW1+b1)W2+b2\text{FFN}(x) = \max(0,\ xW_1 + b_1)W_2 + b_2

典型地,中间维度 dffd_{ff}dmodeld_{model} 的 4 倍——先升维到一个更宽的空间做非线性,再投影回来。这一层往往是参数量大头:两个 dmodel×4dmodeld_{model}\times 4d_{model} 的矩阵,参数远多于注意力的投影矩阵。激活函数现代多用 GELU/SwiGLU 替代 ReLU。

同样包上残差和归一化:

X=LayerNorm(X+FFN(X))X'' = \text{LayerNorm}\big(X' + \text{FFN}(X')\big)

残差与归一化:深层可训练的支柱

残差连接让梯度有一条"高速公路"直达浅层:因为 (x+f(x))x=I+fx\frac{\partial (x + f(x))}{\partial x} = I + \frac{\partial f}{\partial x},那个单位矩阵 II 保证即使 ff 的梯度很小,整体梯度也不会消失。这是堆几十上百层还能训练的根本原因。

层归一化(LayerNorm) 对每个 token 的特征向量做标准化:

LN(x)=γxμσ2+ϵ+β\text{LN}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta

μ,σ\mu,\sigma 是该 token 向量内部所有特征的均值方差,γ,β\gamma,\beta 是可学习的缩放和偏移。和 BatchNorm 不同,LayerNorm 不跨样本统计,因此不依赖 batch 大小,对变长序列和小 batch 友好——这正是序列模型选它的原因。

一个重要工程细节是 Pre-LN vs Post-LN:原始结构把 LN 放在残差之后(Post-LN),训练深层时不稳定、常需要 warmup;现代实现多改成 Pre-LN(LN 放进残差分支之前 x+sublayer(LN(x))x + \text{sublayer}(\text{LN}(x))),梯度更稳、更好训。这是踩坑高发区。

最小代码骨架

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch, torch.nn as nn

class Block(nn.Module):
def __init__(self, d_model, n_head, d_ff):
super().__init__()
self.ln1 = nn.LayerNorm(d_model)
self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)
self.ln2 = nn.LayerNorm(d_model)
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model)
)

def forward(self, x, mask=None):
# Pre-LN:归一化放进残差分支之前,训练更稳
h = self.ln1(x)
x = x + self.attn(h, h, h, attn_mask=mask)[0] # 自注意力:Q=K=V=h
x = x + self.ffn(self.ln2(x)) # 前馈
return x

整个模型就是 Embedding + PositionalEncoding -> N×Block -> final LN -> Linear(lm_head)。可以看到 block 是高度规整的重复单元,这种同构堆叠正是它易于扩展(scaling)的工程优势。

三种架构变体

同一个 block 拼法不同,对应三类模型:

  • Encoder-only(双向、看全文):每个 token 能看到左右全部上下文,适合理解类任务(分类、抽取、检索)。无因果掩码。
  • Decoder-only(自回归、只看左侧):用因果掩码,逐 token 生成,是当今主流大语言模型的形态。
  • Encoder-Decoder:编码器双向理解源序列,解码器用交叉注意力(Q 来自解码器,K/V 来自编码器输出)逐步生成目标序列,适合翻译、摘要等序列到序列任务。

工程权衡与边界

  • 复杂度 O(n2d)O(n^2 d):注意力随序列长度平方增长,是长上下文的根本瓶颈。FlashAttention 类方法通过分块、不把完整 n×nn\times n 矩阵写回显存,把显存从 O(n2)O(n^2) 降到 O(n)O(n)(计算量量级不变但常数大降)。
  • 显存构成:参数、优化器状态(Adam 要存一阶二阶动量,约为参数的 2 倍)、激活值(反传要缓存)。长序列训练时激活和注意力中间量常是峰值来源,梯度检查点(重算换显存)是常用手段。
  • 推理优化 KV Cache:自回归生成时,已生成 token 的 K、V 可缓存复用,避免每步重算,把单步复杂度从 O(n2)O(n^2) 降到 O(n)O(n)。代价是 KV cache 占显存,且随上下文线性增长——长对话的显存压力主要在这。
  • 常见误区:以为加深层数总能提升效果。没有 Pre-LN、合适的 warmup/初始化,深层 Transformer 极易训练发散;残差和归一化的细节决定能不能堆深。

小结

Transformer 把序列建模拆成两个正交的操作:注意力做跨 token 的信息混合,FFN 做逐 token 的非线性变换,再用残差连接(梯度高速公路)和层归一化(稳定激活分布)把它们包成可深层堆叠的同构 block。位置编码补上注意力缺失的顺序信息。这套规整、可并行、易扩展的设计,加上 O(n2)O(n^2) 的代价与对应的 FlashAttention、KV Cache 等工程优化,构成了今天几乎所有大模型的骨架。