pytorch 检查Tensor值是否包含在另一Tensor中

ajsxfq5m  于 8个月前  发布在  其他
关注(0)|答案(1)|浏览(160)

我有一个这样的 Torch Tensor:

a=[1, 234, 54, 6543, 55, 776]

字符串
和其他Tensor一样:

b=[234, 54]
c=[55, 776]


我想创建一个新的掩码Tensor,如果有另一个Tensor(bc)与a相等,则a的值将为true。
例如,在上面的Tensor中,我想创建以下掩码Tensor:

a_masked =[False, True, True, False, True, True]
# The first two True values correspond to tensor `b` while the last two True values 
correspond to tensor `c`.


我见过其他方法来检查一个完整的Tensor是否包含在另一个Tensor中,但这里不是这样。
有没有一个 Torch 的方式来做到这一点有效?谢谢!

cgyqldqp

cgyqldqp1#

根据PyTorch论坛here上的答案,您可以显式地使用for循环,例如,

import torch

a = torch.tensor([1, 234, 54, 6543, 55, 776])
b = torch.tensor([234, 54])
c = torch.tensor([55, 776])

a_masked = sum(a == i for i in b).bool() + sum(a == i for i in c).bool()

print(a_masked)
tensor([False,  True,  True, False, True, True])

字符串
然而,实际上有一个PyTorch isin函数,你可以这样做:

a_masked = torch.isin(a, torch.cat([b, c]))


这比sum方法快几倍。

相关问题