numpy (8,8)@(4,8,2)的 Torch 广播是如何工作的?

o4hqfura  于 2023-10-19  发布在  其他
关注(0)|答案(1)|浏览(69)

我有Tensor的大小:

a.shape 
=> (8, 8)
b.shape 
=> (4, 8, 2)

c = a @ b 
c.shape
=> (4, 8, 2)

我很惊讶地看到a @ b是可以广播的。我浏览了pytorch的广播规则,似乎这些不应该兼容。
有人能解释一下这是如何计算的吗?

qaxu7uf2

qaxu7uf21#

查看torch.matmul的文档
在这种情况下,b被解释为批量矩阵,第一个轴是批量维度。

相关问题