动手学深度学习——转置卷积代码

张开发
2026/4/21 20:32:41 15 分钟阅读

分享文章

动手学深度学习——转置卷积代码
这一篇要比上一节更偏“动手验证”重点不是再空讲概念而是通过代码把这几件事看清楚转置卷积到底怎么计算ConvTranspose2d怎么用kernel_size、padding、stride如何影响输出转置卷积和普通卷积在形状变化上有什么关系1. 前言上一篇我们已经学习了**转置卷积Transposed Convolution**的基本概念。我们知道转置卷积常用于上采样恢复特征图空间分辨率语义分割、生成模型等任务但如果只停留在概念层面还是不够扎实。真正要理解一个算子最好的方法之一就是看代码、跑代码、验证输出。因此这一节我们就按照《动手学深度学习》的思路用代码来具体观察一个小输入矩阵经过转置卷积后变成什么样转置卷积的输出是如何形成的不同参数会怎样改变输出尺寸2. 一个最基础的手写示例李沐这里一开始通常不会直接上ConvTranspose2d而是先写一个最小化版本的函数帮助我们理解转置卷积的计算过程。例如import torch from d2l import torch as d2l def trans_conv(X, K): h, w K.shape Y torch.zeros((X.shape[0] h - 1, X.shape[1] w - 1)) for i in range(X.shape[0]): for j in range(X.shape[1]): Y[i:ih, j:jw] X[i, j] * K return Y这段代码非常关键因为它直接把转置卷积最核心的计算过程写出来了。3. 这段代码在做什么我们逐步拆解一下。3.1 输入X表示输入矩阵。K表示卷积核。3.2 输出大小Y torch.zeros((X.shape[0] h - 1, X.shape[1] w - 1))如果输入大小是m × n卷积核大小是h × w那么这里先构造一个大小为(m h - 1) × (n w - 1)的输出矩阵。这正体现了转置卷积“让输出变大”的特点。3.3 核心循环for i in range(X.shape[0]): for j in range(X.shape[1]): Y[i:ih, j:jw] X[i, j] * K这就是转置卷积最本质的操作输入中的每个元素X[i, j]都会乘上整个卷积核K然后加到输出矩阵Y的一个局部区域中。而多个输入元素投影后的区域可能会重叠重叠部分就直接相加。这正是转置卷积和普通卷积最大的直观区别。4. 代入一个具体例子接下来我们用一个最经典的例子X torch.tensor([[0.0, 1.0], [2.0, 3.0]]) K torch.tensor([[0.0, 1.0], [2.0, 3.0]]) trans_conv(X, K)输出结果为tensor([[ 0., 0., 1.], [ 0., 4., 6.], [ 4., 12., 9.]])5. 这个结果是怎么来的这个地方非常适合博客里详细解释因为它能让读者真正看懂转置卷积。5.1 输入左上角元素00 * K [[0, 0], [0, 0]]加到输出左上角区域不产生影响。5.2 输入右上角元素11 * K [[0, 1], [2, 3]]加到输出中从(0,1)开始的区域。5.3 输入左下角元素22 * K [[0, 2], [4, 6]]加到输出中从(1,0)开始的区域。5.4 输入右下角元素33 * K [[0, 3], [6, 9]]加到输出中从(1,1)开始的区域。5.5 重叠部分相加最终把这些局部块叠加起来就得到[[ 0., 0., 1.], [ 0., 4., 6.], [ 4., 12., 9.]]这样一看转置卷积的结果就不再神秘了。6. 用 PyTorch 的ConvTranspose2d实现同样过程手写代码只是为了理解原理。真正做深度学习时通常还是直接调用框架提供的算子。对应代码如下X torch.tensor([[[[0.0, 1.0], [2.0, 3.0]]]]) K torch.tensor([[[[0.0, 1.0], [2.0, 3.0]]]]) tconv torch.nn.ConvTranspose2d(1, 1, kernel_size2, biasFalse) tconv.weight.data K tconv(X)输出结果通常是tensor([[[[ 0., 0., 1.], [ 0., 4., 6.], [ 4., 12., 9.]]]], grad_fn...)可以看到这和我们手写函数得到的结果是一致的。7. 为什么这里输入要写成四维张量很多初学者会被这里的形状绕住。X.shape (1, 1, 2, 2)这四个维度分别表示批量大小batch_size 1输入通道数in_channels 1高height 2宽width 2这是 PyTorch 卷积层统一使用的输入格式。所以虽然我们直观看到的是一个2 × 2矩阵但喂给网络层时必须补成四维张量。8.padding对转置卷积输出有什么影响接下来就进入参数观察部分。先看代码tconv torch.nn.ConvTranspose2d(1, 1, kernel_size2, padding1, biasFalse) tconv.weight.data K tconv(X)这时输出尺寸会变小。这说明什么在转置卷积中padding会影响输出边缘通常可以理解为对输出做裁剪。这和普通卷积里“给输入补零”的直觉不太一样所以特别容易混。9.stride对输出大小的影响再看步幅。tconv torch.nn.ConvTranspose2d(1, 1, kernel_size2, stride2, biasFalse) tconv.weight.data K tconv(X)当stride2时输出尺寸会明显变大。这说明转置卷积中的步幅越大输出通常越大。直观上可以理解为输入元素在投影到输出时间隔被拉开了于是整体输出范围增大。10. 用公式验证输出大小二维转置卷积的输出大小常用公式为输出大小 (输入大小 - 1) * stride - 2 * padding kernel_size例如输入大小2stride 1padding 0kernel_size 2那么输出大小为(2 - 1) * 1 - 0 2 3所以输出是3 × 3。如果改成输入大小2stride 2padding 0kernel_size 2则输出大小为(2 - 1) * 2 - 0 2 4所以输出会变成4 × 4。这和代码实验是对应的。11. 转置卷积可以实现“放大特征图”通过上面的代码我们可以非常清楚地看到普通卷积常常把图变小转置卷积常常把图变大因此在语义分割中当我们有一个较小的特征图希望把它恢复到更高分辨率时就可以考虑使用转置卷积。也就是说这一节代码不是孤立的它是在为后面的 FCN 做技术准备。12. 卷积和转置卷积在形状变化上的对照这个点李沐也会特别强调。假设一个普通卷积层conv nn.Conv2d(10, 20, kernel_size5, padding2, stride3)输入形状是X torch.rand(size(1, 10, 16, 16))那么输出形状可能会变成Y conv(X) Y.shape如果我们再构造一个“对应参数”的转置卷积层tconv nn.ConvTranspose2d(20, 10, kernel_size5, padding2, stride3)然后把刚才卷积得到的Y输入进去tconv(Y).shape最终会发现输出形状回到了(1, 10, 16, 16)13. 这说明了什么这说明转置卷积在形状变换关系上确实和普通卷积密切相关。也就是说普通卷积把16 × 16变成更小对应的转置卷积可以把这个更小的结果变回16 × 16注意这里说的是形状可以恢复但并不意味着数值内容一定恢复成原来那个输入这两点一定要区分开。14. 为什么叫“转置卷积”从代码实验中我们已经能感觉到转置卷积并不是普通卷积的简单反向播放。它之所以叫“转置卷积”是因为从矩阵表示角度看普通卷积可以写成一个线性变换矩阵乘法而转置卷积对应的是这个矩阵的转置形式所以这里的“转置”本质上是线性代数意义上的转置不是说它一定能把原输入精确还原回来。15. 这节代码最该掌握什么如果从“考试/面试/博客总结”角度看这一节最重要的是以下几点15.1 手写版计算过程要知道转置卷积是输入元素乘卷积核投影到输出局部区域重叠部分相加15.2ConvTranspose2d的基本使用方式至少要知道输入是四维张量可以手动设置权重核可以通过stride、padding控制输出形状15.3 输出尺寸公式这一点非常常用输出大小 (输入大小 - 1) * stride - 2 * padding kernel_size15.4 它和普通卷积的关系要知道它和卷积有密切联系但它不是严格意义上的“逆卷积”。16. 本节总结这一节我们通过代码具体验证了转置卷积核心内容可以总结为以下几点。16.1 手写实现帮助理解原理输入中的每个元素都会乘上卷积核并加到输出局部区域中。16.2ConvTranspose2d可以直接实现转置卷积PyTorch 中已经提供了标准接口。16.3padding和stride都会影响输出尺寸其中stride增大时输出通常也会增大。16.4 转置卷积能恢复空间尺寸因此非常适合语义分割等需要上采样的任务。16.5 它和普通卷积在形状变化上密切相关但不是严格意义上的可逆还原。17. 学习感悟这一节最大的价值在于它把“转置卷积”从一个抽象概念变成了一个你能亲手验证的运算过程。很多深度学习知识一旦能用一个小矩阵例子跑通理解就会扎实很多。转置卷积就是典型代表。前面只说它“能上采样”可能还比较模糊但通过这节代码我们就能真正看到它为什么能放大输出参数又是如何控制这个放大过程的。

更多文章