numpy: how to find the threshold that yields a desired number of array elements

Posted by 8yparm6h on 2023-10-19 in Other

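Given an array of numbers and a target count, I want to find the threshold such that the number of elements above it equals the target (or is as close to it as possible).
For example: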

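import numpy as np

arr = np.random.rand(100)
target = 80
# Step through candidate thresholds 0.00, 0.01, ..., 0.99
for i in range(100):
    t = i * 0.01
    # Stop at the first threshold with fewer than `target` elements above it
    if (arr > t).sum() < target:
        break
print(t)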

However, this is neither efficient nor very precise. Perhaps someone has already solved this problem.

Answer #1 (hpxqektj)

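While the suggested binary search is quite fast, here is a one-line alternative for relatively small arrays (note that the intermediate computation builds a boolean matrix of shape t_arr.size x arr.size). It compares the original arr against the whole array of thresholds at once, then finds the position where the absolute difference from the target is smallest, i.e. the "closest" threshold: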

target = 80
t_arr = np.arange(0, 1, 0.01)  # candidate thresholds
t = t_arr[np.abs((arr > t_arr[:, None]).sum(1) - target).argmin()]
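
To make the broadcasting step easier to follow, here is a minimal sketch that splits the one-liner into named steps. The array size (250), the target (200), the seed, and the intermediate names `above`, `counts` and `best` are only illustrative assumptions; they are not part of the original answer.

```python
import numpy as np

rng = np.random.default_rng(0)   # seed chosen only for reproducibility
arr = rng.random(250)            # illustrative data; size differs from t_arr on purpose
t_arr = np.arange(0, 1, 0.01)    # 100 candidate thresholds
target = 200

# t_arr[:, None] has shape (100, 1); comparing it with arr (shape (250,))
# broadcasts to a boolean matrix with one row per threshold.
above = arr > t_arr[:, None]
counts = above.sum(axis=1)       # number of elements above each threshold

# Index of the threshold whose count is closest to the target
best = np.abs(counts - target).argmin()
t = t_arr[best]

print(above.shape)               # (number of thresholds, number of elements)
print(t, counts[best])
```

The memory cost grows with the product of the two sizes, which is why this approach suits relatively small arrays.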

Answer #2 (ui7jx7zq)

You can try a binary search over your values (create the array of thresholds first):

def binary_search(arr, target, thresholdarr):
    low = 0
    high = len(thresholdarr) - 1
    mid = 0

    while low <= high:
        mid = (high + low) // 2

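        s1 = (arr > thresholdarr[mid]).sum()
        # At mid == 0 there is no previous threshold; index -1 would wrap around
        s2 = (arr > thresholdarr[mid - 1]).sum() if mid > 0 else float("inf")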

        if s1 < target and s2 >= target:
            return thresholdarr[mid]
        elif s2 < target:
            high = mid - 1
        else:
            low = mid + 1

    # If we reach here, then the element was not present
    # return last element of thresholdarray
    return thresholdarr[high]

np.random.seed(42)
arr = np.random.rand(100)

thresholds = np.arange(0, 1, 0.01)
x = binary_search(arr, 80, thresholds)
print(x)

Prints:

0.16
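
As a possible alternative (not part of this answer), the counts for all thresholds can be computed at once by sorting the array a single time and using `np.searchsorted`. The sketch below assumes the thresholds are in ascending order; the function name `first_threshold_below` is hypothetical.

```python
import numpy as np

def first_threshold_below(arr, target, thresholds):
    """Return the first threshold whose count of larger elements drops below target."""
    sorted_arr = np.sort(arr)  # sort once
    # searchsorted(..., side="right") gives the number of elements <= t for each t,
    # so the remainder is the number of elements strictly greater than t.
    counts = arr.size - np.searchsorted(sorted_arr, thresholds, side="right")
    below = counts < target
    if not below.any():
        return thresholds[-1]  # no threshold qualifies: fall back to the last one
    return thresholds[np.argmax(below)]  # argmax returns the index of the first True

rng = np.random.default_rng(0)  # seed chosen only for reproducibility
arr = rng.random(1000)
thresholds = np.arange(0, 1, 0.01)
print(first_threshold_below(arr, 800, thresholds))
```

This returns the same kind of result as the loop in the question (the first threshold with fewer than `target` elements above it), without scanning the array once per threshold.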

Here is a quick benchmark using perfplot:

import perfplot
import numpy as np

def search_roman_perekhrest(arr, target, t_arr):
    return t_arr[np.abs((arr > t_arr[:, None]).sum(1) - target).argmin()]

def normal_search(arr, target):
    for i in range(100):
        t = i * 0.01
        if (arr > t).sum() < target:
            break
    return t

def binary_search(arr, target, thresholdarr):
    low = 0
    high = len(thresholdarr) - 1
    mid = 0

    while low <= high:
        mid = (high + low) // 2

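        s1 = (arr > thresholdarr[mid]).sum()
        # At mid == 0 there is no previous threshold; index -1 would wrap around
        s2 = (arr > thresholdarr[mid - 1]).sum() if mid > 0 else float("inf")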

        if s1 < target and s2 >= target:
            return thresholdarr[mid]
        elif s2 < target:
            high = mid - 1
        else:
            low = mid + 1

    # If we reach here, then the element was not present
    # return last element of thresholdarray
    return thresholdarr[high]

np.random.seed(42)
thresholds = np.arange(0, 1, 0.01)

perfplot.show(
    setup=lambda n: np.random.rand(n),
    kernels=[
        lambda arr: binary_search(arr, 80, thresholds),
        lambda arr: normal_search(arr, 80),
        lambda arr: search_roman_perekhrest(arr, 80, thresholds),
    ],
    labels=["bin_search", "normal_search", "search_roman_perekhrest"],
    n_range=[2**k for k in range(7, 16)],
    xlabel="N",
    logx=True,
    logy=True,
    equality_check=None,
)

Output: (benchmark plot, not reproduced here)

Answer #3 (h7appiyu)

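Your current implementation looks like O(m*n) complexity, where n is the length of the array and m is the number of threshold candidates you iterate over.
Sorting the array and picking the target-th element from the end would probably do the job more efficiently.
Things get complicated, though, if your array can contain equal values (otherwise you could always pick an exact threshold). That is most likely not the case in your example, but I am not sure what your actual data looks like.
So the sort-and-pick-from-the-end approach can be taken one step further: look at the neighboring distinct values and check whether any of them fits better than the initial candidate.
Here is my code: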

def get_threshold(array, target):
    sorted_array = sorted(array)[::-1]
    candidate = sorted_array[target]
    
    next_val_idx = target
    prev_val_idx = target
    
    while prev_val_idx < len(array) and sorted_array[prev_val_idx] == candidate:
        prev_val_idx += 1
        
    while next_val_idx >= 0 and sorted_array[next_val_idx] == candidate:
        next_val_idx -= 1
        
    candidates = [candidate]
    if next_val_idx >= 0:
        candidates.append(sorted_array[next_val_idx])
    if prev_val_idx < len(sorted_array):
        candidates.append(sorted_array[prev_val_idx])
        
    return min(
        candidates,
        key=lambda t: abs(target - (array > t).sum())
    )

get_threshold(np.array([4, 5, 2, 3, 3, 4, 4, 4]), 2)

Output:

4
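
To see why ties make an exact match impossible here, it may help to list how many elements lie above each distinct value of the same example array. This is only an illustrative sketch; the names `values`, `counts` and `best` are not from the answer.

```python
import numpy as np

arr = np.array([4, 5, 2, 3, 3, 4, 4, 4])
target = 2

values = np.unique(arr)  # sorted distinct values
# Number of elements strictly above each distinct value
counts = np.array([(arr > v).sum() for v in values])
for v, c in zip(values, counts):
    print(f"threshold {v}: {c} elements above")

# No threshold gives exactly 2 elements above it, so pick the closest count
best = values[np.abs(counts - target).argmin()]
print(best)
```

Because four elements share the value 4, the count jumps from 1 straight to 5, and the threshold 4 (one element above it) is the closest available choice for a target of 2.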

However, if you do not expect equal elements, maybe np.quantile would work?

np.quantile(arr, q=1 - target / len(arr))
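
A quick way to check this suggestion is to count the elements above the returned value. The sketch below assumes continuous random data without ties and NumPy's default linear interpolation; the seed is only there for reproducibility.

```python
import numpy as np

rng = np.random.default_rng(0)   # seed chosen only for reproducibility
arr = rng.random(100)            # continuous values, so ties are not expected
target = 80

# Fraction of elements that should lie at or below the threshold
q = 1 - target / arr.size
t = np.quantile(arr, q)          # default: linear interpolation between sorted values

count = int((arr > t).sum())
print(t, count)
```

With distinct values the interpolated quantile falls between two neighboring sorted elements, so the count above it should match the target. With ties this is no longer guaranteed.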

Answer #4 (vltsax25)

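Thanks to @Andrej for suggesting binary search. That got me thinking about various other root-finding methods, such as Newton's method, and I found this wonderful post and some sample code. In the end, I wrote my own.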

def find_root_binary(fn, x0:float = 0, initial_stepsize:float = 1, 
        max_step:float = 1e-7, epsilon:float = 1e-7, max_iter:int = 100):
    """ 
    Given a function, y = f(x) find the root using binary search

    :param function fn: a function that maps float to float
    :param float x0: initial guess, defaults to 0
    :param float initial_stepsize: defaults to 1
    :param float max_step: quit when step size is below this
    :param float epsilon: quit when abs(y) is below this
    :param int max_iter: quit after this many iterations
    """
    y0 = fn(x0)
    if abs(y0) <= epsilon: 
        return x0
    dx = initial_stepsize
    for i in range(max_iter):
        x1 = x0 + dx
        y1 = fn(x1)
        print(i,'\t',y1)
        if abs(y1) <= epsilon or abs(dx) <= max_step: 
            break
        if (y0 < 0) != (y1 < 0):     # zoom in
            dx *= 0.5
            if abs(y0) > abs(y1):  # change direction
                x0 = x1
                y0 = y1
                dx = -dx
        elif (y0 < 0) != (y0 <= y1):  # change direction
            dx = -dx
        else:          # keep going same direction
            y0 = y1
            x0 = x1
            dx *= 1.5  # increase step size??
    return x1

Trying it out:

arr = np.random.rand(10000)
def fn1(x): return (arr - x > 0).sum() - 8000
x = find_root_binary(fn1, epsilon=10)
print(f'When {x=}, then {(arr > x).sum()=}')

The output is:

0        -8000
1        -3043
2        -467
3        797
4        165
5        -152
6        17
7        -63
8        -28
9        3
When x=0.205078125, then (arr > x).sum()=8003

This seems to find the root quite quickly.
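
Since the count of elements above a threshold never increases as the threshold grows, plain bisection over a bracket is another option. The following is a minimal sketch, not the code from this answer; the function name `bisect_threshold` and its parameters `tol` and `max_iter` are assumptions made for illustration.

```python
import numpy as np

def bisect_threshold(arr, target, tol=1e-9, max_iter=200):
    """Bisect on the threshold t so that (arr > t).sum() approaches target."""
    # Bracket: nothing lies above the maximum, almost everything above the minimum
    lo, hi = float(arr.min()), float(arr.max())
    for _ in range(max_iter):
        mid = 0.5 * (lo + hi)
        count = int((arr > mid).sum())
        if count == target:
            return mid
        if count > target:
            lo = mid   # too many elements above: raise the threshold
        else:
            hi = mid   # too few elements above: lower the threshold
        if hi - lo <= tol:
            break      # bracket is tiny; an exact match may not exist (e.g. ties)
    return 0.5 * (lo + hi)

rng = np.random.default_rng(0)   # seed chosen only for reproducibility
arr = rng.random(10000)
x = bisect_threshold(arr, 8000)
print(f"When {x=}, then {(arr > x).sum()=}")
```

Unlike the adaptive step-size search, this needs a valid starting bracket, which the minimum and maximum of the array provide.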
