MMPose实战:手把手教你用HRNet训练自定义关键点检测模型(附完整配置文件解析)

张开发
2026/4/21 15:12:52 15 分钟阅读

分享文章

MMPose实战:手把手教你用HRNet训练自定义关键点检测模型(附完整配置文件解析)
MMPose实战从零构建HRNet自定义关键点检测模型当我们需要在特定场景下实现精准的人体姿态识别时开源工具箱MMPose无疑是最佳选择之一。本文将带你从零开始使用HRNet骨干网络训练一个自定义关键点检测模型涵盖数据准备、配置文件修改到模型优化的全流程。1. 环境准备与数据标注在开始之前我们需要搭建适合MMPose的开发环境。推荐使用Python 3.8和PyTorch 1.8的组合conda create -n mmpose python3.8 -y conda activate mmpose pip install torch torchvision torchaudio pip install openmim mim install mmengine mmcv-full mim install mmpose对于自定义数据集COCO格式是最为推荐的标注方式。一个典型的标注文件应包含以下结构{ images: [ { id: 1, width: 640, height: 480, file_name: image001.jpg } ], annotations: [ { id: 1, image_id: 1, category_id: 1, keypoints: [x1,y1,v1,...,xk,yk,vk], num_keypoints: k, bbox: [x,y,width,height] } ], categories: [ { id: 1, name: person, keypoints: [nose,left_eye,...], skeleton: [[0,1],[1,3],...] } ] }关键点可见性标记v的取值规则0未标注1标注但不可见2标注且可见2. 配置文件深度解析MMPose采用模块化的配置文件系统我们需要重点关注以下几个核心部分2.1 数据集配置首先创建数据集元信息文件configs/_base_/datasets/custom.pydataset_info dict( dataset_namecustom, keypoint_info{ 0: dict(namepoint1, id0, color[255, 0, 0], typeupper, swap), # 其他关键点定义... }, skeleton_info{ 0: dict(link(point1, point2), id0, color[255, 0, 0]), # 其他骨骼连接定义... }, joint_weights[1.0, 1.0, ...], # 各关键点权重 sigmas[0.025, 0.025, ...] # 各关键点标准差 )2.2 模型配置HRNet的典型配置如下所示model dict( typeTopdownPoseEstimator, data_preprocessordict( typePoseDataPreprocessor, mean[123.675, 116.28, 103.53], std[58.395, 57.12, 57.375], bgr_to_rgbTrue ), backbonedict( typeHRNet, in_channels3, extradict( stage1dict( num_modules1, num_branches1, blockBOTTLENECK, num_blocks(4,), num_channels(64,) ), stage2dict( num_modules1, num_branches2, blockBASIC, num_blocks(4,4), num_channels(32,64) ), # 更多阶段配置... ) ), headdict( typeHeatmapHead, in_channels32, out_channelsnum_keypoints, lossdict(typeKeypointMSELoss, use_target_weightTrue), decoderdict( typeHeatmapDecoder, input_size(256, 256), heatmap_size(64, 64) ) ) )2.3 训练策略配置优化器和学习率策略对模型性能至关重要optim_wrapper dict( optimizerdict(typeAdamW, lr0.001, weight_decay0.01), clip_graddict(max_norm1.0, norm_type2) ) param_scheduler [ dict( typeLinearLR, start_factor0.001, by_epochTrue, begin0, end10 ), dict( typeMultiStepLR, milestones[170, 200], gamma0.1, by_epochTrue ) ]3. 数据流水线构建MMPose的数据处理流程非常灵活下面是一个典型的数据增强配置train_pipeline [ dict(typeLoadImage), dict(typeGetBBoxCenterScale), dict(typeRandomFlip, directionhorizontal), dict(typeRandomHalfBody), dict(typeRandomBBoxTransform), dict(typeTopdownAffine, input_size(256, 256)), dict(typeGenerateTarget, target_typeheatmap, encoderdict( typeHeatmapEncoder, input_size(256, 256), heatmap_size(64, 64), sigma2.0 )), dict(typePackPoseInputs) ]关键数据变换说明变换类型功能描述重要参数RandomFlip随机水平翻转flip_prob, directionRandomHalfBody随机半身增强min_keypoints, upper_body_idsRandomBBoxTransform随机缩放旋转scale_factor, rotate_factorTopdownAffine仿射变换到输入尺寸input_size, scale_type4. 训练与优化技巧4.1 启动训练使用以下命令开始模型训练mim train mmpose configs/body_2d_keypoint/topdown_heatmap/custom/hrnet_w32_custom_256x256.py4.2 常见问题解决问题1显存不足解决方案减小batch size或使用梯度累积train_dataloader dict( batch_size32, num_workers4, persistent_workersTrue, samplerdict(typeDefaultSampler, shuffleTrue), collate_fndict(typedefault_collate), datasetdict( typeCustomDataset, data_rootdata/custom/, pipelinetrain_pipeline ) )问题2关键点预测偏移解决方案调整heatmap生成参数encoderdict( typeHeatmapEncoder, input_size(256, 256), heatmap_size(64, 64), sigma2.0, # 适当增大可提高鲁棒性 use_udpTrue # 使用Unbiased Data Processing )4.3 模型评估指标MMPose支持多种评估指标最常用的是OKSObject Keypoint Similarityval_evaluator dict( typeCocoMetric, ann_filedata/custom/annotations/val.json, score_modebbox, # 使用检测框作为实例匹配依据 metricAP, # 使用平均精度 format_onlyFalse )关键评估参数说明参数说明推荐值score_mode实例匹配方式bbox或keypointmetric评估指标AP, AR等iou_thrOKS阈值默认[0.5:0.05:0.95]keypoint_nms关键点NMS通常关闭5. 高级优化策略5.1 知识蒸馏使用更大的HRNet-W48作为教师模型model dict( typeTopdownPoseEstimatorWithDistiller, teacherdict( typeTopdownPoseEstimator, backbonedict(typeHRNet, extradict(...)), # 教师配置 headdict(...) ), studentdict( typeTopdownPoseEstimator, backbonedict(typeHRNet, extradict(...)), # 学生配置 headdict(...) ), distillerdict( typePoseDistiller, student_trainableTrue, components[ dict( student_modulehead.final_layer, teacher_modulehead.final_layer, losses[ dict(typeKLDiscretLoss, nameloss_kd) ] ) ] ) )5.2 模型量化使用TorchQuant进行INT8量化quant_config dict( activation_observerdict(typeHistogramObserver, bins256), weight_observerdict(typeMinMaxObserver), quantizerdict(typeTensorRTQuantizer) ) model dict( typeQuantizableTopdownPoseEstimator, backbonedict(...), headdict(...), quant_configquant_config )5.3 部署优化使用ONNX和TensorRT进行部署优化python tools/deployment/pytorch2onnx.py \ configs/body_2d_keypoint/topdown_heatmap/custom/hrnet_w32_custom_256x256.py \ checkpoints/hrnet_w32_custom_256x256.pth \ --output-file hrnet.onnx \ --shape 1 3 256 256 \ --verify然后使用TensorRT转换trtexec --onnxhrnet.onnx \ --saveEnginehrnet.engine \ --fp16 \ --workspace20486. 实际应用案例6.1 工业质检场景在电路板元件检测中我们定义了12个关键点keypoint_info { 0: dict(namecapacitor_top, id0, color[255, 0, 0], type, swap), 1: dict(namecapacitor_bottom, id1, color[0, 255, 0], type, swap), # 其他元件关键点... }特殊处理技巧增加小目标关键点的权重使用更高的输入分辨率512x512针对金属反光增加数据增强6.2 体育动作分析篮球投篮动作分析的关键配置train_pipeline [ # ...基础变换... dict(typeRandomMotionBlur, prob0.3, kernel_size7), dict(typeRandomOcclusion, prob0.5, occlusion_size(30, 30)), dict(typeTemporalRandomFlip, flip_ratio0.5) ]时序处理技巧使用3D卷积扩展HRNet增加时序一致性损失采用多帧输入单帧输出的训练策略通过以上完整的配置和优化策略我们可以在自定义数据集上训练出高精度的关键点检测模型。HRNet的高分辨率特性使其特别适合需要精确定位的应用场景。

更多文章