Why do I get "AssertionError: did not find fuser method for:" when statically quantizing a PyTorch model?

wa7juj8i posted on 2023-11-19 in Other

When I try to apply static quantization to my model, I get the error below. It is raised in the fusion step, torch.quantization.fuse_modules(model, modules_to_fuse):

model = torch.quantization.fuse_modules(model, modules_to_fuse)
  File "/Users/celik/PycharmProjects/GFPGAN/colorization/lib/python3.8/site-packages/torch/ao/quantization/fuse_modules.py", line 146, in fuse_modules
    _fuse_modules(model, module_list, fuser_func, fuse_custom_config_dict)
  File "/Users/celik/PycharmProjects/GFPGAN/colorization/lib/python3.8/site-packages/torch/ao/quantization/fuse_modules.py", line 77, in _fuse_modules
    new_mod_list = fuser_func(mod_list, additional_fuser_method_mapping)
  File "/Users/celik/PycharmProjects/GFPGAN/colorization/lib/python3.8/site-packages/torch/ao/quantization/fuse_modules.py", line 45, in fuse_known_modules
    fuser_method = get_fuser_method(types, additional_fuser_method_mapping)
  File "/Users/celik/PycharmProjects/GFPGAN/colorization/lib/python3.8/site-packages/torch/ao/quantization/fuser_method_mappings.py", line 132, in get_fuser_method
    assert fuser_method is not None, "did not find fuser method for: {} ".format(op_list)
AssertionError: did not find fuser method for: (<class 'torch.nn.modules.conv.Conv2d'>,)
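The assertion can be reproduced with a minimal sketch like the following. It assumes a PyTorch version that ships torch.ao.quantization (the paths in the traceback suggest this), and TinyNet with its layer names conv, bn and relu is a hypothetical model made up for illustration. The only thing that matters is that one group passed to fuse_modules contains a single Conv2d.

```python
import torch.nn as nn
from torch.ao.quantization import fuse_modules


class TinyNet(nn.Module):
    # Hypothetical model: the attribute names are illustrative assumptions.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(8)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


model = TinyNet().eval()

# This group holds only the conv layer, so the type tuple looked up
# internally is the one-element (nn.Conv2d,), which has no fuser method.
bad_modules_to_fuse = [["conv"]]

try:
    fuse_modules(model, bad_modules_to_fuse)
    error_message = None
except AssertionError as exc:
    error_message = str(exc)

print(error_message)
```

The caught message should name the one-element tuple (Conv2d,), the same shape as in the traceback.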


Answer #1, by u3r8eeie

The modules_to_fuse list must follow these rules (from the fuse_modules docstring):

Fuses only the following sequence of modules:
    conv, bn
    conv, bn, relu
    conv, relu
    linear, relu
    bn, relu
    All other sequences are left unchanged.
    For these sequences, replaces the first item in the list
    with the fused module, replacing the rest of the modules
    with identity.

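I could not fuse a 'torch.nn.modules.conv.Conv2d' on its own. It has to be fused together with the modules that follow it, as "conv, bn", "conv, bn, relu" or "conv, relu"; other combinations do not work. The one-element tuple (Conv2d,) in the error message shows that one of the groups in modules_to_fuse contains only a Conv2d. Use the list above to build your fusion list. It worked for me.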
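Here is a sketch of a fusion list that follows these rules, assuming a hypothetical two-block model (the attribute names conv1, bn1, relu1, conv2 and relu2 are illustrative). The model is switched to eval mode first, since fuse_modules folds the BatchNorm statistics into the conv weights for post-training quantization.

```python
import torch
import torch.nn as nn
from torch.ao.quantization import fuse_modules


class ConvBlock(nn.Module):
    # Hypothetical model; the attribute names are assumptions for illustration.
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(8)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(8, 8, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()

    def forward(self, x):
        x = self.relu1(self.bn1(self.conv1(x)))
        return self.relu2(self.conv2(x))


model = ConvBlock().eval()

# Each inner list is one supported sequence: (conv, bn, relu) and (conv, relu).
# No Conv2d appears in a group by itself.
modules_to_fuse = [["conv1", "bn1", "relu1"], ["conv2", "relu2"]]
fused = fuse_modules(model, modules_to_fuse)  # inplace=False: returns a fused copy

# The first module of each group becomes the fused module, the rest nn.Identity.
print(type(fused.conv1).__name__, type(fused.bn1).__name__, type(fused.relu1).__name__)

x = torch.randn(1, 3, 16, 16)
same_output = torch.allclose(model(x), fused(x), atol=1e-5)
print(same_output)
```

Only the first module of each group is replaced (by a ConvReLU2d here) and the others become nn.Identity, as the docstring describes; the allclose check compares the fused copy against the untouched original.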
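Here is another list: the mapping from module-type tuples to fuser methods, defined in fuser_method_mappings.py (the file named in the traceback):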

DEFAULT_OP_LIST_TO_FUSER_METHOD : Dict[Tuple, Union[nn.Sequential, Callable]] = {
    (nn.Conv1d, nn.BatchNorm1d): fuse_conv_bn,
    (nn.Conv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
    (nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn,
    (nn.Conv2d, nn.BatchNorm2d, nn.ReLU): fuse_conv_bn_relu,
    (nn.Conv3d, nn.BatchNorm3d): fuse_conv_bn,
    (nn.Conv3d, nn.BatchNorm3d, nn.ReLU): fuse_conv_bn_relu,
    (nn.Conv1d, nn.ReLU): nni.ConvReLU1d,
    (nn.Conv2d, nn.ReLU): nni.ConvReLU2d,
    (nn.Conv3d, nn.ReLU): nni.ConvReLU3d,
    (nn.Linear, nn.BatchNorm1d): fuse_linear_bn,
    (nn.Linear, nn.ReLU): nni.LinearReLU,
    (nn.BatchNorm2d, nn.ReLU): nni.BNReLU2d,
    (nn.BatchNorm3d, nn.ReLU): nni.BNReLU3d,
}
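Since the lookup is an exact match on a tuple of module types, one way to catch this error before calling fuse_modules is to build that tuple for each group yourself and compare it against the keys above. The helper find_unfusable_groups below is hypothetical, not a PyTorch API, and SUPPORTED_2D_PATTERNS only copies the 2d and Linear entries of the mapping; extend it for 1d/3d layers as needed.

```python
import torch.nn as nn

# Hypothetical helper data, not part of PyTorch: keys copied from the
# DEFAULT_OP_LIST_TO_FUSER_METHOD mapping (2d and Linear entries only).
SUPPORTED_2D_PATTERNS = {
    (nn.Conv2d, nn.BatchNorm2d),
    (nn.Conv2d, nn.BatchNorm2d, nn.ReLU),
    (nn.Conv2d, nn.ReLU),
    (nn.Linear, nn.BatchNorm1d),
    (nn.Linear, nn.ReLU),
    (nn.BatchNorm2d, nn.ReLU),
}


def find_unfusable_groups(model, modules_to_fuse):
    """Return (group, type_tuple) pairs whose exact type tuple is not supported."""
    bad = []
    for group in modules_to_fuse:
        # Exact types, like the dict lookup in get_fuser_method: subclasses won't match.
        types = tuple(type(model.get_submodule(name)) for name in group)
        if types not in SUPPORTED_2D_PATTERNS:
            bad.append((group, types))
    return bad


# Usage: children of nn.Sequential are named "0", "1", ...
model = nn.Sequential(
    nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU(),  # "0", "1", "2": fusable
    nn.Conv2d(8, 8, 3), nn.LeakyReLU(),               # "3", "4": not in the mapping
    nn.Conv2d(8, 8, 3),                               # "5": a lone conv
)
problems = find_unfusable_groups(model, [["0", "1", "2"], ["3", "4"], ["5"]])
for group, types in problems:
    print(group, [t.__name__ for t in types])
```

This flags the (Conv2d, LeakyReLU) group and the lone Conv2d group, which are the two failure cases discussed in this thread.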

Answer #2, by b4lqfgs4

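I got the same error, but in my case the problem was that I was using LeakyReLU, which is not supported for fusion (it does not appear in the mapping above). I changed LeakyReLU() to nn.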
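The answer's final identifier is cut off. Assuming the replacement is nn.ReLU (the activation that appears in the mapping), the difference can be checked with a sketch like this, which builds the same hypothetical conv-activation nn.Sequential twice and attempts the fusion on each.

```python
import torch.nn as nn
from torch.ao.quantization import fuse_modules


def build(activation):
    # Hypothetical conv -> activation model; child names are "0" and "1".
    return nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), activation).eval()


def try_fuse(model):
    # Returns the fused module's class name, or the assertion text on failure.
    try:
        fused = fuse_modules(model, [["0", "1"]])
        return type(fused[0]).__name__
    except AssertionError as exc:
        return f"AssertionError: {exc}"


leaky_result = try_fuse(build(nn.LeakyReLU(0.1)))  # (Conv2d, LeakyReLU): no fuser method
relu_result = try_fuse(build(nn.ReLU()))           # (Conv2d, ReLU): fused to ConvReLU2d
print(leaky_result)
print(relu_result)
```

Keep in mind that swapping LeakyReLU for ReLU changes what the network computes, so a model trained with LeakyReLU may need retraining or fine-tuning after the swap.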
