计算PySpark中每行的列中不同子串的出现次数?

ippsafx7  于 2023-06-28  发布在  Spark
关注(0)|答案(2)|浏览(126)

我的数据集看起来像这样,其中有一组逗号分隔的字符串值,分别为col1col2col3是连接在一起的两列。

+===========+========+=============
|col1       |col2    |col3   
+===========+========+=============
|a,b,c,d    |a,c,d   |a,b,c,d,a,c,d 
|e,f,g      |f,g,h   |e,f,g,f,g,h
+===========+========+=============

基本上,我尝试做的是获取col3中所有用逗号分隔的值,并将每个值及其计数附加到另一列。
我尝试在col4中得到这样的输出:

+===========+========+==============+======================
|col1       |col2    |col3          |col4
+===========+========+==============+======================
|a,b,c,d    |a,c,d   |a,b,c,d,a,c,d |a: 2, b: 1, c: 2, d: 2
|e,f,g      |f,g,h   |e,f,g,f,g,h   |e: 1, f: 2, g: 2, h: 1
+===========+========+==============+======================

我已经找到了如何将列concat组合在一起以得到col3,但在得到col4时遇到了一些麻烦。这里是我离开的地方,我有点不确定从这里去哪里:

from pyspark.sql.functions import concat, countDistinct

df = df.select(concat(df.col1, df.col2).alias('col3'), '*')
df.agg(countDistinct('col3')).show()
# +--------------------+ 
# |count(DISTINCT col3)|
# +--------------------+
# |                   2|
# +--------------------+

如何动态统计col3中以逗号分隔的子字符串,并创建一个最后一列,显示数据集中所有行的每个子字符串的频率?

fnx2tebb

fnx2tebb1#

使用自定义项

这里有一个使用udfs的方法。首先是数据生成。

from pyspark.sql.types import StringType, StructType, StructField
from pyspark.sql.functions import concat_ws, udf

data = [("a,b,c,d", "a,c,d", "a,b,c,d,a,c,d"),
        ("e,f,g", "f,g,h", "e,f,g,f,g,h")
       ]
schema = StructType([
    StructField("col1",StringType(),True),
    StructField("col2",StringType(),True),
    StructField("col3",StringType(),True),
])
df = spark.createDataFrame(data=data,schema=schema)

然后使用一些原生Python函数,如Counter和json来完成任务。

from collections import Counter
import json

@udf(StringType())
def count_occurances(row):
    return json.dumps(dict(Counter(row.split(','))))

df.withColumn('concat', concat_ws(',', df.col1, df.col2, df.col3))\
  .withColumn('counts', count_occurances('concat')).show(2, False)

结果

+-------+-----+-------------+---------------------------+--------------------------------+
|col1   |col2 |col3         |concat                     |counts                          |
+-------+-----+-------------+---------------------------+--------------------------------+
|a,b,c,d|a,c,d|a,b,c,d,a,c,d|a,b,c,d,a,c,d,a,b,c,d,a,c,d|{"a": 4, "b": 2, "c": 4, "d": 4}|
|e,f,g  |f,g,h|e,f,g,f,g,h  |e,f,g,f,g,h,e,f,g,f,g,h    |{"e": 2, "f": 4, "g": 4, "h": 2}|
+-------+-----+-------------+---------------------------+--------------------------------+

使用原生pyspark函数的解决方案

这个解决方案比使用udf要复杂一些,但是由于没有udf,它的性能可能会更好。这个想法是concat三个字符串列和爆炸。为了知道每一个被分解的行是从哪里来的,我们添加了一个索引。双重分组将帮助我们得到想要的结果。最后,我们将结果连接回原始框架,以获得所需的模式。

from pyspark.sql.functions import concat_ws, monotonically_increasing_id, split, explode, collect_list
df = df.withColumn('index', monotonically_increasing_id())

df.join(
    df
  .withColumn('concat', concat_ws(',', df.col1, df.col2, df.col3))\
  .withColumn('arr_col', split('concat', ','))\
  .withColumn('explode_col', explode('arr_col'))\
  .groupBy('index', 'explode_col').count()\
  .withColumn('concat_counts', concat_ws(':', 'explode_col', 'count'))\
  .groupBy('index').agg(concat_ws(',', collect_list('concat_counts')).alias('grouped_counts')), on='index').show()

导致

+-----------+-------+-----+-------------+---------------+
|      index|   col1| col2|         col3| grouped_counts|
+-----------+-------+-----+-------------+---------------+
|42949672960|a,b,c,d|a,c,d|a,b,c,d,a,c,d|a:4,b:2,c:4,d:4|
|94489280512|  e,f,g|f,g,h|  e,f,g,f,g,h|h:2,g:4,f:4,e:2|
+-----------+-------+-----+-------------+---------------+

请注意,我们在udf部分创建的json通常比使用原生pyspark函数在grouped_counts列中创建的简单字符串更方便。

qnakjoqk

qnakjoqk2#

我建议使用原生Spark选项,使用 * 数组 * 而不是字符串。

from pyspark.sql import functions as F
df = spark.createDataFrame([("a,b,c,d", "a,c,d"), ("e,f,g", "f,g,h")], ['col1', 'col2'])

conc = F.concat(F.split('col1', ','), F.split('col2', ','))
dist = F.array_distinct(conc)
m = F.map_from_arrays(dist, F.transform(dist, lambda x: F.size(F.filter(conc, lambda y: y == x))))
df = df.withColumn('counts', m)

df.show(truncate=0)
# +-------+-----+--------------------------------+
# |col1   |col2 |counts                          |
# +-------+-----+--------------------------------+
# |a,b,c,d|a,c,d|{a -> 2, b -> 1, c -> 2, d -> 2}|
# |e,f,g  |f,g,h|{e -> 1, f -> 2, g -> 2, h -> 1}|
# +-------+-----+--------------------------------+

输出类型为map而不是string。通过这种方式,可以使用例如F.col('counts')['g']

相关问题