Slicing a multi-dimensional PyTorch tensor based on values in other tensors

Asked by lf3rwulv on 2023-03-23

I have 4 PyTorch tensors:

  • data, with shape (l, m, n)
  • a, with shape (k,) and dtype long
  • b, with shape (k,) and dtype long
  • c, with shape (k,) and dtype long

I want to slice the tensor data so that, along the 0th dimension, it selects the elements addressed by a. Along the 1st and 2nd dimensions, I want to select a patch of values around the elements addressed by b and c. Specifically, I want to select 9 values: a 3x3 patch centered on the position addressed by b and c. My sliced tensor should therefore have shape (k, 3, 3).
MWE:

import torch

data = torch.arange(200).reshape((2, 10, 10))
a = torch.Tensor([1, 0, 1, 1, 0]).long()
b = torch.Tensor([5, 6, 3, 4, 7]).long()
c = torch.Tensor([4, 3, 7, 6, 5]).long()

data1 = data[a, b-1:b+1, c-1:c+1]  # gives error

>>> TypeError: only integer tensors of a single element can be converted to an index

Expected output:

data1[0] = [[143,144,145],[153,154,155],[163,164,165]]
data1[1] = [[52,53,54],[62,63,64],[72,73,74]]
data1[2] = [[126,127,128],[136,137,138],[146,147,148]]
and so on

How can I do this without using a for loop?
P.S.:

  • I have padded data to make sure the positions addressed by a, b, c stay within bounds.
  • I don't need gradients to flow through this operation, so I could convert these to NumPy and slice there if that is faster. But I would prefer a solution in PyTorch.

vcirk6k6 1#

I would first expand the indices and then add shifts to the repeated indices. Note that the row and column shifts have to be laid out in opposite orders (repeat_interleave vs. repeat). For example:

import torch

data = torch.arange(200).reshape((2, 10, 10))
a = torch.Tensor([1, 0, 1, 1, 0]).long()
b = torch.Tensor([5, 6, 3, 4, 7]).long()
c = torch.Tensor([4, 3, 7, 6, 5]).long()

index1 = a.repeat_interleave(9)  # repeat each batch index kernel_size^2 = 9 times

index2 = b.repeat_interleave(9)
shift = torch.arange(-1, 2).repeat_interleave(3).repeat(5)  # row offsets per point: [-1, -1, -1,  0,  0,  0,  1,  1,  1]
shifted_index2 = index2 + shift

index3 = c.repeat_interleave(9)
shift = torch.arange(-1, 2).repeat(3).repeat(5)  # column offsets per point: [-1,  0,  1, -1,  0,  1, -1,  0,  1]
shifted_index3 = index3 + shift

# Use the index arrays to select the patches, then reshape to (k, 3, 3)
data1 = data[index1, shifted_index2, shifted_index3].view(5, 3, 3)

print(data1[0])
print(data1[1])
print(data1[2])

Output:

tensor([[143, 144, 145],
        [153, 154, 155],
        [163, 164, 165]])
tensor([[52, 53, 54],
        [62, 63, 64],
        [72, 73, 74]])
tensor([[126, 127, 128],
        [136, 137, 138],
        [146, 147, 148]])
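
For comparison, the same result can be obtained with broadcasting instead of repeat_interleave. This is only a sketch under the question's setup (3x3 patches, variable names as in the question), not part of the approach above:

import torch

data = torch.arange(200).reshape((2, 10, 10))
a = torch.tensor([1, 0, 1, 1, 0])
b = torch.tensor([5, 6, 3, 4, 7])
c = torch.tensor([4, 3, 7, 6, 5])

off = torch.arange(-1, 2)                     # offsets -1, 0, 1
rows = b[:, None, None] + off[None, :, None]  # shape (k, 3, 1)
cols = c[:, None, None] + off[None, None, :]  # shape (k, 1, 3)
data1 = data[a[:, None, None], rows, cols]    # index tensors broadcast to (k, 3, 3)

print(data1[0])  # -> the 3x3 patch [[143, 144, 145], [153, 154, 155], [163, 164, 165]]
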
avwztpqn 2#

I managed to do it with slice, although it ends with a list comprehension. However, that is a loop over only k elements.

import torch

a = torch.tensor([1, 0, 1, 1, 0]).long()
b = torch.tensor([5, 6, 3, 4, 7]).long()
c = torch.tensor([4, 3, 7, 6, 5]).long()
data = torch.arange(200).reshape((2, 10, 10))

a = [slice(val, val + 1) for val in a]        # length-1 slice along dim 0
b = [slice(val - 1, val + 2) for val in b]    # 3-wide slice along dim 1
c = [slice(val - 1, val + 2) for val in c]    # 3-wide slice along dim 2

data1 = [data[a_, b_, c_] for a_, b_, c_ in zip(a, b, c)]
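
Each element of data1 above has shape (1, 3, 3), because the length-1 slice along dim 0 is kept. If a single (k, 3, 3) tensor is wanted, as in the expected output, the list can be stacked and squeezed afterwards, for example:

# Stack the k patches into one tensor and drop the length-1 dimension
data1 = torch.stack(data1).squeeze(1)
print(data1.shape)  # torch.Size([5, 3, 3])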
