numpy 如果小于或等于前一个索引,则将索引i处的值加+1,否则不执行任何操作

ars1skjm  于 12个月前  发布在  其他
关注(0)|答案(3)|浏览(109)

我尝试使用numpy内置函数优化这段代码:

result = [test[0]]
for x in range(1, len(test)):
    if test[x] <= result[-1]:
        result.append(result[-1]+1)
    else:
        result.append(test[x])
print(result)

字符串
上面的代码循环遍历数组,检查前一个值是否等于或上级当前值,如果是,则加+1。否则,它什么也不做。
它是递归的(当前值取决于先前计算的值)。
对于test = np.array([0, 0, 0, 1, 4, 15, 16, 16, 16, 17]),我希望得到[0, 1, 2, 3, 4, 15, 16, 17, 18, 19]
有没有更好的方法呢?我需要多次使用这个方法来处理非常大的数组(>=10M长度)。

zi8p0yeb

zi8p0yeb1#

我已经测试了不同的方法以下Onyambu评论:

import numpy as np
import pandas as pd
from timeit import timeit
import itertools

# Original loop-based version
def adjust_array_if(test):
    result = [test[0]]
    for i in range(1, len(test)):
        if test[i] <= result[-1]:
            result.append(result[-1]+1)
        else:
            result.append(test[i])
    return np.array(result)

# Max loop-based version
def adjust_array_max(test):
    result = [test[0]]
    for i in range(1, len(test)):
        result.append(max(test[i], result[-1] + 1))
    return np.array(result)

# Version using itertools
def adjust_array_itertools(test):
    result = np.fromiter(itertools.accumulate(test, lambda x, y: max(x + 1, y)), dtype=np.int32)
    return result

# Set up the DataFrame
functions = pd.Index(['adjust_array_if', 'adjust_array_max', 'adjust_array_itertools'], name='function')
lengths = np.arange(100, 11000, 1000)
results = pd.DataFrame(index=lengths, columns=functions)

# Time the functions
for i in lengths:
    a = np.random.randint(1, i, size=i)
    for j in functions:
        stmt = f"{j}(a)"
        setup = f"from __main__ import {j}, a"
        timing = timeit(stmt, setup, number=10)
        results.at[i, j] = timing

# Plot the results
results.plot()

字符串
结果如下:

在这种情况下,更简单更好。

cyvaqqii

cyvaqqii2#

怎麽样?

import numpy as np

def adjust_array(test):
    result = [test[0]]
    for i in range(1, len(test)):
        result.append(max(test[i], result[-1] + 1))
    return np.array(result)

# Test the function
test = np.array([0, 0, 0, 1, 4, 15, 16, 16, 16, 17])
result = adjust_array(test)
print(result)

字符串
这样可以确保,如果输入数组中的相应元素不大于前一个元素,则结果数组中的每个元素都严格大于前一个元素。
输出量:

>>> [ 0  1  2  3  4 15 16 17 18 19]

yvt65v4c

yvt65v4c3#

一些更快的解决方案基于您的原始,使用您的答案的基准:


的数据
第一个在开始时只添加了test = test.tolist()。因为Python在自己的int上比在NumPy int上工作得更快。另外两个避免了索引并使用了列表解析。

def adjust_array_Kelly(test):
    test = test.tolist()
    result = [test[0]]
    for i in range(1, len(test)):
        if test[i] <= result[-1]:
            result.append(result[-1]+1)
        else:
            result.append(test[i])
    return np.array(result)

def adjust_array_Kelly2(test):
    last = float('-inf')
    return np.array([
        last := last + 1 if last >= t else t
        for t in test.tolist()
    ])

def adjust_array_Kelly3(test):
    return np.array([
        last
        for last in [float('-inf')]
        for t in test.tolist()
        for last in [last + 1 if last >= t else t]
    ])

字符串

相关问题