model = models.vgg16()
model.load_state_dict(torch.load('model_weights.pth')) # This line uses .load() to read a .pth file and load the network weights on to the architecture.
model.eval() # enabling the eval mode to test with new samples.
import torch
from torch_model import Model # Made up package
# select gpu when available, else work with cpu resources
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Model()
model.load_state_dict(torch.load('weights.pt'))
model = model.to(device) # Set model to gpu
model.eval();
inputs = torch.random.randn(1, 3, 224, 224) # Dtype is fp32
inputs = inputs.to(device) # You can move your input to gpu, torch defaults to cpu
# Run forward pass
with torch.no_grad():
pred = model(inputs)
# Do something with pred
pred = pred.detach().cpu().numpy() # remove from computational graph to cpu and as numpy
2条答案
按热度按时间brtdzjyr1#
要使用预训练模型,您应该将状态加载到架构的新示例上,如docs/tutorials中所述:
这里
models
是预先导入的:字符串
如果您使用的是自定义架构,则只需更改第一行。
型
启用
eval
模式后,您可以执行以下操作:Dataset
示例中,然后加载到DataLoader
示例中。更多关于
Dataset
和DataLoader
here。1zmg4dgp2#
对于预测来说,有一种叫做向前传球的东西,
字符串