给定n × n矩阵A,其中A的每一行是[n]的置换,例如,
import torch n = 100 AA = torch.rand(n, n) A = torch.argsort(AA, dim=1)
同样给定另一个n × n矩阵P,我们想要构造一个3DTensorQs. t。
Q[i, j, k] = P[A[i, j], k]
在pytorch中有什么有效的方法吗?我知道 torch.gather,但似乎很难直接应用在这里。
az31mfrm1#
您可以直接用途:
Q = P[A]
yqkkidmi2#
为什么不简单地使用A作为索引:
A
Q = P[A, :]
2条答案
按热度按时间az31mfrm1#
您可以直接用途:
yqkkidmi2#
为什么不简单地使用
A
作为索引: