我用pytorch写了下面的代码,遇到了一个运行时错误:
tns = torch.tensor([1,0,1])
tns.mean()
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-666-194e5ab56931> in <module>
----> 1 tns.mean()
RuntimeError: mean(): input dtype should be either floating point or complex dtypes. Got Long instead.
但是,如果我将Tensor更改为float,错误就会消失:
tns = torch.tensor([1.,0,1])
tns.mean()
---------------------------------------------------------------------------
tensor(0.6667)
我的问题是为什么会出现错误。第一个tenor的数据类型是int64而不是Long,为什么PyTorch将其作为Long?
2条答案
按热度按时间72qzrwbm1#
这是因为
torch.int64
和torch.long
都引用相同的数据类型,即64位有符号整数。有关所有数据类型的概述,请参见here。gc0ot86w2#
您应该将'torch.tensor([1,0,1])'更改为'torch.Tensor([1,0,1])。