给定一个数字数组和一个目标计数,我想找到阈值,使得高于它的元素的数量等于目标(或尽可能接近)。比如
arr = np.random.rand(100) target = 80 for i in range(100): t = i * 0.01 if (arr > t).sum() < target: break print(t)
然而,这不是有效的,也不是很精确,也许有人已经解决了这个问题。
hpxqektj1#
虽然建议的二分搜索相当快,但这里有一个用于相对较小的数组的单行替代方案(注意,中间计算形成矩阵arr.size x arr.size)。它一次比较初始arr和阈值数组,以找到绝对差和“最接近”阈值的位置:
arr.size
arr
target = 80 t_arr = np.arange(0, 1, 0.01) t = t_arr[np.abs((arr > t_arr[:, None]).sum(1) - 80).argmin()]
ui7jx7zq2#
你可以尝试对你的值进行二进制搜索(首先创建阈值数组):
def binary_search(arr, target, thresholdarr): low = 0 high = len(thresholdarr) - 1 mid = 0 while low <= high: mid = (high + low) // 2 s1 = (arr > thresholdarr[mid]).sum() s2 = (arr > thresholdarr[mid - 1]).sum() if s1 < target and s2 >= target: return thresholdarr[mid] elif s2 < target: high = mid - 1 else: low = mid + 1 # If we reach here, then the element was not present # return last element of thresholdarray return thresholdarr[high] np.random.seed(42) arr = np.random.rand(100) thresholds = np.arange(0, 1, 0.01) x = binary_search(arr, 80, thresholds) print(x)
图纸:
0.16
下面是使用perfplot的快速基准测试:
perfplot
import perfplot import numpy as np def search_roman_perekhrest(arr, target, t_arr): return t_arr[np.abs((arr > t_arr[:, None]).sum(1) - target).argmin()] def normal_search(arr, target): for i in range(100): t = i * 0.01 if (arr > t).sum() < target: break return t def binary_search(arr, target, thresholdarr): low = 0 high = len(thresholdarr) - 1 mid = 0 while low <= high: mid = (high + low) // 2 s1 = (arr > thresholdarr[mid]).sum() s2 = (arr > thresholdarr[mid - 1]).sum() if s1 < target and s2 >= target: return thresholdarr[mid] elif s2 < target: high = mid - 1 else: low = mid + 1 # If we reach here, then the element was not present # return last element of thresholdarray return thresholdarr[high] np.random.seed(42) thresholds = np.arange(0, 1, 0.01) perfplot.show( setup=lambda n: np.random.rand(n), kernels=[ lambda arr: binary_search(arr, 80, thresholds), lambda arr: normal_search(arr, 80), lambda arr: search_roman_perekhrest(arr, 80, thresholds), ], labels=["bin_search", "normal_search", "search_roman_perekhrest"], n_range=[2**k for k in range(7, 16)], xlabel="N", logx=True, logy=True, equality_check=None, )
输出量:
h7appiyu3#
你当前的实现看起来像是O(m*n)复杂度,其中n是数组的长度,m是你迭代的阈值候选的长度。对数组进行排序并从末尾选择target元素可能会更有效地完成这项工作。但如果你的数组可以包含相等的值(否则你可以总是选择一个精确的阈值)复杂的开始。你的例子很可能不是这种情况,但我不确定你的实际数据是什么样的。因此,可以通过搜索相邻值并检查它们是否比初始候选者更适合来进一步继续从最后进行排序和挑选。下面是我的代码:
n
m
target
def get_threshold(array, target): sorted_array = sorted(array)[::-1] candidate = sorted_array[target] next_val_idx = target prev_val_idx = target while prev_val_idx < len(array) and sorted_array[prev_val_idx] == candidate: prev_val_idx += 1 while next_val_idx >= 0 and sorted_array[next_val_idx] == candidate: next_val_idx -= 1 candidates = [candidate] if next_val_idx >= 0: candidates.append(sorted_array[next_val_idx]) if prev_val_idx < len(sorted_array): candidates.append(sorted_array[prev_val_idx]) return min( candidates, key=lambda t: abs(target - (array > t).sum()) ) get_threshold(np.array([4, 5, 2, 3, 3, 4, 4, 4]), 2)
4
然而,如果你不期望相等的元素-可能np.quantile会工作?
np.quantile
np.quantile(arr, q=1 - target / len(arr))
vltsax254#
感谢@Andrej建议使用二进制搜索。这让我想到了其他各种根查找方法,比如Newtons,我找到了这个wonderful post和一些示例代码。最后,我写了自己的。
def find_root_binary(fn, x0:float = 0, initial_stepsize:float = 1, max_step:float = 1e-7, epsilon:float = 1e-7, max_iter:int = 100): """ Given a function, y = f(x) find the root using binary search :param function fn: a function that maps float to float :param float x0: initial guess, defaults to 0 :param float initial_stepsize: defaults to 1 :param float max_step: quit when step size is below this :param float epsilon: quit when abs(y) is below this :param float max_iter: quit after this many iterations """ y0 = fn(x0) if abs(y0) <= epsilon: return x0 dx = initial_stepsize for i in range(max_iter): x1 = x0 + dx y1 = fn(x1) print(i,'\t',y1) if abs(y1) <= epsilon or abs(dx) <= max_step: break if (y0 < 0) != (y1 < 0): # zoom in dx *= 0.5 if abs(y0) > abs(y1): # change direction x0 = x1 y0 = y1 dx = -dx elif (y0 < 0) != (y0 <= y1): # change direction dx = -dx else: # keep going same direction y0 = y1 x0 = x1 dx *= 1.5 # increase step size?? return x1
试试看
arr = np.random.rand(10000) def fn1(x): return (arr - x > 0).sum() - 8000 x = find_root_binary(fn1, epsilon=10) print(f'When {x=}, then {(arr > x).sum()=}')
输出结果是
0 -8000 1 -3043 2 -467 3 797 4 165 5 -152 6 17 7 -63 8 -28 9 3 When x=0.205078125, then (arr > x).sum()=8003
这似乎很快就找到了根源。
4条答案
按热度按时间hpxqektj1#
虽然建议的二分搜索相当快,但这里有一个用于相对较小的数组的单行替代方案(注意,中间计算形成矩阵
arr.size
xarr.size
)。它一次比较初始arr
和阈值数组,以找到绝对差和“最接近”阈值的位置:ui7jx7zq2#
你可以尝试对你的值进行二进制搜索(首先创建阈值数组):
图纸:
下面是使用
perfplot
的快速基准测试:输出量:
h7appiyu3#
你当前的实现看起来像是O(m*n)复杂度,其中
n
是数组的长度,m
是你迭代的阈值候选的长度。对数组进行排序并从末尾选择
target
元素可能会更有效地完成这项工作。但如果你的数组可以包含相等的值(否则你可以总是选择一个精确的阈值)复杂的开始。你的例子很可能不是这种情况,但我不确定你的实际数据是什么样的。
因此,可以通过搜索相邻值并检查它们是否比初始候选者更适合来进一步继续从最后进行排序和挑选。
下面是我的代码:
输出量:
然而,如果你不期望相等的元素-可能
np.quantile
会工作?vltsax254#
感谢@Andrej建议使用二进制搜索。这让我想到了其他各种根查找方法,比如Newtons,我找到了这个wonderful post和一些示例代码。最后,我写了自己的。
试试看
输出结果是
这似乎很快就找到了根源。