PySpark:groupby()count('*')未按预期工作,或者我误解了

z31licg0  于 2022-12-22  发布在  Spark
关注(0)|答案(1)|浏览(152)

我在试着
1.类别中的一行事物
1.类别内所有事物的行
下面是我尝试过的方法。

# This is PySpark
# df has variables 'id' 'category' 'thing'
# 'category' one : many 'id'
#
# sample data:
# id    | category | thing
# alpha |        A |     X
# alpha |        A |     X
# alpha |        A |     Y
# beta  |        A |     X
# beta  |        A |     Z
# beta  |        A |     Z
# gamma |        B |     X
# gamma |        B |     Y
# gamma |        B |     Z

df_count_per_category = df.\
    select('category', 'thing').\
    groupby('category', 'thing').\
    agg(F.count('*').alias('thing_count'))

# Proposition total, to join with df_turnover_segmented
df_total = df.\
    select('category').\
    groupby('category').\
    agg(F.count('*').alias('thing_total'))

df_merge = df.\
    join(df_count_per_category,\
        (df_count_per_category.thing== df_count_per_category.thing) & \
        (df_count_per_category.category== df_count_per_category.category), \
    'inner').\
    drop(df_count_per_category.thing).\
    drop(df_count_per_category.category).\
    join(df_total,\
        (df.category== df_total.category), \
    'inner').\
    drop(df_total.category)

df_rate = df_merge.\
    withColumn('thing_rate', F.round(F.col('thing_count') / F.col('thing_total'), 3))

我希望thing_countthing_total,和thing_rate对于相同的thing是相同的,因为每个thing都是category互斥的。然而,尽管thing_count在各行之间是相同的值,但thing_rate不是。为什么呢?
这是我想要达到的R等效值:

# This is R
library(tidytable)
df_total = df |>
  mutate(.by = c(category, thing),
         thing_count = n()) |>
  mutate(.by = category,
         thing_total = n()) |>
  mutate(thing_rate = thing_count / thing_total)

这是预期结果(+/-一些列):

# This is a table
category | thing | thing_count | thing_total | thing_rate
       A |     X |           3 |           6 |        0,5
       A |     Y |           1 |           6 |     0,1667
       A |     Z |           2 |           6 |     0,3333
       B |     X |           1 |           3 |     0,3333
       B |     Y |           1 |           3 |     0,3333
       B |     Z |           1 |           3 |     0,3333
ikfrs5lh

ikfrs5lh1#

我认为你的第二个join不是你打算做的。
您正在第二个联接条件中引用原始df,这导致创建错误的关联。相反,您希望将df_total联接到第一个联接的结果。

df_merge = df.\
    join(df_count_per_category ,\
        (df.thing== df_count_per_category.thing) & \
        (df.category== df_count_per_category.category), \
    'inner').\
    drop(df_count_per_category .thing).\
    drop(df_count_per_category .category)

df_merge = df_merge.join(df_total ,\
        (df_merge.category== df_total.category), \  # Reference df_merge.category.
    'inner').\
    drop(df_total.category)

或者,您可以使用窗口函数实现预期的 Dataframe ,而无需多个连接。

from pyspark.sql import Window
from pyspark.sql import functions as F

df = (df.select('category', 'thing',
                F.count('*').over(Window.partitionBy('category', 'thing')).alias('thing_count'),
                F.count('*').over(Window.partitionBy('category')).alias('thing_total'))
      .withColumn('thing_rate', F.round(F.col('thing_count') / F.col('thing_total'), 3)))

相关问题