How to apply a window function in an in-memory transformation with a new column in Scala

ohfgkhjo  posted on 2021-05-19  in Spark

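I have a DataFrame that I want to transform into the output shown below, where each row's start_duration and end_duration are derived from the previous row's start_duration and end_duration. Please let me know how to achieve this in Spark using Scala.
These are the formulas for start_duration and end_duration: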

start_duration = max(previous end_duration + 1, current date);
end_duration = min(prescription end_date, start_duration + duration - 1)
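
To make the recurrence concrete, here is a minimal sketch in plain Scala (no Spark) that applies the two formulas row by row for a single prescription. The `Dispensation` and `Coverage` case classes and the `DurationRecurrence` object are hypothetical names introduced only for illustration, and the sketch assumes that the first row, which has no previous end_duration, simply starts on its own date.

```scala
import java.time.LocalDate

// Hypothetical record types for illustration; these names are not from the question.
final case class Dispensation(date: LocalDate, duration: Int)
final case class Coverage(start: LocalDate, end: LocalDate)

object DurationRecurrence {
  private def maxDate(a: LocalDate, b: LocalDate): LocalDate = if (a.isAfter(b)) a else b
  private def minDate(a: LocalDate, b: LocalDate): LocalDate = if (a.isBefore(b)) a else b

  // Applies the two formulas to the rows of one prescription, in date order.
  // Assumption: the first row has no previous end_duration, so it starts on its own date.
  def coverage(rows: Seq[Dispensation], prescriptionEnd: LocalDate): Vector[Coverage] =
    rows.sortBy(_.date.toEpochDay).foldLeft(Vector.empty[Coverage]) { (acc, row) =>
      // start_duration = max(previous end_duration + 1, current date)
      val start = acc.lastOption match {
        case Some(prev) => maxDate(prev.end.plusDays(1), row.date)
        case None       => row.date
      }
      // end_duration = min(prescription end_date, start_duration + duration - 1)
      val end = minDate(prescriptionEnd, start.plusDays(row.duration - 1L))
      acc :+ Coverage(start, end)
    }
}
```

Each step of the fold needs the end_duration computed in the previous step, which is why the calculation is sequential. For the sample rows in this question, the sketch gives start dates 2015-06-10, 2015-07-15, 2015-08-14 and 2015-10-01. Applied literally, the formula ends the third period on 2015-09-12 (2015-08-14 plus 29 days), one day earlier than the 2015-09-13 listed in the expected output. In Spark, one possible way to run such a fold per patient is `groupByKey` followed by `flatMapGroups` on a typed Dataset; that wiring is not shown here.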

Here is my input DataFrame:

+----------------+-----------+---------+-----------+----------------+----------+--------+----------+----------+
|prescription_uid|patient_uid|ndc      |label      |dispensation_uid|date      |duration|start_date|end_date  |
+----------------+-----------+---------+-----------+----------------+----------+--------+----------+----------+
|0               |0          |16714-128|sinvastatin|0               |2015-06-10|30      |2015-06-01|2015-12-01|
|0               |0          |16714-128|sinvastatin|1               |2015-07-15|30      |2015-06-01|2015-12-01|
|0               |0          |16714-128|sinvastatin|2               |2015-08-01|30      |2015-06-01|2015-12-01|
|0               |0          |16714-128|sinvastatin|3               |2015-10-01|30      |2015-06-01|2015-12-01|
+----------------+-----------+---------+-----------+----------------+----------+--------+----------+----------+

Expected output DataFrame:

        +----------------+-----------+---------+-----------+----------------+----------+--------+----------+----------+--------------------+------------------+--------------+------------+
        |prescription_uid|patient_uid|ndc      |label      |dispensation_uid|date      |duration|start_date|end_date  |first_start_duration|first_end_duration|start_duration|end_duration|
        +----------------+-----------+---------+-----------+----------------+----------+--------+----------+----------+--------------------+------------------+--------------+------------+
        |0               |0          |16714-128|sinvastatin|0               |2015-06-10|30      |2015-06-01|2015-12-01|2015-06-10          |2015-07-09        |2015-06-10    |2015-07-09  |
        |0               |0          |16714-128|sinvastatin|1               |2015-07-15|30      |2015-06-01|2015-12-01|2015-06-10          |2015-07-09        |2015-07-15    |2015-08-13  |
        |0               |0          |16714-128|sinvastatin|2               |2015-08-01|30      |2015-06-01|2015-12-01|2015-06-10          |2015-07-09        |2015-08-14    |2015-09-13  |
        |0               |0          |16714-128|sinvastatin|3               |2015-10-01|30      |2015-06-01|2015-12-01|2015-06-10          |2015-07-09        |2015-10-01    |2015-10-30  |
        +----------------+-----------+---------+-----------+----------------+----------+--------+----------+----------+--------------------+------------------+--------------+------------+

Code I tried:

    val windowByPatient = Window.partitionBy($"patient_uid").orderBy($"date")
    val windowByPatientBeforeCurrentRow = windowByPatient.rowsBetween(Window.unboundedPreceding, -1)
    joinedPrDF = joinedPrDF
      .withColumn("first_start_duration", firstStartDuration(first($"date").over(windowByPatient), $"start_date"))
      .withColumn("first_end_duration", firstEndDuration($"first_start_duration", $"end_date", $"duration"))
      .withColumn("start_duration", when(count("*").over(windowByPatient) === 1, $"first_start_duration")
        .otherwise(startDurationCalc($"first_start_duration", $"date", $"start_date", coalesce(sum($"duration").over(windowByPatientBeforeCurrentRow), lit("0")))))
      .withColumn("end_duration", when(count("*").over(windowByPatient) === 1, $"first_end_duration")
        .otherwise(endDurationCalc($"end_date", $"start_duration", $"duration")))

UDFs:

val startDurationCalc = udf( (firstStrtDur:java.sql.Date, currentDsDate:java.sql.Date,
                                      prsStartDate:java.sql.Date,duration:Int) => {
      println("==="+firstStrtDur+"==="+currentDsDate +"==="+prsStartDate +"==="+duration )

        var startDate = java.sql.Date.valueOf(firstStrtDur.toLocalDate.plusDays(duration))
        if (startDate.after(currentDsDate)) {
          startDate
        } else {
          currentDsDate
        }
    } : java.sql.Date)

    val endDurationCalc = udf( (prsEndDate:java.sql.Date, startDuration:java.sql.Date,duration:Int) => {

      println("endDateCalcContRow ==="+prsEndDate+"==="+startDuration +"==="+duration )

      val currEndDate = java.sql.Date.valueOf(startDuration.toLocalDate.plusDays(duration-1))
      if (currEndDate.before(prsEndDate)) {
        currEndDate
      } else {
        prsEndDate
      }

    } : java.sql.Date)

bis0qfac  #1

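You should not expect window functions to operate on data that is not present in the DataFrame but is only computed during execution (what you call "in-memory rows"). That is not possible.
You can try a different approach: compute each start_duration as an offset from the first one, based on duration, taking possible gaps into account.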

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val windowByPatient = Window.partitionBy("patient_uid").orderBy("date")
val windowByPatientBeforeCurrentRow = windowByPatient.rowsBetween(Window.unboundedPreceding, -1)

data
  // Days elapsed since the previous dispensation (0 for the first row of each patient).
  .withColumn("previous_date", lag("date", 1).over(windowByPatient))
  .withColumn("diff_from_prev", datediff(col("date"), coalesce(col("previous_date"), col("date"))))
  // Keep the real gap when it is at least the previous duration; otherwise use the duration.
  .withColumn("diff_with_duration", when(col("diff_from_prev") >= lag("duration", 1, 0).over(windowByPatient), col("diff_from_prev")).otherwise(col("duration")))
  .withColumn("first_date_by_patient", first("date").over(windowByPatient))
  // Offset from the first date: this row's step plus all earlier gaps. lit(0) keeps the sum numeric.
  .withColumn("duration_from_first_with_gaps", col("diff_with_duration") + coalesce(sum("diff_from_prev").over(windowByPatientBeforeCurrentRow), lit(0)))
  .withColumn("start_duration", expr("date_add(first_date_by_patient, duration_from_first_with_gaps)"))
  .withColumn("end_duration", expr("date_add(start_duration, duration - 1)"))
  .select((data.columns ++ Seq("start_duration", "end_duration")).map(col): _*)
  .show()
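
`date_add` is wrapped in `expr` because the Scala function, in the Spark version used here, takes an `Int` as its second argument, whereas inside a SQL expression it can be used with a column.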

woobm2wo  #2

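Here is the final start_duration calculator, which is fed by lag window functions over the previous duration and the previous dispensation date: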

val startDurationCalc = udf((currentDsDate: java.sql.Date, prevDsDate: java.sql.Date, prevDuration: Int, prsEndDate: java.sql.Date,
                                 firstStrtDur:java.sql.Date,acDuration:Int) => {
      println("startDurationCalc===currentDsDate===" + currentDsDate + "===prevDsDate===" + prevDsDate +
        "===prevDuration===" + prevDuration +"===prsEndDate==="+prsEndDate+"===firstStrtDur=="+firstStrtDur+"===acDuration==="+acDuration)
      val prevDurStartDate = prevDsDate.toLocalDate.plusDays(prevDuration - 1)
      var derivedDsDate = java.sql.Date.valueOf(prevDurStartDate.plusDays(1))
      val accumulatedDSDate = java.sql.Date.valueOf(firstStrtDur.toLocalDate.plusDays(acDuration))

      if (derivedDsDate.before(accumulatedDSDate)) {
        derivedDsDate = accumulatedDSDate
      }

      if (derivedDsDate.after(prsEndDate)) {
        val derPrsEndDate = java.sql.Date.valueOf(prsEndDate.toLocalDate.plusDays(1))
        derPrsEndDate
      } else {
        if (currentDsDate.after(derivedDsDate)) {
          currentDsDate
        } else {
          derivedDsDate
        }
      }
    }: java.sql.Date).asNondeterministic()
