pytorch 从2D矩阵构造3DTensor

tquggr8v  于 2022-11-09  发布在  其他
关注(0)|答案(2)|浏览(148)

给定n × n矩阵A,其中A的每一行是[n]的置换,例如,

import torch
n = 100
AA = torch.rand(n, n)
A = torch.argsort(AA, dim=1)

同样给定另一个n × n矩阵P,我们想要构造一个3DTensorQs. t。

Q[i, j, k] = P[A[i, j], k]

在pytorch中有什么有效的方法吗?我知道 torch.gather,但似乎很难直接应用在这里。

az31mfrm

az31mfrm1#

您可以直接用途:

Q = P[A]
yqkkidmi

yqkkidmi2#

为什么不简单地使用A作为索引:

Q = P[A, :]

相关问题