pytorch 检查Tensor的每个元素是否包含在列表中

jxct1oxe  于 2022-11-09  发布在  其他
关注(0)|答案(4)|浏览(309)

假设我有一个TensorA和一个值vals的容器。有没有一个简洁的方法来返回一个与A形状相同的布尔Tensor,其中每个元素是A的元素是否包含在vals中?例如:
第一个

pw9qyyiw

pw9qyyiw1#

您可以使用for循环来实现这一点:

sum(A==i for i in B).bool()
wfveoks0

wfveoks02#

你可以简单地这样做:

result = A.apply_(lambda x: x in vals).bool()

那么result将包含这个Tensor:

tensor([[ True, False, False],
        [False,  True, False]])

在这里,我只是使用了一个lambda函数和apply_ method,您可以在official documentation中找到它们。

csbfibhn

csbfibhn3#

[list(map(lambda x: x in vals, thelist)) for thelist in A]
dldeef67

dldeef674#

torch.isin方法是最方便的方法,简单如下:torch.isin(A, vals)

相关问题