Linux环境下PyTorch模型保存与加载
导读:在Linux环境下,使用PyTorch保存和加载模型非常简单。以下是一个简单的例子来说明如何保存和加载一个PyTorch模型。 首先,我们需要导入所需的库并定义一个简单的模型: import torch import torch.nn as...
在Linux环境下,使用PyTorch保存和加载模型非常简单。以下是一个简单的例子来说明如何保存和加载一个PyTorch模型。
首先,我们需要导入所需的库并定义一个简单的模型:
import torch
import torch.nn as nn
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.fc = nn.Linear(10, 5)
def forward(self, x):
return self.fc(x)
model = SimpleModel()
接下来,我们可以使用torch.save()
函数将模型保存到文件中:
torch.save(model, 'model.pth')
现在,模型已经被保存到了名为model.pth
的文件中。要加载模型,我们可以使用torch.load()
函数:
loaded_model = torch.load('model.pth')
加载模型后,我们可以像使用原始模型一样使用它:
input_data = torch.randn(1, 10)
output = loaded_model(input_data)
注意:在加载模型时,确保你的环境中已经安装了与保存模型时相同的PyTorch版本。否则,可能会出现兼容性问题。
声明:本文内容由网友自发贡献,本站不承担相应法律责任。对本内容有异议或投诉,请联系2913721942#qq.com核实处理,我们将尽快回复您,谢谢合作!
若转载请注明出处: Linux环境下PyTorch模型保存与加载
本文地址: https://pptw.com/jishu/722349.html