PySpark: adding columns to an existing DataFrame

Asked by euoag5mw on 2022-12-28 in Python

I'm using the code below to build a DataFrame that contains one close-price column per symbol.

from pyspark.sql.functions import col

small_list = ["INFY", "TCS", "SBIN", "ICICIBANK"]
frame = spark_frame.where(col("symbol") == small_list[0]).select('close')
# spark_frame is a pyspark.sql.dataframe.DataFrame

for single_stock in small_list[1:]:
    print(single_stock)
    current_stock = spark_frame.where(col("symbol") == single_stock).select(['close'])
    current_stock.collect()
    frame.collect()
    frame = frame.withColumn(single_stock, current_stock.close)

But when I call frame.collect(), I get:

[Row(close=736.85, TCS=736.85, SBIN=736.85, ICICIBANK=736.85),
 Row(close=734.7, TCS=734.7, SBIN=734.7, ICICIBANK=734.7),
 Row(close=746.0, TCS=746.0, SBIN=746.0, ICICIBANK=746.0),
 Row(close=738.85, TCS=738.85, SBIN=738.85, ICICIBANK=738.85)]

All the values come from the first symbol, which is clearly wrong. What am I doing wrong, and what is the best way to solve this?
Edit: spark_frame looks like this:

[Row(SYMBOL='LINC', SERIES='  EQ', TIMESTAMP=datetime.datetime(2021, 12, 20, 0, 0), PREVCLOSE=235.6, OPEN=233.95, HIGH=234.0, LOW=222.15, LAST=222.15, CLOSE=224.2, AVG_PRICE=226.63, TOTTRDQTY=6447, TOTTRDVAL=14.61, TOTALTRADES=206, DELIVQTY=5507, DELIVPER=85.42),
 Row(SYMBOL='LINC', SERIES='  EQ', TIMESTAMP=datetime.datetime(2021, 12, 21, 0, 0), PREVCLOSE=224.2, OPEN=243.85, HIGH=243.85, LOW=222.85, LAST=226.0, CLOSE=225.6, AVG_PRICE=227.0, TOTTRDQTY=8447, TOTTRDVAL=19.17, TOTALTRADES=266, DELIVQTY=3401, DELIVPER=40.26)]

The expected result should look like this:

[Row(close=736.85, TCS=1003.1, SBIN=431.85, ICICIBANK=712.85),
 Row(close=734.7, TCS=1034.7, SBIN=434.7, ICICIBANK=714.7)]
Answer 1, by zbdgwd5y:

I've fully revised my answer based on my new understanding of the question.
To do what you describe in the comments, you need to pivot the table on the close price and the symbol. Here's how.
Input data (slightly modified for testing purposes):

+------+------+-------------------+---------+------+------+------+------+-----+---------+---------+---------+-----------+--------+--------+
|SYMBOL|SERIES|          TIMESTAMP|PREVCLOSE|  OPEN|  HIGH|   LOW|  LAST|CLOSE|AVG_PRICE|TOTTRDQTY|TOTTRDVAL|TOTALTRADES|DELIVQTY|DELIVPER|
+------+------+-------------------+---------+------+------+------+------+-----+---------+---------+---------+-----------+--------+--------+
|  INFY|    EQ|2021-12-20 00:00:00|    235.6|233.95| 234.0|222.15|222.15|224.2|   226.63|     6447|    14.61|        206|    5507|   85.42|
|  LINC|    EQ|2021-12-21 00:00:00|    224.2|243.85|243.85|222.85| 226.0|225.6|    227.0|     8447|    19.17|        266|    3401|   40.26|
|  LINC|    EQ|2021-12-21 00:00:00|    224.2|243.85|243.85|222.85| 226.0|224.2|    227.0|     8447|    19.17|        266|    3401|   40.26|
+------+------+-------------------+---------+------+------+------+------+-----+---------+---------+---------+-----------+--------+--------+

Here is the code:

import datetime

from pyspark.sql import Row, SparkSession
from pyspark.sql.functions import first

spark = SparkSession.builder.getOrCreate()

data = [Row(SYMBOL='INFY', SERIES='  EQ', TIMESTAMP=datetime.datetime(2021, 12, 20, 0, 0), PREVCLOSE=235.6, OPEN=233.95,
            HIGH=234.0, LOW=222.15, LAST=222.15, CLOSE=224.2, AVG_PRICE=226.63, TOTTRDQTY=6447, TOTTRDVAL=14.61,
            TOTALTRADES=206, DELIVQTY=5507, DELIVPER=85.42),
        Row(SYMBOL='LINC', SERIES='  EQ', TIMESTAMP=datetime.datetime(2021, 12, 21, 0, 0), PREVCLOSE=224.2, OPEN=243.85,
            HIGH=243.85, LOW=222.85, LAST=226.0, CLOSE=225.6, AVG_PRICE=227.0, TOTTRDQTY=8447, TOTTRDVAL=19.17,
            TOTALTRADES=266, DELIVQTY=3401, DELIVPER=40.26),
        Row(SYMBOL='LINC', SERIES='  EQ', TIMESTAMP=datetime.datetime(2021, 12, 21, 0, 0), PREVCLOSE=224.2, OPEN=243.85,
            HIGH=243.85, LOW=222.85, LAST=226.0, CLOSE=224.2, AVG_PRICE=227.0, TOTTRDQTY=8447, TOTTRDVAL=19.17,
            TOTALTRADES=266, DELIVQTY=3401, DELIVPER=40.26)]

small_list = ['INFY', 'TCS', 'SBIN', 'LINC']

spark_frame = spark.createDataFrame(data)

# Initial data
spark_frame.show()

# One row per close price, one column per symbol, filled with the first avg_price seen
pivoted_df = spark_frame.groupBy('close').pivot('symbol').agg(first('avg_price'))

# Keep only the requested symbols that actually appear as pivoted columns
select_columns = [single_stock for single_stock in small_list if single_stock in pivoted_df.columns]

pivoted_df = pivoted_df.select('close', *select_columns)

# Output data
pivoted_df.show()
print(pivoted_df.collect())  # Don't use this on production data; you could get an OOM on the driver node.

Sample output:

+-----+------+-----+
|close|  INFY| LINC|
+-----+------+-----+
|224.2|226.63|227.0|
|225.6|  null|227.0|
+-----+------+-----+

[Row(close=224.2, INFY=226.63, LINC=227.0), 
 Row(close=225.6, INFY=None, LINC=227.0)]

You may need to tweak the aggregation logic a bit to compute exactly what you need.
Don't use collect in production: it gathers all the data on the driver, which can cause an OOM exception.
