pyspark: filter values in one DataFrame based on array values in another DataFrame

Asked by ff29svar on 2022-11-01, tagged Spark

I have a PySpark DataFrame that looks like this:

+--------------------+------------+-----------+
|                name|segment_list|rung_list  |
+--------------------+------------+-----------+
|   Campaign 1       | [1.0,  5.0]|  [L2,  L3]|
|   Campaign 1       |       [1.1]|       [L1]|
|   Campaign 2       |       [1.2]|       [L2]|
|   Campaign 2       |       [1.1]|  [L4,  L5]|
+--------------------+------------+-----------+

I have another PySpark DataFrame that holds the segment and rung for each customer:

+-----------+---------------+---------+
|customer_id|     segment   |rung     |
+-----------+---------------+---------+
|  124001823|            1.0|       L2|
|  166001989|            5.0|       L2|
|  768002266|            1.1|       L1|
+-----------+---------------+---------+
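
For anyone who wants to reproduce this, here is a minimal sketch that builds the two DataFrames. The names data1_sdf and data2_sdf are assumptions chosen to match the answer below, and the schemas are inferred from the tables above:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Campaigns with their eligible segment and rung lists (schema inferred from the table above)
data1_sdf = spark.createDataFrame(
    [('Campaign 1', [1.0, 5.0], ['L2', 'L3']),
     ('Campaign 1', [1.1], ['L1']),
     ('Campaign 2', [1.2], ['L2']),
     ('Campaign 2', [1.1], ['L4', 'L5'])],
    'name string, segment_list array<double>, rung_list array<string>'
)

# One segment/rung pair per customer
data2_sdf = spark.createDataFrame(
    [('124001823', 1.0, 'L2'),
     ('166001989', 5.0, 'L2'),
     ('768002266', 1.1, 'L1')],
    'customer_id string, segment double, rung string'
)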

What I want is a final output that resolves customers against each campaign's segment and rung lists. The final output should look like this:

+--------------------+------------+
|                name|customer_id |
+--------------------+------------+
|   Campaign 1       | 124001823  |
|   Campaign 1       | 166001989  |
|   Campaign 1       | 768002266  |
+--------------------+------------+

I tried using a UDF, but that approach didn't really work. I want to avoid a for loop over a collect(), and any other row-by-row processing; what I'm really after is something like a groupby on the name column.
So I'm hoping there is a better way to do the following:

# x: campaigns (name, segment_list, rung_list); eligible: customers (customer_id, segment, rung)
for row in x.collect():  # pulls every campaign row back to the driver
    y = eligible.filter(eligible.segment.isin(row['segment_list'])) \
                .filter(eligible.rung.isin(row['rung_list']))

Answer 1, by 7vhp5slm:

You can try using array_contains as the join condition.
Here is an example:

from pyspark.sql import functions as func

data1_sdf. \
    join(data2_sdf,
         func.expr('array_contains(segment_list, segment)') & func.expr('array_contains(rung_list, rung)'),
         'left'
         ). \
    select('name', 'customer_id'). \
    dropDuplicates(). \
    show(truncate=False)

# +----------+-----------+
# |name      |customer_id|
# +----------+-----------+
# |Campaign 1|166001989  |
# |Campaign 1|124001823  |
# |Campaign 1|768002266  |
# |Campaign 2|null       |
# +----------+-----------+

Pasting the query plan that Spark generates:

== Parsed Logical Plan ==
Deduplicate [name#123, customer_id#129]
+- Project [name#123, customer_id#129]
   +- Join LeftOuter, (array_contains(segment_list#124, segment#130) AND array_contains(rung_list#125, rung#131))
      :- LogicalRDD [name#123, segment_list#124, rung_list#125], false
      +- LogicalRDD [customer_id#129, segment#130, rung#131], false

== Analyzed Logical Plan ==
name: string, customer_id: string
Deduplicate [name#123, customer_id#129]
+- Project [name#123, customer_id#129]
   +- Join LeftOuter, (array_contains(segment_list#124, segment#130) AND array_contains(rung_list#125, rung#131))
      :- LogicalRDD [name#123, segment_list#124, rung_list#125], false
      +- LogicalRDD [customer_id#129, segment#130, rung#131], false

== Optimized Logical Plan ==
Aggregate [name#123, customer_id#129], [name#123, customer_id#129]
+- Project [name#123, customer_id#129]
   +- Join LeftOuter, (array_contains(segment_list#124, segment#130) AND array_contains(rung_list#125, rung#131))
      :- LogicalRDD [name#123, segment_list#124, rung_list#125], false
      +- Filter (isnotnull(segment#130) AND isnotnull(rung#131))
         +- LogicalRDD [customer_id#129, segment#130, rung#131], false

== Physical Plan ==
*(4) HashAggregate(keys=[name#123, customer_id#129], functions=[], output=[name#123, customer_id#129])
+- Exchange hashpartitioning(name#123, customer_id#129, 200), ENSURE_REQUIREMENTS, [id=#267]
   +- *(3) HashAggregate(keys=[name#123, customer_id#129], functions=[], output=[name#123, customer_id#129])
      +- *(3) Project [name#123, customer_id#129]
         +- BroadcastNestedLoopJoin BuildRight, LeftOuter, (array_contains(segment_list#124, segment#130) AND array_contains(rung_list#125, rung#131))
            :- *(1) Scan ExistingRDD[name#123,segment_list#124,rung_list#125]
            +- BroadcastExchange IdentityBroadcastMode, [id=#261]
               +- *(2) Filter (isnotnull(segment#130) AND isnotnull(rung#131))
                  +- *(2) Scan ExistingRDD[customer_id#129,segment#130,rung#131]

The plan doesn't look particularly well optimized (note the BroadcastNestedLoopJoin), and I suspect there are other ways to optimize this.
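
One way to avoid the BroadcastNestedLoopJoin (a sketch on my part, not part of the original answer) is to explode the two arrays first, so the join becomes an equi-join on plain columns that Spark can execute as a shuffle-based hash or sort-merge join:

from pyspark.sql import functions as func

# Explode each array into one row per (segment, rung) combination.
# Note: this builds the cross product of the two lists per campaign row,
# so it assumes the lists are small.
exploded = data1_sdf \
    .withColumn('segment', func.explode('segment_list')) \
    .withColumn('rung', func.explode('rung_list'))

# Equi-join on the exploded columns, then deduplicate, since a customer
# can match a campaign through more than one (segment, rung) pair.
result = exploded \
    .join(data2_sdf, ['segment', 'rung'], 'inner') \
    .select('name', 'customer_id') \
    .dropDuplicates()

result.show(truncate=False)

Unlike the left join above, this inner join drops campaigns with no matching customers (the "Campaign 2 | null" row), so whether it is a drop-in replacement depends on whether unmatched campaigns need to appear in the output.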
