在将pytorch代码导出到ONNX的上下文中,我得到了以下警告:
TracerWarning:torch.tensor结果在跟踪中被注册为常量。如果您使用此函数从常量变量创建Tensor,并且每次调用此函数时都是相同的,则可以安全地忽略此警告。在任何其他情况下,这可能会导致跟踪不正确。
以下是令人反感的一行:
text_lengths = torch.tensor([text_inputs.shape[1]]).long().to(text_inputs.device)
字符串text_inputs
是形状为torch.Size([1, 81])
的torch.Tensor
这个警告是正确的,不能忽略,因为text_inputs
的形状应该是动态的。
我需要text_lengths
是一个torch.Tensor
,它包含来自text_inputs
的shape
的数字81
。上面的“冒犯行”成功地做到了这一点,但我们实际上从pytorch到Python int
再回到pytorch的往返,因为torch.Size
对象中的元素是Python int
s。这(1)有点奇怪,(2)可能在GPU -> CPU -> GPU方面效率低下,如上所述,这是ONNX导出上下文中的实际问题。
有没有其他方法可以让我在不“离开”torch世界的情况下,在torch计算中使用Tensor的形状?
1条答案
按热度按时间vm0i2vca1#
使用
torch.tensor
将导致导出的图中有一个常量值。在跟踪时,torch.tensor.shape
和torch.tensor.size
的输出应该返回Tensor(而不是python整数),因此上面的代码应该按预期导出,只需使用字符串
返回正确值的
LongTensor
。由于构造函数中的括号,代码似乎没有按照预期的方式运行。
[text_inputs.shape[1]]
中的形状在python列表中被解释为整数,因此在跟踪过程中保存为常量。在创建Tensor时删除额外的括号将正确导出为在运行时由
text_inputs
形状定义的标量。型