Converting a SQL UPDATE with multiple INNER JOINs and a WHERE clause to PySpark

vohkndzv · posted 2022-11-25 in Apache

I am trying to execute the following code with PySpark:

join_on = (
    (df_1.C1_PROFIT == df_2.C2_PROFIT) &             # JOIN CONDITION
    (df_1.C1_REVENUE == df_3.C3_REVENUE_BREAK) &     # JOIN CONDITION
    (df_1.C1_LOSS == df_4.C4_TOTAL_LOSS) &           # JOIN CONDITION
    (df_4.TOTAL_YEAR_PROFIT > df_3.TOTAL_GROWTH)     # WHERE CONDITION
)
df = (df_1.alias('a')
    .join(df_2.alias('b'), join_on, 'left')
    .join(df_3.alias('c'), join_on, 'left')
    .join(df_4.alias('d'), join_on, 'left')
    .select(
        *[c for c in df_2.columns if c != 'C2_TARGET'],
        F.expr("nvl2(b.C2_PROFIT, '500', a.C2_TARGET) C2_TARGET")
    )
)

Running the query raises an error:
the columns 'TOTAL_YEAR_PROFIT', 'TOTAL_GROWTH', 'C4_TOTAL_LOSS' and 'C3_REVENUE_BREAK' do not exist among the columns of df_1.
The original SQL query:

UPDATE (( companyc1
          INNER JOIN companyc2
                  ON companyc1.c1_profit = companyc2.c2_profit)
        INNER JOIN companyc3
                ON companyc1.c1_revenue = companyc3.revenue_break)
       INNER JOIN companyc4
               ON companyc1.c1_loss = companyc4.c4_total_loss
SET    companyc1.sales = "500"
WHERE  (( ( companyc4.total_year_profit ) > [companyc3].[total_growth] ))

Can anyone help me figure out where I went wrong?


wwtsj6pe 1#

You need to split the join_on condition per join operation, and apply the WHERE part as a filter, as shown below:

from pyspark.sql import functions as F

df = (df_1.alias('a')
    .join(df_2.alias('b'), df_1.C1_PROFIT == df_2.C2_PROFIT, 'left')
    .join(df_3.alias('c'), df_1.C1_REVENUE == df_3.C3_REVENUE_BREAK, 'left')
    .join(df_4.alias('d'), df_1.C1_LOSS == df_4.C4_TOTAL_LOSS, 'left')
    .where("d.TOTAL_YEAR_PROFIT > c.TOTAL_GROWTH")   # the SQL WHERE, applied before the columns are projected away
    .select(
        *[c for c in df_2.columns if c != 'C2_TARGET'],
        F.expr("nvl2(b.C2_PROFIT, '500', a.C2_TARGET) C2_TARGET")
    )
)
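
If a SQL-first formulation is easier to reason about, another option is to register the DataFrames as temp views and express the update as a SELECT with a CASE expression. This is only a sketch: it reuses the table names of the original Access query as view names, and the column names follow the join_on from the question (C2_TARGET instead of sales).

# reuse the Access table names as view names; df_1 .. df_4 are the DataFrames from the question
df_1.createOrReplaceTempView("companyc1")
df_2.createOrReplaceTempView("companyc2")
df_3.createOrReplaceTempView("companyc3")
df_4.createOrReplaceTempView("companyc4")

df = spark.sql("""
    SELECT a.*,
           CASE WHEN b.C2_PROFIT IS NOT NULL
                 AND d.TOTAL_YEAR_PROFIT > c.TOTAL_GROWTH
                THEN '500' ELSE a.C2_TARGET END AS C2_TARGET_NEW
    FROM companyc1 a
    LEFT JOIN companyc2 b ON a.C1_PROFIT  = b.C2_PROFIT
    LEFT JOIN companyc3 c ON a.C1_REVENUE = c.C3_REVENUE_BREAK
    LEFT JOIN companyc4 d ON a.C1_LOSS    = d.C4_TOTAL_LOSS
""")

Keep in mind that, just like the DataFrame version, the left joins can duplicate rows of companyc1 whenever a lookup table matches more than once; the groupBy/_id trick in the second answer addresses exactly that.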

x8diyxa7 2#

When translating a SQL UPDATE that involves multiple joins, a generally safe approach, in my opinion, is to combine groupBy, agg and monotonically_increasing_id (to make sure the number of rows of the original df does not shrink after the aggregation).
I built the following tables in MS Access to make sure the approach I am suggesting behaves the same way in Spark.

Input: (MS Access table, screenshot in the original post)

Result after the update: (screenshot in the original post)

Spark
MS Access appears to aggregate the column values, so the code below will do the same.
Input:

from pyspark.sql import functions as F

df_1 = spark.createDataFrame(
    [(2,   10,    5, 'replace'),
     (2,   10,    5, 'replace'),
     (1,   10, None,    'keep'),
     (2, None,    5,    'keep')],
    ['C1_PROFIT', 'C1_REVENUE', 'C1_LOSS', 'SALES']
)
df_2 = spark.createDataFrame([(1,), (2,), (8,)], ['C2_PROFIT'])
df_3 = spark.createDataFrame([(10, 51), (10, 50)], ['REVENUE_BREAK', 'TOTAL_GROWTH'])
df_4 = spark.createDataFrame([(5, 50), (5, 51),], ['C4_TOTAL_LOSS', 'TOTAL_YEAR_PROFIT'])

Script:

df_1 = df_1.withColumn('_id', F.monotonically_increasing_id())  # tag every original row with a unique id
df = (df_1.alias('a')
    .join(df_2.alias('b'), df_1.C1_PROFIT == df_2.C2_PROFIT, 'left')
    .join(df_3.alias('c'), df_1.C1_REVENUE == df_3.REVENUE_BREAK, 'left')
    .join(df_4.alias('d'), df_1.C1_LOSS == df_4.C4_TOTAL_LOSS, 'left')
    .groupBy(*[c for c in df_1.columns if c != 'SALES'])  # '_id' is included, so each original row keeps its own group
    .agg(F.when(F.max('d.total_year_profit') > F.min('c.total_growth'), '500')
          .otherwise(F.first('a.SALES')).alias('SALES')
    ).drop('_id')
)
df.show()
# +---------+----------+-------+-----+
# |C1_PROFIT|C1_REVENUE|C1_LOSS|SALES|
# +---------+----------+-------+-----+
# |        1|        10|   null| keep|
# |        2|      null|      5| keep|
# |        2|        10|      5|  500|
# |        2|        10|      5|  500|
# +---------+----------+-------+-----+
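
To see why the extra _id column matters, here is a minimal sketch with made-up data: without it, two identical input rows would collapse into a single group and the row count of the original df would shrink.

# minimal sketch with hypothetical data (assumes the same spark session and F import as above)
df_demo = spark.createDataFrame([(2, 'replace'), (2, 'replace')], ['K', 'SALES'])

# without _id the two identical rows fall into one group -> the row count shrinks
df_demo.groupBy('K').count().show()          # 1 row: K=2, count=2

# with _id every original row gets its own group -> the original granularity is kept
(df_demo
    .withColumn('_id', F.monotonically_increasing_id())
    .groupBy('K', '_id')
    .count()
    .drop('_id')
    .show())                                 # 2 rows, one per original row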
