使用Scala在Spark DataFrame中删除和插入记录

rdrgkggo  于 12个月前  发布在  Scala
关注(0)|答案(3)|浏览(122)

我有一个名为original_df的嵌套框架,它的列数是可变的,还有另外两个嵌套框架delete_dfinsert_df

original_df = 
+---+----+----+---+
|  s|  p1|  p2| p3|
+---+----+----+---+
| s1|  o1|  o2| o7|
| s1|  o2|  o2| o7|
| s2|  o3|null| o4|
| s3|null|  o5| o6|
| s4|null|null| o6|
+---+----+----+---+

delete_df =
+---+---+---+
|  s|  p|  o|
+---+---+---+
| s1| p3| o7|
| s2| p1| o3|
| s4| p3| o6|
+---+---+---+

insert_df = 
+---+---+---+
|  s|  p|  o|
+---+---+---+
| s1| p3| o8|
| s5| p2| o9|
+---+---+---+

final_df = 
+---+----+----+----+
|  s|  p1|  p2|  p3|
+---+----+----+----+
| s1|  o1|  o2|  o8|
| s1|  o2|  o2|  o8|
| s2|null|null|  o4|
| s3|null|  o5|  o6|
| s5|null|  o9|null|
+---+----+----+----+

如果delete_df中特定p的元组(s, o)属于original_df中的(s, P),其中P可以是p1p2p3等。(P是变量;它可以具有任意数量的列。)则对于特定的sfinal_df中的列P的值将是null
示例:final_df具有元组(s2, null, null, o4),因为delete_df具有元组(s2, p1, o3)

+---+----+----+---+
|  s|  p1|  p2| p3|
+---+----+----+---+
| s1|  o1|  o2| o7|
| s1|  o2|  o2| o7|
| s2|null|null| o4|
| s3|null|  o5| o6|
| s4|null|null| o6|
+---+----+----+---+

类似地,我们将得到delete_df(s1, p3, o7)(s4, p3, o6)元组的以下df。

+---+----+----+----+
|  s|  p1|  p2|  p3|
+---+----+----+----+
| s1|  o1|  o2|null|
| s1|  o2|  o2|null|
| s2|null|null|  o4|
| s3|null|  o5|  o6|
| s4|null|null|null|
+---+----+----+----+

我们还必须从final_df中删除记录,如果所有列p1p2p3等。对于特定的s值,具有null值。
例如:我们在final_df中有(s4, null, null, null),因为delete_df中有(s4, p3, o6)。我们将从final_df中丢弃此记录,因为所有P列都有null值。

+---+----+----+----+
|  s|  p1|  p2|  p3|
+---+----+----+----+
| s1|  o1|  o2|null|
| s1|  o2|  o2|null|
| s2|null|null|  o4|
| s3|null|  o5|  o6|
+---+----+----+----+

insert_df将用于将记录插入final_df。示例:我们将获得insert_df的元组(s1, p3, o8)的以下df。

+---+----+----+----+
|  s|  p1|  p2|  p3|
+---+----+----+----+
| s1|  o1|  o2|  o8|
| s1|  o2|  o2|  o8|
| s2|null|null|  o4|
| s3|null|  o5|  o6|
+---+----+----+----+

请让我知道如何有效地从original_df使用delete_dfinsert_df在Scala中实现final_df

cu6pst1q

cu6pst1q1#

下面是另一种方法,与我之前提出的方法非常相似,但不涉及original_df的爆炸。相反,我们可以用sinsert_dfdelete_df进行分组,并在p上旋转,使它们具有与original_df相同的结构。然后,我们加入s并应用您在问题中描述的逻辑。最后,我们仍然需要删除所有p列都为null的行:

// the list of p columns
val P = original_df.columns.filter(_ startsWith "p")

val insert_pivot = insert_df
    .groupBy("s").pivot("p", P).agg(first('o))
    .select($"s" +: P.map(p => col(p) as s"insert_$p") : _* )
val delete_pivot = delete_df
    .groupBy("s").pivot("p", P).agg(first('o))
    .select($"s" +: P.map(p => col(p) as s"delete_$p") : _* )

// simply displaying insert_pivot to visualize the structure:
insert_pivot.show
+---+---------+---------+---------+
|  s|insert_p1|insert_p2|insert_p3|
+---+---------+---------+---------+
| s1|     null|     null|       o8|
| s5|     null|       o9|     null|
+---+---------+---------+---------+
original_df
    .join(delete_pivot, Seq("s"), "left")
    .select($"s" +: P.map(p => when( col(p) === col(s"delete_$p"), null).otherwise(col(p)) as p) : _*)
    .join(insert_pivot, Seq("s"), "full")
    .select($"s" +: P.map(p => coalesce(col(s"insert_$p"), col(p)) as p) : _* )
    .where(P.map(col).map(_.isNotNull).reduce(_||_))
    .show
+---+----+----+----+
|  s|  p1|  p2|  p3|
+---+----+----+----+
| s1|  o1|  o2|  o8|
| s1|  o2|  o2|  o8|
| s2|null|null|  o4|
| s3|null|  o5|  o6|
| s5|null|  o9|null|
+---+----+----+----+
8yparm6h

8yparm6h2#

如果delete_dfinsert_df足够小,可以收集到驱动程序中并在when子句中使用,那么这可能是最有效的方法。
在一般情况下,如果不能做出这样的假设,我建议将original_df分解如下:

+---+---+----+
|  s|  p|   o|
+---+---+----+
| s1| p1|  o1|
| s1| p2|  o2|
| s1| p3|  o7|
| s1| p1|  o2|
| s1| p2|  o2|
| s1| p3|  o7|
| s2| p1|  o3|
| s2| p2|null|
| s2| p3|  o4|
| s3| p1|null|
| s3| p2|  o5|
| s3| p3|  o6|
| s4| p1|null|
| s4| p2|null|
| s4| p3|  o6|
+---+---+----+

然后用[s,p,o]上的delete_df连接这个分解的框架,找出需要删除的内容,用[s,p]上的insert_df插入需要插入的内容。然后我们可以按s分组,并在p上旋转,以获得final_df
代码如下:

// the list of p columns
val P = original_df.columns.filter(_ startsWith "p")

original_df
    .withColumn("id", monotonically_increasing_id)
    .select($"s", 'id, explode(array(
        P.map(p => struct(lit(p).alias("p"), col(p).alias("o"))) : _*
    )) as "P" )
    .select('s, 'id, $"P.*")
    .join(delete_df.withColumn("delete", lit(true)), Seq("s", "p", "o"), "left")
    .withColumn("o", when('delete, null).otherwise('o))
    .join(insert_df.withColumnRenamed("o", "new_o"), Seq("s", "p"), "full")
    .withColumn("o", coalesce('new_o, 'o))
    .groupBy("id", "s").pivot("p").agg(first('o) as "o")
    // removing values for which all P columns are null
    .where(P.map(col).map(_.isNotNull).reduce(_||_)).orderBy("s", "id")
    .drop("id")
    .show
+---+----+----+----+
|  s|  p1|  p2|  p3|
+---+----+----+----+
| s1|  o1|  o2|  o8|
| s1|  o2|  o2|  o8|
| s2|null|null|  o4|
| s3|null|  o5|  o6|
| s5|null|  o9|null|
+---+----+----+----+

由于s列中可以有重复项,因此还有一个额外的微妙之处。为了能够在旋转时返回所有行,我们需要一个唯一的行id来跟踪哪些值属于哪一行。

snvhrwxg

snvhrwxg3#

python中的代码!
在这里,你有另一种方法来做到这一点。它不使用groupBy().pivot(),而是使用ArrayMap操作,如array_exceptarray_union

original_df = spark.createDataFrame([
  ("s1", "o1", "o2", "o7"),
  ("s1", "o2", "o2", "o7"),
  ("s2", "o3", None, "o4"),
  ("s3", None, "o5", "o6"),
  ("s4", None, None, "o6")],
  schema=["s", "p1", "p2", "p3"]
)

delete_df = spark.createDataFrame([
  ("s1", "p3", "o7"),
  ("s2", "p1", "o3"),
  ("s4", "p3", "o6")],
  schema=["s", "p", "o"]
)

insert_df = spark.createDataFrame([
  ("s1", "p3", "o8"),
  ("s5", "p2", "o9")],
  schema=["s", "p", "o"]
)
# array with all "" columns
p_arr = [F.struct(F.lit(col).name("p"), F.col(col).name("o")) for col in original_df.columns if col.startswith("p")]

original_arr = original_df.select("s", F.array(*p_arr).name("p_arr"))
delete_arr = delete_df.groupBy("s").agg(F.collect_list(F.struct("p", "o")).name("delete_arr"))
insert_arr = insert_df.groupBy("s").agg(F.collect_list(F.struct("p", "o")).name("insert_arr"))
original_map = original_arr \
  .withColumn("p_arr",  F.filter("p_arr", lambda e: e.getItem("o").isNotNull())) \
  .join(delete_arr, "s", how="left") \
  .join(insert_arr, "s", how="outer") \
  .withColumn("p_arr",  F.expr("IF (p_arr IS NULL, Array(), p_arr)")) \
  .withColumn("delete_arr", F.expr("IF (delete_arr IS NULL, Array(), delete_arr)")) \
  .withColumn("insert_arr", F.expr("IF (insert_arr IS NULL, Array(), insert_arr)")) \
  .withColumn("p_arr", F.array_except("p_arr", "delete_arr")) \
  .withColumn("p_arr", F.array_union("p_arr", "insert_arr")) \
  .withColumn("p_map", F.map_from_entries("p_arr")) \
  .filter(F.size("p_arr") > 0)
# get all different keys
keys = original_map.select(F.explode(F.map_keys("p_map")).name("key")) \
  .distinct().orderBy("key").collect()
keys = [row.key for row in keys]

# unfold the map
original_map.select(
    "s",
    *[F.col("p_map").getItem(key).name(key) for key in keys]
).orderBy("s").show()
+---+----+----+----+
|  s|  p1|  p2|  p3|
+---+----+----+----+
| s1|  o1|  o2|  o8|
| s1|  o2|  o2|  o8|
| s2|null|null|  o4|
| s3|null|  o5|  o6|
| s5|null|  o9|null|
+---+----+----+----+

奥利的解决方案似乎更清晰和有效,但我花了太多的时间不张贴我的解决方案!😊

相关问题