别再死记硬背SW-MSA了!用Python手动画图,带你搞懂Swin Transformer的滑动窗口到底怎么滑

张开发
2026/4/15 6:11:08 15 分钟阅读

分享文章

别再死记硬背SW-MSA了!用Python手动画图,带你搞懂Swin Transformer的滑动窗口到底怎么滑
用Python动态图解Swin Transformer滑动窗口自注意力机制第一次读到Swin Transformer论文时我被那个滑动窗口自注意力(SW-MSA)的图示搞得一头雾水——窗口到底是怎么滑动的循环位移后怎么保持计算效率掩码矩阵又是如何工作的直到我决定用Python把这些过程动态画出来一切才变得清晰可见。本文将带你用matplotlib一步步实现SW-MSA的可视化从基础窗口划分到循环位移技巧再到掩码矩阵的应用让抽象的理论变成可运行的代码动画。1. 环境准备与基础概念在开始绘制之前我们需要明确几个关键概念。Swin Transformer中的滑动窗口自注意力(SW-MSA)是对标准窗口自注意力(W-MSA)的扩展主要解决窗口间信息隔离的问题。其核心创新在于规则网格划分将输入特征图划分为不重叠的均匀窗口滑动窗口重组通过偏移产生新的窗口配置循环位移技巧将不规则的窗口转换为规则形状以保持计算效率掩码机制确保自注意力计算时只考虑原本相邻的区域先准备好Python环境我们需要以下工具包import numpy as np import matplotlib.pyplot as plt from matplotlib.animation import FuncAnimation from matplotlib.patches import Rectangle from matplotlib.collections import PatchCollection定义一个基础函数来绘制特征图和窗口划分def draw_feature_map(ax, img_size56, window_size7, shift0): # 创建空白特征图 feature_map np.zeros((img_size, img_size)) # 绘制网格线 for i in range(0, img_size1, window_size): ax.axhline(i-shift, colorgray, linestyle--, alpha0.5) ax.axvline(i-shift, colorgray, linestyle--, alpha0.5) # 设置坐标轴 ax.set_xlim(0, img_size) ax.set_ylim(0, img_size) ax.set_aspect(equal) return feature_map2. 基础窗口划分与滑动窗口可视化让我们首先实现基本的窗口划分。假设我们有一个56×56的特征图窗口大小为7×7def basic_window_partition(): fig, ax plt.subplots(figsize(8,8)) feature_map draw_feature_map(ax) # 添加窗口编号 window_id 0 for i in range(0, 56, 7): for j in range(0, 56, 7): rect Rectangle((i, j), 7, 7, linewidth2, edgecolorr, facecolornone) ax.add_patch(rect) ax.text(i3.5, j3.5, str(window_id), hacenter, vacenter, colorred) window_id 1 plt.title(Basic Window Partition (W-MSA)) plt.show()执行这段代码会生成一个8×8的窗口网格每个窗口被红色边框标记并编号。这就是W-MSA的基础划分方式。接下来实现滑动窗口划分def sliding_window_partition(): fig, (ax1, ax2) plt.subplots(1, 2, figsize(16,8)) # 原始划分 draw_feature_map(ax1) window_id 0 for i in range(0, 56, 7): for j in range(0, 56, 7): rect Rectangle((i, j), 7, 7, linewidth2, edgecolorr, facecolornone) ax1.add_patch(rect) ax1.text(i3.5, j3.5, str(window_id), hacenter, vacenter, colorred) window_id 1 ax1.set_title(Original Partition) # 滑动窗口划分 (偏移3个像素) shift 3 draw_feature_map(ax2, shiftshift) window_id 0 for i in range(-shift, 56, 7): for j in range(-shift, 56, 7): rect Rectangle((i, j), 7, 7, linewidth2, edgecolorb, facecolornone) ax2.add_patch(rect) ax2.text(i3.5, j3.5, str(window_id), hacenter, vacenter, colorblue) window_id 1 ax2.set_title(fShifted Partition (shift{shift})) plt.show()这个可视化清晰地展示了滑动窗口如何产生新的窗口配置——通过将网格偏移窗口大小的一半这里是3像素我们得到了与原始划分不同的窗口布局。3. 循环位移技巧的动态实现滑动窗口划分带来了一个实际问题窗口大小变得不一致。Swin Transformer采用循环位移(cyclic shifting)来解决这个问题。让我们用动画展示这一过程def cyclic_shift_animation(): fig, ax plt.subplots(figsize(8,8)) img_size 56 window_size 7 shift 3 # 初始化特征图 feature_map np.zeros((img_size, img_size)) # 创建窗口集合 rects [] for i in range(-shift, img_size, window_size): for j in range(-shift, img_size, window_size): rect Rectangle((i, j), window_size, window_size, linewidth2, edgecolorb, facecolornone) rects.append(rect) ax.add_patch(rect) # 动画更新函数 def update(frame): for rect in rects: x, y rect.get_xy() new_x x shift if x 0 else x new_y y shift if y 0 else y new_x new_x - img_size if new_x img_size else new_x new_y new_y - img_size if new_y img_size else new_y rect.set_xy((new_x, new_y)) return rects # 创建动画 ani FuncAnimation(fig, update, frames30, interval100, blitTrue) plt.title(Cyclic Shift Animation) plt.show() return ani这个动画展示了如何将超出边界的窗口部分循环位移到另一侧最终得到四个规则的窗口。循环位移后我们可以用标准窗口自注意力计算但需要配合掩码机制来确保注意力只在原本相邻的区域间计算。4. 掩码矩阵的生成与应用掩码是SW-MSA的关键组件它确保循环位移后的窗口在进行自注意力计算时只考虑原本相邻的区域。让我们实现掩码生成和应用的完整流程def generate_mask(window_size7, shift3): # 创建相对位置索引 coords np.stack(np.meshgrid( np.arange(window_size), np.arange(window_size)), axis-1) # 计算相对位置 relative_coords coords[:, :, None, :] - coords[None, None, :, :] relative_coords window_size - 1 # 确保索引非负 # 生成掩码模式 mask np.zeros((window_size**2, window_size**2)) # 左上区域 (原本相邻) mask[:shift*window_size, :shift*window_size] 1 # 右上区域 (原本不相邻) mask[:shift*window_size, shift*window_size:] 0 # 左下区域 (原本不相邻) mask[shift*window_size:, :shift*window_size] 0 # 右下区域 (原本相邻) mask[shift*window_size:, shift*window_size:] 1 return mask def visualize_mask(): mask generate_mask() fig, ax plt.subplots(figsize(8,8)) im ax.imshow(mask, cmapviridis) # 添加网格线 for i in range(0, 49, 7): ax.axhline(i-0.5, colorwhite, linewidth1) ax.axvline(i-0.5, colorwhite, linewidth1) plt.colorbar(im) plt.title(SW-MSA Attention Mask) plt.show()这个掩码矩阵会被加到自注意力得分上在softmax之前将不希望计算的注意力权重设置为负无穷大。实际应用中我们会对不同位移方向的窗口使用不同的掩码模式。5. 完整SW-MSA流程的可视化实现现在我们将所有组件组合起来创建一个完整的SW-MSA流程可视化def full_sw_msa_visualization(): # 设置参数 img_size 56 window_size 7 shift window_size // 2 # 创建图形 fig, axes plt.subplots(2, 2, figsize(16,16)) axes axes.ravel() # 步骤1: 原始窗口划分 draw_feature_map(axes[0]) window_id 0 for i in range(0, img_size, window_size): for j in range(0, img_size, window_size): rect Rectangle((i, j), window_size, window_size, linewidth2, edgecolorr, facecolornone) axes[0].add_patch(rect) axes[0].text(iwindow_size/2, jwindow_size/2, str(window_id), hacenter, vacenter, colorred) window_id 1 axes[0].set_title(Step 1: Original Window Partition (W-MSA)) # 步骤2: 滑动窗口划分 draw_feature_map(axes[1], shiftshift) window_id 0 for i in range(-shift, img_size, window_size): for j in range(-shift, img_size, window_size): rect Rectangle((i, j), window_size, window_size, linewidth2, edgecolorb, facecolornone) axes[1].add_patch(rect) axes[1].text(iwindow_size/2, jwindow_size/2, str(window_id), hacenter, vacenter, colorblue) window_id 1 axes[1].set_title(fStep 2: Shifted Window Partition (shift{shift})) # 步骤3: 循环位移后窗口 draw_feature_map(axes[2]) window_id 0 for i in [0, img_size-shift]: for j in [0, img_size-shift]: rect Rectangle((i, j), window_size, window_size, linewidth2, edgecolorg, facecolornone) axes[2].add_patch(rect) axes[2].text(iwindow_size/2, jwindow_size/2, str(window_id), hacenter, vacenter, colorgreen) window_id 1 axes[2].set_title(Step 3: After Cyclic Shift) # 步骤4: 掩码可视化 mask generate_mask() im axes[3].imshow(mask, cmapviridis) for i in range(0, window_size**21, window_size): axes[3].axhline(i-0.5, colorwhite, linewidth1) axes[3].axvline(i-0.5, colorwhite, linewidth1) plt.colorbar(im, axaxes[3]) axes[3].set_title(Step 4: Attention Mask for SW-MSA) plt.tight_layout() plt.show()这个完整的可视化展示了SW-MSA的四个关键步骤原始窗口划分、滑动窗口重组、循环位移和掩码应用。通过这种动态可视化的方式SW-MSA的工作原理变得直观易懂。

更多文章