(PySpark) Create a new array column containing the result of a binary comparison between a list column and a static list

pu82cl6c asked on 2022-11-01 in Spark

Scenario

I have a DataFrame with the following data:

import pandas as pd
from pyspark.sql.types import ArrayType, StringType, IntegerType, FloatType, StructType, StructField
import pyspark.sql.functions as F

a = [1,2,3]
b = [['a', 'b', 'c'], ['d', 'e', 'f'], ['g', 'h', 'i']]

df = pd.DataFrame({
    'id': a,
    'list1': b,
})

df = spark.createDataFrame(df)
df.printSchema()
df.show()

+---+---------+
| id|    list1|
+---+---------+
|  1|[a, b, c]|
|  2|[d, e, f]|
|  3|[g, h, i]|
+---+---------+

I also have a static list with the following values:

list2 = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i']

What I want to do

I want to compare each value of list2 against each row's list1 and build an array of 0/1 values, where 1 indicates that the corresponding value of list2 is present in list1.
The resulting output should look like this:

+---+-----------+-----------------------------+
| id|    list1  |   result                    |
+---+-----------+-----------------------------+
|  1| [a, b, c] | [1, 1, 1, 0, 0, 0, 0, 0, 0] |
|  2| [d, e, f] | [0, 0, 0, 1, 1, 1, 0, 0, 0] |
|  3| [g, h, i] | [0, 0, 0, 0, 0, 0, 1, 1, 1] |
+---+-----------+-----------------------------+

I need the result in this format because I will eventually multiply the result array by a scaling factor.

My attempt

# Insert list2 into the dataframe as an array column

df = df.withColumn("list2", F.array([F.lit(x) for x in list2]))

# Get the result arrays

differencer = F.udf(lambda list1, list2: [1 if x in list1 else 0 for x in list2], ArrayType(IntegerType()))

df = df.withColumn('result', differencer('list1', 'list2'))

df.show()

However, I get the following error:

An error was encountered:
An error occurred while calling o151.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 11.0 failed 4 times, most recent failure: Lost task 0.3 in stage 11.0 (TID 287) (ip-10-0-0-142.ec2.internal executor 8): java.lang.RuntimeException: Failed to run command: /usr/bin/virtualenv -p python3 --system-site-packages virtualenv_application_1665327460183_0007_0
    at org.apache.spark.api.python.VirtualEnvFactory.execCommand(VirtualEnvFactory.scala:120)
    at org.apache.spark.api.python.VirtualEnvFactory.setupVirtualEnv(VirtualEnvFactory.scala:78)
    at org.apache.spark.api.python.PythonWorkerFactory.<init>(PythonWorkerFactory.scala:94)
    at org.apache.spark.SparkEnv.$anonfun$createPythonWorker$1(SparkEnv.scala:125)
    at scala.collection.mutable.HashMap.getOrElseUpdate(HashMap.scala:86)
    at org.apache.spark.SparkEnv.createPythonWorker(SparkEnv.scala:125)
    at org.apache.spark.api.python.BasePythonRunner.compute(PythonRunner.scala:162)
    at org.apache.spark.sql.execution.python.BatchEvalPythonExec.evaluate(BatchEvalPythonExec.scala:81)
    at org.apache.spark.sql.execution.python.EvalPythonExec.$anonfun$doExecute$2(EvalPythonExec.scala:130)
    at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2(RDD.scala:863)
    at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2$adapted(RDD.scala:863)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
    at org.apache.spark.scheduler.Task.run(Task.scala:133)
    at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:506)
    at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1474)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:509)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    at java.lang.Thread.run(Thread.java:750)

I have tried dozens of iterations and approaches, but virtually everything I do ends in the error above.
How can I get this to work? Ideally without having to insert list2 into the DataFrame before running the comparison.
Thanks

Answer #1 (ckocjqey)

The idea is to add list2 as an extra column to the DataFrame and then use transform to check, for each element of the newly added column, whether it is part of the array in the list1 column.

from pyspark.sql import functions as F

df.withColumn("result", F.array(*map(F.lit, list2))) \
    .withColumn("result", F.transform("result", lambda v: F.array_contains(F.col("list1"), v).cast("int"))) \
    .show(truncate=False)

Output:

+---+---------+---------------------------+
|id |list1    |result                     |
+---+---------+---------------------------+
|1  |[a, b, c]|[1, 1, 1, 0, 0, 0, 0, 0, 0]|
|2  |[d, e, f]|[0, 0, 0, 1, 1, 1, 0, 0, 0]|
|3  |[g, h, i]|[0, 0, 0, 0, 0, 0, 1, 1, 1]|
+---+---------+---------------------------+
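
The question mentions eventually multiplying the result array by a scaling factor. Once the result column above is kept on the DataFrame (e.g. assigned back to df rather than just shown), that step could look roughly like the sketch below; the scale value and the result_scaled column name are purely illustrative, and F.transform here is the same built-in used in the answer (Python API available from Spark 3.1):

# Sketch only: assumes df was kept with its "result" column and uses a
# hypothetical scaling factor; F.transform maps a function over every
# element of the array column
scale = 0.5  # hypothetical scaling factor

df_scaled = df.withColumn(
    "result_scaled",
    F.transform("result", lambda v: v * F.lit(scale))
)
df_scaled.show(truncate=False)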

Using the built-in transform function is faster than a UDF, since it avoids the overhead that comes with UDFs (shipping every row to and from the Python workers).
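
The question also asks whether the comparison can be run without first inserting list2 into the DataFrame. Since list2 is just a static Python list, the result column can likely be built directly with a comprehension over it, again using only built-in functions; this is a sketch in the spirit of the answer above, not the answer author's code:

# Sketch: build the 0/1 flags directly from the static Python list,
# without materializing list2 as an intermediate column
df.withColumn(
    "result",
    F.array(*[F.array_contains("list1", x).cast("int") for x in list2])
).show(truncate=False)

Both variants stay within Spark's built-in functions, so neither pays the Python UDF overhead mentioned above.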
