PyTorch中通过训练图像去雾数据集 建立基于SFNet图像去雾算法的完整系统

张开发
2026/4/14 21:58:42 15 分钟阅读

分享文章

PyTorch中通过训练图像去雾数据集 建立基于SFNet图像去雾算法的完整系统
PyTorch中通过训练图像去雾数据集 建立基于SFNet图像去雾算法的完整系统文章目录(a) 整体架构(b) 浅层特征提取(c) ResBlock(d) Decoupler(e) Modulator总结1. 环境配置2. 数据集准备3. SFNet模型定义4. 数据加载与预处理5. 模型训练6. 界面代码1. main.py - 训练和测试脚本2. SFNet_model.py - SFNet模型定义3. GUI.py - GUI界面代码运行步骤以下文字及代码仅供参考。SFNet图像去雾算法 PyTorch 附图像去雾数据集基于SFNet图像去雾算法的完整系统包括环境配置、数据集准备、模型训练、优化以及界面代码深度学习模型用于图像恢复如去雾、超分辨率等的详细设计。让我们深入解析这个架构的各个部分。(a) 整体架构整体架构展示了模型如何处理输入的降质图像并输出恢复后的图像。流程如下输入层接收降质图像。浅层特征提取通过一个Conv 3x3卷积层提取浅层特征。ResBlock堆叠多个残差块ResBlocks被串联起来每个ResBlock内部包含复杂的特征学习机制见© ResBlock。这些ResBlocks负责学习更深层次的特征表示。上采样与下采样在某些ResBlocks之间使用Conv 1x1进行通道调整并通过箭头指示的上采样或下采样操作来改变特征图的空间尺寸。最终恢复经过一系列特征学习后通过Conv 3x3层生成最终的恢复图像。(b) 浅层特征提取浅层特征提取模块主要由几个基础的卷积操作组成Conv 3x3标准的3x3卷积核用于提取局部特征。Conv 1x11x1卷积用于调整通道数不改变空间维度。MCBF和MDSF可能是特定的多尺度融合模块用于结合不同尺度的信息。© ResBlockResBlock是整个网络的核心组件它包括多个Conv 3x3层用于逐层提取特征。Decoupler和Modulator模块见(d)和(e))用于解耦和调制特征增强模型的表达能力。残差连接用⊕符号表示将输入直接加到输出上有助于缓解梯度消失问题。(d) DecouplerDecoupler模块的作用是将输入特征分解为两部分GAP全局平均池化获取全局信息。Split将特征分为两部分分别进行不同的处理。Invert可能是一个逆变换操作用于恢复或转换特征。Concat将处理后的特征重新拼接在一起。(e) ModulatorModulator模块对特征进行调制Sum、GAP、FC全连接层、Concat、Softmax、Split等操作共同作用实现对特征的非线性变换和选择性增强。这些操作有助于模型关注更重要的特征抑制不重要的信息。总结该模型通过多层次的特征提取和复杂的特征调制机制能够有效地从降质图像中恢复出高质量的图像。其设计考虑了特征的多尺度融合、深度残差学习以及特征的动态调制体现了现代深度学习模型在图像恢复任务中的先进性和复杂性。1. 环境配置首先确保你的环境中安装了必要的库pipinstalltorch torchvision opencv-python pillow PyQt52. 数据集准备假设你已经有了RSHAZE或其他图像去雾数据集并且已经按照以下结构组织好data/ train/ hazy/ gt/ test/ hazy/ gt/3. SFNet模型定义这里我们简化地展示一个基础的SFNet模型定义实际应用中请参考官方或相关论文中的具体实现importtorchimporttorch.nnasnnclassSFNet(nn.Module):def__init__(self):super(SFNet,self).__init__()self.encodernn.Sequential(nn.Conv2d(3,64,kernel_size3,padding1),nn.ReLU(),nn.Conv2d(64,128,kernel_size3,padding1),nn.ReLU())self.decodernn.Sequential(nn.Conv2d(128,64,kernel_size3,padding1),nn.ReLU(),nn.Conv2d(64,3,kernel_size3,padding1))defforward(self,x):xself.encoder(x)xself.decoder(x)returnx4. 数据加载与预处理fromtorch.utils.dataimportDataset,DataLoaderfromPILimportImageimportosfromtorchvisionimporttransformsclassDehazeDataset(Dataset):def__init__(self,hazy_dir,gt_dir,transformNone):self.hazy_imagessorted([os.path.join(hazy_dir,img)forimginos.listdir(hazy_dir)])self.gt_imagessorted([os.path.join(gt_dir,img)forimginos.listdir(gt_dir)])self.transformtransformdef__len__(self):returnlen(self.hazy_images)def__getitem__(self,idx):hazy_imageImage.open(self.hazy_images[idx]).convert(RGB)gt_imageImage.open(self.gt_images[idx]).convert(RGB)ifself.transform:hazy_imageself.transform(hazy_image)gt_imageself.transform(gt_image)returnhazy_image,gt_image transformtransforms.Compose([transforms.Resize((256,256)),transforms.ToTensor()])train_datasetDehazeDataset(data/train/hazy,data/train/gt,transformtransform)train_loaderDataLoader(train_dataset,batch_size8,shuffleTrue)5. 模型训练modelSFNet()criterionnn.MSELoss()optimizertorch.optim.Adam(model.parameters(),lr0.001)num_epochs10forepochinrange(num_epochs):fori,(hazy,gt)inenumerate(train_loader):optimizer.zero_grad()outputsmodel(hazy)losscriterion(outputs,gt)loss.backward()optimizer.step()print(fEpoch [{epoch1}/{num_epochs}], Step [{i1}/{len(train_loader)}], Loss:{loss.item()})6. 界面代码SFNet图像去雾系统包括训练、测试和推理GUI界面我们需要编写多个Python脚本文件。以下是详细的代码示例1.main.py- 训练和测试脚本importargparseimportosimporttorchimporttorch.nnasnnfromtorchvisionimporttransformsfromtorch.utils.dataimportDataLoaderfromPILimportImageimportnumpyasnpfromSFNet_modelimportSFNet# 假设SFNet模型定义在SFNet_model.py中classDehazeDataset(Dataset):def__init__(self,hazy_dir,gt_dir,transformNone):self.hazy_imagessorted([os.path.join(hazy_dir,img)forimginos.listdir(hazy_dir)])self.gt_imagessorted([os.path.join(gt_dir,img)forimginos.listdir(gt_dir)])self.transformtransformdef__len__(self):returnlen(self.hazy_images)def__getitem__(self,idx):hazy_imageImage.open(self.hazy_images[idx]).convert(RGB)gt_imageImage.open(self.gt_images[idx]).convert(RGB)ifself.transform:hazy_imageself.transform(hazy_image)gt_imageself.transform(gt_image)returnhazy_image,gt_imagedeftrain(args):transformtransforms.Compose([transforms.Resize((256,256)),transforms.ToTensor()])datasetDehazeDataset(os.path.join(args.data_dir,train,hazy),os.path.join(args.data_dir,train,gt),transformtransform)dataloaderDataLoader(dataset,batch_sizeargs.batch_size,shuffleTrue)modelSFNet().cuda()criterionnn.MSELoss()optimizertorch.optim.Adam(model.parameters(),lrargs.learning_rate)forepochinrange(args.num_epoch):fori,(hazy,gt)inenumerate(dataloader):hazy,gthazy.cuda(),gt.cuda()optimizer.zero_grad()outputsmodel(hazy)losscriterion(outputs,gt)loss.backward()optimizer.step()print(fEpoch [{epoch1}/{args.num_epoch}], Step [{i1}/{len(dataloader)}], Loss:{loss.item()})torch.save(model.state_dict(),fresults/SFNet/{args.data}/Training-Results/Epoch_{epoch1}.pkl)deftest(args):transformtransforms.Compose([transforms.Resize((256,256)),transforms.ToTensor()])datasetDehazeDataset(os.path.join(args.data_dir,test,hazy),os.path.join(args.data_dir,test,gt),transformtransform)dataloaderDataLoader(dataset,batch_sizeargs.batch_size,shuffleFalse)modelSFNet().cuda()model.load_state_dict(torch.load(args.test_model))model.eval()withtorch.no_grad():fori,(hazy,gt)inenumerate(dataloader):hazy,gthazy.cuda(),gt.cuda()outputsmodel(hazy)ifargs.save_image:forjinrange(outputs.size(0)):output_imgtransforms.ToPILImage()(outputs[j].cpu())output_img.save(fresults/SFNet/{args.data}/Test-Results/image_{i*args.batch_sizej}.png)if__name____main__:parserargparse.ArgumentParser(descriptionSFNet Image Dehazing)parser.add_argument(--data_dir,typestr,requiredTrue,helpdirectory of the dataset)parser.add_argument(--data,typestr,requiredTrue,helpdataset name)parser.add_argument(--mode,typestr,requiredTrue,choices[train,test],helptrain or test mode)parser.add_argument(--batch_size,typeint,default4,helpbatch size)parser.add_argument(--learning_rate,typefloat,default2e-5,helplearning rate)parser.add_argument(--num_epoch,typeint,default300,helpnumber of epochs)parser.add_argument(--test_model,typestr,default,helppath to the trained model for testing)parser.add_argument(--save_image,typebool,defaultFalse,helpwhether to save dehazed images)argsparser.parse_args()ifargs.modetrain:train(args)elifargs.modetest:test(args)2.SFNet_model.py- SFNet模型定义importtorchimporttorch.nnasnnclassSFNet(nn.Module):def__init__(self):super(SFNet,self).__init__()self.encodernn.Sequential(nn.Conv2d(3,64,kernel_size3,padding1),nn.ReLU(),nn.Conv2d(64,128,kernel_size3,padding1),nn.ReLU())self.decodernn.Sequential(nn.Conv2d(128,64,kernel_size3,padding1),nn.ReLU(),nn.Conv2d(64,3,kernel_size3,padding1))defforward(self,x):xself.encoder(x)xself.decoder(x)returnx3.GUI.py- GUI界面代码importsysfromPyQt5.QtWidgetsimportQApplication,QWidget,QPushButton,QVBoxLayout,QLabel,QFileDialogfromPyQt5.QtGuiimportQPixmapimportcv2importnumpyasnpimporttorchfromtorchvisionimporttransformsfromSFNet_modelimportSFNetclassDehazeApp(QWidget):def__init__(self):super().__init__()self.initUI()definitUI(self):self.setWindowTitle(图像去雾)self.setGeometry(100,100,800,400)layoutQVBoxLayout()self.btn_selectQPushButton(选择图像,self)self.btn_select.clicked.connect(self.select_image)layout.addWidget(self.btn_select)self.btn_dehazeQPushButton(SFNet去雾,self)self.btn_dehaze.clicked.connect(self.dehaze_image)layout.addWidget(self.btn_dehaze)self.image_labelQLabel(self)layout.addWidget(self.image_label)self.setLayout(layout)defselect_image(self):optionsQFileDialog.Options()fileName,_QFileDialog.getOpenFileName(self,选择图像,,Images (*.png *.xpm *.jpg *.bmp);;All Files (*),optionsoptions)iffileName:self.image_pathfileName pixmapQPixmap(fileName)self.image_label.setPixmap(pixmap.scaled(400,400))defdehaze_image(self):ifhasattr(self,image_path):# Load and preprocess imageimagecv2.imread(self.image_path)imagecv2.resize(image,(256,256))imageimage/255.0imagenp.transpose(image,(2,0,1))imagetorch.tensor(image,dtypetorch.float32).unsqueeze(0).cuda()# Load pre-trained modelmodelSFNet().cuda()model.load_state_dict(torch.load(results/SFNet/Outdoor/Training-Results/Best.pkl))model.eval()# Perform dehazingwithtorch.no_grad():outputmodel(image).squeeze().cpu().numpy()outputnp.transpose(output,(1,2,0))output(output*255).astype(np.uint8)# Display resultcv2.imwrite(dehazed.jpg,output)pixmapQPixmap(dehazed.jpg)self.image_label.setPixmap(pixmap.scaled(400,400))if__name____main__:appQApplication(sys.argv)exDehazeApp()ex.show()sys.exit(app.exec_())运行步骤训练模型python main.py--data_dirdehaze--dataOutdoor--modetrain--batch_size4--learning_rate2e-5--num_epoch300测试模型python main.py--data_dirdehaze--dataOutdoor--modetest--batch_size4--test_modelresults/SFNet/Outdoor/Training-Results/Best.pkl--save_imageTrue运行GUI界面python GUI.py确保所有路径正确并根据实际情况调整参数和文件路径。

更多文章