pytorch 运行时错误:有办法在CPU上torch.load(model.pt)吗?

jdzmm42g  于 2023-04-06  发布在  其他
关注(0)|答案(1)|浏览(145)

对于我的GPU训练模型的使用,我想在我的CPU上运行它。如果我理解正确,已经有一个解决方案,那就是保存和加载模型的参数(或 state_dict)。加载过程(这是我在这里感兴趣的)将像这样进行:

device = torch.device('cpu')
state_dict = torch.load(self.model_state_dict_path, map_location=self.device)

model = ModelClass(*model_params*)
model.load_state_dict(state_dict)
model.to(device)

,看起来效果很好。
不过,我想知道是否也可以通过保存/加载整个模型而不是其参数来实现相同的功能。尝试在CPU上加载GPU训练的模型,我做了以下操作:

def __init__(self, entire_model_path):
        self.model_path = entire_model_path
        self.__device = device('cpu')
        self.__model = load(self.model_path, map_location=self.__device)

,这会引发以下错误:

Traceback (most recent call last):
  File "predictor.py", line 138, in <module>
    p.predict_masks_from_directory(config.image_path, config.mask_path)
  File "predictor.py", line 25, in predict_masks_from_directory
    pred = self.predict_mask_from_imagepath(imagepath)
  File "predictor.py", line 37, in predict_mask_from_imagepath
    pred_patches.append(np.moveaxis(self.predict_segmentation_from_image(img)[0], 0, -1))
  File "predictor.py", line 53, in predict_segmentation_from_image
    segmentation_result = sigmoid(self.__unet(image))
  File "C:\Users\bboche\Anaconda3\envs\UNet\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Users\bboche\Anaconda3\envs\UNet\lib\site-packages\torch\nn\parallel\data_parallel.py", line 155, in forward
    "them on device: {}".format(self.src_device_obj, t.device))
RuntimeError: module must have its parameters and buffers on device cuda:0 (device_ids[0]) but found one of them on device: cpu

是不是不可能以这种方式加载整个模型,因为它可能保存了一些特定于GPU的数据?
还有:加载整个模型而不是它的状态字典是可取的吗?

myzjeezk

myzjeezk1#

不确定你的问题是否是因为使用TensorFlow SavedModel的习惯,但PyTorch通常是关于保存/加载参数值的。
将模型移动到设备实际上是将其所有参数(值和梯度)移动到目标设备。因此,除了对您来说非常耗时外,最好的选择通常是:

  • 在CPU上示例化模型
  • 在CPU上加载检查点
  • 将检查点中的参数值加载到模型中
  • 将模型移动到目标设备
import torch
model = ...
state_dict = torch.load("path/to/your/checkpoint.pt", map_location="cpu")
model.load_state_dict(state_dict)
model = model.to(device=...)

现在,如果你想保存整个模型(架构+它的权重),你可以这样做(参见https://pytorch.org/tutorials/beginner/saving_loading_models.html#save-load-entire-model):

import torch
model = ...
# Save the architecture and the weights
torch.save(model, "path/to/your/checkpoint.pt")

# Load it
restored_model = torch.load("path/to/your/checkpoint.pt")

相关问题