我有4个PyTorchTensor:
data
形状(l, m, n)
a
,形状为(k,)
,数据类型为long
b
,形状为(k,)
,数据类型为long
c
,形状为(k,)
,数据类型为long
我想对Tensordata
进行切片,使得它在0th
维度中选择a
所寻址的元素。在1st
和2nd
维度中,我想根据b
和c
所寻址的元素选择一片值。具体来说,我想选择9个值-一个3x3
补丁围绕b
寻址的值。因此,我的切片Tensor应该具有(k, 3, 3)
的形状。
MWE:
data = torch.arange(200).reshape((2, 10, 10))
a = torch.Tensor([1, 0, 1, 1, 0]).long()
b = torch.Tensor([5, 6, 3, 4, 7]).long()
c = torch.Tensor([4, 3, 7, 6, 5]).long()
data1 = data[a, b-1:b+1, c-1:c+1] # gives error
>>> TypeError: only integer tensors of a single element can be converted to an index
预期产出
data1[0] = [[143,144,145],[153,154,155],[163,164,165]]
data1[1] = [[52,53,54],[62,63,64],[72,73,74]]
data1[2] = [[126,127,128],[136,137,138],[146,147,148]]
and so on
我如何在不使用for循环的情况下做到这一点?
附言:
- 我填充了
data
,以确保a,b,c
寻址的位置在限制范围内。 - 我不需要梯度来通过这个操作。所以,我可以将这些转换为NumPy并切片,如果这样更快的话。但我更喜欢PyTorch中的解决方案。
2条答案
按热度按时间vcirk6k61#
我会先展开索引,然后在重复的索引上添加移位。注意,行和列的移位应该是相反的。例如,
输出:
avwztpqn2#
我可以用
slice
来实现它,尽管在最后有一个列表解析。然而,它是一个只有k个元素的循环。