Deformable DETR源码精讲:从稀疏采样到小目标检测的实战优化

张开发
2026/4/19 17:45:29 15 分钟阅读

分享文章

Deformable DETR源码精讲:从稀疏采样到小目标检测的实战优化
1. Deformable DETR的核心创新稀疏采样机制Deformable DETR最关键的改进在于引入了稀疏采样机制。传统DETR模型在处理特征图时需要对每个像素点与其他所有像素点进行注意力计算这种全局计算方式带来了巨大的计算开销。想象一下当输入特征图尺寸为100×100时需要计算1万个点之间的两两关系计算量会达到惊人的1亿次而Deformable DETR的聪明之处在于它发现图像中的每个点其实只需要关注周围少数几个关键点就够了。就像我们看一幅画时眼睛会自动聚焦在重要区域而不会平均分配注意力到每个像素。具体实现上模型为每个查询点预测K个采样位置论文中K4只计算这些采样点与当前点的关系。这种设计带来了三个显著优势计算量从O(N²)降到了O(N×K)其中K远小于N由于可以关注更精细的局部特征小目标检测效果明显提升训练收敛速度大幅加快30个epoch就能达到原始DETR 300个epoch的效果2. 多尺度特征融合的实现细节2.1 特征金字塔构建Deformable DETR采用了经典的特征金字塔结构通常包含4个不同尺度的特征图。在代码实现中这通过ResNet backbone和额外的neck层完成# mmdet/models/necks/channel_mapper.py def forward(self, inputs): outs [self.convs[i](inputs[i]) for i in range(len(inputs))] if self.extra_convs: outs.append(self.extra_convs[0](inputs[-1])) return tuple(outs)这里有个精妙的设计所有层级的特征通道数都被统一映射到256维这样不同尺度的特征可以在同一空间进行计算。同时模型还引入了可学习的层级位置编码level_embed让网络能够区分不同层级的特征。2.2 参考点生成机制处理多尺度特征时一个关键挑战是如何统一不同尺度的坐标空间。Deformable DETR通过参考点归一化解决了这个问题# mmdet/models/layers/transformer/deformable_detr_layers.py def get_encoder_reference_points(spatial_shapes, valid_ratios): reference_points_list [] for lvl, (H, W) in enumerate(spatial_shapes): ref_y, ref_x torch.meshgrid(torch.linspace(0.5, H-0.5, H), torch.linspace(0.5, W-0.5, W)) ref_y ref_y.reshape(-1)[None] / (valid_ratios[:,None,lvl,1]*H) ref_x ref_x.reshape(-1)[None] / (valid_ratios[:,None,lvl,0]*W) reference_points_list.append(torch.stack((ref_x, ref_y), -1)) return torch.cat(reference_points_list, 1)这段代码做了三件事为每个特征层级生成网格坐标根据valid_ratios有效区域比例进行归一化将所有层级的参考点拼接起来这样处理后不同尺度的特征点都被映射到了统一的归一化坐标空间方便后续的注意力计算。3. 可变形注意力的代码实现剖析3.1 采样偏移量预测可变形注意力的核心是动态预测采样位置。在代码中这通过两个全连接层实现# mmcv/ops/multi_scale_deform_attn.py sampling_offsets self.sampling_offsets(query).view( bs, num_query, self.num_heads, self.num_levels, self.num_points, 2) attention_weights self.attention_weights(query).view( bs, num_query, self.num_heads, self.num_levels * self.num_points)这里有个非常巧妙的设计采样偏移量直接从查询特征预测得到不需要额外的监督信号。模型会自动学习到哪些位置的偏移对目标检测最有用。实验表明这种数据驱动的方式比手工设计的采样模式效果更好。3.2 双线性特征采样得到采样位置后需要使用双线性插值获取特征值# 伪代码表示采样过程 sampling_value F.grid_sample( value, sampling_grid, modebilinear, padding_modezeros, align_cornersFalse)这个过程就像用放大镜查看特征图对于每个采样点查看它周围4个像素的特征值然后根据距离进行加权平均。这种操作既保留了亚像素级的精度又保持了特征的可微性使得整个系统能够端到端训练。4. 解码器设计与目标查询优化4.1 动态锚框生成Deformable DETR的解码器使用了300个目标查询比原始DETR的100个更多这些查询会逐步细化预测结果# mmdet/models/detectors/deformable_detr.py query_embed self.query_embedding.weight # (300, 512) query_pos, query torch.split(query_embed, c, dim1) reference_points self.reference_points_fc(query_pos).sigmoid()特别值得注意的是参考点坐标是通过sigmoid归一化的这使得初始锚框都位于图像中心区域。在训练过程中模型会学习逐步调整这些参考点位置最终得到准确的检测框。4.2 级联预测头解码器包含6个相同的层每层都会产生中间预测结果# mmdet/models/layers/transformer/deformable_detr_layers.py for layer_id, layer in enumerate(self.layers): output layer(output, reference_points_input) if reg_branches is not None: tmp reg_branches[layer_id](output) new_reference_points tmp inverse_sigmoid(reference_points) reference_points new_reference_points.sigmoid()这种级联结构有两个好处让后续层能够基于前面层的预测结果进行细化在训练时可以提供多层次的监督信号加速收敛5. 损失函数与匹配策略5.1 匈牙利匹配算法Deformable DETR延续了DETR的二分图匹配策略但改进了cost计算方式# mmdet/models/task_modules/assigners/hungarian_assigner.py cost torch.stack([ self.cls_cost(pred_instances, gt_instances), self.reg_cost(pred_instances, gt_instances), self.iou_cost(pred_instances, gt_instances) ]).sum(dim0) matched_row_inds, matched_col_inds linear_sum_assignment(cost)这里综合考虑了分类得分、边界框位置和IoU三个因素确保每个真实框都能匹配到最合适的预测框。这种全局最优的匹配方式比传统的基于IoU的匹配更加鲁棒。5.2 小目标检测优化技巧针对小目标检测我们在实践中总结了几点经验增加小尺度特征图的权重在数据增强中适当保留小目标调整正负样本匹配阈值使用更高分辨率的测试图像这些技巧配合Deformable DETR的稀疏采样机制可以将小目标检测的AP提升5-10个百分点。特别是在无人机航拍、医学影像等小目标密集的场景效果提升非常明显。

更多文章