如何使用pyspark中的@pandas_udf实现groupby.agg

r3i60tvu  于 12个月前  发布在  Spark
关注(0)|答案(1)|浏览(116)

我在Spark上使用pandas API。我使用groupby.agg操作。我找到了a similar issue,但解决方案不适合我。我也查了官方文件。但这些文档并没有提供足够的例子。
以下是我的样本数据:
| | O_ORDER优先级| O_ORDERPRIORITY |
| --|--|--|
| 0 |邮件|2 -高|
| 1 |船|1 -紧急|
我只想通过以下udf对L_SHIPMODE进行分组并对O_ORDERPRIORITY进行计数:

@pandas_udf(IntegerType())
def g1(x):
  return ((x == "1-URGENT") | (x == "2-HIGH")).sum()

@pandas_udf(IntegerType())
def g2(x):
  return ((x != "1-URGENT") & (x != "2-HIGH")).sum()

# tryied register, but it seems this is not the problem
# spark.udf.register('g1_udf', g1)
# spark.udf.register('g2_udf', g2)

total = jn.groupby("L_SHIPMODE", as_index=False)["O_ORDERPRIORITY"].agg({"O_ORDERPRIORITY": [g1, g2]})

我阴阳怪气:
ValueError: aggs must be a dict mapping from column name to aggregate functions (string or list of strings).
是否有关于如何在groupby.agg上使用UDF的详细示例?

bfrts1fy

bfrts1fy1#

我认为如果你在groupby之前创建一些新列(使用与你类似的示例框架),你可以完全避免使用udfs:

import pyspark.pandas as ps

jn = ps.DataFrame({
    'L_SHIPMODE': ['MAIL']*4+['SHIP']*4,
    'O_ORDERPRIORITY': [
        '1 -URGENT', '2 -HIGH', '3 -OTHER', '4 -OTHER',
        '3 -OTHER', '4 -OTHER', '3 -OTHER', '4 -OTHER',
    ]
})

    L_SHIPMODE  O_ORDERPRIORITY
0   MAIL    1 -URGENT
1   MAIL    2 -HIGH
2   MAIL    3 -OTHER
3   MAIL    4 -OTHER
4   SHIP    3 -OTHER
5   SHIP    4 -OTHER
6   SHIP    3 -OTHER
7   SHIP    4 -OTHER

jn['1_OR_2'] = jn['O_ORDERPRIORITY'].isin(["1 -URGENT", "2 -HIGH"]).astype('long')
jn['not_1_OR_2'] = (~jn['O_ORDERPRIORITY'].isin(["1 -URGENT", "2 -HIGH"])).astype('long')

jn.groupby(
    'L_SHIPMODE'
).agg({
    '1_OR_2': 'sum',
    'not_1_OR_2': 'sum'
})

测试结果:

1_OR_2  not_1_OR_2
L_SHIPMODE      
MAIL    2   2
SHIP    0   4

相关问题