From Zero to One: Implementing U-Net in PyTorch for Remote-Sensing Image Segmentation (Complete Code and Dataset Included)

张开发
2026/4/15 10:54:59 · 15 min read


Remote-sensing image segmentation has long been a popular research direction in computer vision, with wide applications in fields such as agricultural monitoring and urban planning. U-Net, the classic network for medical image segmentation, performs just as well on remote-sensing imagery thanks to its distinctive encoder-decoder structure and skip connections. This article walks you through implementing a complete U-Net in PyTorch from scratch, then training and evaluating it on a public remote-sensing dataset.

## 1. Environment Setup and Data Loading

Before we start, prepare the development environment. Python 3.8 with PyTorch 1.10 is recommended to ensure compatibility with the code below:

```bash
conda create -n unet python=3.8
conda activate unet
pip install torch torchvision torchaudio
pip install opencv-python matplotlib tqdm numpy
```

For remote-sensing segmentation, the ISPRS Vaihingen dataset is a good choice. It contains 33 high-resolution aerial images (about 2500×2000 pixels on average) with corresponding label maps covering six semantic classes: impervious surfaces, buildings, low vegetation, trees, cars, and background.

```python
import os

import cv2
import numpy as np
from torch.utils.data import Dataset


class RemoteSensingDataset(Dataset):
    def __init__(self, img_dir, mask_dir, transform=None):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = os.listdir(img_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir,
                                 self.images[idx].replace('.tif', '_label.tif'))
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']
        return image, mask
```

> **Tip:** In real projects, data augmentation is essential for model performance. The `albumentations` library is recommended: it is optimized for segmentation tasks and works well with PyTorch.
## 2. The U-Net Architecture in Detail

The core idea of U-Net is to extract features with an encoder, recover spatial information with a decoder, and preserve fine detail through skip connections. Below we implement U-Net component by component.

### 2.1 The Encoder

The encoder consists of several downsampling blocks. Each block contains two 3×3 convolutions (each followed by BatchNorm and ReLU) and a 2×2 max-pooling layer:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """(Conv => BN => ReLU) * 2"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downsampling block: max pooling followed by a double conv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels),
        )

    def forward(self, x):
        return self.maxpool_conv(x)
```

### 2.2 The Decoder

The decoder upsamples the feature maps (with bilinear interpolation by default, or optionally a transposed convolution) and fuses the encoder's features through skip connections:

```python
class Up(nn.Module):
    """Upsampling block: upsample, concatenate the skip features, double conv."""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            # A transposed convolution halves the channel count while upsampling.
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2,
                                         kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Pad so the upsampled map matches the skip connection's spatial size
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
```

### 2.3 The Complete U-Net

Now combine encoder and decoder into the full architecture. Note the `factor` trick: bilinear upsampling does not reduce the channel count the way a transposed convolution does, so the bottleneck and decoder widths are halved to keep the concatenated channel counts consistent:

```python
class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        factor = 2 if bilinear else 1
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits
```

> **Note:** In practice you may need to adjust the network depth or the number of base channels to fit your GPU memory and input size. For high-resolution remote-sensing imagery, a deeper or wider network can be worth trying.

## 3. Training Strategy and Tricks

Several factors matter when training U-Net: the choice of loss function, learning-rate scheduling, and the data-augmentation strategy.

### 3.1 Choosing a Loss Function

For multi-class segmentation, commonly used loss functions include:

- **Cross-entropy loss**: suitable for class-balanced datasets;
- **Dice loss**: more robust to class imbalance;
- **Combined losses**: merge the strengths of cross-entropy and Dice.

The implementations below are written for binary masks (sigmoid plus BCE); for the six-class Vaihingen labels you would apply them per class or pair a multi-class Dice term with `nn.CrossEntropyLoss` instead:

```python
class DiceLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        inputs = torch.sigmoid(inputs)
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        intersection = (inputs * targets).sum()
        dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
        return 1 - dice


class DiceBCELoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='mean')
        inputs = torch.sigmoid(inputs)
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        intersection = (inputs * targets).sum()
        dice_loss = 1 - (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
        return bce + dice_loss
```

### 3.2 Implementing the Training Loop

Below is a complete training loop with a validation phase and checkpointing of the best model. The `dice_coeff` helper used during evaluation is defined first:

```python
from tqdm import tqdm


def dice_coeff(pred, target, smooth=1):
    """Dice coefficient for already-binarized predictions."""
    pred = pred.view(-1)
    target = target.view(-1)
    intersection = (pred * target).sum()
    return (2. * intersection + smooth) / (pred.sum() + target.sum() + smooth)


def train_model(model, device, train_loader, val_loader, criterion,
                optimizer, num_epochs=25):
    best_dice = 0.0
    for epoch in range(num_epochs):
        print(f'Epoch {epoch + 1}/{num_epochs}')
        print('-' * 10)
        model.train()
        running_loss = 0.0
        for images, masks in tqdm(train_loader):
            images = images.to(device)
            masks = masks.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * images.size(0)
        epoch_loss = running_loss / len(train_loader.dataset)
        print(f'Train Loss: {epoch_loss:.4f}')

        # Validation phase
        val_loss, dice_score = evaluate(model, device, val_loader, criterion)
        print(f'Val Loss: {val_loss:.4f}, Dice: {dice_score:.4f}')

        # Save the best checkpoint
        if dice_score > best_dice:
            best_dice = dice_score
            torch.save(model.state_dict(), 'best_model.pth')
    print(f'Training complete, best Dice: {best_dice:.4f}')


def evaluate(model, device, loader, criterion):
    model.eval()
    running_loss = 0.0
    dice_score = 0.0
    with torch.no_grad():
        for images, masks in loader:
            images = images.to(device)
            masks = masks.to(device)
            outputs = model(images)
            loss = criterion(outputs, masks)
            running_loss += loss.item() * images.size(0)
            # Dice on thresholded sigmoid outputs
            preds = torch.sigmoid(outputs)
            preds = (preds > 0.5).float()
            dice_score += dice_coeff(preds, masks).item() * images.size(0)
    loss = running_loss / len(loader.dataset)
    dice = dice_score / len(loader.dataset)
    return loss, dice
```

### 3.3 Learning-Rate Scheduling and Early Stopping

To prevent overfitting and speed up convergence, introduce a learning-rate schedule (an early-stopping check can monitor the same validation metric):

```python
from torch.optim import lr_scheduler

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'max',
                                           patience=3, factor=0.1, verbose=True)

# Inside the training loop, step on the validation Dice score:
scheduler.step(dice_score)
```

## 4. Evaluation and Visualization

After training, we evaluate the model's performance thoroughly and visualize the segmentation results.

### 4.1 Computing Evaluation Metrics

Besides plain accuracy and the Dice coefficient, remote-sensing segmentation commonly reports the following metrics:

| Metric | Formula | Meaning |
| --- | --- | --- |
| IoU | TP / (TP + FP + FN) | Overlap between predicted and true regions |
| Precision | TP / (TP + FP) | Fraction of predicted positives that are truly positive |
| Recall | TP / (TP + FN) | Fraction of actual positives that are detected |
| F1 score | 2 × Precision × Recall / (Precision + Recall) | Harmonic mean of precision and recall |

```python
def calculate_metrics(pred, target, threshold=0.5):
    pred = (pred > threshold).float()
    target = (target > 0).float()
    tp = (pred * target).sum()
    fp = (pred * (1 - target)).sum()
    fn = ((1 - pred) * target).sum()
    tn = ((1 - pred) * (1 - target)).sum()
    precision = tp / (tp + fp + 1e-8)
    recall = tp / (tp + fn + 1e-8)
    f1 = 2 * (precision * recall) / (precision + recall + 1e-8)
    iou = tp / (tp + fp + fn + 1e-8)
    return {
        'precision': precision.item(),
        'recall': recall.item(),
        'f1': f1.item(),
        'iou': iou.item(),
    }
```

### 4.2 Visualizing Results

Visualization is an important way to understand model behavior. We can display the original image, the ground truth, and the prediction side by side:

```python
import matplotlib.pyplot as plt


def plot_results(image, mask, pred, num_samples=3):
    plt.figure(figsize=(15, 5 * num_samples))
    for i in range(num_samples):
        idx = np.random.randint(0, len(image))
        plt.subplot(num_samples, 3, i * 3 + 1)
        plt.imshow(image[idx].permute(1, 2, 0).cpu().numpy())
        plt.title('Original Image')
        plt.axis('off')
        plt.subplot(num_samples, 3, i * 3 + 2)
        plt.imshow(mask[idx].cpu().numpy(), cmap='gray')
        plt.title('Ground Truth')
        plt.axis('off')
        plt.subplot(num_samples, 3, i * 3 + 3)
        plt.imshow(pred[idx].cpu().numpy(), cmap='gray')
        plt.title('Prediction')
        plt.axis('off')
    plt.tight_layout()
    plt.show()
```

In real projects, I have found a few points particularly effective for boosting performance:

- **Diverse augmentation:** beyond the usual rotations and flips, adding color jitter and random elastic deformation noticeably improves generalization.
- **Loss choice:** for class-imbalanced data, Dice loss usually outperforms cross-entropy.
- **Learning-rate warmup:** starting with a low learning rate and ramping it up in the early epochs helps stabilize training.
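The warmup tip above can be sketched with PyTorch's built-in `LambdaLR`. The five-epoch warmup length and the small stand-in model are illustrative assumptions, not part of the original recipe:

```python
import torch
from torch.optim import lr_scheduler

model = torch.nn.Linear(4, 2)  # stand-in for the U-Net
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

warmup_epochs = 5  # assumed warmup length

# Ramp the learning rate linearly from 1/5 of its base value up to the
# base value over the first warmup_epochs epochs, then hold it there.
warmup = lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda epoch: min(1.0, (epoch + 1) / warmup_epochs),
)

for epoch in range(8):
    optimizer.step()  # placeholder for one epoch of training
    warmup.step()

# After the warmup epochs, the lr sits at its base value of 1e-4.
```

After warmup you can hand control to the `ReduceLROnPlateau` scheduler from Section 3.3 (on newer PyTorch versions, `lr_scheduler.SequentialLR` can chain the two).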
