在Python中将一个LARGE(x,y)点集与另一个包含异常值的点集进行匹配

x0fgdtte  于 2022-11-21  发布在  Python
关注(0)|答案(1)|浏览(119)

我有两个大的(x,y)点集合,我想在Python中将一个集合的每个点与另一个集合的“对应点”关联起来。
第二组也可能包含异常值,即额外的噪声点,如图所示,其中绿色多于红点:

两组点之间的关联不是简单的平移,如下图所示:

在这两个链接中,您可以找到红点和绿色(图像坐标列表,原点位于左上角):
https://drive.google.com/file/d/1fptkxEDYbIJ93r_OXJSstDHMfk67DDYo/view?usp=share_linkhttps://drive.google.com/file/d/1Z_ghWIzUZv8sxfawOBoGG3fJz4h_z7Qv/view?usp=share_link显示器
我的问题与这两个类似:
Match set of x,y points to another set that is scaled, rotated, translated, and with missing elements
How to align two sets of points (translation+rotation) when those sets contain noise?
但我有一个很大的点集,所以这里提出的解决方案不适用于我的情况。我的点在行中有一定的结构,所以很难计算旋转-缩放-平移函数,因为点的行彼此混淆。

rt4zxlrg

rt4zxlrg1#

我发现了一种方法,它可以通过两个阶段相当精确地恢复哪些点对应于哪些点,第一阶段校正仿射变换,第二阶段校正非线性失真。
注意:我选择将红点与绿色匹配,而不是相反。
假设
该方法做了三个假设:
1.它知道三个或更多的绿色和匹配的红点。
1.两者之间的差异大多是线性的。
1.差异的非线性部分是局部相似的,即如果一个点具有(-10,10)的匹配偏移,则相邻点将具有相似的偏移。这由max_search_dist控制。
编号
首先加载两个数据集:

import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
from sklearn.neighbors import NearestNeighbors
from scipy.spatial import KDTree
from collections import Counter

with open('red_points.json', 'rb') as f:
    red_points = json.load(f)
red_points = pd.DataFrame(red_points, columns=list('xy'))
with open('green_points.json', 'rb') as f:
    green_points = json.load(f)
green_points = pd.DataFrame(green_points, columns=list('xy'))

我发现使用一个函数来可视化这两个数据集非常有用:

def plot_two(green, red):
    if isinstance(red, np.ndarray):
        red = pd.DataFrame(red, columns=list('xy'))
    if isinstance(green, np.ndarray):
        green = pd.DataFrame(green, columns=list('xy'))
    both = pd.concat([green.assign(hue='green'), red.assign(hue='red')])
    ax = both.plot.scatter('x', 'y', c='hue', alpha=0.5, s=0.5)
    ax.ticklabel_format(useOffset=False)

接下来,选择三个绿色的点,并提供它们的XY坐标。找到对应的红色点,并提供它们的XY坐标。

green_sample = np.array([
    [5221, 12460],
    [2479, 2497],
    [6709, 6303],
])
red_sample = np.array([
    [5274, 12597],
    [2375, 2563],
    [6766, 6406],
])

接下来,使用这些点来找到一个仿射矩阵。这个仿射矩阵将涵盖旋转、平移、缩放和倾斜。由于它有六个未知数,你至少需要六个约束,否则方程是欠定的。这就是为什么我们之前至少需要三个点。

def add_implicit_ones(matrix):
    b = np.ones((matrix.shape[0], 1))
    return np.concatenate((matrix,b), axis=1)

def transform_points_affine(points, matrix):
    return add_implicit_ones(points) @ matrix

def fit_affine_matrix(red_sample, green_sample):
    red_sample = add_implicit_ones(red_sample)
    X, _, _, _ = np.linalg.lstsq(red_sample, green_sample, rcond=None)
    return X

X = fit_affine_matrix(red_sample, green_sample)
red_points_transformed = transform_points_affine(red_points.values, X)

现在我们进入非线性匹配步骤,这是在红色的值被转换为与绿色的值匹配之后运行的,算法如下:
1.从一个没有非线性分量的红点开始。green_sample点附近的一个点是一个很好的选择-仿射步骤将优先获得这些点。在这个点周围的半径内搜索相应的绿点。记录红点和相应绿点之间的差异为“漂移”。
1.查看该红点的所有红色邻居,并将其添加到列表中进行处理。
1.在其中一个红点附近,取所有相邻红点的平均漂移值,将该漂移值加到红点上,然后在一个半径内搜索绿点。
1.红点和相应绿点之间的差异就是该红点的漂移。
1.将该点的所有红色相邻点添加到列表中进行处理,然后返回步骤3。

def find_nn_graph(red_points_np):
    nbrs = NearestNeighbors(n_neighbors=8, algorithm='ball_tree').fit(red_points_np)
    _, indicies = nbrs.kneighbors(red_points_np)
    return indicies

def point_search(red_points_np, green_points_np, starting_point, max_search_radius):
    starting_point_idx = (((red_points_np - starting_point)**2).mean(axis=1)).argmin()
    green_tree = KDTree(green_points_np)
    dirty = Counter()
    visited = set()
    indicies = find_nn_graph(red_points_np)
    # Mark starting point as dirty
    dirty[starting_point_idx] += 1

    match = {}

    drift = np.zeros(red_points_np.shape)
    # NaN = unknown drift
    drift[:] = np.nan
    while len(dirty) > 0:
        point_idx, num_neighbors = dirty.most_common(1)[0]
        neighbors = indicies[point_idx]
        if point_idx != starting_point_idx:
            neighbor_drift_all = drift[neighbors]
            if np.isnan(neighbor_drift_all).all():
                # All neighbors have no drift
                # Unmark as dirty and come back to this one
                del dirty[point_idx]
                continue
            neighbor_drift = np.nanmean(neighbor_drift_all, axis=0)
            assert not np.isnan(neighbor_drift).any(), "No neighbor drift found"
        else:
            neighbor_drift = np.array([0, 0])
        # Find the point in the green set
        red_point = red_points_np[point_idx]
        green_points_idx = green_tree.query_ball_point(red_point + neighbor_drift, r=max_search_radius)

        assert len(green_points_idx) != 0, f"No green point found near {red_point}"
        assert len(green_points_idx) == 1, f"Too many green points found near {red_point}"
        green_point = green_points_np[green_points_idx[0]]
        real_drift = green_point - red_point
        match[point_idx] = green_points_idx[0]

        # Save drift
        drift[point_idx] = real_drift
        # Mark unvisited neighbors as dirty
        if point_idx not in visited:
            neighbors = indicies[point_idx, 1:]
            neighbors = [n for n in neighbors if n not in visited]
            dirty.update(neighbors)
        # Remove this point from dirty
        del dirty[point_idx]
        # Mark this point as visited
        visited.add(point_idx)
    # Check that there are no duplicates
    assert len(set(match.values())) == len(match)
    # Check that every point in red_points_np was matched
    assert len(match) == red_points_np.shape[0]
    return match, drift

# This point is assumed to have a drift of zero
# Pick one of the points which was used for the linear correction
starting_point = green_sample[0]
# Maximum distance that a point can be found from where it is expected
max_search_radius = 10
green_points_np = green_points.values
match, drift = point_search(red_points_transformed, green_points_np, starting_point, max_search_radius)

接下来,这里有一个工具,你可以用来检查匹配的质量。这是显示前一千个匹配。下面是一个箭图,箭头从红点指向匹配的绿点。(注意:箭头不是按比例的。

red_idx, green_idx = zip(*match.items())
def show_match_subset(start_idx, length):
    end_idx = start_idx + length
    plot_two(green_points_np[np.array(green_idx)][start_idx:end_idx], red_points_np[np.array(red_idx)][start_idx:end_idx])
    plt.show()
    red_xy = red_points_np[np.array(red_idx)][start_idx:end_idx]
    red_drift_direction = drift[np.array(red_idx)][start_idx:end_idx]
    plt.quiver(red_xy[:, 0], red_xy[:, 1], red_drift_direction[:, 0], red_drift_direction[:, 1])
    
show_subset(0, 1000)

图:
x1c 0d1x

匹配

下面是我找到的匹配项的副本。它是JSON格式的,其中的键表示红色点文件中的点的索引,值表示绿点文件中的点的索引。https://pastebin.com/SBezpstu

相关问题