首页主机资讯Linux中PyTorch的图形界面如何操作

Linux中PyTorch的图形界面如何操作

时间2025-11-25 18:14:03发布访客分类主机资讯浏览1268
导读:Linux下PyTorch图形界面操作指南 一 工具总览与定位 TensorBoard:训练过程指标与图结构可视化,浏览器查看(默认端口6006)。 Weights & Biases(W&B):云端实验追踪、团队协作与模型...

Linux下PyTorch图形界面操作指南

一 工具总览与定位

  • TensorBoard:训练过程指标与图结构可视化,浏览器查看(默认端口6006)。
  • Weights & Biases(W& B):云端实验追踪、团队协作与模型管理。
  • Netron:查看模型结构与张量形状,支持多框架模型文件。
  • PyTorchviz:用 Graphviz 绘制计算图,直观展示数据流向与依赖。
  • Torchinfo:模型层信息、参数量与输入输出形状汇总。
  • Matplotlib / Seaborn / Pandas:训练曲线、分布与表格可视化(本地静态图)。
  • Visdom:轻量级可视化服务,适合快速画线、图像、文本等面板。
  • Gradio / Streamlit:快速搭建模型Web Demo与应用界面,便于演示与交互。

二 训练监控与实验追踪

  • TensorBoard
    • 安装:pip install tensorboard
    • 记录:
      from torch.utils.tensorboard import SummaryWriter
      writer = SummaryWriter(log_dir="runs/exp1")
      for epoch in range(num_epochs):
          # ... 训练步骤 ...
          writer.add_scalar('Loss/train', loss, epoch)
          writer.add_scalar('Acc/train', acc, epoch)
      writer.close()
      
    • 启动与查看:tensorboard --logdir=runs,浏览器访问 http://localhost:6006
  • W& B
    • 特点:实验跟踪、超参与权重云端保存、团队协作;在 PyTorch 训练脚本中按官方 API 初始化与打点即可,适合多人协作与云端管理。
  • Visdom
    • 启动服务:python -m visdom.server(默认端口8097
    • 画线示例:
      from visdom import Visdom
      vis = Visdom()
      vis.line([0], [0], win='loss', opts=dict(title='Training Loss'))
      for epoch in range(100):
          loss = ...  # 计算损失
          vis.line([loss], [epoch], win='loss', update='append')
      
    • 常用绘图函数:vis.imagevis.imagesvis.textvis.matplot 等。

三 模型结构与参数查看

  • Netron
    • 安装:pip install netron
    • 使用:
      import torch, torchvision
      model = torchvision.models.resnet18(pretrained=True)
      torch.save(model.state_dict(), "resnet18.pth")
      # 终端启动
      netron resnet18.pth
      
    • 浏览器将自动打开 http://localhost:8080 展示网络结构与张量形状。
  • PyTorchviz
    • 安装:pip install torchviz
    • 使用:
      import torch, torchvision
      from torchviz import make_dot
      model = torchvision.models.resnet18(pretrained=True)
      x = torch.randn(1, 3, 224, 224)
      y = model(x)
      dot = make_dot(y, params=dict(model.named_parameters()))
      dot.render("resnet18_graph", format="png")  # 生成 PNG
      
    • 生成的计算图可保存为 PDF/PNG 并离线查看。
  • Torchinfo
    • 安装:pip install torchinfo
    • 使用:
      from torchinfo import summary
      from torchvision.models import resnet18
      model = resnet18()
      summary(model, input_size=(1, 3, 224, 224))
      
    • 输出包含每层的输出形状、参数量与可训练状态。

四 结果绘图与本地分析

  • Matplotlib / Seaborn / Pandas
    • 绘制训练/验证损失曲线:
      import matplotlib.pyplot as plt
      import seaborn as sns
      import pandas as pd
      
      epochs = list(range(1, num_epochs+1))
      df = pd.DataFrame({
      'Epoch': epochs,
                        'Train Loss': train_losses,
                        'Val Loss': val_losses}
          )
      
      # Matplotlib
      plt.plot(epochs, train_losses, label='Train')
      plt.plot(epochs, val_losses, label='Val')
      plt.xlabel('Epoch');
           plt.ylabel('Loss');
           plt.legend();
           plt.show()
      
      # Seaborn
      sns.lineplot(data=df, x='Epoch', y='Train Loss');
           plt.show()
      sns.histplot(train_losses, kde=True);
       plt.show()
      
    • 适合在 Jupyter Notebook/Lab 中交互式展示与保存图表。

五 部署交互式Web界面

  • Gradio
    • 安装:pip install gradio
    • 示例(图像增强演示):
      import gradio as gr
      
      def enhance_image(img):
          # img: PIL.Image 或 ndarray,做你的增强逻辑
          return img  # 这里直接回显示例
      
      inputs  = gr.Image(type="pil", label="输入图像")
      outputs = gr.Image(type="pil", label="增强结果")
      demo = gr.Interface(fn=enhance_image, inputs=inputs, outputs=outputs,
                           title="图像增强演示")
      demo.launch(share=True)  # 生成可分享链接
      
  • Streamlit
    • 安装:pip install streamlit
    • 示例(MNIST 推理演示):
      import streamlit as st
      import torch, torch.nn as nn
      from torchvision import transforms
      from PIL import Image
      
      # 假设已训练好并保存 state_dict: mnist_cnn.pth
      class CNN(nn.Module):
          def __init__(self):
              super().__init__()
              self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
              self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
              self.fc1 = nn.Linear(32*7*7, 128)
              self.fc2 = nn.Linear(128, 10)
          def forward(self, x):
              x = nn.functional.relu(self.conv1(x))
              x = nn.functional.max_pool2d(x, 2)
              x = nn.functional.relu(self.conv2(x))
              x = nn.functional.max_pool2d(x, 2)
              x = x.view(x.size(0), -1)
              x = nn.functional.relu(self.fc1(x))
              return self.fc2(x)
      
      @st.cache_resource
      def load_model():
          m = CNN()
          m.load_state_dict(torch.load("mnist_cnn.pth", map_location="cpu"))
          m.eval()
          return m
      
      transform = transforms.Compose([
          transforms.ToTensor(),
          transforms.Normalize((0.1307,), (0.3081,))
      ])
      
      st.title("MNIST 手写体识别")
      uploaded = st.file_uploader("上传灰度图", type=["png","jpg","jpeg"])
      if uploaded:
          img = Image.open(uploaded).convert("L").resize((28,28))
          st.image(img, caption="输入", width=150)
          x = transform(img).unsqueeze(0)
          model = load_model()
          with torch.no_grad():
              pred = model(x).argmax(1).item()
          st.write(f"预测数字:{
      pred}
          ")
      
    • 运行:streamlit run app.py,浏览器自动打开本地页面。

声明:本文内容由网友自发贡献,本站不承担相应法律责任。对本内容有异议或投诉,请联系2913721942#qq.com核实处理,我们将尽快回复您,谢谢合作!


若转载请注明出处: Linux中PyTorch的图形界面如何操作
本文地址: https://pptw.com/jishu/755788.html
Linux中PyTorch依赖怎么配置 Linux Oracle集群搭建实践

游客 回复需填写必要信息