Find the first later price greater than the current value in a pandas DataFrame, using vectorization

Posted by idfiyjo8 on 2023-01-19 in Other

Take a pandas DataFrame with two columns, ['date'] and ['price']: ['date'] is always in ascending order and ['price'] is random.

import pandas as pd

df = pd.DataFrame({
    'date': ['01/01/2019', '01/02/2019', '01/03/2019', '01/04/2019', '01/05/2019',
             '01/06/2019', '01/07/2019', '01/08/2019', '01/09/2019', '01/10/2019'],
    'price': [10, 2, 5, 4, 12, 8, 9, 19, 12, 3]
})
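One detail worth keeping in mind: the dates above are plain strings, so comparisons such as `r.date > l.date` are lexicographic. That happens to agree with chronological order for this sample, but not in general. A minimal sketch of the difference, assuming the format is month/day/year (the data does not say) and using a hypothetical two-row frame named `sample`:

```python
import pandas as pd

# Hypothetical sample that crosses a year boundary.
sample = pd.DataFrame({
    'date': ['12/31/2019', '01/01/2020'],
    'price': [10, 12],
})

# As strings, the later date compares as "smaller" because '0' < '1'.
string_order = sample['date'].iloc[1] > sample['date'].iloc[0]

# Parsed as datetimes (format assumed to be month/day/year), the order is chronological.
parsed = pd.to_datetime(sample['date'], format='%m/%d/%Y')
datetime_order = parsed.iloc[1] > parsed.iloc[0]

print(string_order, datetime_order)
```

If the real data spans several years, parsing the column once up front is probably safer than comparing the raw strings.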

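The goal is to add two more columns: ['next_date'], containing the date of the first later row whose price is higher than the current price, and ['next_price'], containing that price.
Like this: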

         date  price   next_date next_price
0  01/01/2019     10  01/05/2019         12
1  01/02/2019      2  01/03/2019          5
2  01/03/2019      5  01/05/2019         12
3  01/04/2019      4  01/05/2019         12
4  01/05/2019     12  01/08/2019         19
5  01/06/2019      8  01/07/2019          9
6  01/07/2019      9  01/08/2019         19
7  01/08/2019     19         NaN        NaN
8  01/09/2019     12         NaN        NaN
9  01/10/2019      3         NaN        NaN
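To pin down the definition above, here is a small reference sketch using NumPy broadcasting. It is not one of the solutions from this thread and is only meant for small inputs: it builds an n×n boolean matrix, so a million rows would need on the order of 10^12 entries (roughly 1 TB at one byte each).

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({
    'date': ['01/01/2019', '01/02/2019', '01/03/2019', '01/04/2019', '01/05/2019',
             '01/06/2019', '01/07/2019', '01/08/2019', '01/09/2019', '01/10/2019'],
    'price': [10, 2, 5, 4, 12, 8, 9, 19, 12, 3],
})

price = df['price'].to_numpy()
n = len(price)

# greater[i, j] is True when row j comes after row i and has a higher price.
later = np.arange(n)[None, :] > np.arange(n)[:, None]
greater = later & (price[None, :] > price[:, None])

# argmax gives the position of the first True in each row;
# rows without any True are masked out afterwards.
first = greater.argmax(axis=1)
found = greater.any(axis=1)

df['next_date'] = pd.Series(df['date'].to_numpy()[first], index=df.index).where(found)
df['next_price'] = pd.Series(price[first], index=df.index).where(found)
print(df)
```

This is handy as a correctness check for faster implementations on a small slice of the data.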

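I have tested several solutions that do produce what I want, but their performance is very poor, and the actual df has more than a million rows.
Here are the solutions I tested:
Using pandasql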

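from pandasql import sqldf

# Self-join: for each row, take the earliest later date with a higher price.
result = sqldf("SELECT l.date, l.price, min(r.date) as next_date from df as l left join df as r on (r.date > l.date and r.price > l.price) group by l.date, l.price  order by l.date")
# Look up the price at next_date (it ends up in the price_next column).
result = pd.merge(result, df, left_on='next_date', right_on='date', suffixes=('', '_next'), how='left')
print(result)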

Using pandas with SQLite

import sqlite3

# Assumed setup: the original does not show how conn was created.
conn = sqlite3.connect(':memory:')

df.to_sql('df', conn, index=False)
qry = "SELECT l.date, l.price, min(r.date) as next_date from df as l left join df as r on (r.date > l.date and r.price > l.price) group by l.date, l.price  order by l.date "
result = pd.read_sql_query(qry, conn)
result = pd.merge(result, df, left_on='next_date', right_on='date', suffixes=('', '_next'), how='left')
print(result)

Using apply

import numpy as np

def find_next_price(row):
    # Rows that are later than the current row and have a higher price.
    mask = (df['price'] > row['price']) & (df['date'] > row['date'])
    matches = df[mask]
    if len(matches):
        return matches['date'].iloc[0], matches['price'].iloc[0]
    else:
        return np.nan, np.nan

df[['next_date', 'next_price']] = list(df.apply(find_next_price, axis=1))
print(df)

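Some of these solutions start to fail on a 50,000-row DataFrame, while I need to run this task on a 1,000,000-row DataFrame.
Note: there is a very similar question here: https://stackoverflow.com/questions/72047646/python-pandas-add-column-containing-first-index-where-future-column-value-is-gr but its solutions also perform poorly.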

62o28rlo  #1

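Since you need to run this on a large number of rows (1M+), a conventional NumPy approach may not be feasible, especially if memory is limited. Here is a function-based approach that uses a basic loop; you can compile the function with numba's just-in-time compiler to get C-like speed: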

import numba
import numpy as np

@numba.njit
def argmax(price: np.ndarray):
    # For each position i, yield the index of the first later price
    # that is strictly greater, or -1 if there is none.
    for i in range(len(price)):
        idx = -1
        for j in range(i + 1, len(price)):
            if price[i] < price[j]:
                idx = j
                break

        yield idx

i = np.array(list(argmax(df['price'].values)))
m = i != -1 # index is -1 if there's no next greater price

df.loc[m, 'next_date'] = df['date'].values[i[m]]
df.loc[m, 'next_price'] = df['price'].values[i[m]]

Result

         date  price   next_date  next_price
0  01/01/2019     10  01/05/2019        12.0
1  01/02/2019      2  01/03/2019         5.0
2  01/03/2019      5  01/05/2019        12.0
3  01/04/2019      4  01/05/2019        12.0
4  01/05/2019     12  01/08/2019        19.0
5  01/06/2019      8  01/07/2019         9.0
6  01/07/2019      9  01/08/2019        19.0
7  01/08/2019     19         NaN         NaN
8  01/09/2019     12         NaN         NaN
9  01/10/2019      3         NaN         NaN

PS: Solution tested on 1M+ rows.
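The nested loop in the answer can still degrade to quadratic time on unfavorable inputs, for example a long run of falling prices. A different technique, not the answer's method, is the classic monotonic stack for the "next greater element" problem, which visits each row a constant number of times. Below is a minimal sketch in plain Python and NumPy. The function name `next_greater_index` is hypothetical, and the stack is a preallocated array so that decorating the function with `numba.njit` should be possible, although that is an assumption and is not tested here.

```python
import numpy as np
import pandas as pd

def next_greater_index(price):
    """For each position, return the index of the first later element that is
    strictly greater, or -1 if there is none."""
    n = len(price)
    result = np.full(n, -1, dtype=np.int64)
    stack = np.empty(n, dtype=np.int64)  # indices still waiting for a higher price
    top = 0                              # number of indices currently on the stack
    for j in range(n):
        # price[j] resolves every waiting index whose price is lower.
        while top > 0 and price[stack[top - 1]] < price[j]:
            top -= 1
            result[stack[top]] = j
        stack[top] = j
        top += 1
    return result

df = pd.DataFrame({
    'date': ['01/01/2019', '01/02/2019', '01/03/2019', '01/04/2019', '01/05/2019',
             '01/06/2019', '01/07/2019', '01/08/2019', '01/09/2019', '01/10/2019'],
    'price': [10, 2, 5, 4, 12, 8, 9, 19, 12, 3],
})

idx = next_greater_index(df['price'].to_numpy())
m = idx != -1  # -1 means there is no later, higher price

df.loc[m, 'next_date'] = df['date'].to_numpy()[idx[m]]
df.loc[m, 'next_price'] = df['price'].to_numpy()[idx[m]]
print(df)
```

Each index is pushed once and popped at most once, so the work is linear in the number of rows and the extra memory is two integer arrays of length n.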
