Cumulative sum that resets to 0 when negative, Spark 1.6

zphenhs4 posted on 2023-05-01 in Apache

I have this code (Spark 1.6), which computes a sum over a window partition:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.types._

val sconf = new SparkConf()
  .setAppName("TestPonderacion")
  .setMaster("local[*]")
val sc = new SparkContext(sconf)

val sqlContext = new HiveContext(sc)

val schema = StructType(List(
  StructField("CLIENT", IntegerType, true),
  StructField("FAMILY", StringType, true),      
  StructField("MOV_DATE", StringType, true),      
  StructField("IMP_PONDERACION", DoubleType, true)      
))    
    
val data = List(
    Row(20898871, "VAUL 10", "2023-01-01 00:00:00", 25.00),
    Row(20898871, "VAUL 10", "2023-02-03 00:00:00", 25.00),
    Row(20898871, "VAUL 10", "2023-02-04 00:00:00", 1250.00),
    Row(20898871, "VAUL 10", "2023-03-01 00:00:00", -750.00),
    Row(20898871, "VAUL 10", "2023-03-02 00:00:00", 25.00),
    Row(20898871, "VAUL 10", "2023-03-03 00:00:00", 25.00),
    Row(20898871, "VAUL 10", "2023-04-01 00:00:00", -750.00),
    Row(20898871, "VAUL 10", "2023-04-02 00:00:00", 25.00),
    Row(20898871, "VAUL 10", "2023-04-03 00:00:00", 25.00)
)

val mov = sqlContext.createDataFrame(sc.parallelize(data), schema)

val movPonderados = mov.withColumn(
  "CUMULATIVE_SUM",
  sum("IMP_PONDERACION")
    .over(Window.partitionBy("FAMILY", "CLIENT").orderBy("MOV_DATE"))
    .cast(DecimalType(17, 2))
)

movPonderados.printSchema()
movPonderados.show(false)

The result is:

+--------+-------+-------------------+---------------+--------------+
|CLIENT  |FAMILY |MOV_DATE           |IMP_PONDERACION|CUMULATIVE_SUM|
+--------+-------+-------------------+---------------+--------------+
|20898871|VAUL 10|2023-01-01 00:00:00|25.0           |25.00         |
|20898871|VAUL 10|2023-02-03 00:00:00|25.0           |50.00         |
|20898871|VAUL 10|2023-02-04 00:00:00|1250.0         |1300.00       |
|20898871|VAUL 10|2023-03-01 00:00:00|-750.0         |550.00        |
|20898871|VAUL 10|2023-03-02 00:00:00|25.0           |575.00        |
|20898871|VAUL 10|2023-03-03 00:00:00|25.0           |600.00        |
|20898871|VAUL 10|2023-04-01 00:00:00|-750.0         |-150.00       |
|20898871|VAUL 10|2023-04-02 00:00:00|25.0           |-125.00       |
|20898871|VAUL 10|2023-04-03 00:00:00|25.0           |-100.00       |
+--------+-------+-------------------+---------------+--------------+

But what I want is for the running total to be reset to zero whenever it goes negative:

+--------+-------+-------------------+---------------+--------------+
|CLIENT  |FAMILY |MOV_DATE           |IMP_PONDERACION|CUMULATIVE_SUM|
+--------+-------+-------------------+---------------+--------------+
|20898871|VAUL 10|2023-01-01 00:00:00|25.0           |25.00         |
|20898871|VAUL 10|2023-02-03 00:00:00|25.0           |50.00         |
|20898871|VAUL 10|2023-02-04 00:00:00|1250.0         |1300.00       |
|20898871|VAUL 10|2023-03-01 00:00:00|-750.0         |550.00        |
|20898871|VAUL 10|2023-03-02 00:00:00|25.0           |575.00        |
|20898871|VAUL 10|2023-03-03 00:00:00|25.0           |600.00        |
|20898871|VAUL 10|2023-04-01 00:00:00|-750.0         |0.00          | <- Reset to 0 since 600-750 is negative
|20898871|VAUL 10|2023-04-02 00:00:00|25.0           |25.00         | <- Start accumulating again from 0
|20898871|VAUL 10|2023-04-03 00:00:00|25.0           |50.00         |
+--------+-------+-------------------+---------------+--------------+
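
In other words, the running total should follow the recurrence cum(i) = max(0, cum(i-1) + x(i)). As a minimal plain-Scala illustration of that rule on the IMP_PONDERACION values above (just to pin down the logic; this is not Spark code):

val xs = Seq(25.0, 25.0, 1250.0, -750.0, 25.0, 25.0, -750.0, 25.0, 25.0)
val cums = xs.scanLeft(0.0)((acc, x) => math.max(0.0, acc + x)).drop(1)
// cums == Seq(25.0, 50.0, 1300.0, 550.0, 575.0, 600.0, 0.0, 25.0, 50.0)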

All the solutions I have found are for Spark 2.x, e.g. Cumulative Sum with Reset BEFORE Negative in Pyspark, but I need it to work on Spark 1.6. Can anyone give me a hand?
Thanks in advance

wwwo4jvm

For anyone who may be interested, I eventually solved it by replacing the window function with the following code:

// UDF that walks the date-sorted movements of one (FAMILY, CLIENT) group
// and keeps a running sum that is clamped at zero whenever it goes negative.
val calculaPozoUDF = sqlContext.udf.register("calculaPozo", (data: Seq[Row]) => {
  val zero = new java.math.BigDecimal("0.00").setScale(2)
  var accumulatedPonderacion = zero
  var result = Array[(String, Double, java.math.BigDecimal)]()

  for (i <- data.indices) {
    // Struct fields: index 0 = MOV_DATE (String), index 1 = IMP_PONDERACION (Double)
    val currentPonderacion = new java.math.BigDecimal(data(i).getDouble(1))
    val next = accumulatedPonderacion.add(currentPonderacion)
    // Reset to zero instead of letting the running total go negative
    accumulatedPonderacion = if (next.compareTo(zero) < 0) zero else next
    result = result :+ ((data(i).getString(0), data(i).getDouble(1), accumulatedPonderacion))
  }

  result
})

val movPonderados = mov
  .groupBy("FAMILY", "CLIENT")
  // Collect each group's movements as an array of structs, sorted by MOV_DATE
  .agg(sort_array(collect_list(struct("MOV_DATE", "IMP_PONDERACION"))).as("MOV_GROUPED"))
  // Apply the running-sum-with-reset UDF, then flatten back to one row per movement
  .withColumn("POZO_TEMP", calculaPozoUDF(col("MOV_GROUPED")))
  .withColumn("POZO_EXPLOD", explode(col("POZO_TEMP")))
  .select(
      col("CLIENT"),
      col("FAMILY"),
      col("POZO_EXPLOD._1").as("MOV_DATE"),
      col("POZO_EXPLOD._2").cast(DecimalType(17, 2)).as("IMP_PONDERACION"),
      col("POZO_EXPLOD._3").cast(DecimalType(17, 2)).as("CUMULATIVE_SUM")
  )
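
For completeness, the same reset logic can also be computed on the RDD API, which avoids the UDF and is equally available on Spark 1.6. The following is an untested sketch assuming the mov DataFrame, the type imports, and the sqlContext from the question (resultRDD, resultSchema, and movPonderadosRDD are names introduced here for illustration):

// Key each movement by (FAMILY, CLIENT), group, sort by date within each
// group, and fold with a running sum clamped at zero. Note that, like
// collect_list above, groupByKey materializes one whole group per key.
val resultRDD = mov.rdd
  .map(r => ((r.getString(1), r.getInt(0)), (r.getString(2), r.getDouble(3))))
  .groupByKey()
  .flatMap { case ((family, client), movs) =>
    movs.toSeq.sortBy(_._1)
      .scanLeft(("", 0.0, 0.0)) { case ((_, _, acc), (date, imp)) =>
        (date, imp, math.max(0.0, acc + imp))
      }
      .drop(1) // drop the scanLeft seed
      .map { case (date, imp, cum) => Row(client, family, date, imp, cum) }
  }

val resultSchema = StructType(List(
  StructField("CLIENT", IntegerType, true),
  StructField("FAMILY", StringType, true),
  StructField("MOV_DATE", StringType, true),
  StructField("IMP_PONDERACION", DoubleType, true),
  StructField("CUMULATIVE_SUM", DoubleType, true)
))
val movPonderadosRDD = sqlContext.createDataFrame(resultRDD, resultSchema)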
