LLM 推理优化:批处理与KV缓存

张开发
2026/4/16 19:03:03 15 分钟阅读

分享文章

LLM 推理优化:批处理与KV缓存
LLM 推理优化批处理与KV缓存1. 引言大型语言模型LLM的推理性能是部署和应用中的关键挑战。随着模型规模的不断增大推理速度和内存使用成为了限制LLM实际应用的重要因素。本文将深入探讨LLM推理优化的核心技术特别是批处理Batching和KV缓存Key-Value Cache这两种技术是提升LLM推理性能的关键。2. LLM 推理的挑战2.1 计算复杂度LLM的推理过程涉及大量的矩阵运算计算复杂度为O(n²d)其中n是序列长度d是模型维度。随着序列长度的增加计算量呈平方级增长。2.2 内存使用LLM推理需要存储模型参数、激活值和中间结果内存使用量巨大。对于长序列内存使用量可能超出硬件限制。2.3 延迟要求在实际应用中LLM需要满足实时或近实时的响应要求尤其是在对话系统、搜索等场景中。3. 批处理优化3.1 批处理的基本原理批处理是将多个输入样本合并成一个批次进行处理充分利用GPU的并行计算能力。# 基本批处理示例 def batch_inference(model, inputs): # 将多个输入合并成批次 batch_inputs torch.stack(inputs) # 一次性处理整个批次 outputs model(batch_inputs) # 返回批次结果 return outputs # 使用示例 inputs [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6]), torch.tensor([7, 8, 9])] outputs batch_inference(model, inputs)3.2 静态批处理与动态批处理3.2.1 静态批处理静态批处理使用固定大小的批次适用于输入长度相似的场景def static_batching(model, inputs, batch_size32): batches [] for i in range(0, len(inputs), batch_size): batch inputs[i:ibatch_size] # 填充到相同长度 max_len max(len(x) for x in batch) padded_batch [torch.cat([x, torch.zeros(max_len - len(x), dtypetorch.long)]) for x in batch] batches.append(torch.stack(padded_batch)) outputs [] for batch in batches: outputs.extend(model(batch)) return outputs3.2.2 动态批处理动态批处理根据输入长度动态调整批次大小提高GPU利用率def dynamic_batching(model, inputs, max_tokens1024): # 按长度排序 sorted_inputs sorted(inputs, keylambda x: len(x), reverseTrue) batches [] current_batch [] current_tokens 0 for input in sorted_inputs: input_len len(input) if current_tokens input_len max_tokens: # 达到令牌上限创建新批次 batches.append(current_batch) current_batch [input] current_tokens input_len else: # 添加到当前批次 current_batch.append(input) current_tokens input_len if current_batch: batches.append(current_batch) # 处理每个批次 outputs [] for batch in batches: max_len max(len(x) for x in batch) padded_batch [torch.cat([x, torch.zeros(max_len - len(x), dtypetorch.long)]) for x in batch] batch_outputs model(torch.stack(padded_batch)) outputs.extend(batch_outputs) return outputs3.3 批处理的性能分析批处理大小吞吐量 (tokens/s)延迟 (ms)GPU利用率 (%)11208.315864012.5651696016.78532112028.69564120053.3984. KV缓存优化4.1 KV缓存的基本原理KV缓存是一种缓存机制存储注意力计算中的键Key和值Value张量避免重复计算class KVCache: def __init__(self): self.key_cache [] self.value_cache [] def update(self, keys, values): self.key_cache.append(keys) self.value_cache.append(values) def get(self): if not self.key_cache: return None, None return torch.cat(self.key_cache, dim1), torch.cat(self.value_cache, dim1) def clear(self): self.key_cache [] self.value_cache []4.2 多头注意力中的KV缓存在多头注意力机制中KV缓存需要为每个注意力头存储键值对class MultiHeadKVCache: def __init__(self, num_heads): self.num_heads num_heads self.caches [KVCache() for _ in range(num_heads)] def update(self, keys, values): # keys shape: [batch_size, num_heads, seq_len, head_dim] # values shape: [batch_size, num_heads, seq_len, head_dim] for i in range(self.num_heads): self.caches[i].update(keys[:, i:i1], values[:, i:i1]) def get(self): keys [] values [] for cache in self.caches: k, v cache.get() if k is not None: keys.append(k) values.append(v) if not keys: return None, None return torch.cat(keys, dim1), torch.cat(values, dim1) def clear(self): for cache in self.caches: cache.clear()4.3 增量推理与KV缓存增量推理是LLM生成过程中的常见模式KV缓存可以显著加速这一过程def generate_with_kv_cache(model, prompt, max_length100): kv_cache MultiHeadKVCache(model.config.num_attention_heads) input_ids prompt output_ids input_ids.tolist() for _ in range(max_length): # 前向传播使用KV缓存 outputs model(input_ids, use_cacheTrue, past_key_valueskv_cache.get()) # 更新KV缓存 kv_cache.update(outputs.past_key_values) # 生成下一个token next_token torch.argmax(outputs.logits[:, -1, :], dim-1) output_ids.append(next_token.item()) input_ids next_token.unsqueeze(0) # 检查是否生成结束符 if next_token.item() model.config.eos_token_id: break return output_ids4.4 KV缓存的内存优化KV缓存的内存使用与序列长度和模型维度成正比我们可以通过以下方法优化量化对KV缓存使用低精度数据类型分块将长序列分成多个块每次只处理一个块选择性缓存只缓存重要的层或头# 量化KV缓存示例 def quantize_kv_cache(kv_cache, bits8): if bits 8: return kv_cache.to(torch.int8) elif bits 4: # 4位量化 min_val kv_cache.min() max_val kv_cache.max() scale (max_val - min_val) / 15 return ((kv_cache - min_val) / scale).round().to(torch.uint8) return kv_cache5. 批处理与KV缓存的结合5.1 批处理中的KV缓存管理在批处理场景中KV缓存需要为每个样本单独管理class BatchedKVCache: def __init__(self, batch_size, num_heads): self.batch_size batch_size self.num_heads num_heads # 为每个样本创建一个KV缓存 self.caches [MultiHeadKVCache(num_heads) for _ in range(batch_size)] def update(self, batch_keys, batch_values): # batch_keys shape: [batch_size, num_heads, seq_len, head_dim] for i in range(self.batch_size): self.caches[i].update(batch_keys[i:i1], batch_values[i:i1]) def get(self): batch_keys [] batch_values [] for cache in self.caches: k, v cache.get() if k is not None: batch_keys.append(k) batch_values.append(v) else: # 对于没有缓存的样本使用空张量 batch_keys.append(torch.empty(1, self.num_heads, 0, model.config.hidden_size // self.num_heads)) batch_values.append(torch.empty(1, self.num_heads, 0, model.config.hidden_size // self.num_heads)) return torch.cat(batch_keys, dim0), torch.cat(batch_values, dim0) def clear(self): for cache in self.caches: cache.clear()5.2 动态批处理与KV缓存的协同优化def optimized_batch_generation(model, prompts, max_length100): # 按长度排序 sorted_prompts sorted(enumerate(prompts), keylambda x: len(x[1]), reverseTrue) indices [i for i, _ in sorted_prompts] sorted_inputs [x for _, x in sorted_prompts] # 初始化批次 batch_size len(sorted_inputs) kv_cache BatchedKVCache(batch_size, model.config.num_attention_heads) input_ids torch.stack([torch.tensor(p) for p in sorted_inputs]) output_ids [list(p) for p in sorted_inputs] # 跟踪每个样本的生成状态 finished [False] * batch_size remaining batch_size for _ in range(max_length): if remaining 0: break # 前向传播 outputs model(input_ids, use_cacheTrue, past_key_valueskv_cache.get()) # 更新KV缓存 kv_cache.update(outputs.past_key_values) # 生成下一个token next_tokens torch.argmax(outputs.logits[:, -1, :], dim-1) # 更新输出和输入 new_inputs [] new_output_ids [] new_finished [] new_remaining 0 for i in range(batch_size): if not finished[i]: output_ids[i].append(next_tokens[i].item()) if next_tokens[i].item() model.config.eos_token_id: finished[i] True remaining - 1 else: new_inputs.append(next_tokens[i].unsqueeze(0)) new_output_ids.append(output_ids[i]) new_finished.append(finished[i]) new_remaining 1 if new_inputs: input_ids torch.cat(new_inputs, dim0) output_ids new_output_ids finished new_finished remaining new_remaining # 重新初始化KV缓存 kv_cache BatchedKVCache(remaining, model.config.num_attention_heads) else: break # 恢复原始顺序 result [None] * len(prompts) for i, idx in enumerate(indices): result[idx] output_ids[i] return result6. 性能分析与对比6.1 不同优化策略的性能对比优化策略吞吐量 (tokens/s)延迟 (ms)内存使用 (GB)无优化10010.08.2仅批处理60016.78.5仅KV缓存3003.39.0批处理 KV缓存12008.39.5批处理 KV缓存 量化14007.16.56.2 内存使用分析KV缓存的内存使用计算$$\text{KV内存} 2 \times \text{batch_size} \times \text{num_heads} \times \text{seq_len} \times \text{head_dim} \times \text{sizeof(dtype)}$$例如对于 batch_size32, num_heads16, seq_len1024, head_dim64, dtypefloat16$$\text{KV内存} 2 \times 32 \times 16 \times 1024 \times 64 \times 2 134,217,728 \text{ bytes} \approx 128 \text{ MB}$$7. 实现细节与最佳实践7.1 PyTorch 中的KV缓存实现# PyTorch中使用KV缓存的示例 import torch from transformers import AutoModelForCausalLM, AutoTokenizer model_name gpt2 tokenizer AutoTokenizer.from_pretrained(model_name) model AutoModelForCausalLM.from_pretrained(model_name) # 启用KV缓存 model.config.use_cache True # 生成文本 prompt Once upon a time inputs tokenizer(prompt, return_tensorspt) # 增量生成 output model.generate( **inputs, max_length50, use_cacheTrue, # 启用KV缓存 num_return_sequences1 ) print(tokenizer.decode(output[0], skip_special_tokensTrue))7.2 批处理的最佳实践动态批处理根据输入长度动态调整批次大小长度分组将相似长度的输入分到同一批次填充优化使用变长批次减少填充开销批量大小选择根据GPU内存和延迟要求选择合适的批量大小7.3 KV缓存的最佳实践增量推理使用KV缓存加速自回归生成内存管理根据硬件限制调整缓存大小量化对KV缓存使用低精度数据类型缓存清理及时清理不再需要的缓存8. 高级优化技术8.1 连续批处理Continuous Batching连续批处理是一种更高级的批处理技术允许新请求插入到正在处理的批次中class ContinuousBatcher: def __init__(self, model, max_batch_size32, max_tokens1024): self.model model self.max_batch_size max_batch_size self.max_tokens max_tokens self.batch [] self.kv_caches [] def add_request(self, prompt): if len(self.batch) self.max_batch_size: return False input_ids tokenizer(prompt, return_tensorspt).input_ids self.batch.append(input_ids) self.kv_caches.append(MultiHeadKVCache(model.config.num_attention_heads)) return True def process_batch(self): if not self.batch: return [] # 处理批次 # ... return outputs8.2 注意力计算优化Flash Attention使用CUDA核心优化注意力计算块级注意力将长序列分成块减少内存访问稀疏注意力只计算重要位置的注意力# 使用Flash Attention from flash_attn import flash_attn_qkvpacked_func def optimized_attention(q, k, v): # q, k, v shape: [batch_size, seq_len, num_heads, head_dim] qkv torch.stack([q, k, v], dim2) output flash_attn_qkvpacked_func(qkv, causalTrue) return output8.3 编译优化使用TorchScript或ONNX编译模型进一步提升推理性能# 使用TorchScript编译模型 scripted_model torch.jit.script(model) # 保存编译后的模型 torch.jit.save(scripted_model, model.pt) # 加载编译后的模型 loaded_model torch.jit.load(model.pt)9. 实际应用案例9.1 对话系统def chatbot_inference(model, tokenizer, user_inputs, context_window2048): # 构建对话历史 conversation for user_input in user_inputs: conversation fUser: {user_input}\nAssistant: # 分词 inputs tokenizer(conversation, return_tensorspt) # 处理长上下文 if inputs.input_ids.shape[1] context_window: inputs.input_ids inputs.input_ids[:, -context_window:] if attention_mask in inputs: inputs.attention_mask inputs.attention_mask[:, -context_window:] # 使用KV缓存生成响应 output model.generate( **inputs, max_lengthinputs.input_ids.shape[1] 100, use_cacheTrue, num_return_sequences1, pad_token_idtokenizer.eos_token_id ) # 提取助手响应 response tokenizer.decode(output[0], skip_special_tokensTrue) response response[len(conversation):] return response9.2 文本摘要def summarize_text(model, tokenizer, text, max_length150): # 构建提示 prompt fSummarize the following text:\n\n{text}\n\nSummary: # 分词 inputs tokenizer(prompt, return_tensorspt) # 使用批处理和KV缓存生成摘要 output model.generate( **inputs, max_lengthinputs.input_ids.shape[1] max_length, use_cacheTrue, num_return_sequences1, pad_token_idtokenizer.eos_token_id ) # 提取摘要 summary tokenizer.decode(output[0], skip_special_tokensTrue) summary summary[len(prompt):] return summary10. 代码优化建议10.1 内存优化使用半精度将模型和KV缓存转换为float16或bfloat16梯度检查点在推理过程中使用梯度检查点减少内存使用内存池使用内存池管理临时张量# 使用半精度 model model.half().to(cuda) # 使用内存池 class MemoryPool: def __init__(self): self.pool {} def get_tensor(self, shape, dtypetorch.float16, devicecuda): key (shape, dtype, device) if key in self.pool: return self.pool[key] tensor torch.empty(shape, dtypedtype, devicedevice) self.pool[key] tensor return tensor def clear(self): self.pool {} memory_pool MemoryPool()10.2 计算优化批处理大小调优根据硬件和延迟要求调整批处理大小KV缓存量化对KV缓存使用INT8或INT4量化注意力计算优化使用Flash Attention等优化库10.3 并行优化多GPU推理使用DataParallel或DistributedDataParallel流水线并行将模型分成多个部分在不同GPU上执行张量并行将注意力头分配到不同GPU上11. 常见问题与解决方案11.1 内存不足问题问题KV缓存导致GPU内存不足解决方案使用量化减少KV缓存大小限制序列长度使用分块处理长序列11.2 批处理效率问题问题批处理中的填充开销过大解决方案使用动态批处理将相似长度的输入分组使用变长批次11.3 延迟问题问题实时应用中的推理延迟过高解决方案使用KV缓存加速增量推理优化批处理策略使用编译优化12. 未来发展方向更高效的KV缓存探索更紧凑的KV缓存表示动态计算图根据输入长度动态调整计算图硬件优化针对特定硬件设计优化的推理方案混合精度使用不同精度的混合表示自动优化通过机器学习自动选择最佳优化策略13. 总结LLM推理优化是实现模型高效部署的关键。通过批处理和KV缓存等技术我们可以显著提升LLM的推理性能同时降低内存使用。批处理充分利用了GPU的并行计算能力提高了吞吐量KV缓存避免了重复计算减少了延迟和内存使用。两者的结合可以实现最佳的推理性能。在实际应用中我们需要根据具体场景和硬件条件选择合适的优化策略平衡吞吐量、延迟和内存使用。随着技术的不断发展LLM推理优化将继续演进为更广泛的应用场景提供支持。

更多文章