Linux中PyTorch的图形界面如何操作
导读:Linux下PyTorch图形界面操作指南 一 工具总览与定位 TensorBoard:训练过程指标与图结构可视化,浏览器查看(默认端口6006)。 Weights & Biases(W&B):云端实验追踪、团队协作与模型...
Linux下PyTorch图形界面操作指南
一 工具总览与定位
- TensorBoard:训练过程指标与图结构可视化,浏览器查看(默认端口6006)。
- Weights & Biases(W& B):云端实验追踪、团队协作与模型管理。
- Netron:查看模型结构与张量形状,支持多框架模型文件。
- PyTorchviz:用 Graphviz 绘制计算图,直观展示数据流向与依赖。
- Torchinfo:模型层信息、参数量与输入输出形状汇总。
- Matplotlib / Seaborn / Pandas:训练曲线、分布与表格可视化(本地静态图)。
- Visdom:轻量级可视化服务,适合快速画线、图像、文本等面板。
- Gradio / Streamlit:快速搭建模型Web Demo与应用界面,便于演示与交互。
二 训练监控与实验追踪
- TensorBoard
- 安装:
pip install tensorboard - 记录:
from torch.utils.tensorboard import SummaryWriter writer = SummaryWriter(log_dir="runs/exp1") for epoch in range(num_epochs): # ... 训练步骤 ... writer.add_scalar('Loss/train', loss, epoch) writer.add_scalar('Acc/train', acc, epoch) writer.close() - 启动与查看:
tensorboard --logdir=runs,浏览器访问 http://localhost:6006。
- 安装:
- W&
B
- 特点:实验跟踪、超参与权重云端保存、团队协作;在 PyTorch 训练脚本中按官方 API 初始化与打点即可,适合多人协作与云端管理。
- Visdom
- 启动服务:
python -m visdom.server(默认端口8097) - 画线示例:
from visdom import Visdom vis = Visdom() vis.line([0], [0], win='loss', opts=dict(title='Training Loss')) for epoch in range(100): loss = ... # 计算损失 vis.line([loss], [epoch], win='loss', update='append') - 常用绘图函数:
vis.image、vis.images、vis.text、vis.matplot等。
- 启动服务:
三 模型结构与参数查看
- Netron
- 安装:
pip install netron - 使用:
import torch, torchvision model = torchvision.models.resnet18(pretrained=True) torch.save(model.state_dict(), "resnet18.pth") # 终端启动 netron resnet18.pth - 浏览器将自动打开 http://localhost:8080 展示网络结构与张量形状。
- 安装:
- PyTorchviz
- 安装:
pip install torchviz - 使用:
import torch, torchvision from torchviz import make_dot model = torchvision.models.resnet18(pretrained=True) x = torch.randn(1, 3, 224, 224) y = model(x) dot = make_dot(y, params=dict(model.named_parameters())) dot.render("resnet18_graph", format="png") # 生成 PNG - 生成的计算图可保存为 PDF/PNG 并离线查看。
- 安装:
- Torchinfo
- 安装:
pip install torchinfo - 使用:
from torchinfo import summary from torchvision.models import resnet18 model = resnet18() summary(model, input_size=(1, 3, 224, 224)) - 输出包含每层的输出形状、参数量与可训练状态。
- 安装:
四 结果绘图与本地分析
- Matplotlib / Seaborn / Pandas
- 绘制训练/验证损失曲线:
import matplotlib.pyplot as plt import seaborn as sns import pandas as pd epochs = list(range(1, num_epochs+1)) df = pd.DataFrame({ 'Epoch': epochs, 'Train Loss': train_losses, 'Val Loss': val_losses} ) # Matplotlib plt.plot(epochs, train_losses, label='Train') plt.plot(epochs, val_losses, label='Val') plt.xlabel('Epoch'); plt.ylabel('Loss'); plt.legend(); plt.show() # Seaborn sns.lineplot(data=df, x='Epoch', y='Train Loss'); plt.show() sns.histplot(train_losses, kde=True); plt.show() - 适合在 Jupyter Notebook/Lab 中交互式展示与保存图表。
- 绘制训练/验证损失曲线:
五 部署交互式Web界面
- Gradio
- 安装:
pip install gradio - 示例(图像增强演示):
import gradio as gr def enhance_image(img): # img: PIL.Image 或 ndarray,做你的增强逻辑 return img # 这里直接回显示例 inputs = gr.Image(type="pil", label="输入图像") outputs = gr.Image(type="pil", label="增强结果") demo = gr.Interface(fn=enhance_image, inputs=inputs, outputs=outputs, title="图像增强演示") demo.launch(share=True) # 生成可分享链接
- 安装:
- Streamlit
- 安装:
pip install streamlit - 示例(MNIST 推理演示):
import streamlit as st import torch, torch.nn as nn from torchvision import transforms from PIL import Image # 假设已训练好并保存 state_dict: mnist_cnn.pth class CNN(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(1, 16, 3, padding=1) self.conv2 = nn.Conv2d(16, 32, 3, padding=1) self.fc1 = nn.Linear(32*7*7, 128) self.fc2 = nn.Linear(128, 10) def forward(self, x): x = nn.functional.relu(self.conv1(x)) x = nn.functional.max_pool2d(x, 2) x = nn.functional.relu(self.conv2(x)) x = nn.functional.max_pool2d(x, 2) x = x.view(x.size(0), -1) x = nn.functional.relu(self.fc1(x)) return self.fc2(x) @st.cache_resource def load_model(): m = CNN() m.load_state_dict(torch.load("mnist_cnn.pth", map_location="cpu")) m.eval() return m transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ]) st.title("MNIST 手写体识别") uploaded = st.file_uploader("上传灰度图", type=["png","jpg","jpeg"]) if uploaded: img = Image.open(uploaded).convert("L").resize((28,28)) st.image(img, caption="输入", width=150) x = transform(img).unsqueeze(0) model = load_model() with torch.no_grad(): pred = model(x).argmax(1).item() st.write(f"预测数字:{ pred} ") - 运行:
streamlit run app.py,浏览器自动打开本地页面。
- 安装:
声明:本文内容由网友自发贡献,本站不承担相应法律责任。对本内容有异议或投诉,请联系2913721942#qq.com核实处理,我们将尽快回复您,谢谢合作!
若转载请注明出处: Linux中PyTorch的图形界面如何操作
本文地址: https://pptw.com/jishu/755788.html
