从H5AD到空间感知scGPT:手把手复现与多任务训练实战

张开发
2026/4/20 7:22:01 15 分钟阅读

分享文章

从H5AD到空间感知scGPT:手把手复现与多任务训练实战
1. 空间转录组与scGPT技术背景单细胞转录组测序技术近年来快速发展为生物医学研究提供了前所未有的细胞分辨率。传统的单细胞RNA测序scRNA-seq虽然能够揭示细胞间的基因表达差异但丢失了细胞在组织中的空间位置信息。而空间转录组技术如10x Visium、MERFISH等的出现使得我们能够在保留空间位置信息的同时获取全转录组数据。H5AD文件格式已成为单细胞和空间转录组数据的标准存储格式它基于HDF5二进制格式能够高效存储大规模的稀疏矩阵数据。一个典型的H5AD文件包含X矩阵基因表达数据细胞×基因obs细胞注释信息如细胞类型、样本来源var基因注释信息如基因名称、特征选择标记obsm细胞维度嵌入数据如空间坐标、UMAP坐标scGPT是基于Transformer架构设计的单细胞数据分析模型其核心优势在于基因词表映射将基因名称转换为离散token类似自然语言处理中的单词编码多任务学习同时处理基因表达预测MLM、细胞表达补全MVC和空间信息辅助补全MVC Impute空间感知通过坐标信息增强模型对组织结构的理解2. 数据预处理实战2.1 H5AD文件读取与基础处理首先使用scanpy加载H5AD文件并进行基础质控import scanpy as sc # 读取H5AD文件 adata sc.read_h5ad(spatial_data.h5ad) # 基础质控 print(f原始数据维度: {adata.shape}) sc.pp.filter_cells(adata, min_genes200) # 过滤低质量细胞 sc.pp.filter_genes(adata, min_cells3) # 过滤低频基因 print(f质控后维度: {adata.shape})关键预处理步骤包括归一化处理使用CPMCounts Per Million校正测序深度差异对数变换稳定方差使数据更适合下游分析高变基因选择减少计算量聚焦信息量丰富的基因from scgpt.preprocess import Preprocessor preprocessor Preprocessor( normalize_total1e4, log1pTrue, subset_hvg2000 # 选择2000个高变基因 ) preprocessor(adata)2.2 空间坐标处理空间转录组数据的核心价值在于其坐标信息通常存储在obsm[spatial]中import pandas as pd # 检查空间坐标 if spatial in adata.obsm: coords pd.DataFrame(adata.obsm[spatial], columns[x, y], indexadata.obs_names) print(空间坐标示例) print(coords.head()) else: raise ValueError(未找到空间坐标信息)对于坐标数据的特殊处理坐标归一化不同样本间坐标尺度可能不同需统一到相同范围邻域构建基于空间坐标计算细胞邻域关系用于后续的KNN补全# 坐标归一化 coords (coords - coords.min()) / (coords.max() - coords.min()) adata.obsm[spatial_norm] coords.values3. 基因词表构建与数据转换3.1 基因到token的映射scGPT使用预定义的基因词表将基因名称映射为数字IDfrom scgpt.tokenizer import GeneVocab # 加载预训练词表 vocab GeneVocab.from_file(vocab.json) # 基因名称统一为大写 adata.var_names [gene.upper() for gene in adata.var_names] # 建立映射关系 adata.var[gene_id] [vocab[gene] for gene in adata.var_names] valid_genes adata.var[gene_id] 0 adata adata[:, valid_genes] # 过滤不在词表中的基因3.2 构建训练样本每个细胞需要转换为模型可接受的输入格式import torch import numpy as np def create_cell_example(adata, cell_idx): # 获取基因ID和表达值 gene_ids adata.var[gene_id].values expressions adata.X[cell_idx].toarray().flatten() # 构建样本字典 example { genes: torch.tensor(gene_ids, dtypetorch.long), expressions: torch.tensor(expressions, dtypetorch.float32), } # 添加空间坐标如果存在 if spatial_norm in adata.obsm: example[coordinates] torch.tensor( adata.obsm[spatial_norm][cell_idx], dtypetorch.float32 ) return example4. 模型训练与多任务整合4.1 scGPT模型架构scGPT的核心是一个多层Transformer编码器from scgpt.model import TransformerModel model TransformerModel( ntokenlen(vocab), # 词表大小 d_model512, # 嵌入维度 nhead8, # 注意力头数 d_hid2048, # 前馈网络维度 nlayers6, # Transformer层数 n_cls1, # 分类头数量 vocabvocab, # 基因词表 do_mvcTrue, # 启用表达补全任务 do_mvc_imputeTrue # 启用空间辅助补全 ).to(device)4.2 多任务损失函数scGPT同时优化三个任务的损失def compute_loss(outputs, inputs): # Masked Language Modeling损失 loss_mlm F.mse_loss( outputs[mlm_output], inputs[expr_values] ) # Masked Value Completion损失 loss_mvc F.mse_loss( outputs[mvc_output], inputs[expr_values] ) # 空间辅助补全损失 loss_mvci masked_mse_loss( outputs[impute_pred], inputs[expr_values], ~inputs[padding_mask] ) # 加权组合 total_loss loss_mlm 0.2 * loss_mvc 0.1 * loss_mvci return total_loss4.3 训练循环实现完整的训练流程包括数据加载、前向传播和参数更新from torch.utils.data import DataLoader from tqdm import tqdm # 构建DataLoader dataset SpatialDataset(adata, vocab) dataloader DataLoader(dataset, batch_size32, shuffleTrue) # 优化器设置 optimizer torch.optim.AdamW(model.parameters(), lr1e-4) for epoch in range(30): model.train() total_loss 0 for batch in tqdm(dataloader): # 数据转移到设备 genes batch[genes].to(device) exprs batch[expressions].to(device) coords batch.get(coordinates, None) # 前向传播 outputs model( srcgenes, valuesexprs, coordinatescoords ) # 计算损失 loss compute_loss(outputs, batch) # 反向传播 loss.backward() optimizer.step() optimizer.zero_grad() total_loss loss.item() print(fEpoch {epoch1}, Loss: {total_loss/len(dataloader):.4f})5. 常见问题与解决方案5.1 基因名称映射失败问题现象大量基因无法匹配到词表中的ID可能原因基因命名方式不一致ENSEMBL ID vs Symbol大小写不匹配物种不匹配人类vs小鼠解决方案# 使用mygene进行ID转换 import mygene mg mygene.MyGeneInfo() query_result mg.querymany( adata.var_names.tolist(), scopesensembl.gene, fieldssymbol, specieshuman ) # 构建映射字典 ensg_to_symbol { item[query]: item.get(symbol, None) for item in query_result if not item.get(notfound) }5.2 训练损失不收敛典型表现Loss初始值很高100训练过程中波动大后期下降缓慢调试策略检查数据归一化确保表达值经过log1p变换print(f表达值范围: {np.min(adata.X)} - {np.max(adata.X)})调整学习率尝试1e-5到1e-3之间的不同值验证损失计算单独检查每个任务的损失print(fMLM loss: {loss_mlm.item():.4f}) print(fMVC loss: {loss_mvc.item():.4f}) print(fImpute loss: {loss_mvci.item():.4f})5.3 空间信息利用不足优化方向增强坐标特征将原始坐标转换为相对位置特征def enhance_coordinates(coords): # 计算细胞间距离矩阵 dist torch.cdist(coords, coords) # 添加局部密度特征 k min(5, coords.shape[0]-1) knn_dist torch.topk(dist, kk1, largestFalse).values density 1.0 / (knn_dist[:, 1:].mean(dim1) 1e-6) return torch.cat([coords, density.unsqueeze(1)], dim1)调整损失权重提高空间相关任务的损失系数6. 进阶技巧与性能优化6.1 多GPU训练加速使用PyTorch的DistributedDataParallel实现多卡训练# 启动命令 torchrun --nproc_per_node4 train.py对应的训练脚本修改# 初始化分布式环境 def setup_ddp(): dist.init_process_group(nccl) local_rank int(os.environ[LOCAL_RANK]) torch.cuda.set_device(local_rank) return torch.device(fcuda:{local_rank}) # 模型包装 model DDP(model, device_ids[device.index])6.2 混合精度训练通过自动混合精度AMP减少显存占用from torch.cuda.amp import GradScaler, autocast scaler GradScaler() for batch in dataloader: optimizer.zero_grad() with autocast(): outputs model(**batch) loss compute_loss(outputs, batch) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()6.3 模型保存与加载保存检查点时包含完整训练状态checkpoint { model_state: model.state_dict(), optimizer_state: optimizer.state_dict(), epoch: epoch, loss: best_loss } torch.save(checkpoint, model_checkpoint.pt)加载时恢复训练checkpoint torch.load(model_checkpoint.pt) model.load_state_dict(checkpoint[model_state]) optimizer.load_state_dict(checkpoint[optimizer_state]) start_epoch checkpoint[epoch]在实际项目中空间感知的scGPT模型训练通常需要20-50个epoch才能达到理想效果。关键是要监控各个任务的损失变化当MVC损失开始稳定时可以适当降低学习率继续微调。

更多文章