寻找pyspark数组的逆数组

mbzjlibv  于 2021-05-29  发布在  Spark
关注(0)|答案(3)|浏览(347)

我有以下格式化的输入Dataframe:

from pyspark.sql import SparkSession
import pyspark.sql.functions as F

spark = SparkSession.builder.master("local").getOrCreate()

input_df = spark.createDataFrame(
    [
        ('Alice;Bob;Carol',),
        ('12;13;14',),
        ('5;;7',),
        ('1;;3',),
        (';;3',)
    ],
    ['data']
)

input_df.show()

# +---------------+

# |           data|

# +---------------+

# |Alice;Bob;Carol|

# |       12;13;14|

# |           5;;7|

# |           1;;3|

# |            ;;3|

# +---------------+

实际输入是一个分号分隔的csv文件,其中一列包含一个人的值。每个人可以有不同数量的值。在这里,爱丽丝有3个值,鲍勃只有一个值,卡罗尔有4个值。
我想在pyspark中将其转换为一个输出Dataframe,该Dataframe为每个人保存一个数组,在本例中,输出为:

result = spark.createDataFrame(
    [
        ("Alice", [12, 5, 1]),
        ("Bob", [13,]),
        ("Carol", [14, 7, 3, 3])
    ],
    ['name', 'values']
)

result.show()

# +-----+-------------+

# | name|       values|

# +-----+-------------+

# |Alice|   [12, 5, 1]|

# |  Bob|         [13]|

# |Carol|[14, 7, 3, 3]|

# +-----+-------------+

我该怎么做?我想这会是 F.arrays_zip() , F.split() 和/或 F.explode() ,但我想不通。
我现在被困在这里,这是我目前的尝试:

(input_df
    .withColumn('splits', F.split(F.col('data'), ';'))
    .drop('data')
).show()

# +-------------------+

# |             splits|

# +-------------------+

# |[Alice, Bob, Carol]|

# |       [12, 13, 14]|

# |           [5, , 7]|

# |           [1, , 3]|

# |            [, , 3]|

# +-------------------+
rkttyhzu

rkttyhzu1#

一种方法是将第一行作为标题读取,然后取消对数据的分割

df1 = spark.createDataFrame([(12,13,14),(5,None,7),(1,None,3),(None,None,3)], ['Alice','Bob','Carol'])

df1.show()
+-----+----+-----+
|Alice| Bob|Carol|
+-----+----+-----+
|   12|  13|   14|
|    5|null|    7|
|    1|null|    3|
| null|null|    3|
+-----+----+-----+

df1.select(f.expr('''stack(3,'Alice',Alice,'Bob',Bob,'Carol',Carol) as (Name,Value)'''))\
   .groupBy('Name').agg(f.collect_list('value').alias('Value')).orderBy('Name').show()

+-----+-------------+
| Name|        Value|
+-----+-------------+
|Alice|   [12, 5, 1]|
|  Bob|         [13]|
|Carol|[14, 7, 3, 3]|
+-----+-------------+

对于动态传递列,请使用下面的代码

cols = ','.join([f"'{i[0]}',{i[1]}" for i in zip(df1.columns,df1.columns)])
df1.select(f.expr(f'''stack(3,{cols}) as (Name,Value)''')).groupBy('Name').agg(f.collect_list('value').alias('Value')).orderBy('Name').show()

+-----+-------------+
| Name|        Value|
+-----+-------------+
|Alice|   [12, 5, 1]|
|  Bob|         [13]|
|Carol|[14, 7, 3, 3]|
+-----+-------------+
slwdgvem

slwdgvem2#

Solution for Spark-2.4+: 使用 groupBy 使用 collect_list 然后拆分以创建新列。
使用 arrays_zip 压缩数组并创建嵌套数组 [key,[values]] 最后 explode 嵌套数组。 Example: ```
df.show()

+---------------+

| data|

+---------------+

|Alice;Bob;Carol|

| 12;13;14|

| 5;;7|

| 1;;3|

| ;;3|

+---------------+

from pyspark.sql.functions import *

df.agg(split(concat_ws("|",collect_list(col("data"))),"\|").alias("tmp")).
withColumn("col1",split(element_at(col("tmp"),1),";")).
withColumn("col2",split(element_at(col("tmp"),2),";")).
withColumn("col3",split(element_at(col("tmp"),3),";")).
withColumn("col4",split(element_at(col("tmp"),4),";")).
withColumn("zip",arrays_zip(col("col1"),arrays_zip(col("col2"),col("col3"),col("col4")))).
selectExpr("explode(zip)as tmp").
selectExpr("tmp.*").
toDF("name","values").
show(10,False)

+-----+----------+

|name |values |

+-----+----------+

|Alice|[12, 5, 1]|

|Bob |[13, , ] |

|Carol|[14, 7, 3]|

+-----+----------+

为了 `spark < 2.4` 对数组使用自定义项 `getItem(<n>)` 而不是 `element_at` 功能。
kgsdhlau

kgsdhlau3#

我建议把数据读作 ; 分离csv,然后处理以获取 name 以及 values 列如下-
请注意,这段代码是用scala编写的,但是类似的代码可以在pyspark中实现,只需很少的修改

加载;分开的csv

val data =
      """
        |Alice;Bob;Carol
        |       12;13;14
        |           5;;7
        |           1;;3
        |            ;;3
      """.stripMargin
    val stringDS = data.split(System.lineSeparator())
      .map(_.split("\\;").map(_.replaceAll("""^[ \t]+|[ \t]+$""", "")).mkString(";"))
      .toSeq.toDS()
    val df = spark.read
      .option("sep", ";")
      .option("inferSchema", "true")
      .option("header", "true")
      .option("nullValue", "null")
      .csv(stringDS)
    df.printSchema()
    df.show(false)
    /**
      * root
      * |-- Alice: integer (nullable = true)
      * |-- Bob: integer (nullable = true)
      * |-- Carol: integer (nullable = true)
      *
      * +-----+----+-----+
      * |Alice|Bob |Carol|
      * +-----+----+-----+
      * |12   |13  |14   |
      * |5    |null|7    |
      * |1    |null|3    |
      * |null |null|3    |
      * +-----+----+-----+
      */

派生名称和值列

val columns = df.columns.map(c => expr(s"named_struct('name', '$c', 'values',  collect_list($c))"))
    df.select(array(columns: _*).as("array"))
      .selectExpr("inline_outer(array)")
      .show(false)
    /**
      * +-----+-------------+
      * |name |values       |
      * +-----+-------------+
      * |Alice|[12, 5, 1]   |
      * |Bob  |[13]         |
      * |Carol|[14, 7, 3, 3]|
      * +-----+-------------+
      */

相关问题