numpy 如何在3D和4D矩阵上做matmul?

e5nqia27  于 2023-10-19  发布在  其他
关注(0)|答案(1)|浏览(146)

我有一个数组a,它的大小是torch.Size([330,330,36])
原始数据结构为[330,6,330,6]
意思是:

系统中有330个原子,每个原子有6个轨道
我想知道所有原子和所有轨道之间的相互作用。

我想执行这些操作:

(1) a.reshape(330,330,6,6).
permute(0,2,1,3).reshape(1980, 1980)

将矩阵转换为(330 x 6)x(330 x 6)

(2) torch.sum(torch.diag(b@b)[1:6])

执行matmul运算并对对角线元素1-5求和
我想知道是否有任何方法可以执行matmul操作而不重塑330 x330 x36矩阵。
非常感谢.
(1)a.重塑(330,330,6,6).置换(0,2,1,3).重塑(1980,1980)
(2)torch.sum(torch.diag(B@B)[1:6])
如果我有一个矩阵列表,如何在一个命令中进行matmul操作?

biswetbf

biswetbf1#

你要求了几件事,而你所做的是没有效率的。

Matmul无整形

正如我将在下面解释的那样,你不应该做这种收缩。但假设你想。您无法避免“拆分”轴36 -> 6 * 6的整形,但可以通过使用torch.tensordot来避免合并6 * 303 -> 1980。对你来说

b = a.reshape(330, 330, 6, 6)
c = torch.tensordot(b, b, ([1, 3], [0, 2]))  # shape [330, 6, 330, 6]

矩阵列表

如果它是一个torch.Tensor的列表,你不能绕过doint某种类型的循环,所以没有“一个命令”的解决方案。如果您有一个Tensor,创建例如通过torch.tensor,比方说对as.shape == (42, 330, 330, 36)的形状进行42不同的“矩阵”,可以批量进行 Torch 操作;

bs = as.reshape(42, 330, 330, 6, 6)
cs = torch.tensordot(bs, bs, ([2, 4], [1, 3]))  # shape [42, 330, 6, 330, 6]

更有效的方式来计算你所追求的

看起来你只对矩阵乘积的几个对角元素感兴趣。在您的情况下,只有1980 * 1980总条目的5。因此,您应该只计算这些条目,因为不需要计算其他大约4000000条目。例如

b = a.reshape(330, 330, 6, 6)
c = torch.sum(b[0, :, 1:5, :] * b[:, 0, :, 1:5])

应该给予给予和你在上面的片段中得到的一样。请注意,由于C风格的重塑,索引1:5变为01:5,例如。

after_reshape = before_reshape.reshape(330, 6)
before_reshape[1:5] == after_reshape[0, 1:5]

相关问题