pandas: Replacing DataFrame.iterrows() with a list comprehension / vectorization to improve performance

Posted by xnifntxz on 2023-06-28 in Other

I have the code below, which computes the average of the scores over a DataFrame whose data comes from an XLSX file. calculate_score() returns a float score, e.g. 5.12.

import pandas as pd

testset = pd.read_excel(xlsx_filename_here)
total_score = 0
num_records = 0
for index, row in testset.iterrows():
    # Cell values are scalars and have no .isna() method, so use pd.isna()
    if pd.isna(row['Data1']) or pd.isna(row['Data2']) or pd.isna(row['Data3']):
        continue
    else:
        score = calculate_score([row['Data1'], row['Data2']], row['Data3'])
        total_score += score
        num_records += 1

print("Average score:", round(total_score/num_records, 2))

According to this answer, df.iterrows() is slow and an anti-pattern. How can I change the code above to use vectorization or a list comprehension?

Update

I oversimplified calculate_score() in the example above. It actually uses the SacreBLEU library to compute the BLEU score of some sentences:

import evaluate
sacrebleu = evaluate.load("sacrebleu")

def calculate_score(ref, translation):
    # compute() returns a dict of results; the BLEU value is under the "score" key
    return sacrebleu.compute(predictions=[translation], references=[ref])["score"]

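Note that the original code has also been updated slightly. How can I modify calculate_score() so that it works with a list comprehension? Thanks.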

Answer 1 (fjaof16o)

Here is how to modify the code using vectorization:

import pandas as pd
import numpy as np

testset = pd.read_excel(xlsx_filename_here)

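# Boolean mask of the rows where neither column is missing
valid_rows = testset['Data1'].notna() & testset['Data2'].notna()

# calculate_score must itself accept two Series (i.e. be vectorized) for this call to work
scores = calculate_score(testset.loc[valid_rows, 'Data1'], testset.loc[valid_rows, 'Data2'])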

average_score = np.mean(scores)

print("Average score:", round(average_score, 2))
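
The Series-based call above only helps when calculate_score is itself vectorized. A per-sentence scorer such as the SacreBLEU wrapper in the question takes one row at a time, so a list comprehension over the non-missing rows is probably the closer fit. Below is a minimal sketch of that pattern. The calculate_score shown here is a hypothetical stand-in (an exact-match ratio) so that the example runs without the evaluate package, and the sample DataFrame is made up for illustration. The real scorer from the question would be dropped in unchanged.

```python
import pandas as pd


def calculate_score(refs, translation):
    # Hypothetical stand-in for the real BLEU scorer:
    # fraction of references that match the translation exactly.
    return sum(r == translation for r in refs) / len(refs)


# Made-up sample data; in the question this comes from pd.read_excel(...)
testset = pd.DataFrame({
    "Data1": ["a cat", None, "a dog", "hello"],
    "Data2": ["the cat", "x", "a dog", None],
    "Data3": ["a cat", "x", "a dog", "hello"],
})

cols = ["Data1", "Data2", "Data3"]

# Drop every row that has a missing value in any of the three columns
valid = testset.dropna(subset=cols)

# zip() walks the three columns in lockstep without building a Series per row
scores = [
    calculate_score([d1, d2], d3)
    for d1, d2, d3 in zip(valid["Data1"], valid["Data2"], valid["Data3"])
]

# Guard against an empty result to avoid dividing by zero
average_score = sum(scores) / len(scores) if scores else float("nan")
print("Average score:", round(average_score, 2))
```

This removes the iterrows() overhead, but the scorer is still called once per row, so the total runtime will most likely be dominated by the scoring function itself.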

Answer 2 (ev7lccsx)

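You have to change the implementation of calculate_score so that it accepts two Series as arguments (or a single DataFrame with two columns) instead of two scalar values: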

def calculate_score(sr1, sr2):
    out = sr1 / sr2
    return out  # out is a Series

# Hide unwanted rows
cols = ['Data1', 'Data2']
m = testset[cols].notna().all(axis=1)

# Compute score with vectorized function
score = calculate_score(testset.loc[m, cols[0]], testset.loc[m, cols[1]])

# Stats
total_score, average_score = score.agg(['sum', 'mean'])

Output:

>>> score
0    0.333333
1    0.142857
3    2.000000
5    0.500000
6    1.000000
7    0.000000
9    0.375000
dtype: float64

>>> total_score
4.351190476190476

>>> average_score
0.6215986394557823

Input:

>>> testset
   Data1  Data2
0    2.0    6.0
1    1.0    7.0
2    NaN    4.0
3    4.0    2.0
4    4.0    NaN
5    4.0    8.0
6    1.0    1.0
7    0.0    5.0
8    NaN    5.0
9    3.0    8.0
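
To reproduce the numbers above without an XLSX file, the input table can be built inline. This is only a sketch that pairs the sample input with the division-based calculate_score from this answer; the division is the answer's placeholder for a vectorizable computation, not the BLEU scorer from the question.

```python
import numpy as np
import pandas as pd

# The input table shown above, typed in by hand
testset = pd.DataFrame({
    "Data1": [2.0, 1.0, np.nan, 4.0, 4.0, 4.0, 1.0, 0.0, np.nan, 3.0],
    "Data2": [6.0, 7.0, 4.0, 2.0, np.nan, 8.0, 1.0, 5.0, 5.0, 8.0],
})


def calculate_score(sr1, sr2):
    # Element-wise division of two Series; the result is a Series
    return sr1 / sr2


cols = ["Data1", "Data2"]

# Keep only the rows where both columns are present
m = testset[cols].notna().all(axis=1)

score = calculate_score(testset.loc[m, cols[0]], testset.loc[m, cols[1]])
total_score, average_score = score.agg(["sum", "mean"])

print(score)
print(total_score, average_score)
```

Rows 2, 4 and 8 contain a NaN and are masked out, which is why the index of score skips them.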
