我有一个有1000万个条目的numpy数组。该阵列具有5列,其中前4列指定x、y、z和t的坐标。最后一列指定这些点中每个点的标量值。现在,对于这个数据集中的每个点,我想查询由x_min,x_max,y_min,y_max,z_min,z_max,t_min和t_max指定的n-D边界框内的点。对于边界框内的点,计算值的中位数和标准差并存储它。
需要注意的几点:
1.每个点的边界框规范将是不同的,对于一些点,它可能是小框,对于一些点,它可能是大框。
1.注意,x,y,z,t;所有4个轴具有不同的分辨率和不同的比例。
1.该阵列是沿着第一轴而不是其他三个轴的有序阵列。
现在,我已经利用了点3的信息来减少搜索空间,但我想创建某种树数据结构,它可以立即获取边界框内的点,因为我需要为每行执行1000万次查询(考虑以每个点为中心的行进边界框)。
有了上面的信息,我已经尝试实现了下面的伪代码
import numpy as np
import time
num_points = 10_000_000
df_us = np.random.rand(num_points, 5)
df = df_us[df_us[:,0].argsort()] #sort along the first axis to mimick the real data feature
# for this expt ignoring the fact that each axis has different limits and different sampling resolution
xmin = np.random.rand(num_points)
xmax = xmin + 0.2
ymin = np.random.rand(num_points)
ymax = ymin + 0.4
zmin = np.random.rand(num_points)
zmax = zmin + 0.4
tmin = np.random.rand(num_points)
tmax = tmin + 0.4
def bbox_stat(xr_df, xmin, xmax, ymin, ymax, zmin, zmax, tmin, tmax):
for i in range(len(xmin)):
# use the fact the first axis is presorted in the dataset
x_si = np.searchsorted(xr_df[:,0], xmin[i], side='left')
x_ei = np.searchsorted(xr_df[:,0], xmax[i], side='right')
y_conditions = (xr_df[x_si:x_ei, 1] >= ymin[i]) & (xr_df[x_si:x_ei, 1] <= ymax[i])
z_conditions = (xr_df[x_si:x_ei, 2] >= zmin[i]) & (xr_df[x_si:x_ei, 2] <= zmax[i])
t_conditions = (xr_df[x_si:x_ei, 3] >= tmin[i]) & (xr_df[x_si:x_ei, 3] <= tmax[i])
conditions = y_conditions & z_conditions & t_conditions
med_i = xr_df[x_si:x_ei,:][conditions,3].median() # use this fillup an array mean
std_i = xr_df[x_si:x_ei,:][conditions,3].std() # use this to fillup an array std
arguments = []
num_cores = 50
orig_len = df.shape[0]
unit_len = orig_len//num_cores
for i in range(num_cores):
arguments.append((df, xmin[i*uni_len:(i+1)*uni_len], xmax[i*uni_len:(i+1)*uni_len], ymin[i*uni_len:(i+1)*uni_len], ymax[i*uni_len:(i+1)*uni_len], zmin[i*uni_len:(i+1)*uni_len], zmax[i*uni_len:(i+1)*uni_len], tmin[i*uni_len:(i+1)*uni_len], tmax[i*uni_len:(i+1)*uni_len]))
start_time = time.time()
pool = Pool(processes=num_cores)
pool.starmap(bbox_stat, arguments)
pool. Close()
pool. Join()
end_time = time.time()
print(f"Time taken: {end_time - start_time} seconds")
1条答案
按热度按时间7lrncoxx1#
更新解决方案
我最初的解决方案(保存在下面供后人参考)是不正确的,因为我误解了。也就是说,这个问题仍然存在效率问题(由于其规模),尽管希望我的答案将是一个很好的起点。与我最初的解决方案类似,您可以使用kd树来加快搜索速度。使用scipy
KDTree
,您可以为每个方向(x,y,z)和t创建一个kd树,然后使用query_ball_point
查找距离长方体中心给定距离(定义为xmin
,xmax
,ymin
,...)内的点。使用使用每个范围找到的点的交集仅找到满足所有条件的点。我只能在我的机器上连续运行100,000个点,这花了12分钟。1000万可能是棘手的(至少在串行)。
这段代码目前的问题是,随着点的数量增加,kd树的查询速度越来越慢,这是针对4棵树完成的。现在,我将把这个解决方案作为起点,也许我或其他人会改进它。
旧(不正确)解决方案
这个问题的问题,我在你的last post中提到过,是你不可能存储你想要的所有信息。最后,您需要在感兴趣的点周围的任何长方体内的点的中位数和标准差。假设median和std值为1字节,则您将拥有10,000,000 x 10,000,000个数据点,每个数据点对应于100Tb数组。这是假设值是布尔值,而实际上它们是浮点数(假设最小的
np.float16
是2个字节,每个数组是200Tb)。假设你有一个更小、更易管理的问题,你可以使用kd树来加速搜索。使用scipy
KDTree
,可以将query_ball_point
方法与p=np.inf
一起使用,以查询多维数据集区域。这可以用来快速减少采样点的数量,首先查询最大可能的立方体。有了这个结果,你可以继续进行正常的检查,并从那里计算中位数和标准差。现在,这实际上可能运行10,000,000个点,这取决于点的实际间隔。你展示了使用从0到1的随机点分布,这意味着有10,000,000个点,在给定的长方体中仍然会有大量的点。如果这些点分布得更广,树查询将返回更少的点(希望是显著更少)。但是你仍然会遇到存储信息的问题。但是如果数据足够分散,你将有稀疏的中位数和标准矩阵,所以你可以使用
scipy.sparse
来减少存储需求。