Pytorch随机选择?

3j86kqsm  于 2022-11-09  发布在  其他
关注(0)|答案(6)|浏览(538)

我有Tensor的图片,并希望从它随机选择。我正在寻找等效的np.random.choice()

import torch

pictures = torch.randint(0, 256, (1000, 28, 28, 3))

假设我想要10张这样的照片。

46qrfjad

46qrfjad1#

torch没有np.random.choice()的等价实现,参见讨论here。替代方案是使用混洗索引或随机整数进行索引。

若要执行 * 替换 *:

1.生成 n 个随机索引
1.使用这些索引为原始Tensor建立索引

pictures[torch.randint(len(pictures), (10,))]

* 无需 * 更换:

1.随机播放索引
1.取第 n 个元素

indices = torch.randperm(len(pictures))[:10]

pictures[indices]

阅读更多关于torch.randinttorch.randperm的信息。第二个代码片段的灵感来自PyTorch论坛中的post

vhmi4jdf

vhmi4jdf2#

对于此大小的Tensor:

N, D = 386363948, 2
k = 190973
values = torch.randn(N, D)

下面的代码运行得相当快,大约需要0.2秒:

indices = torch.tensor(random.sample(range(N), k))
indices = torch.tensor(indices)
sampled_values = values[indices]

但是,使用torch.randperm将花费20秒以上的时间:

sampled_values = values[torch.randperm(N)[:k]]
hujrc8aj

hujrc8aj3#

torch.multinomial提供了与numpy的random.choice等效的行为(包括有/无替换的采样):
第一个

ozxc1zmp

ozxc1zmp4#

试试看:

input_tensor = torch.randn(5, 8)
print(input_tensor)
indices = torch.LongTensor(np.random.choice(5,2, replace=False)) 
output_tensor = torch.index_select(input_tensor, 0, indices)
print(output_tensor)
kzipqqlq

kzipqqlq5#

一个简单的方法是使用代码从Tensor中选择一个元素。在你的情况下,你有一个大小为(1000,28,28,3)的Tensor,我们想从1000张图片中选择10张。

index = torch.randint(0,1000,(10,))
selected_pics = [pictures[i] for i in index]
tf7tbtn2

tf7tbtn26#

正如另一个答案所提到的,torch没有choice,你可以用randint或permutation来代替:

import torch

n = 4
replace = True # Can change
choices = torch.rand(4, 3)
choices_flat = choices.view(-1)

if replace:
    index = torch.randint(choices_flat.numel(), (n,))
else:
    index = torch.randperm(choices_flat.numel())[:n]

select = choices_flat[index]

相关问题