Why is pd.DataFrame.stack so slow compared to numpy.flatten, and how can it be sped up?

Posted by tez616oj on 2023-08-05 in Other

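pandas.DataFrame.stack is very slow. In my measurements, df.stack() costs about 6 times as much as flattening the underlying NumPy array with flatten(). Why is it so slow, and is there a way to speed it up?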

> df.shape    # dtype = float64
  (2578, 809)
> %timeit df.stack()
  42 ms ± 130 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
> %timeit df.values.flatten()
  7.35 ms ± 17.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


gopyfrb3 #1

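DataFrame.stack is slower because it does more than operate on the data blocks that the BlockManager stores as np.ndarray inside the DataFrame: it also has to handle the index. In particular, stack has to build a MultiIndex to hold the additional index level it produces.
pandas focuses on providing great flexibility in data manipulation, whereas NumPy provides a low-level abstraction over efficient in-memory data storage.
A pd.DataFrame carries row and column labels, and these require extra indexing and label handling during stack().
In addition, all the labels and indexes take up extra memory, and accessing that memory takes extra time, especially with large datasets.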
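To make the difference in output concrete, here is a minimal sketch on a tiny frame. The row and column labels are made up for illustration. It only shows what each call returns, not how long it takes.

```python
import numpy as np
import pandas as pd

# Small illustrative frame; the labels are invented for this example.
df = pd.DataFrame(
    np.arange(6, dtype="float64").reshape(2, 3),
    index=["r0", "r1"],
    columns=["a", "b", "c"],
)

stacked = df.stack()             # Series indexed by (row label, column label)
flat = df.to_numpy().flatten()   # bare 1-D copy of the values, no labels

print(type(stacked.index).__name__)  # MultiIndex
print(stacked.index.nlevels)         # 2
print(stacked.index[:3].tolist())    # [('r0', 'a'), ('r0', 'b'), ('r0', 'c')]
print(np.array_equal(stacked.to_numpy(), flat))  # True (no NaNs in this frame)
```

The values are identical in both results; the extra work in stack() goes into building and attaching the two-level index.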
The performance of DataFrame.stack is indeed much worse than that of ndarray.flatten:

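import numpy as np
import pandas as pd

def generate_dataframe(N):
    # Build an N x N DataFrame of random floats with default integer labels
    data = np.random.rand(N, N)
    df = pd.DataFrame(data)
    return df

# Side lengths N of the square test frames
data_size = [10, 50, 100, 200, 500, 1000]

def stack(df):
    return df.stack()

def flatten(df):
    return df.values.flatten()

approaches = [
    stack,
    flatten
]
# run_performance_comparison is defined in the profiling code below
run_performance_comparison(approaches, data_size, setup=generate_dataframe)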

Profiling code:

import timeit
import matplotlib.pyplot as plt
from typing import List, Dict, Callable

from contextlib import contextmanager

@contextmanager
def data_provider(data_size, setup=lambda N: N, teardown=lambda: None):
    # Build the input once per size and always run teardown afterwards
    data = setup(data_size)
    try:
        yield data
    finally:
        teardown()

def run_performance_comparison(approaches: List[Callable], data_size: List[int],
                               setup=lambda N: N, teardown=lambda: None, number_of_repetitions=5, title='N'):
    # Maps each approach to its timings, one entry per value in data_size
    approach_times: Dict[Callable, List[float]] = {approach: [] for approach in approaches}

    for N in data_size:
        with data_provider(N, setup, teardown) as data:
            for approach in approaches:
                # Total time for number_of_repetitions calls, not a per-call average
                approach_time = timeit.timeit(lambda: approach(data), number=number_of_repetitions)
                approach_times[approach].append(approach_time)

    for approach in approaches:
        plt.plot(data_size, approach_times[approach], label=approach.__name__)

    plt.xlabel(title)
    plt.ylabel('Execution Time (seconds)')
    plt.title('Performance Comparison')
    plt.legend()
    plt.show()


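If you know you do not need the extra convenience of the pd.MultiIndex that stack creates for you, stick with the NumPy implementation. It is implemented efficiently in compiled C code.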
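As a rough sketch of that advice: the values alone can be taken straight from NumPy, and if the labels are needed after all, a two-level index can be built explicitly. The helper name stack_via_numpy is hypothetical, and the sketch assumes single-level columns, a single dtype, and no missing values (stack() may drop NaN entries depending on pandas version and options). Whether this beats stack() depends on the pandas version and the frame shape, so it is worth timing on your own data.

```python
import numpy as np
import pandas as pd

def stack_via_numpy(df: pd.DataFrame) -> pd.Series:
    """Hypothetical helper: rebuild the labelled result from the flat values.

    Assumes single-level columns, one dtype, and no missing values.
    """
    values = df.to_numpy().ravel()  # row-major 1-D values; may be a view
    index = pd.MultiIndex.from_product([df.index, df.columns])
    return pd.Series(values, index=index)

df = pd.DataFrame(
    np.random.rand(4, 3),
    index=[0, 1, 2, 3],
    columns=["a", "b", "c"],
)

values_only = df.to_numpy().ravel()   # enough if the labels are not needed
manual = stack_via_numpy(df)          # labelled, built by hand
reference = df.stack()                # what pandas produces

print(manual.index.equals(reference.index))
print(np.array_equal(manual.to_numpy(), reference.to_numpy()))
```

Note that ravel() returns a view where it can, while flatten() always copies, so use flatten() if the result will be modified independently of the frame.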
