pytorch 逐列点-产品torch.einsum不匹配torch.sum(torch.穆尔(),轴=0)

aiqt4smr  于 2023-04-06  发布在  其他
关注(0)|答案(1)|浏览(160)

我试图在两个Tensor的列之间执行点积。我试图以最有效的方式做到这一点。然而,我的两种方法并不匹配。
我使用torch.sum(torch.mul(a, b), axis=0)的第一个方法给了我预期的结果,torch.einsum('ji, ji -> i', a, b)(取自Efficient method to compute the row-wise dot product of two square matrices of the same size in PyTorch)没有。可重复的代码如下:

import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(0)

a = torch.randn(3,1, dtype=torch.float).to(device)
b = torch.randn(3,4, dtype=torch.float).to(device)

print(f"a : \n{a}\n")
print(f"b : \n{b}\n")
print(f"Expected:    {a[0,0]*b[0,0] + a[1,0]*b[1,0] + a[2,0]*b[2,0]}")

c = torch.sum(torch.mul(a, b), axis=0)
print(f"sum and mul: {c[0].item()}")

d = torch.einsum('ji, ji -> i', a, b)
print(f"einsum:      {d[0].item()}\n")

print(torch.eq(c,d))

输出为:

注意:在CPU上(我所做的只是删除.to(device)),最后一行torch.eq(c,d)都是真的,但是,我需要Tensor在GPU上。
此外,对于一些种子,如torch.manual_seed(100),Tensor相等…
我觉得它必须是与einsum的东西,因为我可以得到我的预期答案与其他方式。

blmhpbnm

blmhpbnm1#

正如@Hayoung所强调的,这是由于操作计算中的数值错误。即使使用 double 作为数据类型,您也会得到不同的结果。您可以使用torch.isclose来比较Tensor值。

>>> torch.isclose(c, d)
tensor([True, True, True, True], device='cuda:0')

相关问题