假设我有下面的嵌入emb_user = torch.randn(64, 128, 256)
,我希望从第二维(长度为128)中随机挑选出16个嵌入,我想知道是否有更有效的方法来完成下面的任务:
idx = torch.multinomial(torch.ones(64, 128), 16)
sampled_emb_user = emb_user[torch.arange(len(emb_user)).unsqueeze(-1), idx]
我还发现,如果权重矩阵(torch.ones(64, 128)
)超过2维以上,上述多项式将不起作用。
1条答案
按热度按时间gcuhipw91#
因为在你的情况下你想要一个均匀的分布你可以加速它
而不是
我的机器上的运行时是
427 µs
和784 µs
与device='cpu'
;以及469 µs
与device='cuda'
的组合。它是如何工作的?
排序随机数给出了具有替换的多项式分布的指数,即递增,加上
arange
项使其严格递增,从而消除了替换。用一个小箱子说明
第一个
下面是一个可能更快的解决方案,但不是来自完全相同的发行版:
还有一种方法,我期望更接近确切的方法,不是非常严格的分析。
但最终它往往比使用sort的方法慢。