PySpark equivalent of adding a constant array as a column to a DataFrame

dojqjjoe · posted 2023-05-06 · in Spark

The following code works in Scala Spark.

scala> val ar = Array("oracle", "java")
ar: Array[String] = Array(oracle, java)

scala> df.withColumn("tags", lit(ar)).show(false)
+------+---+----------+----------+--------------+
|name  |age|role      |experience|tags          |
+------+---+----------+----------+--------------+
|John  |25 |Developer |2.56      |[oracle, java]|
|Scott |30 |Tester    |5.2       |[oracle, java]|
|Jim   |28 |DBA       |3.0       |[oracle, java]|
|Mike  |35 |Consultant|10.0      |[oracle, java]|
|Daniel|26 |Developer |3.2       |[oracle, java]|
|Paul  |29 |Tester    |3.6       |[oracle, java]|
|Peter |30 |Developer |6.5       |[oracle, java]|
+------+---+----------+----------+--------------+

How can I get the same behavior in PySpark? I tried the following, but it does not work and throws a Java error.

from pyspark.sql.types import *
from pyspark.sql.functions import lit

tag = ["oracle", "java"]
df2.withColumn("tags", lit(tag)).show()

java.lang.RuntimeException: Unsupported literal type class java.util.ArrayList [oracle, java]


lhcgjxsq1#

You can import `array` from the functions module:

>>> from pyspark.sql.types import *
>>> from pyspark.sql.functions import array, lit

>>> tag = array(lit("oracle"), lit("java"))
>>> df2.withColumn("tags", tag).show()

Tested below:

>>> from pyspark.sql.functions import array, lit

>>> tag = array(lit("oracle"), lit("java"))
>>>
>>> ranked.withColumn("tag", tag).show()
+------+--------------+----------+-----+----+----+--------------+               
|gender|    ethinicity|first_name|count|rank|year|           tag|
+------+--------------+----------+-----+----+----+--------------+
|  MALE|      HISPANIC|    JAYDEN|  364|   1|2012|[oracle, java]|
|  MALE|WHITE NON HISP|    JOSEPH|  300|   2|2012|[oracle, java]|
|  MALE|WHITE NON HISP|    JOSEPH|  300|   2|2012|[oracle, java]|
|  MALE|      HISPANIC|     JACOB|  293|   4|2012|[oracle, java]|
|  MALE|      HISPANIC|     JACOB|  293|   4|2012|[oracle, java]|
|  MALE|WHITE NON HISP|     DAVID|  289|   6|2012|[oracle, java]|
|  MALE|WHITE NON HISP|     DAVID|  289|   6|2012|[oracle, java]|
|  MALE|      HISPANIC|   MATTHEW|  279|   8|2012|[oracle, java]|
|  MALE|      HISPANIC|   MATTHEW|  279|   8|2012|[oracle, java]|
|  MALE|      HISPANIC|     ETHAN|  254|  10|2012|[oracle, java]|
|  MALE|      HISPANIC|     ETHAN|  254|  10|2012|[oracle, java]|
|  MALE|WHITE NON HISP|   MICHAEL|  245|  12|2012|[oracle, java]|
|  MALE|WHITE NON HISP|   MICHAEL|  245|  12|2012|[oracle, java]|
|  MALE|WHITE NON HISP|     JACOB|  242|  14|2012|[oracle, java]|
|  MALE|WHITE NON HISP|     JACOB|  242|  14|2012|[oracle, java]|
|  MALE|WHITE NON HISP|     MOSHE|  238|  16|2012|[oracle, java]|
|  MALE|WHITE NON HISP|     MOSHE|  238|  16|2012|[oracle, java]|
|  MALE|      HISPANIC|     ANGEL|  236|  18|2012|[oracle, java]|
|  MALE|      HISPANIC|     AIDEN|  235|  19|2012|[oracle, java]|
|  MALE|WHITE NON HISP|    DANIEL|  232|  20|2012|[oracle, java]|
+------+--------------+----------+-----+----+----+--------------+
only showing top 20 rows

jjhzyzn02#

I found that the following list comprehension works:

>>> from pyspark.sql.functions import array, lit

>>> arr = ["oracle", "java"]
>>> mp = [lit(x) for x in arr]
>>> df.withColumn("mk", array(mp)).show()
+------+---+----------+----------+--------------+
|  name|age|      role|experience|            mk|
+------+---+----------+----------+--------------+
|  John| 25| Developer|      2.56|[oracle, java]|
| Scott| 30|    Tester|       5.2|[oracle, java]|
|   Jim| 28|       DBA|       3.0|[oracle, java]|
|  Mike| 35|Consultant|      10.0|[oracle, java]|
|Daniel| 26| Developer|       3.2|[oracle, java]|
|  Paul| 29|    Tester|       3.6|[oracle, java]|
| Peter| 30| Developer|       6.5|[oracle, java]|
+------+---+----------+----------+--------------+

>>>

ctzwtxfj3#

There is a difference between the `ar` declaration in Scala and the `tag` declaration in Python: `ar` is of type Array, while `tag` is a Python list, and `lit` does not accept a list, which is why it throws the error.
You need numpy installed to declare an array, as below:

import numpy as np
tag = np.array(("oracle","java"))

FYI, it also throws an error if you use a List in Scala:

scala> val ar = List("oracle","java")
ar: List[String] = List(oracle, java)

scala> df.withColumn("newcol", lit(ar)).printSchema
java.lang.RuntimeException: Unsupported literal type class scala.collection.immutable.$colon$colon List(oracle, java)
  at org.apache.spark.sql.catalyst.expressions.Literal$.apply(literals.scala:78)
  at org.apache.spark.sql.catalyst.expressions.Literal$$anonfun$create$2.apply(literals.scala:164)
  at org.apache.spark.sql.catalyst.expressions.Literal$$anonfun$create$2.apply(literals.scala:164)
  at scala.util.Try.getOrElse(Try.scala:79)
  at org.apache.spark.sql.catalyst.expressions.Literal$.create(literals.scala:163)
  at org.apache.spark.sql.functions$.typedLit(functions.scala:127)
  at org.apache.spark.sql.functions$.lit(functions.scala:110)

m528fe3b4#

Spark 3.4+

F.lit(["oracle", "java"])

Full example:

from pyspark.sql import functions as F

df = spark.range(5)
df = df.withColumn("tags", F.lit(["oracle", "java"]))

df.show()
# +---+--------------+
# | id|          tags|
# +---+--------------+
# |  0|[oracle, java]|
# |  1|[oracle, java]|
# |  2|[oracle, java]|
# |  3|[oracle, java]|
# |  4|[oracle, java]|
# +---+--------------+
