For each row in input_table, X rows should be created in output_table, where X = the number of days in the year (counted from StartDate). The Info field should contain Y characters, where Y = X * 2; if it holds fewer characters, the field should be padded with extra '#' characters.
In output_table, the AM and PM columns are populated with the Info characters in order, so that every AM and PM field holds exactly one character.
Here is the code:
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType, StringType, DateType, StructField, StructType, TimestampType, ArrayType
from datetime import timedelta
# Connection details for input table
url="..."
user="..."
password="..."
input_table="..."
output_table="..."
# Define schema for input table
input_schema = StructType([
    StructField("ID1", IntegerType(), True),
    StructField("ID2", IntegerType(), True),
    StructField("StartDate", TimestampType(), True),
    StructField("Info", StringType(), True),
    StructField("Extracted", TimestampType(), True)
])
# Define schema for output table
output_schema = StructType([
    StructField("ID1", IntegerType(), True),
    StructField("ID2", IntegerType(), True),
    StructField("Date", DateType(), True),
    StructField("AM", StringType(), True),
    StructField("PM", StringType(), True),
    StructField("CurrentYear", StringType(), True)
])
# Initialize SparkSession
spark = SparkSession.builder.getOrCreate()
# Register UDF for padding Info
pad_info_udf = udf(lambda info, days: info.ljust(days, '#')[:days], StringType())
# Register UDF for creating rows
create_rows_udf = udf(lambda start_date, info, days: [(start_date.date() + timedelta(days=i // 2), info[i], info[i + 1]) for i in range(0, days, 2)],
    ArrayType(StructType([
        StructField("Date", DateType(), True),
        StructField("AM", StringType(), True),
        StructField("PM", StringType(), True),
    ])))
# Define function to pad Info and create rows
def process_row(row):
    id1 = row["ID1"]
    id2 = row["ID2"]
    start_date = row["StartDate"]
    info = row["Info"]
    extracted = row["Extracted"]
    # Number of days in the year, times 2 (one AM + one PM slot per day)
    days = (366 if start_date.year % 4 == 0 else 365) * 2
    # Pad Info
    padded_info = pad_info_udf(info, days)
    # Create rows
    rows = create_rows_udf(start_date, padded_info, days)
    # Prepare output rows
    output_rows = []
    for r in rows:
        date = r["Date"]
        am = r["AM"]
        pm = r["PM"]
        current_year = f"{start_date.year}/{start_date.year + 1}"
        output_rows.append((id1, id2, date, am, pm, current_year))
    return output_rows
# Load input table as DataFrame
df_input = spark.read \
    .format("jdbc") \
    .option("url", url) \
    .option("dbtable", input_table) \
    .option("user", user) \
    .option("password", password) \
    .schema(input_schema) \
    .load()
# Apply processing to input table
output_rows = df_input.rdd.flatMap(process_row)
# Create DataFrame from output rows
df_output = spark.createDataFrame(output_rows, output_schema)
# Write DataFrame to output table
df_output.write \
    .format("jdbc") \
    .option("url", url) \
    .option("user", user) \
    .option("password", password) \
    .option("dbtable", output_table) \
    .mode("append") \
    .save()
Similar code runs without problems in plain Python, but when converted to PySpark it throws an AssertionError. It should not modify anything in input_table, only add the transformed rows to output_table.
1 Answer
So the cause is that the code must not call Spark UDFs from inside a function applied to an RDD; a plain Python function should be used there instead. Spark UDFs can only be used in Spark SQL / DataFrame operations.
The code worked on the local machine because, in local mode, the executors run in the same JVM as the driver.