我需要通过第一列的键值对一批二维矩阵的行进行排序:
原始批处理矩阵(3DTensor):
torch.tensor([[[2, 0],
[0, 1],
[1, 2]],
[[1, 2],
[0, 0],
[2, 1]]])
字符串
所需Tensor:
torch.tensor([[[0, 1],
[1, 2],
[2, 0]],
[[0, 0],
[1, 2],
[2, 1]]])
型
已经知道how to handle one of the batch,和another answer通过for循环解决问题,这不是并行的。那么如何处理整个批处理?
1条答案
按热度按时间rvpgvaaj1#
这可能有点令人困惑,但很有意义:
字符串
在第一行中,我们提取了
torch.argsort
的排序Tensor,并将其应用于my_tensor
,得到了(2, 2, 3, 2)
的形状Tensor。由于我们希望每个元素只根据其第一列进行排序,所以我们只对前两个维度的对角线感兴趣,可以通过切片来提取它(第二行代码)。