在PySpark中复制SAS Retain语句

qmb5sa22  于 2023-04-05  发布在  Spark
关注(0)|答案(1)|浏览(110)

我正在将代码从SAS迁移到PySpark,我正在努力解决以下SAS retain语句:

data ds;
set  ds;
 
    by group date;
   
    retain Target; 
 
    if first.group                             then Target = Orig;
    if first.group and ( Orig in (1,2,3,4,5) ) then Target = 6;
 
    if not first.group and   Target  = 6 and (Orig in (1,2,3,4,5) )   then Target = 6 ;
    if not first.group and ~(Target  = 6 and (Orig in (1,2,3,4,5) ) ) then Target = Orig ;
   
    run;

如何才能做到这一点?
如果第一个原始值不为0,则向上舍入为6。如果不是组中的第一个,并且目标已经为6,并且原始值在1、2、3、4、5中,则将目标保持为6。如果不是组中的第一个,并且目标不是6或原始值不在1、2、3、4、5中,则将目标设置为原始值。
我为一个单独的组提供了一个例子(为冗长的一个道歉):

df = SparkSession.createDataFrame([
(999, 5,6) ,
(999, 6,6) ,
(999, 4,6) ,
(999, 6,6) ,
(999, 3,6) ,
(999, 5,6) ,
(999, 4,6) ,
(999, 6,6) ,
(999, 6,6) ,
(999, 6,6) ,
(999, 6,6) ,
(999, 5,6) ,
(999, 3,6) ,
(999, 2,6) ,
(999, 2,6) ,
(999, 2,6) ,
(999, 2,6) ,
(999, 2,6) ,
(999, 2,6) ,
(999, 2,6) ,
(999, 2,6) ,
(999, 1,6) ,
(999, 0,0) ,
(999, 0,0) ,
(999, 0,0) ,
(999, 0,0) ,
(999, 1,1) ,
(999, 1,1) ,
(999, 2,2) ,
(999, 2,2) ,
(999, 3,3) ,
(999, 2,2) ,
(999, 3,3) ,
(999, 4,4) ,
(999, 5,5) ,
(999, 6,6) ,
(999, 6,6) ,
(999, 6,6) ,
(999, 0,0) ,
(999, 1,1) ,
(999, 0,0) ,
(999, 1,1) ,
(999, 2,2) ,
(999, 3,3) ,
(999, 4,4) ,
(999, 5,5) ,
(999, 6,6) ,
(999, 6,6) ,
(999, 6,6) ,
(999, 6,6) ,
(999, 4,6) ,
(999, 3,6) ,
(999, 2,6) ,
(999, 3,6) ,
(999, 4,6) ,
(999, 5,6) ,
(999, 6,6) ,
(999, 6,6) ],
['Group', 'Orig', 'Target']
)
sd2nnvve

sd2nnvve1#

没有简单的方法来复制SAS的retain。但是可以建立在逻辑上。retain,在某种程度上,正在查看以前计算的值来计算当前值。* 所以,你本质上是滞后于你正在创建的列,而你正在创建它。*
pyspark可以使用结构体数组和高阶函数来实现。
下面是实现方法 (注意,我添加了一个日期字段-dt-这有助于数据排序-类似于您的by group date;语句)

# convert data to array of structs per group
arr_struct_sdf = data_sdf. \
    withColumn('allattr', func.struct('dt', 'orig', 'exp_tgt')). \
    groupBy('group'). \
    agg(func.array_sort(func.collect_list('allattr')).alias('allattr')). \
    withColumn('frst_elm', func.col('allattr')[0])

# +-----+--------------------+--------------------+
# |group|             allattr|            frst_elm|
# +-----+--------------------+--------------------+
# |  999|[{2020-01-01 00:0...|{2020-01-01 00:00...|
# +-----+--------------------+--------------------+

# use `aggregate` higher order function to generate `target` field
arr_struct_sdf. \
    withColumn('new_attr',
               func.aggregate(func.expr('slice(allattr, 2, size(allattr))'),
                              func.array(func.col('frst_elm').withField('tgt', 
                                                                        func.when(func.col('frst_elm.orig').isin(1,2,3,4,5), func.lit(6)).
                                                                        otherwise(func.col('frst_elm.orig'))
                                                                        )
                                         ),
                              lambda x, y: func.array_union(x,
                                                            func.array(y.withField('tgt',
                                                                                   func.when((func.element_at(x, -1).tgt == 6) & (y.orig.isin(1,2,3,4,5)), func.lit(6)).
                                                                                   otherwise(y.orig)
                                                                                   )
                                                                       )
                                                            )
                              )
               ). \
    selectExpr('group', 'inline(new_attr)'). \
    show(100, False)

# +-----+-------------------+----+-------+---+
# |group|dt                 |orig|exp_tgt|tgt|
# +-----+-------------------+----+-------+---+
# |999  |2020-01-01 00:00:00|5   |6      |6  |
# |999  |2020-01-02 00:00:00|6   |6      |6  |
# |999  |2020-01-03 00:00:00|4   |6      |6  |
# |999  |2020-01-04 00:00:00|6   |6      |6  |
# |999  |2020-01-05 00:00:00|3   |6      |6  |
# |999  |2020-01-06 00:00:00|5   |6      |6  |
# |999  |2020-01-07 00:00:00|4   |6      |6  |
# |999  |2020-01-08 00:00:00|6   |6      |6  |
# |999  |2020-01-09 00:00:00|6   |6      |6  |
# |999  |2020-01-10 00:00:00|6   |6      |6  |
# |999  |2020-01-11 00:00:00|6   |6      |6  |
# |999  |2020-01-12 00:00:00|5   |6      |6  |
# |999  |2020-01-13 00:00:00|3   |6      |6  |
# |999  |2020-01-14 00:00:00|2   |6      |6  |
# |999  |2020-01-15 00:00:00|2   |6      |6  |
# |999  |2020-01-16 00:00:00|2   |6      |6  |
# |999  |2020-01-17 00:00:00|2   |6      |6  |
# |999  |2020-01-18 00:00:00|2   |6      |6  |
# |999  |2020-01-19 00:00:00|2   |6      |6  |
# |999  |2020-01-20 00:00:00|2   |6      |6  |
# |999  |2020-01-21 00:00:00|2   |6      |6  |
# |999  |2020-01-22 00:00:00|1   |6      |6  |
# |999  |2020-01-23 00:00:00|0   |0      |0  |
# |999  |2020-01-24 00:00:00|0   |0      |0  |
# |999  |2020-01-25 00:00:00|0   |0      |0  |
# |999  |2020-01-26 00:00:00|0   |0      |0  |
# |999  |2020-01-27 00:00:00|1   |1      |1  |
# |999  |2020-01-28 00:00:00|1   |1      |1  |
# |999  |2020-01-29 00:00:00|2   |2      |2  |
# |999  |2020-01-30 00:00:00|2   |2      |2  |
# |999  |2020-01-31 00:00:00|3   |3      |3  |
# |999  |2020-02-01 00:00:00|2   |2      |2  |
# |999  |2020-02-02 00:00:00|3   |3      |3  |
# |999  |2020-02-03 00:00:00|4   |4      |4  |
# |999  |2020-02-04 00:00:00|5   |5      |5  |
# |999  |2020-02-05 00:00:00|6   |6      |6  |
# |999  |2020-02-06 00:00:00|6   |6      |6  |
# |999  |2020-02-07 00:00:00|6   |6      |6  |
# |999  |2020-02-08 00:00:00|0   |0      |0  |
# |999  |2020-02-09 00:00:00|1   |1      |1  |
# |999  |2020-02-10 00:00:00|0   |0      |0  |
# |999  |2020-02-11 00:00:00|1   |1      |1  |
# |999  |2020-02-12 00:00:00|2   |2      |2  |
# |999  |2020-02-13 00:00:00|3   |3      |3  |
# |999  |2020-02-14 00:00:00|4   |4      |4  |
# |999  |2020-02-15 00:00:00|5   |5      |5  |
# |999  |2020-02-16 00:00:00|6   |6      |6  |
# |999  |2020-02-17 00:00:00|6   |6      |6  |
# |999  |2020-02-18 00:00:00|6   |6      |6  |
# |999  |2020-02-19 00:00:00|6   |6      |6  |
# |999  |2020-02-20 00:00:00|4   |6      |6  |
# |999  |2020-02-21 00:00:00|3   |6      |6  |
# |999  |2020-02-22 00:00:00|2   |6      |6  |
# |999  |2020-02-23 00:00:00|3   |6      |6  |
# |999  |2020-02-24 00:00:00|4   |6      |6  |
# |999  |2020-02-25 00:00:00|5   |6      |6  |
# |999  |2020-02-26 00:00:00|6   |6      |6  |
# |999  |2020-02-27 00:00:00|6   |6      |6  |
# +-----+-------------------+----+-------+---+

解释
aggregate高阶函数接受一个数组,初始值和一个要合并的函数(类似于python的reduce)。

  • 在本例中,我传递了包含除组的第一行以外的每一行的结构体的源数组。
  • 第二个参数是初始值,它是组的第一行沿着它的计算target(这是first.group计算)。
  • 第三个参数是merge函数,它接受初始值,并递归地合并其他值
  • 这是计算not first.group条件的地方
  • y是当前正在计算的值,它查看先前计算的值,而先前计算的值又是来自xelement_at(x, -1))的最后一个值
  • P.S. exp_tgt(预期目标)是问题中共享的示例数据中已有的target字段。tgt是pyspark生成的最终目标字段。*

对于那些由于旧版本的spark而无法使用aggregate函数的人,他们可以使用expr()中的aggregate SQL函数,如下所示。

data_sdf. \
    withColumn('allattr', func.struct('dt', 'orig', 'exp_tgt')). \
    groupBy('group'). \
    agg(func.array_sort(func.collect_list('allattr')).alias('allattr')). \
    withColumn('frst_elm', func.col('allattr')[0]). \
    withColumn('new_attr', 
               func.expr('''
                    aggregate(slice(allattr, 2, size(allattr)), 
                              array(struct(frst_elm.dt as dt, frst_elm.orig as orig, frst_elm.exp_tgt as exp_tgt, if(frst_elm.orig in (1,2,3,4,5), 6, frst_elm.orig) as tgt)),
                              (x, y) -> array_union(x,
                                                    array(struct(y.dt as dt, y.orig as orig, y.exp_tgt as exp_tgt, if(element_at(x, -1).tgt=6 and y.orig in (1,2,3,4,5), 6, y.orig) as tgt))
                                                    )
                              )
               ''')
               ). \
    selectExpr('group', 'inline(new_attr)'). \
    show(100, truncate=False)

相关问题