用于过滤pyspark中的值的函数

rwqw0loc  于 2021-05-27  发布在  Spark
关注(0)|答案(1)|浏览(482)

我试图在pyspark中运行一个for循环,该循环需要为算法过滤变量。
下面是我的Dataframedfïu prods的一个示例:

+----------+--------------------+--------------------+
|ID        |        NAME        |           TYPE     |
+----------+--------------------+--------------------+
|    7983  |SNEAKERS 01         |            Sneakers|
|    7034  |SHIRT 13            |               Shirt|
|    3360  |SHORTS 15           |               Short|

我想迭代一个id列表,从算法中获得匹配,然后过滤产品的类型。
我创建了一个函数来获取以下类型:

def get_type(ID_PROD):
    return [row[0] for row in df_prods.filter(df_prods.ID == ID_PROD).select("TYPE").collect()]

想要它回来:

print(get_type(7983))
Sneakers

但我发现两个问题:
1-这样做需要很长时间(比我在python上做类似的事情要长)
2-它返回一个字符串数组类型:['sneakers'],当我尝试过滤产品时,会发生以下情况:

type = get_type(7983)
df_prods.filter(df_prods.type == type)
java.lang.RuntimeException: Unsupported literal type class java.util.ArrayList [Sneakers]

有人知道在Pypark上更好的方法吗?
事先非常感谢。我学习Pypark很困难。

wtlkbnrh

wtlkbnrh1#

对你的功能稍作调整。这将从筛选后找到的第一条记录返回目标列的实际字符串。

from pyspark.sql.functions import col

def get_type(ID_PROD):
  return df.filter(col("ID") == ID_PROD).select("TYPE").collect()[0]["TYPE"]

type = get_type(7983)
df_prods.filter(col("TYPE") == type) # works

我发现使用 col("colname") 更具可读性。
关于您提到的性能问题,如果没有更多的细节(例如,检查数据和应用程序的其余部分),我真的不能不说。试试这个语法,告诉我性能是否提高了。

相关问题