Group by and collect a list of values sorted by another column in PySpark

pzfprimi  posted on 2022-11-01  in  Spark

Hi, I have a DataFrame like this:

+----------+----------+------------------+
|      id_A|      id_B|          Distance|
+----------+----------+------------------+
| 120745612| 122913167|0.6142857142857143|
|1243257970| 370926553|0.8061224489795918|
|1305652409| 253051944|0.8252427184466019|
|1350805455| 311286173|0.5789473684210527|
|1544864070| 390580289|0.7894736842105263|
| 164533143| 763751752|0.8153846153846154|
|1683553267| 787287056|0.9117647058823529|
| 175951349| 175951349|               0.0|

Now I want to group by id_A and get a list of id_B values ordered by ascending Distance. That is, the id_B with the smallest Distance should come first in the list, and so on.
Expected output:

|  id_A  |   id_B  |
175951349  [175951349, 390580289, ...]

pes8fvy9  #1

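First use collect_list to build an array of structs, then use array_sort with a comparator on the struct's Distance field, and finally use transform to turn the array into the format you need. Passing a comparator lambda to array_sort requires Spark 3.0 or later.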

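from pyspark.sql import functions as F

df = df.groupBy('id_A').agg(
    # Collect (id_B, Distance) structs, sort them by Distance ascending,
    # then keep only id_B from each struct.
    F.expr("""
        transform(
            array_sort(
                collect_list(struct(id_B, Distance)),
                (l, r) -> case when l.Distance < r.Distance then -1 when l.Distance > r.Distance then 1 else 0 end
            ),
            x -> x.id_B
        )
    """).alias('id_B')
)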
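
To make the semantics of the expression above concrete, here is a plain-Python sketch (no Spark involved) of the same three steps: collect (id_B, Distance) pairs per id_A, sort each group by Distance ascending, then keep only id_B. The function name `group_sorted_ids` and the sample rows are hypothetical, chosen only for illustration; something like this can serve as a reference when checking Spark output on a small sample.

```python
from collections import defaultdict


def group_sorted_ids(rows):
    """rows: iterable of (id_A, id_B, distance) tuples (hypothetical input shape)."""
    groups = defaultdict(list)
    # Step 1: collect (id_B, distance) pairs per id_A, like collect_list(struct(...)).
    for id_a, id_b, distance in rows:
        groups[id_a].append((id_b, distance))
    # Step 2: sort each group by distance ascending.
    # Step 3: keep only id_B, like transform(..., x -> x.id_B).
    return {
        id_a: [id_b for id_b, _ in sorted(pairs, key=lambda p: p[1])]
        for id_a, pairs in groups.items()
    }


# Made-up sample rows, not taken from the question's data.
rows = [
    (1, 10, 0.9),
    (1, 11, 0.2),
    (1, 12, 0.5),
    (2, 20, 0.0),
]
result = group_sorted_ids(rows)
print(result)  # {1: [11, 12, 10], 2: [20]}
```

Python's sort is stable, so pairs with equal distance keep their input order here; in Spark the order of such ties depends on the order in which collect_list happened to gather the rows.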

brccelvz  #2

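To do this, first combine Distance and id_B into a struct, then sort with sort_array (which orders structs by their first field, here Distance), and finally extract the field you need.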

Data preparation

from io import StringIO

import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

s = StringIO("""
id_A,id_B,Distance
120745612,122913167,0.6142857142857143
1243257970,370926553,0.8061224489795918
1305652409,253051944,0.8252427184466019
1350805455,311286173,0.5789473684210527
1544864070,390580289,0.7894736842105263
164533143,763751752,0.8153846153846154
1683553267,787287056,0.9117647058823529
175951349,175951349,0.0
1683553267,787287056,0.67217647058823529
1683553267,787287056,0.51236647058823529
1683553267,787287056,0.98176470588235291
""")

# The last 3 records were added manually to demonstrate the sorting

df = pd.read_csv(s, delimiter=',')

sparkDF = spark.createDataFrame(df).orderBy('id_A')

sparkDF.show()

+----------+---------+------------------+
|      id_A|     id_B|          Distance|
+----------+---------+------------------+
| 120745612|122913167|0.6142857142857143|
| 164533143|763751752|0.8153846153846154|
| 175951349|175951349|               0.0|
|1243257970|370926553|0.8061224489795918|
|1305652409|253051944|0.8252427184466019|
|1350805455|311286173|0.5789473684210527|
|1544864070|390580289|0.7894736842105263|
|1683553267|787287056|0.9117647058823528|
|1683553267|787287056|0.6721764705882352|
|1683553267|787287056|0.9817647058823528|
|1683553267|787287056|0.5123664705882351|
+----------+---------+------------------+

Struct and sort_array

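from pyspark.sql import functions as F

# Distance is the first struct field, so sort_array orders by Distance ascending.
# Note: collect_set drops exact duplicate (Distance, id_B) pairs;
# use collect_list instead if duplicates must be kept.
sparkDF.groupby("id_A") \
       .agg(F.sort_array(F.collect_set(F.struct("Distance", "id_B"))).alias("collected_list")) \
       .withColumn("sorted_list", F.col("collected_list.id_B")) \
       .drop("collected_list") \
       .show(truncate=False)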

+----------+--------------------------------------------+
|id_A      |sorted_list                                 |
+----------+--------------------------------------------+
|120745612 |[122913167]                                 |
|164533143 |[763751752]                                 |
|175951349 |[175951349]                                 |
|1243257970|[370926553]                                 |
|1305652409|[253051944]                                 |
|1350805455|[311286173]                                 |
|1544864070|[390580289]                                 |
|1683553267|[787287056, 787287056, 787287056, 787287056]|
+----------+--------------------------------------------+
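
Why the field order inside the struct matters can be shown with a small plain-Python analogy (no Spark involved). Python tuples, like Spark structs under sort_array as far as I understand its ordering, compare field by field from left to right, so whichever field comes first drives the sort. The sample pairs below are made up for illustration, and the set step mimics what collect_set does to exact duplicates.

```python
# (Distance, id_B) pairs for one group; the (0.5, 20) pair appears twice.
pairs = [(0.9, 10), (0.2, 30), (0.5, 20), (0.5, 20)]

# Distance first: tuples compare element by element, so Distance drives the order.
by_distance = [id_b for _, id_b in sorted(pairs)]

# id_B first: the same data now sorts by id_B, which is not what the question asks for.
by_id = [id_b for id_b, _ in sorted((id_b, d) for d, id_b in pairs)]

# Set semantics (analogous to collect_set) drop the exact duplicate pair.
deduplicated = [id_b for _, id_b in sorted(set(pairs))]

print(by_distance)   # [30, 20, 20, 10]
print(by_id)         # [10, 20, 20, 30]
print(deduplicated)  # [30, 20, 10]
```

If every occurrence should appear in the result, the list-based variant is the one to mirror; the set-based variant only differs when a group contains identical (Distance, id_B) pairs.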
