【AI核心原理30讲】-Transformer架构(一)

张开发
2026/4/14 12:33:31 15 分钟阅读

分享文章

【AI核心原理30讲】-Transformer架构(一)
Self-AttentionTransformer 的第一个核心部分详细拆解前置知识建议先阅读 01-Transformer架构 获得整体认识。本专栏《AI核心原理30讲》专注AI核心原理回归技术本质。一、开篇为什么 Self-Attention 是革命性的1.1 RNN 的困境要理解 Self-Attention 的价值先看它的对手——RNN循环神经网络。RNN 处理序列的方式时间步 1 时间步 2 时间步 3 时间步 4 时间步 5 ↓ ↓ ↓ ↓ ↓ The → cat → sat → on → the ↓ ↓ ↓ ↓ ↓ h₁ h₂ h₃ h₄ h₅RNN 的信息传递是链式的h₁ 包含 “The” 的信息h₂ f(h₁, “cat”)这时 “The” 的信息被编码进了 h₂h₃ f(h₂, “sat”)“The” 和 “cat” 的信息继续传递但已经有所损失以此类推……问题在哪当序列很长时比如1000个词序列开头的 “The” 要经过1000次传递才能影响最后一个词的表示。每次传递都可能有信息损失到达末端时早期信息已经被稀释得几乎看不见。这就是所谓的长期依赖问题Long-Range Dependency Problem。1.2 Self-Attention 如何破局Self-Attention 的核心思想让任意两个位置之间的信息直接交互不经过任何中间传递。Self-Attention 的连接方式完全并行 ┌─────────────────────────┐ The ─────────────┼──→ cat │ └──→ ─ ─ ─ ─ ──┼──────────→ sat │ │ ↓ │ │ ↓ │ │ ↓ │ └──────────────→ on ←───┘ (所有位置两两直接相连)关键对比特性RNNSelf-Attention信息传递路径长度O(n)O(1)两个位置间的依赖必须顺序传递直接建模并行化能力差必须顺序计算强完全并行长距离信息保留差逐层稀释强直接连接二、Self-Attention 的完整数学推导2.1 从输入到 Q、K、V假设输入序列是[The, cat, sat]每个词已经经过 embedding 得到向量输入 X [x₁, x₂, x₃] 形状: (3, d_model) x₁ embedding(The) → (d_model,) x₂ embedding(cat) → (d_model,) x₃ embedding(sat) → (d_model,)然后通过三个独立的线性变换将每个词的 embedding 投影到 Q、K、V 空间Q X · W_Q 形状: (3, d_k) K X · W_K 形状: (3, d_k) V X · W_V 形状: (3, d_v)其中W_Q,W_K,W_V是可学习的权重矩阵形状均为(d_model, d_k)或(d_model, d_v)通常d_k d_v d_model / num_heads为什么要投影直接用原始 embedding 做 attention 不是不行但投影后的 Q/K/V 能学习到更有意义的表示。每个投影空间让模型能够关注不同的方面。2.2 注意力分数计算对于序列中的每个词计算它对所有词的注意力分数scores Q · K^T / √d_k 形状: (3, 3)展开来看scores[i,j] q_i · k_j / √d_k 其中 - q_i x_i · W_Q 第 i 个词的 query - k_j x_j · W_K 第 j 个词的 key - √d_k 是缩放因子为什么要除以 √d_k假设 Q 和 K 的每个分量是均值为0、方差为1的独立随机变量那么 Q·K^T 的方差会是 d_k。当 d_k 较大时点积的值会很大导致 softmax 函数进入饱和区梯度接近于零。除以 √d_k 后点积的方差恢复到 1softmax 的输出分布更加均匀梯度也更稳定。2.3 Softmax 归一化attention_weights softmax(scores, dim-1) 形状: (3, 3)Softmax 将每一行转换为概率分布softmax(x_i) exp(x_i) / Σ exp(x_j)每行的所有值加和为 1表示当前位置对序列中各个位置的关注程度。2.4 加权求和得到输出output attention_weights · V 形状: (3, d_v)这步操作用注意力权重对 V 向量做加权平均output[i] Σ attention_weights[i,j] · v_j意思是对于第 i 个词我根据它对其他词的关注程度取其他词的 value 向量的加权平均。2.5 完整公式Attention ( Q , K , V ) softmax ( Q K T d k ) V \text{Attention}(Q, K, V) \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)VAttention(Q,K,V)softmax(dk​​QKT​)V三、Multi-Head Attention 详解3.1 为什么需要多个注意力头一个注意力头只能学到一种匹配模式。但语言的复杂性需要多种类型的关注以句子“The animal didn’t cross the street because it was too tired”为例指代消解“it” 应该关注 “animal”不是 “street”因果关系“because” 连接了 “didn’t cross” 和 “tired”位置关系“street” 和 “cross” 紧密相连单一 attention 头难以同时捕捉这些关系。Multi-Head 让模型能在不同的子空间并行学习不同的关系。3.2 Multi-Head 的计算MultiHead(Q, K, V) Concat(head_1, head_2, ..., head_h) · W_O 其中 head_i Attention(Q · W_Q_i, K · W_K_i, V · W_V_i)张量形状变化输入 Q: (batch, seq_len, d_model) ↓ 投影 (h 个头) Q_i: (batch, seq_len, d_k) 每个头 i 1..h K_i: (batch, seq_len, d_k) V_i: (batch, seq_len, d_v) ↓ 分头 (view 操作) Q_i: (batch, num_heads, seq_len, d_k) K_i: (batch, num_heads, seq_len, d_k) V_i: (batch, num_heads, seq_len, d_v) ↓ 注意力计算 head_i: (batch, num_heads, seq_len, d_v) ↓ 拼接 Concat: (batch, seq_len, h * d_v) (batch, seq_len, d_model) ↓ 输出投影 output: (batch, seq_len, d_model)3.3 论文中的标准配置参数Base 模型Large 模型d_model5121024num_heads816d_k d_v6464FFN 维度20484096四、代码实现从理论到代码4.1 最简版本纯 Python / NumPyimportnumpyasnpdefsoftmax(x,axis-1):Numerically stable softmaxexp_xnp.exp(x-np.max(x,axisaxis,keepdimsTrue))returnexp_x/np.sum(exp_x,axisaxis,keepdimsTrue)defself_attention(Q,K,V,d_k): 简化版 Self-Attention无 batch 维度 参数: Q: (seq_len, d_k) 查询矩阵 K: (seq_len, d_k) 键矩阵 V: (seq_len, d_v) 值矩阵 d_k: 缩放因子 返回: output: (seq_len, d_v) 注意力输出 weights: (seq_len, seq_len) 注意力权重 # Step 1: 计算点积注意力分数scoresnp.dot(Q,K.T)/np.sqrt(d_k)# Step 2: Softmax 归一化attention_weightssoftmax(scores,axis-1)# Step 3: 加权求和outputnp.dot(attention_weights,V)returnoutput,attention_weights# 测试d_k64seq_len5d_v64# 模拟输入Qnp.random.randn(seq_len,d_k)Knp.random.randn(seq_len,d_k)Vnp.random.randn(seq_len,d_v)output,weightsself_attention(Q,K,V,d_k)print(f输出形状:{output.shape})print(f注意力权重形状:{weights.shape})print(f权重验证每行和1:{weights.sum(axis1)})4.2 PyTorch 完整实现importtorchimporttorch.nnasnnimporttorch.nn.functionalasFimportmathclassMultiHeadAttention(nn.Module):Multi-Head Attention 的完整 PyTorch 实现def__init__(self,d_model,num_heads,dropout0.1):super().__init__()assertd_model%num_heads0,d_model 必须能被 num_heads 整除self.d_modeld_model self.num_headsnum_heads self.d_kd_model//num_heads# 每个注意力头的维度# 四个线性变换层self.W_qnn.Linear(d_model,d_model)self.W_knn.Linear(d_model,d_model)self.W_vnn.Linear(d_model,d_model)self.W_onn.Linear(d_model,d_model)self.dropoutnn.Dropout(dropout)defforward(self,query,key,value,maskNone): 参数: query: (batch, seq_len, d_model) key: (batch, seq_len, d_model) value: (batch, seq_len, d_model) mask: (batch, seq_len, seq_len) 或 (batch, 1, seq_len, seq_len) 返回: output: (batch, seq_len, d_model) attention_weights: (batch, num_heads, seq_len, seq_len) batch_sizequery.size(0)seq_lenquery.size(1)# Step 1: 线性投影 分头 # Q, K, V: (batch, seq_len, d_model)Qself.W_q(query)Kself.W_k(key)Vself.W_v(value)# 视图分割最后一项分成 num_heads × d_k# (batch, seq_len, num_heads, d_k) 然后转置# → (batch, num_heads, seq_len, d_k)QQ.view(batch_size,seq_len,self.num_heads,self.d_k).transpose(1,2)KK.view(batch_size,seq_len,self.num_heads,self.d_k).transpose(1,2)VV.view(batch_size,seq_len,self.num_heads,self.d_k).transpose(1,2)# Step 2: 计算注意力分数 # scores: (batch, num_heads, seq_len, seq_len)scorestorch.matmul(Q,K.transpose(-2,-1))/math.sqrt(self.d_k)# Step 3: 应用 Mask ifmaskisnotNone:# 支持两种 mask 格式ifmask.dim()2:# (seq_len, seq_len)maskmask.unsqueeze(0).unsqueeze(0)# → (1, 1, seq_len, seq_len)elifmask.dim()3:# (batch, seq_len, seq_len)maskmask.unsqueeze(1)# → (batch, 1, seq_len, seq_len)scoresscores.masked_fill(mask0,float(-inf))# Step 4: Softmax Dropout attention_weightsF.softmax(scores,dim-1)attention_weightsself.dropout(attention_weights)# Step 5: 加权求和 # context: (batch, num_heads, seq_len, d_k)contexttorch.matmul(attention_weights,V)# Step 6: 合并多头 # (batch, num_heads, seq_len, d_k) → (batch, seq_len, num_heads, d_k)contextcontext.transpose(1,2).contiguous()# → (batch, seq_len, d_model)contextcontext.view(batch_size,seq_len,self.d_model)# Step 7: 最终线性投影 outputself.W_o(context)returnoutput,attention_weights4.3 使用示例# 创建一个 Multi-Head Attention 层d_model512num_heads8attentionMultiHeadAttention(d_model,num_heads)# 模拟输入batch_size2seq_len10xtorch.randn(batch_size,seq_len,d_model)# 前向传播output,attn_weightsattention(x,x,x)print(f输出形状:{output.shape})# (2, 10, 512)print(f注意力权重形状:{attn_weights.shape})# (2, 8, 10, 10)# 可视化第 1 个样本第 1 个头的注意力权重# attn_weights[0, 0] 形状是 (10, 10)五、Mask 的作用与实现5.1 为什么需要 Mask在 Transformer 中Mask 主要有两种用途用途 1Padding Mask输入序列长度不一需要 padding 到统一长度。但 padding 位置不应该参与注意力计算。原始句子: [The, cat, sat] Padding后: [The, cat, sat, [PAD], [PAD]] 注意力权重应该是: The cat sat [PAD] [PAD] The [ 0.4 0.3 0.2 0.0 0.0 ] cat [ 0.3 0.4 0.2 0.0 0.0 ] sat [ 0.2 0.2 0.4 0.0 0.0 ] [PAD] [ 0.0 0.0 0.0 0.0 0.0 ] [PAD] [ 0.0 0.0 0.0 0.0 0.0 ] ↑ padding 位置权重为 0用途 2因果掩码Causal Mask / Look-Ahead Mask在解码器中预测第 N 个词时不能看到第 N 个词之后的任何信息。目标序列: [The, cat, sat, [EOS]] 允许的注意力连接✓ 表示可见 Step1 Step2 Step3 Step4 Step1 → ✓ ✗ ✗ ✗ Step2 → ✓ ✓ ✗ ✗ Step3 → ✓ ✓ ✓ ✗ Step4 → ✓ ✓ ✓ ✓ 解码器中真正的注意力权重: Step1 [ 1.0 0.0 0.0 0.0 ] Step2 [ 0.5 0.5 0.0 0.0 ] Step3 [ 0.3 0.3 0.4 0.0 ] Step4 [ 0.2 0.2 0.2 0.4 ]5.2 Mask 的 PyTorch 实现defcreate_padding_mask(seq,pad_idx0): 创建 Padding Mask 参数: seq: (batch, seq_len) token IDs pad_idx: padding 的 token ID默认为 0 返回: mask: (batch, 1, 1, seq_len) True 表示有效位置False 表示需要 mask mask(seq!pad_idx).unsqueeze(1).unsqueeze(2)returnmask# (batch, 1, 1, seq_len) → 广播后 (batch, num_heads, seq_len, seq_len)defcreate_causal_mask(seq_len): 创建因果掩码上三角 mask 返回: mask: (1, 1, seq_len, seq_len) 上三角为 False不可见 # torch.triu: 上三角不含对角线为 1masktorch.triu(torch.ones(seq_len,seq_len),diagonal1).bool()# → (seq_len, seq_len)# 对角线及以上为 True需要 mask对角线以下为 False可见maskmask.unsqueeze(0).unsqueeze(0)# → (1, 1, seq_len, seq_len)return~mask# 取反True 表示可见False 表示 mask# 使用示例batch_size2seq_len5# Padding Maskseqtorch.tensor([[1,2,3,0,0],# 句子1: [PAD]0[1,2,0,0,0]])# 句子2: [PAD]0padding_maskcreate_padding_mask(seq)print(Padding Mask:)print(padding_mask[0])# 句子1的 mask# 因果 Maskcausal_maskcreate_causal_mask(seq_len)print(\n因果 Mask:)print(causal_mask[0,0])六、注意力权重的可视化6.1 典型的注意力模式不同层的注意力头会捕捉不同类型的依赖Layer 1 Head 1捕捉局部关系: The cat sat on mat The [ 0.5 0.3 0.1 0.05 0.05 ] cat [ 0.3 0.4 0.2 0.05 0.05 ] sat [ 0.1 0.2 0.4 0.2 0.1 ] on [ 0.05 0.05 0.2 0.4 0.3 ] mat [ 0.05 0.05 0.1 0.3 0.5 ] ↑ 局部性关注相邻词 Layer 3 Head 5捕捉语法关系: The cat sat on mat The [ 0.4 0.1 0.1 0.1 0.1 ] cat [ 0.1 0.6 0.1 0.1 0.1 ] ← cat 关注自身主语 sat [ 0.1 0.1 0.6 0.1 0.1 ] ← sat 关注自身谓语 on [ 0.1 0.1 0.1 0.6 0.1 ] mat [ 0.1 0.1 0.1 0.1 0.6 ] ↑ 语法每个词更关注自己完整实体表示6.2 可视化代码importmatplotlib.pyplotaspltimportseabornassnsdefplot_attention_weights(weights,tokens,save_pathNone): 可视化注意力权重热力图 参数: weights: (num_heads, seq_len, seq_len) 或 (seq_len, seq_len) tokens: list of strings词元列表 save_path: 可选保存路径 ifweights.dim()4:# 取第一个样本的第一个头weightsweights[0,0].detach().numpy()else:weightsweights.detach().numpy()plt.figure(figsize(10,8))sns.heatmap(weights,xticklabelstokens,yticklabelstokens,cmapBlues,annotFalse,fmt.2f)plt.xlabel(Key 位置)plt.ylabel(Query 位置)plt.title(Self-Attention 权重热力图)plt.tight_layout()ifsave_path:plt.savefig(save_path)plt.show()# 使用示例tokens[The,cat,sat,on,the,mat]# 假设已经得到 attention_weights# plot_attention_weights(attention_weights, tokens)七、Self-Attention 的复杂度分析7.1 时间复杂度Self-Attention 的主要计算Step 1: Q, K, V 投影 O(n · d_model · d_k) × 3 Step 2: 计算 QK^T O(n² · d_k) Step 3: Softmax O(n²) Step 4: 加权求和 O(n² · d_v) 总计: O(n² · d_model)其中 n 是序列长度d_model 是模型维度。关键点Self-Attention 的时间复杂度是序列长度的平方O(n²)。这是 Transformer 的主要瓶颈。7.2 空间复杂度存储 Q, K, V: O(n · d_model) × 3 存储注意力权重矩阵: O(n²) 存储输出: O(n · d_model) 总计: O(n² n · d_model)注意力权重矩阵 O(n²) 是最大的开销。7.3 复杂度对比模型注意力复杂度RNN/LSTMO(n · d)Self-AttentionO(n² · d)局部注意力 (窗口 k)O(n · k · d)LinformerO(n · k)ReformerO(n · log n)这就是为什么 Long Context 是一个热门研究方向——标准 Self-Attention 无法直接处理超长序列。八、Self-Attention 在 Transformer 中的位置┌─────────────────────────────────────────────────────────────┐ │ Encoder Layer │ │ │ │ Input X ──→ Multi-Head Self-Attention ──→ Add Norm ──┐ │ │ │ │ │ ↑ │ │ │ │ │ │ │ └──────── Feed Forward ────────────────────────┘ │ │ │ │ × N 层 │ └─────────────────────────────────────────────────────────────┘ 每层内部 x_input │ ├──→ Multi-Head Self-Attention ──→ Add(x_input, attention_output) │ │ │ ↓ │ LayerNorm │ │ │ ↓ (sublayer_1 output) │ │ └──→ Feed Forward Network ─────────→ Add(sublayer_1, ffn_output) │ ↓ LayerNorm │ ↓ x_output (传给下一层)Self-Attention 的核心作用Encoder Self-Attention让每个位置能看到序列中所有其他位置学习输入的上下文表示Decoder Self-Attention类似但有 Mask确保自回归生成时不泄露未来信息Cross AttentionDecoder 层中Q 来自 DecoderK/V 来自 Encoder 输出实现跨模块交互九、关键设计选择的原因设计选择选择原因缩放因子 √d_kQK^T / √d_k防止 d_k 较大时 softmax 梯度消失多头注意力h 个独立头不同头学习不同类型的依赖关系QKV 输入自注意力让序列内部进行自我比较学习内部结构Linear 投影学习的 W_Q, W_K, W_V增加模型表达能力让投影空间更有意义拼接后投影Concat → W_O合并多头的不同子空间表示十、总结与延伸10.1 核心要点回顾Self-Attention 通过直接建模任意位置间的依赖解决了 RNN 的长距离依赖问题Q/K/V 三元组让每个词既能问问题Query也能回答问题Key/Value缩放因子 √d_k 是关键细节防止大维度下的梯度消失Multi-Head 扩展了模型的表示能力不同头学习不同类型的关系Mask 机制让 Transformer 能处理变长序列和控制信息流动10.2 延伸阅读方向方向关键技术适用场景高效注意力Flash Attention, Sparse Attention长上下文位置编码RoPE, ALiBi,绝对位置编码位置感知注意力变体Grouped Query Attention高效推理跨模态Cross Attention多模态融合参考资料Vaswani et al., “Attention Is All You Need”, NeurIPS 2017The Illustrated Transformer: http://nlp.seas.harvard.edu/2018/04/03/attention.htmlLilian Weng, “Attention? Attention!”, Lil’LogPyTorch Transformer Documentation: https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html本专栏持续更新下一篇《Feed Forward Network注意力之外的另一条腿》

更多文章