设a
和b
是两个PyTorchTensor,分别为a.shape=[A,3]
和b.shape=[B,3]
,且b
的类型为long
。
然后我知道有几种方法可以对a
进行切片。
c = a[N1:N2:jump,[0,2]] # N1<N2<A
对于N1=1、N2=4和jump=2,将返回c.shape = [2,2]
。
但是下面的代码应该会抛出一个错误,
c = a[b]
而是c.shape = [B,3,3]
。
例如,
a = torch.rand(10,3)
b = torch.rand(20,3).long()
print(a[b].shape) #torch.Size([20, 3, 3])
有人能解释一下a[b]
的切片工作原理吗?
2条答案
按热度按时间vlju58qv1#
基础知识
例如
假设B具有以下值:
以下是这些值的计算方法:
B的第二行是[3,4,5]。
B的第三行是[1,2,3]。
所有这些切片沿着第一维度连接以产生具有形状[3,3,3]的最终结果。
cxfofazt2#
由于B是长的,torch把它当作指数仓位,如果它不是长的,上面的操作就不起作用。
注意,
a[b]
的第一个元素是a
的第一个元素,并且最后一个元素和再次的第一个元素对应于索引[0, -1, 0]
,并且因此由于它对于a
的相关位置的每个条目进行采样,所以你得到[20, 3, 3]
形状。因此,假设
b
中的每个条目对应于a
焊炬切片a
中具有给定位置的有效索引,并且对于b
的每个条目也是如此,并将所有条目连接到具有上述形状的新Tensor。如果将存在无效索引(b = torch.randn(20, 3).long() * 10
),则您将得到: