我正在尝试即时跟踪并保存分割模型包中的pytorch模型。但是我收到一个错误。“无法导出Python函数调用'SwishImplementation'。在导出之前删除对python函数的调用。是否忘记添加@script或@scrript_method注解?如果这是nn.ModuleList
,add it to _ constants _”只有当我使用efficientnet Backbone.js 时才会发生这种情况。我如何才能让save()
函数工作呢?我需要能够在c++应用程序中使用该模型。
import torch
import segmentation_models_pytorch as smp
model = smp.Unet('efficientnet-b7')
model.eval()
input = torch.randn((1,3,224,224))
torch_out = model(input)
model = torch.jit.trace(model,input)
trace_out = model(input)
model.save('model.pt')
1条答案
按热度按时间1l5u6lss1#
segmentation_models_pytorch
模块中的UNET
模型使用EfficientNet
,而EfficientNet
使用MemoryEfficientSwish
模块。若要修复此错误,请在保存模型之前将MemoryEfficientSwish
的所有示例更改为Swish
。您可以迭代
UNET
模型,如果模块是EfficientNet
的示例,则调用函数.set_swish(memory_efficient = False)
。之后,您可以加载
state_dict
,然后跟踪并保存模型。