pandas pands_udf,以pd.Series和其他对象作为参数

b5lpy0ml  于 2023-01-15  发布在  其他
关注(0)|答案(1)|浏览(123)

我在创建一个Pandas UDF时遇到了麻烦,该UDF基于底层Spark Dataframe的同一行中的值对pd系列执行计算。
然而,最直接的解决方案似乎并不支持Pandas on Spark API:
下面是一个非常简单的例子

from pyspark.sql.types import IntegerType

import pyspark.sql.functions as F
import pandas as pd

@F.pandas_udf(IntegerType())
def addition(arr: pd.Series, addition: int) -> pd.Series:
  return arr.add(addition)

df = spark.createDataFrame([([1,2,3],10),([4,5,6],20)],["array","addition"])
df.show()

df.withColumn("added", addition(F.col("array"),F.col("addition")))

在udf定义行抛出以下异常

NotImplementedError: Unsupported signature: (arr: pandas.core.series.Series, addition: int) -> pandas.core.series.Series.

我处理这个问题的方式是错误的吗?我可以在原生PySpark中重新实现整个“加法”函数,但我所谈论的真实的函数非常复杂,这意味着大量的返工。

34gzjxbg

34gzjxbg1#

加载示例,添加import array

from pyspark.sql.types as T
import pyspark.sql.functions as F
import pandas as pd
from array import array

df = spark.createDataFrame([([1,2,3],10),([4,5,6],20)],["array","addition"])
df.show(truncate=False)
print(df.schema.fields)

答案是,

+---------+--------+
|    array|addition|
+---------+--------+
|[1, 2, 3]|      10|
|[4, 5, 6]|      20|
+---------+--------+

[StructField('array', ArrayType(LongType(), True), True), StructField('addition', LongType(), True)]

如果你必须使用Pandas函数来完成你的任务,这里有一个在PySpark UDF中使用Pandas函数的解决方案。

  • Spark DF arr列为ArrayType,转换为Pandas系列
  • 应用Pandas功能
  • 然后,将Pandas系列转换回阵列
@F.udf(T.ArrayType(T.LongType()))
def addition_pd(arr, addition):
    pd_arr = pd.Series(arr)
    added = pd_arr.add(addition)
    return array("l", added)

df = df.withColumn("added", addition_pd(F.col("array"),F.col("addition")))
df.show(truncate=False)
print(df.schema.fields)

返回

+---------+--------+------------+
|array    |addition|added       |
+---------+--------+------------+
|[1, 2, 3]|10      |[11, 12, 13]|
|[4, 5, 6]|20      |[24, 25, 26]|
+---------+--------+------------+

[StructField('array', ArrayType(LongType(), True), True), StructField('addition', LongType(), True), StructField('added', ArrayType(LongType(), True), True)]

然而,值得说明的是,在可能的情况下,建议使用PySpark函数而不是PySpark UDF(参见here

相关问题