PySpark: create new rows based on a number in a column and multiple conditions (explode)

mdfafbf1 posted on 2023-06-28 in Spark

I have a DataFrame with several columns: a unique ID, a month, and a split. I need to explode the DataFrame and create new rows for each unique combination of id, month, and split. The number of rows to explode has already been calculated and is stored in the column bad_call_dist. For example, if ID is 12345, month is Jan, split is 'A', and bad_call_dist is 6, then 6 rows are needed in total. This process must be repeated for every unique combination.
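For concreteness, a minimal sketch of input like this (column names are from the question; the rows and values are hypothetical):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# hypothetical sample: one row per id/month/split combination, where
# bad_call_dist is the number of rows that combination should expand into
call_data = spark.createDataFrame(
    [(12345, 'Jan', 'A', 6, 0), (12345, 'Feb', 'A', 8, 0)],
    ['id', 'month', 'split', 'bad_call_dist', 'bad_call_flag'],
)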
I have code that works for small datasets, but I need to apply it to a much larger DataFrame and it times out every run. The code below pulls a single-row DataFrame out of the original data and adds a temporary range column that encodes how many rows must exist for one unique combination of the columns. I then use explode() to generate the new rows and union them into a master DataFrame. I'm looking for help optimizing this code to speed up processing while producing the same result:

from pyspark.sql import functions as F

# unique ID-month-split combinations for the data
idMonthSplits = call_data.select('id', 'month', 'split').distinct().collect()

# set the schema to all cols except the bad call flag, which is set to 1 in the loop
master_explode = spark.createDataFrame([], schema=call_data.select([col for col in call_data.columns if col != 'bad_call_flag']).schema)

# loop over each unique combination
for ims in idMonthSplits:

    id = ims['id']
    month = ims['month']
    split = ims['split']

    # single-row df for this combination, flagged as a bad call
    explode_df = call_data.filter((call_data['id'] == id) & (call_data['month'] == month) & (call_data['split'] == split))\
        .withColumn('bad_call_flag', F.lit(1))

    try:

        # extract the value that represents the number of rows to explode
        expVal = explode_df.select(F.first(F.col("bad_call_dist")).cast("int")).first()[0]

        # range column used by explode() to turn the single row into multiple rows,
        # one array element per required row
        explode_df = explode_df.withColumn(
            'range',
            F.array(
                [F.lit(i) for i in range(expVal)]
            )
        )

        # explode the df, then drop the cols no longer needed for the union
        explode_df = explode_df.withColumn('explode', F.explode(F.col('range')))\
            .drop(*['explode', 'range', 'bad_call_dist'])

        # union into the master df
        master_explode = master_explode.unionAll(explode_df)

    # if there is no value to explode, there is nothing to expand; skip to avoid an error
    except:
        continue

vcirk6k6 1#

Loops in Spark are almost always disastrous. It is best to use Spark's built-in functions as much as possible, since they are optimized internally, and your case can be solved with array_repeat() inside expr().
Here's an example:

# given the following data
# +---+-----+-----+-------------+
# | id|month|split|bad_call_dist|
# +---+-----+-----+-------------+
# |  1|  Jan|    A|            6|
# |  1|  Feb|    A|            8|
# +---+-----+-----+-------------+

from pyspark.sql import functions as func

# create a dummy array with `array_repeat`, then explode it
data_sdf. \
    withColumn('dummy_arr', func.expr('array_repeat(1, cast(bad_call_dist as int))')). \
    selectExpr(*data_sdf.columns, 'explode(dummy_arr) as exp_dummy'). \
    show()

# +---+-----+-----+-------------+---------+
# |id |month|split|bad_call_dist|exp_dummy|
# +---+-----+-----+-------------+---------+
# |1  |Jan  |A    |6            |1        |
# |1  |Jan  |A    |6            |1        |
# |1  |Jan  |A    |6            |1        |
# |1  |Jan  |A    |6            |1        |
# |1  |Jan  |A    |6            |1        |
# |1  |Jan  |A    |6            |1        |
# |1  |Feb  |A    |8            |1        |
# |1  |Feb  |A    |8            |1        |
# |1  |Feb  |A    |8            |1        |
# |1  |Feb  |A    |8            |1        |
# |1  |Feb  |A    |8            |1        |
# |1  |Feb  |A    |8            |1        |
# |1  |Feb  |A    |8            |1        |
# |1  |Feb  |A    |8            |1        |
# +---+-----+-----+-------------+---------+
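If you also need the bad_call_flag column produced by your original loop, one option (a sketch, assuming the constant exp_dummy column can simply serve as the flag) is to rename it and drop bad_call_dist after the explode:

# hypothetical name `exploded_sdf`; same pipeline as above plus a rename/drop
exploded_sdf = data_sdf. \
    withColumn('dummy_arr', func.expr('array_repeat(1, cast(bad_call_dist as int))')). \
    selectExpr(*data_sdf.columns, 'explode(dummy_arr) as exp_dummy'). \
    withColumnRenamed('exp_dummy', 'bad_call_flag'). \
    drop('bad_call_dist')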

Note that I used array_repeat inside expr. That's because you want the number of repeats to come from a column: the Spark-native function does not accept a column for its second parameter, but the SQL function does.

data_sdf. \
    withColumn('dummy_arr', func.expr('array_repeat(1, cast(bad_call_dist as int))')). \
    show(truncate=False)

# +---+-----+-----+-------------+------------------------+
# |id |month|split|bad_call_dist|dummy_arr               |
# +---+-----+-----+-------------+------------------------+
# |1  |Jan  |A    |6            |[1, 1, 1, 1, 1, 1]      |
# |1  |Feb  |A    |8            |[1, 1, 1, 1, 1, 1, 1, 1]|
# +---+-----+-----+-------------+------------------------+
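As an aside (not part of the original answer): in Spark 3.0 and later, the Python array_repeat does accept a column for the count, so the expr() detour is only needed on older versions. A version-dependent sketch:

from pyspark.sql import functions as F

# Spark 3.0+: the count argument of array_repeat may itself be a column;
# on older versions it must be a literal int, hence the expr() workaround above
data_sdf.withColumn(
    'dummy_arr',
    F.array_repeat(F.lit(1), F.col('bad_call_dist').cast('int'))
)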
