假设有一个矩阵和一个向量,如下所示:
import torch x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) y = torch.tensor([0, 2, 1])
有没有一种方法可以将其分割为x[y],结果是:
x[y]
res = [1, 6, 8]
所以基本上我取y的第一个元素和x中对应于第一行和元素列的元素。
y
x
vjhs03f71#
可以将相应的行索引指定为:
import torch x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) y = torch.tensor([0, 2, 1]) x[range(x.shape[0]), y] tensor([1, 6, 8])
6rqinv9w2#
pytorch中的高级索引就像NumPy's一样工作,即索引数组在轴上一起广播,所以你可以像FBruzzesi的答案那样做。虽然与np.take_along_axis类似,但在pytorch中,您也可以使用torch.gather来获取沿着特定轴的值:
NumPy's
np.take_along_axis
torch.gather
x.gather(1, y.view(-1,1)).view(-1) # tensor([1, 6, 8])
2条答案
按热度按时间vjhs03f71#
可以将相应的行索引指定为:
6rqinv9w2#
pytorch中的高级索引就像
NumPy's
一样工作,即索引数组在轴上一起广播,所以你可以像FBruzzesi的答案那样做。虽然与
np.take_along_axis
类似,但在pytorch中,您也可以使用torch.gather
来获取沿着特定轴的值: