保姆级教程:用PyTorch和TorchText搞定AG_NEWS新闻分类(附完整代码)

张开发
2026/4/15 7:31:34 15 分钟阅读

分享文章

保姆级教程:用PyTorch和TorchText搞定AG_NEWS新闻分类(附完整代码)
从零构建新闻分类器PyTorchTorchText实战AG_NEWS数据集当你第一次接触自然语言处理NLP时文本分类可能是最直观的入门项目。今天我们就用PyTorch和TorchText这两个强大的工具从零开始构建一个新闻分类器。这个教程会手把手带你走过每一个环节包括数据加载、预处理、模型构建、训练和评估确保即使是没有经验的新手也能顺利跑通整个流程。1. 环境准备与数据集介绍在开始之前我们需要确保开发环境配置正确。推荐使用Python 3.8和最新稳定版的PyTorch。可以通过以下命令安装必要的库pip install torch torchtextAG_NEWS是学术界常用的新闻分类基准数据集包含4个类别世界新闻World体育新闻Sports商业新闻Business科技新闻Sci/Tech数据集中的每条样本都由类别标签和新闻文本组成。原始数据集以CSV格式存储但幸运的是TorchText已经内置了对这个数据集的支持我们可以直接使用。提示如果你在国内网络环境下遇到下载问题可以尝试设置代理或手动下载数据集后指定本地路径。2. 数据加载与初步探索让我们首先加载数据集并查看其结构import torch from torchtext.datasets import AG_NEWS # 加载训练集和测试集 train_iter, test_iter AG_NEWS(root./data, split(train, test)) # 查看前几个样本 for i, (label, text) in enumerate(train_iter): if i 3: break print(fLabel: {label}, Text: {text[:100]}...)这段代码会输出训练集的前三个样本。你可能注意到标签是数字形式对应关系如下1: 世界新闻2: 体育新闻3: 商业新闻4: 科技新闻3. 数据预处理与词表构建文本数据不能直接输入模型需要转换为数值形式。这一过程包括分词、构建词表和数值化。3.1 分词与词表构建TorchText提供了多种分词器我们使用基础的英文分词器from torchtext.data.utils import get_tokenizer from torchtext.vocab import build_vocab_from_iterator # 获取分词器 tokenizer get_tokenizer(basic_english) # 构建词表 def yield_tokens(data_iter): for _, text in data_iter: yield tokenizer(text) vocab build_vocab_from_iterator(yield_tokens(train_iter), specials[unk]) vocab.set_default_index(vocab[unk])3.2 文本向量化与批处理我们需要将文本转换为统一的张量格式并处理变长序列问题import torch from torch.nn.utils.rnn import pad_sequence # 文本处理流水线 text_pipeline lambda x: vocab(tokenizer(x)) label_pipeline lambda x: int(x) - 1 # 批处理函数 def collate_batch(batch): label_list, text_list [], [] for (_label, _text) in batch: label_list.append(label_pipeline(_label)) processed_text torch.tensor(text_pipeline(_text), dtypetorch.int64) text_list.append(processed_text) return torch.tensor(label_list), pad_sequence(text_list, padding_value0)4. 模型设计与实现我们将实现一个简单的嵌入平均池化模型这是文本分类的经典基线方法。4.1 模型架构import torch.nn as nn class TextClassifier(nn.Module): def __init__(self, vocab_size, embed_dim, num_class): super().__init__() self.embedding nn.EmbeddingBag(vocab_size, embed_dim, sparseTrue) self.fc nn.Linear(embed_dim, num_class) def forward(self, text, offsets): embedded self.embedding(text, offsets) return self.fc(embedded)4.2 模型初始化vocab_size len(vocab) embed_dim 64 num_class 4 model TextClassifier(vocab_size, embed_dim, num_class)5. 训练与评估5.1 训练过程from torch.utils.data import DataLoader # 准备数据加载器 train_loader DataLoader(list(train_iter), batch_size8, collate_fncollate_batch) # 定义损失函数和优化器 criterion torch.nn.CrossEntropyLoss() optimizer torch.optim.SGD(model.parameters(), lr4.0) # 训练函数 def train(dataloader): model.train() total_acc, total_count 0, 0 for idx, (label, text) in enumerate(dataloader): optimizer.zero_grad() predicted_label model(text, torch.zeros(len(text), dtypetorch.int64)) loss criterion(predicted_label, label) loss.backward() optimizer.step() total_acc (predicted_label.argmax(1) label).sum().item() total_count label.size(0) if idx % 500 0: print(f| {idx:5d} batches | accuracy {total_acc/total_count:.3f})5.2 评估过程def evaluate(dataloader): model.eval() total_acc, total_count 0, 0 with torch.no_grad(): for label, text in dataloader: predicted_label model(text, torch.zeros(len(text), dtypetorch.int64)) total_acc (predicted_label.argmax(1) label).sum().item() total_count label.size(0) return total_acc / total_count6. 完整训练流程现在我们把所有部分组合起来进行完整的训练和评估import time # 训练多个epoch for epoch in range(1, 6): epoch_start_time time.time() train(train_loader) accu_val evaluate(test_loader) print(fEpoch: {epoch}, |time: {time.time()-epoch_start_time:.1f}s|, val_accuracy: {accu_val:.3f})7. 模型优化与技巧7.1 学习率调整学习率对模型性能影响很大我们可以实现学习率衰减scheduler torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma0.9)7.2 使用预训练词向量提升模型性能的一个简单方法是使用预训练词向量from torchtext.vocab import GloVe # 使用GloVe词向量初始化嵌入层 pretrained_embeddings GloVe(name6B, dim100) model.embedding.weight.data.copy_(pretrained_embeddings.get_vecs_by_tokens(vocab.get_itos()))8. 实际应用与预测训练好的模型可以用来预测新的新闻类别def predict(text): with torch.no_grad(): text torch.tensor(text_pipeline(text)) output model(text, torch.tensor([0])) return output.argmax(1).item() 1 # 示例预测 sample_text Apple announced new products at their annual developer conference print(fPredicted category: {predict(sample_text)})在实际项目中我发现嵌入维度设置为64-128之间通常能取得不错的效果同时不会显著增加计算负担。对于更复杂的任务可以考虑使用LSTM或Transformer架构但对于这个简单的新闻分类任务我们的基线模型已经能达到约85%的准确率。

更多文章