动手学深度学习——注意力机制代码

张开发
2026/4/16 2:59:05 15 分钟阅读

分享文章

动手学深度学习——注意力机制代码
1. 前言上一篇我们已经从思想上理解了注意力机制基础 Seq2Seq 的问题在于固定长度上下文向量解码器在不同时间步其实应该关注输入序列的不同位置注意力机制的本质就是对输入表示做加权和权重由当前位置和各输入位置的相关性决定这一篇就继续按李沐的节奏把注意力机制真正落到代码上。这一节最重要的不是一开始就把所有复杂变体都铺开而是先把最核心的代码骨架看懂查询query是什么键key是什么值value是什么注意力权重怎么得到加权求和怎么实现你会发现注意力机制代码的灵魂其实很简单先算相关性分数再做 softmax再对 value 加权求和。2. 注意力代码到底在做什么如果从最抽象的角度看注意力机制的输入通常是三部分querykeyvalue然后输出一个结果根据 query 和 key 的匹配程度决定如何对 value 做加权汇总。所以它的计算主线可以写成三步第一步算分数score(query, key)第二步归一化成权重softmax(scores)第三步对 value 加权和attention_weights values所以注意力代码不是神秘黑箱本质就是分数 → 权重 → 加权和3. 为什么会有 query、key、value 这三个名字这三个名字第一次看会有点抽象但其实非常形象。你可以把它理解成“查询数据库”的过程query表示你现在想找什么。key表示每个候选位置的“索引标签”。value表示每个候选位置真正存放的内容。在注意力里query 决定当前需要什么信息key 决定每个位置和当前需求有多相关value 才是最终被加权汇总的内容所以query 用来问key 用来比value 用来取。4. 在 Seq2Seq 中query、key、value 分别是谁放到机器翻译的解码器场景里最常见的理解是query当前解码器时刻的隐藏状态也就是我现在要生成第t个目标词我当前最需要什么信息key编码器每个时间步的输出表示也就是源句子每个位置都提供一个“可匹配的表示”value通常也是编码器每个时间步的输出表示也就是最终真正被加权汇总的源句信息所以在最基础的 Seq2Seq 注意力里常见是query decoder hidden statekey encoder outputsvalue encoder outputs5. 最基础的注意力代码要先解决什么这一节李沐这里通常会先实现一种比较简单的注意力层例如“加性注意力”或一个通用注意力模块。但在进入具体分数函数之前通常会先把一个公共步骤处理掉masked softmax因为在序列任务里输入往往有 padding。如果不把 padding 位置屏蔽掉模型可能会把注意力错误地分给那些补齐出来的无效位置。所以注意力代码里非常基础的一步就是先算出分数再对无效位置 mask再做 softmax6. 什么是 masked softmaxmasked softmax 的作用是只在有效位置上做 softmax把 padding 位置的权重压成 0。为什么需要它假设一个 batch 里两条句子长度不同第一句长度是 5第二句长度是 3但 pad 到了 5那么第二句后面两个位置其实是无效的pad。如果注意力还把权重分给这两个位置就会污染上下文向量。所以必须在 softmax 前把这些位置“屏蔽掉”。7. masked softmax 代码怎么理解常见写法大致如下def masked_softmax(X, valid_lens): if valid_lens is None: return nn.functional.softmax(X, dim-1) else: shape X.shape if valid_lens.dim() 1: valid_lens torch.repeat_interleave(valid_lens, shape[1]) else: valid_lens valid_lens.reshape(-1) X d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value-1e6) return nn.functional.softmax(X.reshape(shape), dim-1)这段代码乍一看有点绕但核心思想其实很简单第一步把无效位置赋成一个非常小的值例如-1e6第二步再做 softmax因为 softmax 后有效位置还能得到正常权重无效位置由于值极小权重几乎就是 0所以它本质上就是先 mask再 softmax8. 为什么把无效位置设成-1e6因为 softmax 的形式是指数归一化exp(x_i) / sum(exp(x_j))如果某个位置被设成-1e6那么exp(-1e6) ≈ 0这样它在 softmax 后的权重就几乎为 0。所以这种做法非常常见也非常实用。它不需要单独手写一个“软屏蔽公式”只要借助 softmax 的性质就行。9.valid_lens是什么valid_lens表示每个样本真实有效的序列长度例如一个 batch 有两条序列第一条长度 5第二条长度 3那么valid_lens [5, 3]这样注意力层就知道第一条的前 5 个位置有效第二条只有前 3 个位置有效后面是 padding所以valid_lens本质上就是 mask 的依据。10. 为什么注意力代码里常常要保存attention_weights很多实现里都会写self.attention_weights ...这是因为注意力机制一个很大的优点就是可解释性很强保存注意力权重有两个作用第一后续计算需要有些模块需要直接拿权重做加权和。第二便于可视化分析你可以把注意力权重画出来看模型当前到底在关注哪些输入位置。这也是注意力机制特别有魅力的一点它不像普通隐状态那么黑箱至少你能看到“它把注意力放在哪里”。11. 一个典型的注意力层长什么样李沐这里通常会实现一个加性注意力层例如class AdditiveAttention(nn.Module): def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs): super(AdditiveAttention, self).__init__(**kwargs) self.W_k nn.Linear(key_size, num_hiddens, biasFalse) self.W_q nn.Linear(query_size, num_hiddens, biasFalse) self.w_v nn.Linear(num_hiddens, 1, biasFalse) self.dropout nn.Dropout(dropout)这里只先看初始化。你会发现它并没有直接拿 query 和 key 点乘而是先做了几次线性变换。这就是“加性注意力”的特点。12. 加性注意力为什么叫“加性”因为它不是直接做内积而是先把 query 和 key 投影到同一个隐藏空间再相加、过非线性、再打分。直觉上可以写成score(q, k) w^T tanh(W_q q W_k k)这里最显眼的地方就是W_q q W_k k有个“加”。所以它被称为加性注意力Additive Attention这类注意力最早在 Seq2Seq 里非常经典也常叫Bahdanau attention的打分思路。13. 这三个线性层分别在干什么在初始化代码里self.W_k nn.Linear(key_size, num_hiddens, biasFalse) self.W_q nn.Linear(query_size, num_hiddens, biasFalse) self.w_v nn.Linear(num_hiddens, 1, biasFalse)可以这样理解。W_k把 key 投影到共同隐藏空间。W_q把 query 也投影到共同隐藏空间。w_v把两者融合后的隐藏表示再压成一个标量分数。也就是说加性注意力的分数不是直接算出来的而是先投影融合压缩最后得到一个注意力分数。14. 加性注意力的前向传播怎么写常见写法如下def forward(self, queries, keys, values, valid_lens): queries, keys self.W_q(queries), self.W_k(keys) features queries.unsqueeze(2) keys.unsqueeze(1) features torch.tanh(features) scores self.w_v(features).squeeze(-1) self.attention_weights masked_softmax(scores, valid_lens) return torch.bmm(self.dropout(self.attention_weights), values)这段代码就是注意力机制代码里最值得细拆的一段。15.queries, keys self.W_q(queries), self.W_k(keys)在做什么这一句表示先把 query 投影到隐藏空间再把 key 投影到同一个隐藏空间这样做的好处是不管原始 query 和 key 维度是否一样都可以先映射到统一空间里再比较。这是一种非常常见的做法。因为不同来源的表示未必天然适合直接比较先投影能让匹配更灵活。16.unsqueeze和广播加法为什么这么写这一句是核心features queries.unsqueeze(2) keys.unsqueeze(1)它的目的就是让每个 query 和每个 key 两两配对。假设queries形状是(batch_size, num_queries, num_hiddens)keys形状是(batch_size, num_kv_pairs, num_hiddens)那么queries.unsqueeze(2)会变成(batch_size, num_queries, 1, num_hiddens)keys.unsqueeze(1)会变成(batch_size, 1, num_kv_pairs, num_hiddens)然后通过广播相加就得到(batch_size, num_queries, num_kv_pairs, num_hiddens)这就相当于每个 query 都和所有 key 组合了一遍。这一步特别关键因为注意力本质上就是要比较当前 query 和所有 key 的相关性17. 为什么后面要tanhfeatures torch.tanh(features)这是加性注意力的非线性变换步骤。它的作用是增强表达能力让 query-key 融合后的表示不只是线性相加为后面的分数计算提供更灵活特征这和前面 RNN/LSTM/GRU 中tanh的作用有些相似都是为了让模型不只是简单线性变换。18.scores self.w_v(features).squeeze(-1)在干什么这一句表示把最后那个num_hiddens维特征压成一个标量分数。也就是说对每个 query-key 对最终都会得到一个实数分数于是scores的形状通常是(batch_size, num_queries, num_kv_pairs)这正好对应每个 query 对所有 key 的相关性打分表这张分数表后面经过 softmax就会变成注意力权重。19.masked_softmax(scores, valid_lens)在这里的意义是什么这里就是把前面讲的 mask 用上了。因为 key/value 序列可能有 padding所以在注意力分数转成权重之前必须把无效位置屏蔽掉。这一步之后self.attention_weights就会变成一组合法的注意力分布非负和为 1padding 位置权重几乎为 0所以这一步本质上是在说只在真实有效输入位置上分配注意力。20.torch.bmm(attention_weights, values)为什么能得到上下文向量最后一步torch.bmm(self.attention_weights, values)这里的bmm是 batch matrix multiplication也就是批量矩阵乘法。假设attention_weights形状是(batch_size, num_queries, num_kv_pairs)values形状是(batch_size, num_kv_pairs, value_dim)那么相乘后得到(batch_size, num_queries, value_dim)这正好就是对每个 query把所有 value 按注意力权重做加权和所以bmm这一步其实就是把“加权求和”高效矩阵化实现了。这也是注意力代码最核心的落地点注意力输出 权重 × values21. 为什么 values 不一定等于 keys在很多基础 Seq2Seq 注意力里keys encoder outputsvalues encoder outputs所以两者看起来一样。但从更一般的框架看它们其实不是必须相同。key负责被 query 匹配决定权重。value负责被加权求和形成输出。在更复杂模型里key 和 value 可以来自不同投影或不同表示。所以把它们分开是一种更通用的设计。22. 这一节代码最该掌握什么如果从学习重点看最重要的是这几件事。22.1 理解 masked softmax知道为什么注意力一定要 mask padding。22.2 理解 query、key、value 的角色分工query当前需求key匹配对象value最终取出的内容22.3 理解unsqueeze broadcast的作用这是实现 query-key 两两配对的关键。22.4 理解注意力分数到注意力权重的转换也就是打分softmax得到分布22.5 理解bmm为什么就是加权和这是注意力机制代码最核心的一步。23. 这一节和下一节“注意力分数”是什么关系这一节主要是在讲注意力机制的基本代码框架怎么搭也就是分数算出来以后怎么办权重怎么算加权和怎么实现而下一节“注意力分数”会更聚焦于分数本身到底怎么设计例如加性注意力缩放点积注意力打分函数不同会带来什么差异所以这两节可以这么理解这一节偏整体计算流程。下一节偏分数函数本身。24. 本节总结这一节我们学习了注意力机制的代码基础核心内容可以总结为以下几点。24.1 注意力机制代码的主线是打分 → softmax → 加权和这是最核心的三步。24.2 masked softmax 用于屏蔽 padding 位置确保无效 token 不参与注意力分配。24.3 query、key、value 分别承担不同角色它们共同决定当前上下文向量如何生成。24.4 加性注意力通过线性变换、非线性融合和打分得到注意力分数这是经典的 Seq2Seq 注意力实现方式。24.5torch.bmm实现了对 values 的批量加权求和这是注意力输出的关键一步。25. 学习感悟这一节特别有价值因为它让注意力机制第一次真正“落地”成了一个你能看懂的计算过程。以前我们说模型在关注某些位置模型在动态分配注意力这些话听起来都很抽象。但代码一拆开你会发现它其实很朴素先比较相关性再把相关性变成权重再按权重把信息汇总出来。也就是说注意力机制的伟大之处不在于它特别复杂而在于它用一种很自然的方式把“选择性读信息”这件事变成了可训练的模块。

更多文章