Pytorch -从维度x中选择n个索引而不进行替换

ds97pgxw  于 2022-11-09  发布在  其他
关注(0)|答案(1)|浏览(115)

假设我有下面的嵌入emb_user = torch.randn(64, 128, 256),我希望从第二维(长度为128)中随机挑选出16个嵌入,我想知道是否有更有效的方法来完成下面的任务:

idx = torch.multinomial(torch.ones(64, 128), 16)
sampled_emb_user = emb_user[torch.arange(len(emb_user)).unsqueeze(-1), idx]

我还发现,如果权重矩阵(torch.ones(64, 128))超过2维以上,上述多项式将不起作用。

gcuhipw9

gcuhipw91#

因为在你的情况下你想要一个均匀的分布你可以加速它

idx = torch.sort(torch.randint(
    0, 128 - 15, (64, 16), device=device
), axis=1).values + torch.arange(0, 16, device=device).reshape(1, -1)
sampled_emb_user = emb_user[torch.arange(len(emb_user)).unsqueeze(-1), idx]

而不是

idx = torch.multinomial(torch.ones(64, 128, device=device), 16)
sampled_emb_user = emb_user[torch.arange(len(emb_user)).unsqueeze(-1), idx]

我的机器上的运行时是427 µs784 µsdevice='cpu';以及469 µsdevice='cuda'的组合。
它是如何工作的?
排序随机数给出了具有替换的多项式分布的指数,即递增,加上arange项使其严格递增,从而消除了替换。
用一个小箱子说明
第一个
下面是一个可能更快的解决方案,但不是来自完全相同的发行版:

idx = torch.cumsum(torch.diff(
torch.sort(torch.randint(
    0, 128 - 16, (64, 17), device=device
), axis=1).values
, axis=1) + 1, axis=1) - 1
sampled_emb_user = emb_user[torch.arange(len(emb_user)).unsqueeze(-1), idx]

还有一种方法,我期望更接近确切的方法,不是非常严格的分析。


# 1-rand() to include 1 and exclude zero.

d = torch.cumsum(1 - torch.rand(64, 17, device=device
), axis=1)

# this produces a sorted tensor with values in the range [0:128-16]

d = (((128 - 15) * d[:, :-1]) / d[:, -1:]).to(torch.long)
idx = d + torch.arange(0, 16, device=device).reshape(1, -1)

但最终它往往比使用sort的方法慢。

相关问题