Ubuntu系统中PyTorch内存不足怎么解决
导读:Ubuntu下PyTorch内存不足的排查与优化 一 快速定位问题 明确是GPU显存不足还是系统内存不足:GPU常见报错为CUDA out of memory;系统内存不足常表现为进程被系统终止(如日志中出现Killed)。 GPU侧监控...
Ubuntu下PyTorch内存不足的排查与优化
一 快速定位问题
- 明确是GPU显存不足还是系统内存不足:GPU常见报错为CUDA out of memory;系统内存不足常表现为进程被系统终止(如日志中出现Killed)。
- GPU侧监控与诊断:使用nvidia-smi -l 1实时查看显存;在代码中用torch.cuda.memory_summary()、**torch.cuda.max_memory_allocated()**定位峰值与碎片。
- 系统侧排查:用free -h观察内存与swap使用情况,确认是否因物理内存不足触发OOM Killer。
二 训练阶段的高效优化
- 降低批量与梯度累积:优先减小batch_size;若需保持“大批次”效果,使用梯度累积在不增大显存的前提下稳定训练。
- 混合精度训练:启用torch.cuda.amp.autocast()与GradScaler(),在保持精度的同时通常可减少约**30%–50%**显存占用。
- 梯度检查点:对计算密集或占用激活多的模块使用torch.utils.checkpoint,以计算换显存,常见开销约15%–20%。
- 优化器与激活管理:在精度允许时以SGD替代Adam可显著降低优化器状态占用;必要时对部分激活/参数进行CPU卸载。
- 多卡与分片:超大规模模型采用FSDP或DDP进行张量/参数分片,显著降低单卡显存压力。
- 数据加载:设置pin_memory=True、合理的num_workers,避免数据管道成为瓶颈。
三 推理阶段与常见OOM场景
- 减小batch_size或输入分辨率;推理中同样可启用AMP降低显存占用。
- 及时清理:推理循环中删除不再使用的张量,调用**torch.cuda.empty_cache()与gc.collect()**释放未使用缓存与Python对象。
- 避免异步导致的释放无效:在需要释放前执行torch.cuda.synchronize(),确保计算完成后再清理。
- 使用torch.profiler定位显存峰值与热点算子,针对性优化。
四 系统与CUDA层面的调优
- 调整CUDA内存分配器:设置环境变量PYTORCH_CUDA_ALLOC_CONF,如max_split_size_mb:32,缓解大块分配失败与碎片问题。
- 清理缓存与计划缓存:在关键点调用torch.cuda.empty_cache();必要时清理cufft_plan_cache与cublas_lru_cache以减少残留占用。
- 谨慎清理系统缓存:仅在必要时以root执行echo 1/2/3 > /proc/sys/vm/drop_caches释放PageCache/Dentries/Inodes,避免影响系统性能。
- 增加Swap空间:当发生torch.save或数据预处理阶段被系统Killed时,扩充swap可作为临时兜底方案。
五 实用代码片段
- 混合精度训练模板
import torch
from torch.cuda.amp import autocast, GradScaler
model = MyModel().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = GradScaler()
for data, target in dataloader:
optimizer.zero_grad()
with autocast():
output = model(data)
loss = criterion(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
- 自适应批量大小(示例思路)
def find_max_batch(model, input_shape, max_mem=8*1024**3):
batch_size = 1
while True:
try:
with torch.cuda.amp.autocast(enabled=True):
_ = model(torch.randn(*input_shape, device='cuda')[:batch_size])
used = torch.cuda.max_memory_allocated()
if used >
0.9 * max_mem:
return max(1, batch_size - 1)
batch_size *= 2
except RuntimeError:
return max(1, batch_size // 2)
- 显存清理与同步
import torch, gc
def clear_cache():
if torch.cuda.is_available():
torch.cuda.synchronize() # 等待所有流完成
torch.cuda.empty_cache()
gc.collect()
声明:本文内容由网友自发贡献,本站不承担相应法律责任。对本内容有异议或投诉,请联系2913721942#qq.com核实处理,我们将尽快回复您,谢谢合作!
若转载请注明出处: Ubuntu系统中PyTorch内存不足怎么解决
本文地址: https://pptw.com/jishu/778077.html
