numpy: loading a .npy file into torch raises a missing-key(s) error

7ivaypg9 posted on 2023-10-19 in Other

This is my code; I'm trying to load a .npy file into Torch:

import numpy as np
import torch

# Load the state dictionary from the numpy file
state_dict_np = np.load(model_path, allow_pickle=True).item()

# Convert numpy arrays within the state dictionary to PyTorch tensors
state_dict_torch = {k: torch.tensor(v, dtype=torch.float32).cpu() for k, v in state_dict_np.items()}

# Load the converted state dictionary into the model
self.model.load_state_dict(state_dict_torch)

This is the error I get:

Error while subprocess initialization: Traceback (most recent call last):
  File "/app/core/joblib/SubprocessorBase.py", line 62, in _subprocess_run
    self.on_initialize(client_dict)
  File "/app/mainscripts/Extractor.py", line 73, in on_initialize
    self.rects_extractor = facelib.S3FDExtractor(place_model_on_cpu=place_model_on_cpu)
  File "/app/facelib/S3FDExtractor.py", line 156, in __init__
    self.model.load_state_dict(state_dict_torch)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for S3FD:
        Missing key(s) in state_dict: "conv1_1.weight", "conv1_1.bias", "conv1_2.weight", "conv1_2.bias", "conv2_1.weight", "conv2_1.bias", "conv2_2.weight", "conv2_2.bias", "conv3_1.weight", "conv3_1.bias", "conv3_2.weight", "conv3_2.bias", "conv3_3.weight", "conv3_3.bias", "conv4_1.weight", "conv4_1.bias", "conv4_2.weight", "conv4_2.bias", "conv4_3.weight", "conv4_3.bias", "conv5_1.weight", "conv5_1.bias", "conv5_2.weight", "conv5_2.bias", "conv5_3.weight", "conv5_3.bias", "fc6.weight", "fc6.bias", "fc7.weight", "fc7.bias", "conv6_1.weight", "conv6_1.bias", "conv6_2.weight", "conv6_2.bias", "conv7_1.weight", "conv7_1.bias", "conv7_2.weight", "conv7_2.bias", "conv3_3_norm.weight", "conv4_3_norm.weight", "conv5_3_norm.weight", "conv3_3_norm_mbox_conf.weight", "conv3_3_norm_mbox_conf.bias", "conv3_3_norm_mbox_loc.weight", "conv3_3_norm_mbox_loc.bias", "conv4_3_norm_mbox_conf.weight", "conv4_3_norm_mbox_conf.bias", "conv4_3_norm_mbox_loc.weight", "conv4_3_norm_mbox_loc.bias", "conv5_3_norm_mbox_conf.weight", "conv5_3_norm_mbox_conf.bias", "conv5_3_norm_mbox_loc.weight", "conv5_3_norm_mbox_loc.bias", "fc7_mbox_conf.weight", "fc7_mbox_conf.bias", "fc7_mbox_loc.weight", "fc7_mbox_loc.bias", "conv6_2_mbox_conf.weight", "conv6_2_mbox_conf.bias", "conv6_2_mbox_loc.weight", "conv6_2_mbox_loc.bias", "conv7_2_mbox_conf.weight", "conv7_2_mbox_conf.bias", "conv7_2_mbox_loc.weight", "conv7_2_mbox_loc.bias". 
        Unexpected key(s) in state_dict: "conv1_1/weight:0", "conv1_1/bias:0", "conv1_2/weight:0", "conv1_2/bias:0", "conv2_1/weight:0", "conv2_1/bias:0", "conv2_2/weight:0", "conv2_2/bias:0", "conv3_1/weight:0", "conv3_1/bias:0", "conv3_2/weight:0", "conv3_2/bias:0", "conv3_3/weight:0", "conv3_3/bias:0", "conv4_1/weight:0", "conv4_1/bias:0", "conv4_2/weight:0", "conv4_2/bias:0", "conv4_3/weight:0", "conv4_3/bias:0", "conv5_1/weight:0", "conv5_1/bias:0", "conv5_2/weight:0", "conv5_2/bias:0", "conv5_3/weight:0", "conv5_3/bias:0", "fc6/weight:0", "fc6/bias:0", "fc7/weight:0", "fc7/bias:0", "conv6_1/weight:0", "conv6_1/bias:0", "conv6_2/weight:0", "conv6_2/bias:0", "conv7_1/weight:0", "conv7_1/bias:0", "conv7_2/weight:0", "conv7_2/bias:0", "conv3_3_norm/weight:0", "conv4_3_norm/weight:0", "conv5_3_norm/weight:0", "conv3_3_norm_mbox_conf/weight:0", "conv3_3_norm_mbox_conf/bias:0", "conv3_3_norm_mbox_loc/weight:0", "conv3_3_norm_mbox_loc/bias:0", "conv4_3_norm_mbox_conf/weight:0", "conv4_3_norm_mbox_conf/bias:0", "conv4_3_norm_mbox_loc/weight:0", "conv4_3_norm_mbox_loc/bias:0", "conv5_3_norm_mbox_conf/weight:0", "conv5_3_norm_mbox_conf/bias:0", "conv5_3_norm_mbox_loc/weight:0", "conv5_3_norm_mbox_loc/bias:0", "fc7_mbox_conf/weight:0", "fc7_mbox_conf/bias:0", "fc7_mbox_loc/weight:0", "fc7_mbox_loc/bias:0", "conv6_2_mbox_conf/weight:0", "conv6_2_mbox_conf/bias:0", "conv6_2_mbox_loc/weight:0", "conv6_2_mbox_loc/bias:0", "conv7_2_mbox_conf/weight:0", "conv7_2_mbox_conf/bias:0", "conv7_2_mbox_loc/weight:0", "conv7_2_mbox_loc/bias:0".
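Both lists name the same parameters in two different conventions: the checkpoint keys are TensorFlow-style (scope/variable:0), while PyTorch expects module.parameter. As a hedged sketch, load_state_dict(..., strict=False) returns the missing and unexpected keys instead of raising, which makes this kind of mismatch easy to inspect. TinyNet and the two sample keys are hypothetical stand-ins for S3FD and the real .npy contents:

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    # Hypothetical stand-in for S3FD with a single conv layer
    def __init__(self):
        super().__init__()
        self.conv1_1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)


model = TinyNet()

# TF-style names, as in the .npy checkpoint (dummy tensor values)
ckpt = {
    "conv1_1/weight:0": torch.zeros(8, 3, 3, 3),
    "conv1_1/bias:0": torch.zeros(8),
}

# strict=False reports the mismatches instead of raising RuntimeError
result = model.load_state_dict(ckpt, strict=False)
print(sorted(result.missing_keys))     # ['conv1_1.bias', 'conv1_1.weight']
print(sorted(result.unexpected_keys))  # ['conv1_1/bias:0', 'conv1_1/weight:0']
```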

Before that, I tried renaming the keys like this:

renamed_state_dict = {}
for key, value in state_dict.items():
    new_key = key.split("/")[0]
    if "weight" in key:
        new_key += ".weight"
    elif "bias" in key:
        new_key += ".bias"
    renamed_state_dict[new_key] = value

self.model.load_state_dict(renamed_state_dict)

But then I got this error:

Error while subprocess initialization: Traceback (most recent call last):
  File "/app/core/joblib/SubprocessorBase.py", line 62, in _subprocess_run
    self.on_initialize(client_dict)
  File "/app/mainscripts/Extractor.py", line 73, in on_initialize
    self.rects_extractor = facelib.S3FDExtractor(place_model_on_cpu=place_model_on_cpu)
  File "/app/facelib/S3FDExtractor.py", line 165, in __init__
    self.model.load_state_dict(renamed_state_dict)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for S3FD:
        While copying the parameter named "conv1_1.weight", expected torch.Tensor or Tensor-like object from checkpoint but received <class 'numpy.ndarray'>
        While copying the parameter named "conv1_1.bias", expected torch.Tensor or Tensor-like object from checkpoint but received <class 'numpy.ndarray'>
        While copying the parameter named "conv1_2.weight", expected torch.Tensor or Tensor-like object from checkpoint but received <class 'numpy.ndarray'>
        While copying the parameter named "conv1_2.bias", expected torch.Tensor or Tensor-like object from checkpoint but received <class 'numpy.ndarray'>
        ...
        While copying the parameter named "conv3_1.weight", expected torch.Tensor or Tensor-like object from checkpoint but received <class 'numpy.ndarray'>

What am I doing wrong?
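The second traceback is a separate problem: once the keys match, the values are still NumPy arrays, and load_state_dict only accepts tensors. A minimal sketch of the missing conversion step using torch.from_numpy; TinyNet and the array values are hypothetical, and the arrays are assumed to already be in PyTorch's (out, in, H, W) layout:

```python
import numpy as np
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    # Hypothetical stand-in for S3FD with a single conv layer
    def __init__(self):
        super().__init__()
        self.conv1_1 = nn.Conv2d(3, 8, kernel_size=3)


model = TinyNet()

# Hypothetical checkpoint: keys already renamed, values still NumPy arrays
np_state = {
    "conv1_1.weight": np.ones((8, 3, 3, 3), dtype=np.float32),
    "conv1_1.bias": np.zeros(8, dtype=np.float32),
}

# Convert every value to a float32 tensor before calling load_state_dict
torch_state = {k: torch.from_numpy(v).float() for k, v in np_state.items()}
model.load_state_dict(torch_state)
print(model.conv1_1.weight.sum().item())  # 216.0
```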

Update

Following a user's suggestion, I tried this:

def can_squeeze(t):
    shape_set = set(t.shape)
    if len(shape_set) == 2 and 1 in shape_set:
        return True
    return False

def reshape_tensor(t):
    if can_squeeze(t):
        return t.squeeze()
    else:
        return torch.permute(t, [3, 2, 1, 0])

def rename_key(key):
    new_key = key.split(":")[0] # discard :0 and similar
    key_elements = new_key.split("/")
    new_key = ".".join(key_elements) # replace every / with .
    return new_key

# Load the state dictionary from the numpy file
state_dict_np = np.load(model_path, allow_pickle=True).item()

# Convert numpy arrays within the state dictionary to PyTorch tensors
state_dict_torch = {rename_key(k): reshape_tensor(torch.tensor(v, dtype=torch.float32)).cpu() for k, v in state_dict_np.items()}

# Load the converted state dictionary into the model
self.model.load_state_dict(state_dict_torch)

This gave the following error:

Error while subprocess initialization: Traceback (most recent call last):
  File "/app/core/joblib/SubprocessorBase.py", line 62, in _subprocess_run
    self.on_initialize(client_dict)
  File "/app/mainscripts/Extractor.py", line 73, in on_initialize
    self.rects_extractor = facelib.S3FDExtractor(place_model_on_cpu=place_model_on_cpu)
  File "/app/facelib/S3FDExtractor.py", line 172, in __init__
    self.model.load_state_dict(state_dict_torch)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for S3FD:
        size mismatch for fc7.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([1024, 1024, 1, 1]).
        size mismatch for conv3_3_norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1, 256, 1, 1]).
        size mismatch for conv4_3_norm.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1, 512, 1, 1]).
        size mismatch for conv5_3_norm.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1, 512, 1, 1]).
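All four mismatches come from can_squeeze, which squeezes any tensor whose shape holds only 1 and one other size. That suits tensors such as biases, but not fc7.weight (a 1x1 convolution kernel) or the L2Norm scales, which the model keeps as (1, C, 1, 1). The on-disk shapes used below, (1, 1, 1024, 1024) and (1, 1, 1, 256), are assumptions inferred from the error message and TensorFlow's kernel layout, not read from the file:

```python
import torch


def can_squeeze(t):
    # Same test as in the question: exactly two distinct sizes, one of them 1
    shape_set = set(t.shape)
    return len(shape_set) == 2 and 1 in shape_set


# Assumed on-disk shapes (TF layout), inferred from the size-mismatch messages
fc7 = torch.zeros(1, 1, 1024, 1024)
norm = torch.zeros(1, 1, 1, 256)

# What the question's reshape_tensor produces
print(can_squeeze(fc7), tuple(fc7.squeeze().shape))    # True (1024, 1024)
print(can_squeeze(norm), tuple(norm.squeeze().shape))  # True (256,)

# What the model expects instead
print(tuple(fc7.permute(3, 2, 0, 1).shape))    # (1024, 1024, 1, 1)
print(tuple(norm.reshape(1, -1, 1, 1).shape))  # (1, 256, 1, 1)
```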

Answer #1 (ulmd4ohb):

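The solution is to combine your two ideas, because they solve two different problems: one renames the keys, the other converts the NumPy arrays to torch tensors. On top of that, the four parameters listed in the size-mismatch error need an explicit reshape. This might work: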

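import numpy as np
import torch

def can_squeeze(t):
    # True when the shape holds exactly two distinct sizes, one of them 1,
    # e.g. (1, 1, 1, 64) can be squeezed to (64,)
    shape_set = set(t.shape)
    if len(shape_set) == 2 and 1 in shape_set:
        return True
    return False

def reshape_tensor(k, t):
    # fc7 is a 1x1 conv stored as (1, 1, 1024, 1024); permute it like the other
    # kernels (a plain reshape would swap its input and output channels)
    if k == 'fc7.weight':
        return torch.permute(t, [3, 2, 0, 1])
    # The L2Norm scales must become (1, C, 1, 1)
    if k in ['conv4_3_norm.weight', 'conv5_3_norm.weight']:
        return t.reshape([1, 512, 1, 1])
    if k == 'conv3_3_norm.weight':
        return t.reshape([1, 256, 1, 1])
    # Other tensors with singleton dims (e.g. biases) collapse to 1-D
    if can_squeeze(t):
        return t.squeeze()
    # TF conv kernels are (H, W, in, out); PyTorch expects (out, in, H, W)
    return torch.permute(t, [3, 2, 0, 1])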

def rename_key(key):
    new_key = key.split(":")[0] # discard :0 and similar
    key_elements = new_key.split("/")
    new_key = ".".join(key_elements) # replace every / with .
    return new_key

# Load the state dictionary from the numpy file
state_dict_np = np.load(model_path, allow_pickle=True).item()

# Convert numpy arrays within the state dictionary to PyTorch tensors
state_dict_torch = {}
for k, v in state_dict_np.items():
    name = rename_key(k)  # compute the PyTorch-style name once
    state_dict_torch[name] = reshape_tensor(name, torch.tensor(v, dtype=torch.float32)).cpu()

# Load the converted state dictionary into the model
self.model.load_state_dict(state_dict_torch)
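One caveat, assuming the .npy file stores kernels in TensorFlow's (H, W, in, out) layout as the key names suggest: PyTorch expects (out, in, H, W), which is permute(3, 2, 0, 1). permute(3, 2, 1, 0) yields the same shape for square kernels, so load_state_dict accepts it, but it transposes every kernel spatially and the network then computes different convolutions. The sketch below checks both orders against a plain NumPy convolution in TF layout; hwio_to_oihw and reference_conv_nhwc are illustrative helpers, not part of either library:

```python
import numpy as np
import torch
import torch.nn.functional as F


def hwio_to_oihw(kernel_hwio: np.ndarray) -> torch.Tensor:
    """Convert a TF-style (H, W, in, out) kernel to PyTorch's (out, in, H, W)."""
    return torch.from_numpy(kernel_hwio).permute(3, 2, 0, 1).contiguous()


def reference_conv_nhwc(x_nhwc: np.ndarray, kernel_hwio: np.ndarray) -> np.ndarray:
    """Plain NumPy 'valid' cross-correlation in TF layout, used as ground truth."""
    n, h, w, _ = x_nhwc.shape
    kh, kw, _, cout = kernel_hwio.shape
    out = np.zeros((n, h - kh + 1, w - kw + 1, cout))
    for i in range(h - kh + 1):
        for j in range(w - kw + 1):
            patch = x_nhwc[:, i:i + kh, j:j + kw, :]  # (n, kh, kw, cin)
            out[:, i, j, :] = np.einsum("nhwc,hwco->no", patch, kernel_hwio)
    return out


rng = np.random.default_rng(0)
x = rng.standard_normal((1, 6, 6, 2))  # NHWC input
k = rng.standard_normal((3, 3, 2, 4))  # HWIO kernel, not spatially symmetric

expected = reference_conv_nhwc(x, k)
x_nchw = torch.from_numpy(x).permute(0, 3, 1, 2)  # NHWC -> NCHW

# Correct order: (H, W, in, out) -> (out, in, H, W)
good = F.conv2d(x_nchw, hwio_to_oihw(k)).permute(0, 2, 3, 1).numpy()
# Order used in the question: (H, W, in, out) -> (out, in, W, H)
bad_w = torch.from_numpy(k).permute(3, 2, 1, 0).contiguous()
bad = F.conv2d(x_nchw, bad_w).permute(0, 2, 3, 1).numpy()

print(np.allclose(good, expected))  # True
print(np.allclose(bad, expected))   # False
```

Since PyTorch's conv2d, like TensorFlow's, is a cross-correlation, only the axis order has to change; no kernel flipping is needed.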
