我试图在pytorch中找到一个Tensor列表中n个最小值的索引。由于这些Tensor可能包含许多非唯一值,我不能简单地计算百分位数来获得索引。然而,非唯一值的顺序并不重要。
我想出了下面的解决方案,但我想知道是否有一个更优雅的方式来做它:
import torch
n = 10
tensor_list = [torch.randn(10, 10), torch.zeros(20, 20), torch.ones(30, 10)]
all_sorted, all_sorted_idx = torch.sort(torch.cat([t.view(-1) for t in tensor_list]))
cum_num_elements = torch.cumsum(torch.tensor([t.numel() for t in tensor_list]), dim=0)
cum_num_elements = torch.cat([torch.tensor([0]), cum_num_elements])
split_indeces_lt = [all_sorted_idx[:n] < cum_num_elements[i + 1] for i, _ in enumerate(cum_num_elements[1:])]
split_indeces_ge = [all_sorted_idx[:n] >= cum_num_elements[i] for i, _ in enumerate(cum_num_elements[:-1])]
split_indeces = [all_sorted_idx[:n][torch.logical_and(lt, ge)] - c for lt, ge, c in zip(split_indeces_lt, split_indeces_ge, cum_num_elements[:-1])]
n_smallest = [t.view(-1)[idx] for t, idx in zip(tensor_list, split_indeces)]
编辑:理想情况下,解决方案将选择非唯一值的随机子集,而不是选择列表中第一个Tensor的条目。
1条答案
按热度按时间mnemlml81#
Pytorch确实提供了一种更优雅(我认为)的方式来做这件事,用
torch.unique_consecutive
(见这里)我将研究Tensor,而不是Tensor的列表,因为就像你自己做的那样,只需要做一个
cat
,然后解开指数也不难。编辑:我认为增加的排列解决了你的需求随机发生问题