在PyTorch中,内置的torch.roll函数只能以相同的偏移量移动列(或行)。但我想用不同的偏移量移动列。假设输入Tensor为
torch.roll
[[1,2,3], [4,5,6], [7,8,9]]
假设,我想对第i列偏移i。因此,预期输出为
i
[[1,8,6], [4,2,9], [7,5,3]]
这样做的一个选项是使用torch.roll单独移动每一列,并对每一列进行concat。但出于效率和代码紧凑性的考虑,我不想引入循环结构。有更好的办法吗?
ukxgm1gy1#
让我们定义一些名称:
import torch mat = torch.Tensor( [[1,2,3], [4,5,6], [7,8,9]]) indices = torch.LongTensor([0, 1, 2]) # Could also use arange in this specific scenario
首先,你可以做一个Tensor,
[[0, 0, 0], [1, 1, 1], [2, 2, 2]]
使用
arange1 = torch.arange(3).view((3, 1)).repeat((1, 3))
现在,让我们为目标索引创建一个Tensor
[[0, 2, 1], [1, 0, 2], [2, 1, 0]]
与
arange2 = (arange1 - indices) % 3
最后,我们得到预期的输出,
torch.gather(mat, 0, arange2)
yrdbyhpb2#
我对torch.gather的性能持怀疑态度,所以我用numpy搜索了类似的问题,找到了this的帖子。
torch.gather
我从@Andy L那里得到了解决方案,并将其翻译成了pytorch。然而,带着一粒盐,因为我不知道步幅是如何工作的:
from numpy.lib.stride_tricks import as_strided # NumPy solution: def custom_roll(arr, r_tup): m = np.asarray(r_tup) arr_roll = arr[:, [*range(arr.shape[1]),*range(arr.shape[1]-1)]].copy() #need `copy` #print(arr_roll) strd_0, strd_1 = arr_roll.strides #print(strd_0, strd_1) n = arr.shape[1] result = as_strided(arr_roll, (*arr.shape, n), (strd_0 ,strd_1, strd_1)) return result[np.arange(arr.shape[0]), (n-m)%n] # Translated to PyTorch def pcustom_roll(arr, r_tup): m = torch.tensor(r_tup) arr_roll = arr[:, [*range(arr.shape[1]),*range(arr.shape[1]-1)]].clone() #need `copy` #print(arr_roll) strd_0, strd_1 = arr_roll.stride() #print(strd_0, strd_1) n = arr.shape[1] result = torch.as_strided(arr_roll, (*arr.shape, n), (strd_0 ,strd_1, strd_1)) return result[torch.arange(arr.shape[0]), (n-m)%n]
这里也是解决方案从@丹尼尔M作为即插即用。
def roll_by_gather(mat,dim, shifts: torch.LongTensor): # assumes 2D array n_rows, n_cols = mat.shape if dim==0: #print(mat) arange1 = torch.arange(n_rows).view((n_rows, 1)).repeat((1, n_cols)) #print(arange1) arange2 = (arange1 - shifts) % n_rows #print(arange2) return torch.gather(mat, 0, arange2) elif dim==1: arange1 = torch.arange(n_cols).view(( 1,n_cols)).repeat((n_rows,1)) #print(arange1) arange2 = (arange1 - shifts) % n_cols #print(arange2) return torch.gather(mat, 1, arange2)
首先,我在CPU上运行这些方法。令人惊讶的是,上面的gather解决方案是最快的:
gather
n_cols = 10000 n_rows = 100 shifts = torch.randint(-100,100,size=[n_rows,1]) data = torch.arange(n_rows*n_cols).reshape(n_rows,n_cols) npdata = np.arange(n_rows*n_cols).reshape(n_rows,n_cols) npshifts = shifts.numpy() %timeit roll_by_gather(data,1,shifts) %timeit pcustom_roll(data,shifts) %timeit custom_roll(npdata,npshifts) >> 2.41 ms ± 68.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) >> 90.4 ms ± 882 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) >> 247 ms ± 6.08 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
在GPU上运行代码会显示类似的结果:
%timeit roll_by_gather(data,shifts) %timeit pcustom_roll(data,shifts) 131 µs ± 6.79 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 3.29 ms ± 46.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
(注意:roll_by_gather方法中需要torch.arange(...,device='cuda:0'))
roll_by_gather
torch.arange(...,device='cuda:0')
jecbmhm33#
@DanielM解决方案的通用版本。给出:
mat = torch.tensor( [[1,2,3], [4,5,6], [7,8,9]] ) shifts = torch.tensor([0, 1, 2])
indices = (torch.arange(mat.shape[0])[:, None] - shifts[None, :]) % mat.shape[0] torch.gather(mat, 0, indices)
indices = (torch.arange(mat.shape[1])[None, :] - shifts[:, None]) % mat.shape[1] torch.gather(mat, 1, indices)
def roll_along(arr, shifts, dim): assert arr.ndim - 1 == shifts.ndim dim %= arr.ndim shape = (1,) * dim + (-1,) + (1,) * (arr.ndim - dim - 1) dim_indices = torch.arange(arr.shape[dim]).reshape(shape) indices = (dim_indices - shifts.unsqueeze(dim)) % arr.shape[dim] return torch.gather(arr, dim, indices)
roll_along(mat, shifts, dim=0) # roll rows roll_along(mat, shifts, dim=1) # roll columns
3条答案
按热度按时间ukxgm1gy1#
让我们定义一些名称:
首先,你可以做一个Tensor,
使用
现在,让我们为目标索引创建一个Tensor
与
最后,我们得到预期的输出,
yrdbyhpb2#
我对
torch.gather
的性能持怀疑态度,所以我用numpy搜索了类似的问题,找到了this的帖子。NumPy到Pytorch的类似解决方案
我从@Andy L那里得到了解决方案,并将其翻译成了pytorch。然而,带着一粒盐,因为我不知道步幅是如何工作的:
这里也是解决方案从@丹尼尔M作为即插即用。
标杆管理
首先,我在CPU上运行这些方法。令人惊讶的是,上面的
gather
解决方案是最快的:在GPU上运行代码会显示类似的结果:
(注意:
roll_by_gather
方法中需要torch.arange(...,device='cuda:0')
)jecbmhm33#
@DanielM解决方案的通用版本。给出:
滚动行
滚动列
沿沿着任意维度滚动