pytorch 保存JIT跟踪时SwishImplementation出错

8fsztsew  于 2023-01-09  发布在  其他
关注(0)|答案(1)|浏览(332)

我正在尝试即时跟踪并保存分割模型包中的pytorch模型。但是我收到一个错误。“无法导出Python函数调用'SwishImplementation'。在导出之前删除对python函数的调用。是否忘记添加@script或@scrript_method注解?如果这是nn.ModuleList,add it to _ constants _”只有当我使用efficientnet Backbone.js 时才会发生这种情况。我如何才能让save()函数工作呢?我需要能够在c++应用程序中使用该模型。

import torch
import segmentation_models_pytorch as smp

model = smp.Unet('efficientnet-b7')
model.eval()

input = torch.randn((1,3,224,224))
torch_out = model(input)

model = torch.jit.trace(model,input)
trace_out = model(input)

model.save('model.pt')
1l5u6lss

1l5u6lss1#

segmentation_models_pytorch模块中的UNET模型使用EfficientNet,而EfficientNet使用MemoryEfficientSwish模块。若要修复此错误,请在保存模型之前将MemoryEfficientSwish的所有示例更改为Swish
您可以迭代UNET模型,如果模块是EfficientNet的示例,则调用函数.set_swish(memory_efficient = False)
之后,您可以加载state_dict,然后跟踪并保存模型。

相关问题