今天,我遇到了这样一个问题:
TensorA
是形状为(1, 4, 4)
的分割掩模,其值为0或1。
TensorB
是由torch.eye(2)
创建的对角数组。
我的问题是,为什么我们可以用B[A]
的形式的A
(3D)来索引B
(2D),为什么结果是(1, 4, 4, 2)
形状的Tensor?
上面是我的测试示例,socure代码是从一个dicloss类中获得的:
y_true_dummy = torch.eye(num_classes)[y_true.squeeze(1)]
y_true
的形状是(b, h, w)
,num_classes
等于c
。
顺便问一下,为什么我们需要函数.squeeze()
?
我想一些关于索引问题的解释和一些视频更赞赏。
1条答案
按热度按时间afdcj2ne1#
如果您处理一个较小的示例,就可以理解这个问题:
[1, 0]
和[0, 1]
是2x2单位矩阵B
的第一行和第二行。因此,使用形状为(4,)的一维数组A
作为索引是选择B
的4个“行”/沿着轴0选择B的4个元素。B[A]
基本上是[B[1], B[1], B[0], B[1]]
。因此,当
A
是形状为(1, 4, 4)
的三维数组时,B[A]
意味着**选择B的(1,4,4)行。**由于B中的每一行都有2个元素(2列),因此输出为(1,4,4,2)。B
是一个2x2单位矩阵,有2行。从这两行中选出16行,得到一个(16,2)矩阵,然后将其整形得到(1,4,4,2)Tensor。实际上,你可以很容易地检查:这也不是PyTorch特有的现象,你可以在NumPy中观察到同样的索引规则,它与torch保持着紧密的兼容性。