PyTorch: find the n smallest values in a list of tensors

Posted by lokaqttq on 2022-11-09 in Other

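I am trying to find the indices of the n smallest values in a list of tensors in PyTorch. Because these tensors may contain many non-unique values, I cannot simply compute a percentile and use it to get the indices. The order among the non-unique values does not matter, however.
I came up with the following solution, but I wonder whether there is a more elegant way to do it: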

import torch

n = 10

tensor_list = [torch.randn(10, 10), torch.zeros(20, 20), torch.ones(30, 10)]
# Flatten and concatenate all tensors, then sort the combined values
all_sorted, all_sorted_idx = torch.sort(torch.cat([t.view(-1) for t in tensor_list]))

# Start offset of each tensor in the concatenation, followed by the total length
cum_num_elements = torch.cumsum(torch.tensor([t.numel() for t in tensor_list]), dim=0)
cum_num_elements = torch.cat([torch.tensor([0]), cum_num_elements])

# For each tensor, keep the indices of the n smallest values that fall inside its range,
# then shift them to local flat indices within that tensor
split_indeces_lt = [all_sorted_idx[:n] < cum_num_elements[i + 1] for i, _ in enumerate(cum_num_elements[1:])]
split_indeces_ge = [all_sorted_idx[:n] >= cum_num_elements[i] for i, _ in enumerate(cum_num_elements[:-1])]
split_indeces = [all_sorted_idx[:n][torch.logical_and(lt, ge)] - c for lt, ge, c in zip(split_indeces_lt, split_indeces_ge, cum_num_elements[:-1])]

# The n smallest values, grouped by the tensor they come from
n_smallest = [t.view(-1)[idx] for t, idx in zip(tensor_list, split_indeces)]
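Why a percentile cutoff does not work with repeated values can be seen on a tiny tensor: when the n-th smallest value is tied, a `<=` threshold selects every tied entry, not just n of them. A minimal illustration, using `torch.kthvalue` as the cutoff; the tensor is made-up example data:

```python
import torch

# Hypothetical data: the value 0.0 appears four times
t = torch.tensor([3.0, 0.0, 0.0, 0.0, 0.0, 1.0])
n = 2

# The n-th smallest value plays the role of a percentile cutoff
threshold = torch.kthvalue(t, n).values  # 0.0

# Every entry equal to the cutoff passes the test, so more than n indices come back
selected = (t <= threshold).nonzero().flatten()
print(selected)  # tensor([1, 2, 3, 4])
```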

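Edit: Ideally, the solution would pick a random subset of the non-unique values instead of always taking entries from the first tensor in the list.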
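One way to meet this requirement, sketched here as an editor's alternative rather than the asker's code: shuffle the concatenated values with `torch.randperm` before selecting, so tied values are chosen at random, and use `torch.topk(..., largest=False)` instead of a full sort. The function name `n_smallest_in_list` and its `generator` parameter are illustrative assumptions, not part of PyTorch.

```python
import torch

def n_smallest_in_list(tensor_list, n, generator=None):
    """Hypothetical helper: for each tensor in tensor_list, return the flat
    indices of its entries that are among the n smallest values of the whole
    list. Ties at the cutoff are broken at random."""
    flat = torch.cat([t.reshape(-1) for t in tensor_list])
    # Shuffle first so that equal values are not always taken from the
    # first tensor in the list
    perm = torch.randperm(flat.numel(), generator=generator)
    k = min(n, flat.numel())
    _, pos = torch.topk(flat[perm], k, largest=False)
    global_idx = perm[pos]  # indices into the unshuffled concatenation

    # Start offset of each tensor inside the concatenation
    sizes = torch.tensor([t.numel() for t in tensor_list])
    starts = torch.cumsum(sizes, dim=0) - sizes
    # The last start offset <= an index identifies the tensor that owns it
    owner = torch.searchsorted(starts, global_idx, right=True) - 1
    return [global_idx[owner == i] - starts[i] for i in range(len(tensor_list))]


tensor_list = [torch.randn(10, 10), torch.zeros(20, 20), torch.ones(30, 10)]
split_idx = n_smallest_in_list(tensor_list, 10, generator=torch.Generator().manual_seed(0))
n_smallest = [t.reshape(-1)[idx] for t, idx in zip(tensor_list, split_idx)]
```

The per-tensor split is the same offset arithmetic as in the question, done with one `torch.searchsorted` call instead of two lists of boolean masks; for small n, `topk` may also be cheaper than sorting the whole concatenation.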

mnemlml8, answer #1

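PyTorch does offer a more elegant way (I think) to do this, with torch.unique_consecutive (see its documentation).
I will work with a single tensor rather than a list of tensors because, as you did yourself, you only need a cat, and unraveling the indices afterwards is not hard either.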
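The "unraveling" step mentioned here can be sketched as follows, assuming the tensors were flattened in row-major order and concatenated as in the question. `unravel_concat_index` is a hypothetical helper written for illustration; it maps each index in the concatenation to the tensor it came from and to a coordinate tuple inside that tensor.

```python
import torch

def unravel_concat_index(flat_idx, tensor_list):
    """Hypothetical helper: map indices into
    torch.cat([t.reshape(-1) for t in tensor_list]) back to
    (position of the tensor in the list, coordinate tuple inside that tensor)."""
    sizes = torch.tensor([t.numel() for t in tensor_list])
    starts = torch.cumsum(sizes, dim=0) - sizes
    owners = torch.searchsorted(starts, flat_idx, right=True) - 1
    result = []
    for g, o in zip(flat_idx.tolist(), owners.tolist()):
        local = g - int(starts[o])
        coords = []
        # Row-major order: the last dimension varies fastest
        for dim in reversed(tensor_list[o].shape):
            coords.append(local % dim)
            local //= dim
        result.append((o, tuple(reversed(coords))))
    return result


tensor_list = [torch.arange(6).reshape(2, 3), torch.arange(6, 26).reshape(4, 5)]
print(unravel_concat_index(torch.tensor([0, 5, 6, 13]), tensor_list))
# [(0, (0, 0)), (0, (1, 2)), (1, (0, 0)), (1, (1, 2))]
```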


import torch

# Find the n=3 smallest distinct values in t and one position for each of them

n = 3
t = torch.tensor([1,2,3,2,0,1,4,3,2])

# To get a random occurrence, we create a random permutation

randomizer = torch.randperm(len(t))

# First, sort the shuffled t and keep the sort indices

sorted_t, idx_t = t[randomizer].sort()

# Small helper that keeps only the first n entries of two tensors (here: the n smallest unique values and their counts)

head = lambda v,w : (v[:n], w[:n])

# unique_consecutive collapses runs of equal values and, with return_counts=True, returns the length of each run

uniques_t, counts_t = head(*torch.unique_consecutive(sorted_t, return_counts=True))

# Shifting counts_t.cumsum right by one gives the first position of each unique value in sorted_t

uniq_idx_t = torch.cat([torch.tensor([0]), counts_t.cumsum(0)[:-1]], 0)

# Map those positions back through idx_t and randomizer to get positions in the original t:

final_idx_t = randomizer[idx_t[uniq_idx_t]]
print(uniques_t, final_idx_t)

# Possible outputs (the positions vary between runs because of the random permutation):
# >>> tensor([0, 1, 2]) tensor([4, 0, 1])
# >>> tensor([0, 1, 2]) tensor([4, 5, 8])
# >>> tensor([0, 1, 2]) tensor([4, 0, 8])
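The `unique_consecutive` and `cumsum` steps can be checked on their own. The tensor below is the answer's `t` already sorted (without the random shuffle), so the result is deterministic:

```python
import torch

sorted_t = torch.tensor([0, 1, 1, 2, 2, 2, 3, 3, 4])  # t from the answer, sorted

# Distinct values and the length of each run of equal values
values, counts = torch.unique_consecutive(sorted_t, return_counts=True)
print(values)  # tensor([0, 1, 2, 3, 4])
print(counts)  # tensor([1, 2, 3, 2, 1])

# A run starts where all previous runs end: shift the cumulative sum right by one
starts = torch.cat([torch.tensor([0]), counts.cumsum(0)[:-1]])
print(starts)  # tensor([0, 1, 3, 6, 8])
```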

Edit: I think the added permutation addresses your requirement of picking a random occurrence.
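Packaged as a function, the answer's steps can be checked for the property that matters: whichever occurrence the permutation picks, each returned position holds the corresponding value. This is a sketch; `n_smallest_unique` and its `generator` argument are hypothetical, and the start offsets are computed as `counts.cumsum(0) - counts`, which is equivalent to the answer's `cat`/`[:-1]` form and also returns empty results when n is 0.

```python
import torch

def n_smallest_unique(t, n, generator=None):
    """Hypothetical helper: the n smallest distinct values of a 1-D tensor t,
    plus the position in t of one randomly chosen occurrence of each."""
    randomizer = torch.randperm(len(t), generator=generator)
    sorted_t, idx_t = t[randomizer].sort()
    uniques, counts = torch.unique_consecutive(sorted_t, return_counts=True)
    uniques, counts = uniques[:n], counts[:n]
    # First index of each run in sorted_t (exclusive cumulative sum)
    starts = counts.cumsum(0) - counts
    return uniques, randomizer[idx_t[starts]]


t = torch.tensor([1, 2, 3, 2, 0, 1, 4, 3, 2])
for seed in range(3):
    values, positions = n_smallest_unique(t, 3, generator=torch.Generator().manual_seed(seed))
    # Positions may change with the seed, but always point at the right values
    assert torch.equal(t[positions], values)
    print(values, positions)
```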
