How to get the top 5 records of column A based on column B in a Spark DataFrame

Posted by qco9c6ql on 2021-05-27 in Spark

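I'm new to Spark. I have a Spark DataFrame with the records below. I want to sort the data in descending order by the counts column and get the top 5 rows for each value of the code column.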

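from pyspark.sql import functions as sf

# Count rows per (id, code) and rename the generated "count" column to "counts"
counts = to_df.groupBy("id", "code").count().select("id", "code", sf.col("count").alias("counts"))
# Note: this sorts the whole DataFrame by code (descending), not by counts
sorting = counts.sort("code", ascending=False)
sorting.show()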

I tried filter, where, and head, but had no luck. I can't use where without explicitly providing the values, and I want the solution to be dynamic.
Data:

+---+----+------+
| id|code|counts|
+---+----+------+
|  1|  ZB|     2|
|  2|  ZB|     2|
|  3|  ZB|     1|
|  4|  ZB|     3|
|  5|  ZB|     1|
|  6|  ZB|     1|
|  7|  ZB|     3|
|  8|  XY|     3|
|  9|  XY|     1|
| 10|  XY|     2|
| 11|  XY|     1|
| 12|  XY|     1|
| 13|  XY|     1|
| 14|  SD|     2|
| 15|  SD|     1|
| 16|  SD|     1|
| 17|  SD|     3|
| 18|  SD|     1|
| 19|  SD|     2|
|  1|  SD|     7|
+---+----+------+

I want to get the output below. Can anyone help?

+---+----+------+
| id|code|counts|
+---+----+------+
|  7|  ZB|     3|
|  4|  ZB|     3|
|  1|  ZB|     2|
|  2|  ZB|     2|
|  5|  ZB|     1|
|  8|  XY|     3|
| 10|  XY|     2|
| 11|  XY|     1|
| 12|  XY|     1|
| 13|  XY|     1|
|  1|  SD|     7|
| 17|  SD|     3|
| 14|  SD|     2|
| 19|  SD|     2|
| 18|  SD|     1|
+---+----+------+
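Why filter, where and head alone don't get there: sort() produces one ordering of the entire DataFrame, and head(n) then returns the first n rows of that single ordering, so the limit is applied once overall rather than once per code. The plain-Python model below is a sketch, not Spark code (the names ROWS, global_head and per_code are illustrative only). It contrasts the two on the question's data: the global head contains only ZB rows, while a per-code limit returns 5 rows for each of the 3 codes, with groups discovered from the data rather than hardcoded.

```python
from itertools import groupby
from typing import List, Tuple

Row = Tuple[int, str, int]  # (id, code, counts)

# Sample rows copied from the question's data table
ROWS: List[Row] = [
    (1, "ZB", 2), (2, "ZB", 2), (3, "ZB", 1), (4, "ZB", 3), (5, "ZB", 1),
    (6, "ZB", 1), (7, "ZB", 3), (8, "XY", 3), (9, "XY", 1), (10, "XY", 2),
    (11, "XY", 1), (12, "XY", 1), (13, "XY", 1), (14, "SD", 2), (15, "SD", 1),
    (16, "SD", 1), (17, "SD", 3), (18, "SD", 1), (19, "SD", 2), (1, "SD", 7),
]

# Model of sort("code", ascending=False) followed by head(5):
# one ordering of the whole table, then the first 5 rows overall.
global_head = sorted(ROWS, key=lambda r: r[1], reverse=True)[:5]

# Model of "top 5 per code": order by (code, counts descending), then take
# the first 5 rows of every group. groupby() finds the groups from the data,
# so no code value has to be written into the query.
ordered = sorted(ROWS, key=lambda r: (r[1], -r[2]))
per_code = [
    row
    for _, members in groupby(ordered, key=lambda r: r[1])
    for row in list(members)[:5]
]

print({code for _, code, _ in global_head})  # codes present in the global head
print(len(per_code))                         # rows kept across all codes
```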

Answer #1 (pwuypxnk)

PySpark version, using the row_number() analytic (window) function:

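from pyspark.sql import Window as W
from pyspark.sql import functions as F

# Number the rows within each code, highest counts first
w = W.partitionBy('code').orderBy(F.col('counts').desc())
df = df.withColumn("row_number", F.row_number().over(w))
# Keep the first 5 rows of every code group, then drop the helper column
df = df.filter(F.col('row_number') <= 5)
df = df.drop('row_number')
df.show()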

+---+----+------+
| id|code|counts|
+---+----+------+
|  7|  ZB|     3|
|  4|  ZB|     3|
|  1|  ZB|     2|
|  2|  ZB|     2|
|  5|  ZB|     1|
|  8|  XY|     3|
| 10|  XY|     2|
| 11|  XY|     1|
| 12|  XY|     1|
| 13|  XY|     1|
|  1|  SD|     7|
| 17|  SD|     3|
| 14|  SD|     2|
| 19|  SD|     2|
| 18|  SD|     1|
+---+----+------+
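row_number() numbers the rows of each code partition 1, 2, 3, ... in the window's orderBy order, so a `row_number <= 5` filter keeps whichever five rows come first in that order. To keep the rows with the highest counts, the window must be ordered by counts descending (for example `F.col('counts').desc()`); ordering by id keeps the five smallest ids instead. The plain-Python model below is a sketch under that reading (keep_first_n is a hypothetical helper, not a Spark API). It compares both orderings for the ZB group: ordering by id drops id 7 even though its counts value is 3.

```python
from collections import defaultdict
from typing import Callable, Dict, List, Tuple

Row = Tuple[int, str, int]  # (id, code, counts)

# Sample rows copied from the question's data table
ROWS: List[Row] = [
    (1, "ZB", 2), (2, "ZB", 2), (3, "ZB", 1), (4, "ZB", 3), (5, "ZB", 1),
    (6, "ZB", 1), (7, "ZB", 3), (8, "XY", 3), (9, "XY", 1), (10, "XY", 2),
    (11, "XY", 1), (12, "XY", 1), (13, "XY", 1), (14, "SD", 2), (15, "SD", 1),
    (16, "SD", 1), (17, "SD", 3), (18, "SD", 1), (19, "SD", 2), (1, "SD", 7),
]


def keep_first_n(rows: List[Row], sort_key: Callable[[Row], int],
                 n: int = 5) -> Dict[str, List[int]]:
    """Model of row_number().over(partitionBy('code').orderBy(...)) <= n.

    Returns the ids kept for each code, in row_number order.
    """
    partitions: Dict[str, List[Row]] = defaultdict(list)
    for row in rows:
        partitions[row[1]].append(row)
    kept: Dict[str, List[int]] = {}
    for code, members in partitions.items():
        # enumerate(..., start=1) plays the role of row_number()
        numbered = enumerate(sorted(members, key=sort_key), start=1)
        kept[code] = [row[0] for row_number, row in numbered if row_number <= n]
    return kept


by_counts_desc = keep_first_n(ROWS, sort_key=lambda r: -r[2])  # counts desc
by_id = keep_first_n(ROWS, sort_key=lambda r: r[0])            # id asc
print(by_counts_desc["ZB"])
print(by_id["ZB"])
```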

Answer #2 (jtw3ybtb)

You can try using row_number and then filtering out the rows in each group whose row number is greater than 5.

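import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import spark.implicits._

// Number rows within each code, highest counts first, and keep the first 5
df.select($"*", row_number().over(Window.partitionBy($"code").orderBy($"counts".desc)).as("num"))
  .where($"num" <= 5)
  .select($"id", $"code", $"counts")
  .show()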

Output:

+---+----+------+
| id|code|counts|
+---+----+------+
|  8|  XY|     3|
| 10|  XY|     2|
|  9|  XY|     1|
| 11|  XY|     1|
| 12|  XY|     1|
|  1|  SD|     7|
| 17|  SD|     3|
| 14|  SD|     2|
| 19|  SD|     2|
| 15|  SD|     1|
|  4|  ZB|     3|
|  7|  ZB|     3|
|  1|  ZB|     2|
|  2|  ZB|     2|
|  3|  ZB|     1|
+---+----+------+
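row_number() always assigns distinct numbers, so it keeps exactly five rows per code, and the rows tied on counts at the cutoff are picked arbitrarily. That is why the XY rows kept here (9, 11, 12) differ from those in the question's desired output (11, 12, 13). If tied rows should all be kept instead, Spark also provides rank() and dense_rank() in its functions module, which give tied rows the same number. The plain-Python model below sketches those two definitions (sql_rank, sql_dense_rank and kept_per_code are hypothetical names, not Spark APIs). On this data `rank() <= 5` keeps every row (7, 6 and 7 rows for ZB, XY and SD), because the fifth place in each group is a tie.

```python
from typing import Callable, Dict, List, Tuple

Row = Tuple[int, str, int]  # (id, code, counts)

# Sample rows copied from the question's data table
ROWS: List[Row] = [
    (1, "ZB", 2), (2, "ZB", 2), (3, "ZB", 1), (4, "ZB", 3), (5, "ZB", 1),
    (6, "ZB", 1), (7, "ZB", 3), (8, "XY", 3), (9, "XY", 1), (10, "XY", 2),
    (11, "XY", 1), (12, "XY", 1), (13, "XY", 1), (14, "SD", 2), (15, "SD", 1),
    (16, "SD", 1), (17, "SD", 3), (18, "SD", 1), (19, "SD", 2), (1, "SD", 7),
]


def sql_rank(value: int, peers: List[int]) -> int:
    """rank() over orderBy(counts desc): 1 + number of peers with a larger value."""
    return 1 + sum(1 for p in peers if p > value)


def sql_dense_rank(value: int, peers: List[int]) -> int:
    """dense_rank() over orderBy(counts desc): 1 + number of distinct larger values."""
    return 1 + len({p for p in peers if p > value})


def kept_per_code(rows: List[Row], ranker: Callable[[int, List[int]], int],
                  n: int = 5) -> Dict[str, int]:
    """Count how many rows of each code survive a `ranker(...) <= n` filter."""
    peers_by_code: Dict[str, List[int]] = {}
    for _, code, counts in rows:
        peers_by_code.setdefault(code, []).append(counts)
    kept: Dict[str, int] = {}
    for _, code, counts in rows:
        if ranker(counts, peers_by_code[code]) <= n:
            kept[code] = kept.get(code, 0) + 1
    return kept


print(kept_per_code(ROWS, sql_rank))                # rank() <= 5
print(kept_per_code(ROWS, sql_rank, n=3))           # rank() <= 3
print(kept_per_code(ROWS, sql_dense_rank, n=3))     # dense_rank() <= 3
```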

Answer #3 (gc0ot86w)

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.expressions.{Window, WindowSpec}
import org.apache.spark.sql.functions._
import spark.implicits._

val df: DataFrame = Seq(
  (1 , "ZB", 2),
  (2 , "ZB", 2),
  (3 , "ZB", 1),
  (4 , "ZB", 3),
  (5 , "ZB", 1),
  (6 , "ZB", 1),
  (7 , "ZB", 3),
  (8 , "XY", 3),
  (9 , "XY", 1),
  (10, "XY", 2),
  (11, "XY", 1),
  (12, "XY", 1),
  (13, "XY", 1),
  (14, "SD", 2),
  (15, "SD", 1),
  (16, "SD", 1),
  (17, "SD", 3),
  (18, "SD", 1),
  (19, "SD", 2),
  (1 , "SD", 7)
).toDF("id", "code", "counts")

val w: WindowSpec = Window.partitionBy($"code").orderBy($"counts".desc)
df.select($"*", row_number().over(w).alias("rn"))
  .where($"rn" < 6)
  .drop($"rn")
  .show()

Output:

+---+----+------+
| id|code|counts|
+---+----+------+
|  8|  XY|     3|
| 10|  XY|     2|
|  9|  XY|     1|
| 11|  XY|     1|
| 12|  XY|     1|
|  1|  SD|     7|
| 17|  SD|     3|
| 14|  SD|     2|
| 19|  SD|     2|
| 15|  SD|     1|
|  4|  ZB|     3|
|  7|  ZB|     3|
|  1|  ZB|     2|
|  2|  ZB|     2|
|  3|  ZB|     1|
+---+----+------+
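With `orderBy($"counts".desc)` alone, Spark does not guarantee how rows tied on counts are numbered, so which counts=1 row ends up fifth is not fixed. Adding a secondary sort key, for example `Window.partitionBy($"code").orderBy($"counts".desc, $"id")`, removes that ambiguity here, because id is unique within each code in this data (though not overall: id 1 appears under both ZB and SD). The plain-Python model below is a sketch, not Spark code (top5_with_tiebreak is a hypothetical name, and shuffling merely stands in for Spark's unspecified row order). It checks that with this key the result no longer depends on input order; on this data it selects the same ids as the output above.

```python
import random
from collections import defaultdict
from typing import Dict, List, Tuple

Row = Tuple[int, str, int]  # (id, code, counts)

# Sample rows copied from the question's data table
ROWS: List[Row] = [
    (1, "ZB", 2), (2, "ZB", 2), (3, "ZB", 1), (4, "ZB", 3), (5, "ZB", 1),
    (6, "ZB", 1), (7, "ZB", 3), (8, "XY", 3), (9, "XY", 1), (10, "XY", 2),
    (11, "XY", 1), (12, "XY", 1), (13, "XY", 1), (14, "SD", 2), (15, "SD", 1),
    (16, "SD", 1), (17, "SD", 3), (18, "SD", 1), (19, "SD", 2), (1, "SD", 7),
]


def top5_with_tiebreak(rows: List[Row]) -> Dict[str, List[int]]:
    """Model of row_number() <= 5 over partitionBy(code).orderBy(counts desc, id asc)."""
    groups: Dict[str, List[Row]] = defaultdict(list)
    for row in rows:
        groups[row[1]].append(row)
    return {
        code: [r[0] for r in sorted(members, key=lambda r: (-r[2], r[0]))[:5]]
        for code, members in groups.items()
    }


baseline = top5_with_tiebreak(ROWS)

# Reordering the input does not change the result: the (counts, id) key
# leaves no ties inside a code group.
shuffled = list(ROWS)
random.Random(0).shuffle(shuffled)
assert top5_with_tiebreak(shuffled) == baseline
print(baseline)
```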

Answer #4 (dpiehjr4)

// sourceDF is your data
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import spark.implicits._
val sourceDF = Seq(( 1, "ZB",     2),
  (2, "ZB",     2),
  (3, "ZB",     1),
  (4, "ZB",     3),
  (5, "ZB",     1),
  (6, "ZB",     1),
  (7, "ZB",     3),
  (8, "XY",     3),
  (9, "XY",     1),
  (10, "XY",     2),
  (11, "XY",     1),
  (12, "XY",     1),
  (13, "XY",     1),
  (14, "SD",     2),
  (15, "SD",     1),
  (16, "SD",     1),
  (17, "SD",     3),
  (18, "SD",     1),
  (19, "SD",     2),
  (1, "SD",     7)).toDF("id", "code", "counts")

// Number rows within each code, highest counts first; keep the first 5
val windowSpec = Window.partitionBy("code").orderBy('counts.desc)
val resDF = sourceDF.withColumn("row_number", row_number().over(windowSpec))
  .filter('row_number <= 5)
  .drop("row_number")

resDF.show(false)

//  +---+----+------+
//  |id |code|counts|
//  +---+----+------+
//  |8  |XY  |3     |
//  |10 |XY  |2     |
//  |9  |XY  |1     |
//  |11 |XY  |1     |
//  |12 |XY  |1     |
//  |1  |SD  |7     |
//  |17 |SD  |3     |
//  |14 |SD  |2     |
//  |19 |SD  |2     |
//  |15 |SD  |1     |
//  |4  |ZB  |3     |
//  |7  |ZB  |3     |
//  |1  |ZB  |2     |
//  |2  |ZB  |2     |
//  |3  |ZB  |1     |
//  +---+----+------+
