Pytorch,使用多个索引从Tensor中检索值,计算效率最高的解决方案

r7knjye2  于 2023-06-29  发布在  其他
关注(0)|答案(1)|浏览(149)

如果我有一个3dTensor的例子

a = [[4, 2, 1, 6],[1, 2, 3, 8], [92, 4, 23, 54]]
tensor_a = torch.tensor(a)

我可以得到2个一维Tensor沿着第一维使用

tensor_a[[0, 1]]
tensor([[4, 2, 1, 6],
        [1, 2, 3, 8]])

但是,如何使用多个指数呢?
所以我有这样的东西
list_indices = [[0, 0], [0, 2], [1, 2]]
我可以做一些

combos = []
for indi in list_indices:
    combos.append(tensor_a[indi])

但我想知道,既然有for循环,是否有一种更计算化的方法来做到这一点,也许也可以使用PyTorch

wz3gfoph

wz3gfoph1#

使用预定义的Pytorch函数torch.index_select使用索引列表选择Tensor元素在计算上更有效:

a = [[4, 2, 1, 6], [1, 2, 3, 8], [92, 4, 23, 54]]
tensor_a = torch.tensor(a)
list_indices = [[0,  0], [0, 2], [1, 2]]

# Convert list_indices to Tensor.
indices = torch.tensor(list_indices)

# Get elements from 'tensor_a' using indices.
tensor_a = torch.index_select(tensor_a, 0, indices.view(-1))
print(tensor_a)

如果你希望结果是一个列表而不是一个Tensor,你可以将tensor_a转换为一个列表:

tensor_a_list = tensor_a.tolist()

为了测试计算效率,我创建了1,000,000个索引并比较了执行时间。使用循环比使用我建议的PyTorch方法花费更多的时间:

import time
import torch
start_time = time.time()
a = [[4, 2, 1, 6], [1, 2, 3, 8], [92, 4, 23, 54]]
tensor_a = torch.tensor(a)
indices = torch.randint(0, 2, (1000000,)).tolist()
for indi in indices:
   combos.append(tensor_a[indi])
print("--- %s seconds ---" % (time.time() - start_time))
--- 3.3966853618621826 seconds ---

start_time = time.time()
indices = torch.tensor(indices)
tensor_a=torch.index_select(tensor_a, 0, indices)
print("--- %s seconds ---" % (time.time() - start_time))
--- 0.10641193389892578 seconds ---

相关问题