别再死磕TensorFlow了!用PyTorch VGG16从零搭建猫狗分类器(附完整代码与数据集处理)

张开发
2026/4/14 12:25:28 15 分钟阅读

分享文章

别再死磕TensorFlow了!用PyTorch VGG16从零搭建猫狗分类器(附完整代码与数据集处理)
从TensorFlow到PyTorchVGG16猫狗分类实战指南当深度学习新手第一次接触计算机视觉项目时往往会陷入框架选择的困境。网上铺天盖地的TensorFlow教程确实为初学者提供了便利但当课程或项目要求使用PyTorch时这种便利反而成了障碍。本文将带你跨越这个鸿沟用PyTorch实现一个完整的猫狗分类器特别适合那些需要快速从TensorFlow思维转换到PyTorch的开发者。1. 为什么选择PyTorch进行计算机视觉项目PyTorch近年来在学术界和工业界的采用率持续攀升这并非偶然。与TensorFlow相比PyTorch的动态计算图特性让调试过程更加直观特别适合研究型项目和教育场景。在猫狗分类这样的经典计算机视觉任务中PyTorch提供的torchvision库包含了大量预训练模型和数据处理工具能显著降低入门门槛。对于习惯TensorFlow的开发者PyTorch的几个核心差异点值得注意即时执行模式PyTorch不需要先定义完整的计算图代码执行顺序就是计算图的构建顺序Pythonic风格PyTorch的API设计更接近原生Python学习曲线相对平缓调试友好性可以直接使用Python调试工具如pdb检查中间变量# PyTorch与TensorFlow的简单对比示例 import torch import tensorflow as tf # PyTorch方式 x torch.tensor([1.0], requires_gradTrue) y x ** 2 y.backward() print(x.grad) # 输出梯度值 # TensorFlow方式 x tf.Variable(1.0) with tf.GradientTape() as tape: y x ** 2 print(tape.gradient(y, x)) # 输出梯度值2. 环境配置与数据集准备2.1 PyTorch环境搭建正确的环境配置是项目成功的第一步。PyTorch支持CPU和GPU两种计算模式对于猫狗分类这样的任务GPU加速能显著缩短训练时间。以下是配置PyTorch GPU环境的步骤确认显卡支持CUDANVIDIA显卡安装与显卡驱动匹配的CUDA工具包通过PyTorch官网获取对应CUDA版本的安装命令验证安装是否成功# 验证PyTorch是否识别到GPU python -c import torch; print(torch.cuda.is_available())提示使用Anaconda或Miniconda管理Python环境可以避免依赖冲突问题。对于教育用途PyCharm Professional提供了完善的PyTorch开发支持包括远程解释器配置和TensorBoard集成。2.2 数据集处理技巧Kaggle的猫狗数据集包含25000张图片但对于学习和快速验证来说这个规模可能过大。我们可以创建一个小型子集import os import shutil from tqdm import tqdm # 进度条工具 def create_mini_dataset(original_dir, target_dir, samples_per_class200): os.makedirs(target_dir, exist_okTrue) for class_name in [cats, dogs]: class_dir os.path.join(target_dir, class_name) os.makedirs(class_dir, exist_okTrue) # 从原始数据集中随机选取样本 src_dir os.path.join(original_dir, class_name) all_files os.listdir(src_dir) selected_files np.random.choice(all_files, samples_per_class, replaceFalse) # 复制文件 for filename in tqdm(selected_files, descfCopying {class_name}): src os.path.join(src_dir, filename) dst os.path.join(class_dir, filename) shutil.copyfile(src, dst)数据集组织建议采用以下结构Smalldata/ train/ cats/ cat.0.jpg ... dogs/ dog.0.jpg ... test/ cats/ ... dogs/ ...3. 构建PyTorch版VGG16模型3.1 模型架构解析VGG16是牛津大学视觉几何组提出的经典卷积神经网络其特点是使用连续的3×3卷积核代替大尺寸卷积核。在PyTorch中我们可以直接调用预训练的VGG16模型import torchvision.models as models # 加载预训练模型 model models.vgg16(pretrainedTrue) # 冻结卷积层参数 for param in model.parameters(): param.requires_grad False # 替换全连接层 model.classifier torch.nn.Sequential( torch.nn.Linear(25088, 256), torch.nn.ReLU(), torch.nn.Dropout(0.5), torch.nn.Linear(256, 2) )3.2 数据增强策略数据增强是提升模型泛化能力的关键。PyTorch的transforms模块提供了丰富的图像变换方法from torchvision import transforms train_transform transforms.Compose([ transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ColorJitter(brightness0.2, contrast0.2, saturation0.2), transforms.RandomRotation(15), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) test_transform transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ])4. 训练与评估流程4.1 训练循环实现PyTorch的训练循环比TensorFlow更加灵活直观。以下是典型的训练过程def train_model(model, criterion, optimizer, dataloaders, num_epochs10): best_acc 0.0 for epoch in range(num_epochs): print(fEpoch {epoch}/{num_epochs-1}) print(- * 10) # 每个epoch都有训练和验证阶段 for phase in [train, val]: if phase train: model.train() # 训练模式 else: model.eval() # 评估模式 running_loss 0.0 running_corrects 0 # 迭代数据 for inputs, labels in dataloaders[phase]: inputs inputs.to(device) labels labels.to(device) # 梯度清零 optimizer.zero_grad() # 前向传播 with torch.set_grad_enabled(phase train): outputs model(inputs) _, preds torch.max(outputs, 1) loss criterion(outputs, labels) # 反向传播优化仅在训练阶段 if phase train: loss.backward() optimizer.step() # 统计 running_loss loss.item() * inputs.size(0) running_corrects torch.sum(preds labels.data) epoch_loss running_loss / len(dataloaders[phase].dataset) epoch_acc running_corrects.double() / len(dataloaders[phase].dataset) print(f{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}) # 深度拷贝模型 if phase val and epoch_acc best_acc: best_acc epoch_acc best_model_wts copy.deepcopy(model.state_dict()) # 加载最佳模型权重 model.load_state_dict(best_model_wts) return model4.2 模型评估指标除了准确率我们还应该关注其他评估指标指标公式说明精确率TP/(TPFP)预测为正类中实际为正类的比例召回率TP/(TPFN)实际为正类中被正确预测的比例F1分数2*(精确率*召回率)/(精确率召回率)精确率和召回率的调和平均实现多指标评估from sklearn.metrics import classification_report def evaluate_model(model, dataloader): model.eval() all_preds [] all_labels [] with torch.no_grad(): for inputs, labels in dataloader: inputs inputs.to(device) labels labels.to(device) outputs model(inputs) _, preds torch.max(outputs, 1) all_preds.extend(preds.cpu().numpy()) all_labels.extend(labels.cpu().numpy()) print(classification_report(all_labels, all_preds, target_names[cat, dog]))5. 模型部署与预测5.1 保存和加载模型PyTorch提供了灵活的模型保存方式# 保存整个模型 torch.save(model, model.pth) # 仅保存模型参数推荐 torch.save(model.state_dict(), model_weights.pth) # 加载模型 model models.vgg16() # 先初始化模型结构 model.load_state_dict(torch.load(model_weights.pth)) model.eval()5.2 实现预测接口一个完整的预测流程应该包括图像预处理、模型推理和后处理from PIL import Image import matplotlib.pyplot as plt def predict_image(image_path, model, transform, class_names): # 加载图像 img Image.open(image_path) # 预处理 img_t transform(img) batch_t torch.unsqueeze(img_t, 0) batch_t batch_t.to(device) # 预测 with torch.no_grad(): output model(batch_t) # 获取预测结果 _, pred torch.max(output, 1) class_name class_names[pred.item()] confidence torch.nn.functional.softmax(output, dim1)[0] * 100 # 可视化 plt.imshow(img) plt.title(fPrediction: {class_name} ({confidence[pred.item()]:.1f}%)) plt.axis(off) plt.show() return class_name在实际项目中遇到的典型问题包括图像尺寸不匹配、通道顺序错误等。一个健壮的预测接口应该包含异常处理def safe_predict(image_path, model, transform, class_names): try: img Image.open(image_path) if img.mode ! RGB: img img.convert(RGB) return predict_image(img, model, transform, class_names) except Exception as e: print(fPrediction failed: {str(e)}) return None

更多文章