numpy 获取TensorA中存在的TensorB中的值的索引位置

b91juud3  于 9个月前  发布在  其他
关注(0)|答案(3)|浏览(86)

我本质上是在寻找一种完全矢量化的方法来获取TensorB:[1, 2, 3, 9]和TensorA:[1,2,3,3,2,1,4,5,9],对于TensorB中的每个值,找到它在TensorA中的索引位置,这样输出就像这样:[[0,5], [1,4], [2,3], [-1,8]]个(尽管它是一维的也没问题,只要我能检索到哪个索引对应于TensorB中的哪个值的信息),其中每一行对应于TensorB中的一个值,其中列值是该给定值在A中出现的索引。
这种方法的工作原理是:

def vectorized_find_indices(A, B):
    # Expand dimensions of A for broadcasting
    A_expanded = A[:, None, None]

    # Compare B with expanded A to create a boolean mask
    mask = (B == A_expanded)

    # Get the indices where A matches B
    indices = torch.where(mask, torch.arange(A.size(0), device=A.device)[:, None, None], torch.tensor(-1, device=A.device))

    # Reshape the indices to match the shape of B with an additional dimension for indices
    result = indices.permute(1, 2, 0)

    return result

字符串
但是我使用的Tensor太大了,无法进行广播,所以我受到的限制更大。
我也尝试了几种更简单的方法,比如searchsorted,我找到了这个解决方案:(A[..., None] == B).any(-1).nonzero(),这是接近但不够的,因为返回的索引不再直接附加到值。例如,上面的代码片段将返回:[0, 1, 2, 3, 4, 5, 8],这确实是找到匹配的正确索引,但是信息不再嵌套在第二维中,将其与相应的值联系起来,就像我需要的那样,但是我对pytorch很不熟悉,所以也许有可能以某种方式获取这些信息并使用这些信息重建它?

wlp8pajw

wlp8pajw1#

我不确定是否有一种方法可以做到这一点,而不需要沿着B广播A或在B上执行for循环以减少内存开销。
一个解决方案可能是

overlap_idxs = (a.unsqueeze(1) == b).nonzero()

output = [[] for i in b]

for (a_idx, b_idx) in overlap_idxs:
    output[b_idx].append(a_idx.item())

output
>[[0, 5], [1, 4], [2, 3], [8]]

字符串
或者在B上使用Python级别的循环:

output = []

for _b in b:
    idxs = (a==_b).nonzero().squeeze().tolist()
    if type(idxs) != list:
        idxs = [idxs]
        
    output.append(idxs)
    
output
>[[0, 5], [1, 4], [2, 3], [8]]

lx0bsm1f

lx0bsm1f2#

给定输入B = np.array([1, 2, 3, 9])A = np.array([1,2,3,3,2,1,4,5,9]),为了使算法完全可向量化,我会这样进行:
indices定义为形状为(len(B), len(A), 2)的网格,以便对于每个ijindices[i,j] = [i, j]

indices = np.moveaxis(np.c_[np.meshgrid(np.arange(len(A)), np.arange(len(B)))], 0, -1)

字符串
一旦你有了索引,你就可以过滤它们来检索所有满足A[i]==B[j]的对[i,j]

found = indices[B[:,None]==A[None,:]]
#array([[0, 0],
#       [5, 0],
#       [1, 1],
#       [4, 1],
#       [2, 2],
#       [3, 2],
#       [8, 3]])


这个矩阵中的每一行都是一对匹配元素的索引:例如[0,0]表示A[0]==B[0][5,0]表示A[5]==B[0][8,3]表示A[8]==B[3]等等。
一旦你有了这些数据,你就可以决定如何组织这些信息了。因为你想要的输出是像[[0,5], [1,4], [2,3], [-1,8]]这样的东西,我会使用为this question提出的group_by矢量化函数,由@Carlos Pinzón编写:

output = group_by(found[:,1], lambda idx: found[idx,0])
#array([array([0, 5]), array([1, 4]), array([2, 3]), array([8])], dtype=object)


为了完整起见,我附上了group_by函数:

def group_by(by, func=lambda idx: idx, transform=False, equal_nan=True):
    """
    https://stackoverflow.com/a/77150915
    Groups by the unique values of `by`, calls `func` on each group of
    indices and returns the combined results as a numpy array.
    If `transform=True` the output has as many rows as x (like pandas "transform"),
    and as many rows as unique values in x otherwise (like pandas "apply"),
    ```
    # Examples:
    x = np.array([1, 3, 1, 2, 3])
    y = np.array([5, 4, 3, 2, 1])
    print(group_by(x)) # [array([0, 2]) array([3]) array([1, 4])]
    means = group_by(x, lambda idx: np.mean(y[idx]))
    print(means) # [4.  2.  2.5]
    means = group_by(x, lambda idx: np.mean(y[idx]), transform=True)
    print(means) # [4.  2.5 4.  2.  2.5]
    mins, maxs = group_by(x, lambda idx: [np.min(y[idx]), np.max(y[idx])]).T
    print(mins, maxs) # [3 2 1] [5 2 4]
    for idx in group_by(x):
        print(x[idx[0]], y[idx]) # 1 [5 3]; 2 [2]; 3 [4 1]
    ```
    """
    _, invs, cnts = np.unique(
        by, return_counts=True, return_inverse=True, axis=0, equal_nan=equal_nan
    )
    idxs = np.split(np.argsort(by), np.cumsum(cnts)[:-1])
    out = [func(idx) for idx in idxs]
    if transform:
        out = [out[invs[i]] for i in range(len(by))]
    try:
        out = np.array(out)
    except:
        out = np.array(out, dtype=object)
    return out

gjmwrych

gjmwrych3#

如果AB都是1D数组,并且您想要它们的公共值的索引,则可以广播相等检查,然后使用对应维度的OR折叠结果(这可以使用any并指定轴来完成)。

import numpy as np

B = np.array([1, 2, 3, 9])
A = np.array([1, 2, 3, 3, 2, 1, 4, 5, 9])

equal = A[:,None] == B[None,:]
A_indices = np.nonzero(equal.any(1))[0]
B_indices = np.nonzero(equal.any(0))[0]
print(A_indices)  # [0 1 2 3 4 5 8]
print(B_indices)  # [0 1 2 3]

字符串

相关问题