pytorch Torch:使用非零元素更新Tensor

mccptt67  于 2023-01-20  发布在  其他
关注(0)|答案(2)|浏览(229)

假设我有:

>>> a = torch.tensor([1, 2, 3, 0, 0, 1])
>>> b = torch.tensor([0, 1, 3, 3, 0, 0])

我想用a中的元素来更新B,如果它不为零,我该怎么做呢?
预期:

>>> b = torch.tensor([1, 2, 3, 3, 0, 1])
6yt4nkrj

6yt4nkrj1#

要添加到前面的答案中,并且为了更加简单,您可以通过一行代码来完成:

b = torch.where(a!=0,a, b)

输出:

tensor([1, 2, 3, 3, 0, 1])
mfuanj7w

mfuanj7w2#

torch.where是您的答案,我假设基于您的示例,您还希望仅替换a中为0的元素。

mask = torch.logical_and(b!=0,a==0)
output = torch.where(mask,b,a)

相关问题