Pytorch:测试每行第一个二维Tensor是否也存在于第二个Tensor中?

6mzjoqzu  于 2022-11-23  发布在  其他
关注(0)|答案(2)|浏览(180)

给定两个Tensort1t2

t1=torch.tensor([[1,2],[3,4],[5,6]])
t2=torch.tensor([[1,2],[5,6]])

如果t1的行元素存在于t2中,则返回True,否则返回False。理想的结果是[Ture, False, True]。我尝试了torch.isin(t1, t2),但它是按元素返回结果,而不是按行返回结果。顺便说一下,如果是numpy数组,则可以通过

np.in1d(t1.view('i,i').reshape(-1), t2.view('i,i').reshape(-1))

我想知道如何在Tensor中得到类似的结果?

ux6nzvsh

ux6nzvsh1#

def rowwise_in(a,b):
  """ 
  a - tensor of size a0,c
  b - tensor of size b0,c
  returns - tensor of size a1 with 1 for each row of a in b, 0 otherwise
  """
  
  # dimensions
  a0 = a.shape[0]
  b0 = b.shape[0]
  c  = a.shape[1]
  assert c == b.shape[1] , "Tensors must have same number of columns"

  a_expand = a.unsqueeze(1).expand(a0,b0,c)
  b_expand = b.unsqueeze(0).expand(a0,b0,c)

  # element-wise equality
  equal = a_expand == b_expand

  # sum along dim 2 (all elements along this dimension must be true for the summed dimension to be True)
  row_equal = torch.prod(equal,dim = 2)

  row_in_b = torch.max(row_equal, dim = 1)[0]
  return row_in_b
dgsult0t

dgsult0t2#

除了DerekG这个伟大的解决方案之外,这个小小的改变看起来更快、更健壮。

a,b = torch.tensor([[1,2,3],[3,4,5],[5,6,7]],device=torch.device(0)), torch.tensor([[1,2,3],[5,6,7]],device=torch.device(0))
# dimensions
shape1 = a.shape[0]
shape2 = b.shape[0]
c  = a.shape[1]
assert c == b.shape[1] , "Tensors must have same number of columns"

a_expand = a.unsqueeze(1).expand(-1,shape2,c)
b_expand = b.unsqueeze(0).expand(shape1,-1,c)
# element-wise equality
mask = (a_expand == b_expand).all(-1).any(-1)

我尝试了10000行的Tensor,它的工作相当快,没有内存浪费

相关问题