pytorch 如果你有一组索引,如何操作Tensor的元素?

js5cn81o  于 2022-11-29  发布在  其他
关注(0)|答案(1)|浏览(185)

假设我有一个Tensor,使用torch.topk函数,我得到Tensor的最大k个元素及其索引。如以下代码

>>> x = torch.arange(1., 6.)
>>> x
tensor([ 1.,  2.,  3.,  4.,  5.])
>>> torch.topk(x, 3)
torch.return_types.topk(values=tensor([5., 4., 3.]), indices=tensor([4, 3, 2]))

现在假设我想把上面的最大k个元素设置为0,但是保持它们在原始Tensor中的相同位置。我该怎么做呢?
如果k =3,则新Tensor应该看起来像:
Tensor([ 1.,2.,0.,0.,0.])
基本上,我如何使用这些max k元素的索引(返回topk()torch函数)将这些位置的原始值清零(将它们的值设置为0)?
最好是一个 Torch 方法的建议,做什么我问。如果不是,这将是最好的解决方案是尽可能有效的。
提前感谢你的帮助:)

iyr7buue

iyr7buue1#

torch.topk函数会传回所提供维度上前k个元素的索引。您可以使用torch.scatter执行重新指派作业:

>>> _, i = x.topk(k=3)
>>> x.scatter(dim=0, index=i, value=0)
tensor([1., 2., 0., 0., 0.])

相关问题