PySpark DataFrame inside a UDF

nbnkbykc · posted 2023-10-15 in Spark

I have two PySpark DataFrames: qnotes_df (2 columns) and part_numbers_df (1 column). qnotes_df has a column named 'LONG_TEXT'. I want to parse this column and extract any part numbers that appear in the text, matching them against part_numbers_df. I have the tokenization and everything else working, but comparing each token against part_numbers_df fails, because you cannot reference a PySpark DataFrame from inside a UDF. Any suggestions?
Here is my code:

from pyspark.sql.functions import col, udf
from pyspark.sql.types import ArrayType, StringType

# Define a UDF to extract part numbers from a text (nlp is a spaCy pipeline)
def extract_part_numbers_udf(text):
    # Tokenize the text and keep only tokens that match a known part number
    tokens = nlp(text)
    # This is the failing step: a DataFrame (part_numbers_df) cannot be used inside a UDF
    matches = [str(token) for token in tokens if not token.is_punct and not token.is_space and part_numbers_df.filter(col("PART_NUMBER") == str(token)).count() > 0]
    return matches

# Register the UDF with ArrayType return type
udf_extract_part_numbers = udf(extract_part_numbers_udf, ArrayType(StringType()))

# Apply the UDF to create a new column
qnotes_df = qnotes_df.withColumn("REPLACEMENTS", udf_extract_part_numbers(qnotes_df["LONG_TEXT"]))

# Show the DataFrame with the new "REPLACEMENTS" column
qnotes_df.show(truncate=False)

wqnecbli #1

Try this:

from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, StringType
import re

# Collect the part numbers to the driver as a plain Python set
# (assumes the list is small enough to fit in driver memory)
part_numbers_set = set(part_numbers_df.rdd.map(lambda row: row[0]).collect())

# Ship the set to every executor once, instead of looking it up per row
broadcast_part_numbers = spark.sparkContext.broadcast(part_numbers_set)

def extract_part_numbers_udf(text):
    # Tokenize the text and filter part numbers
    tokens = re.split(r'\W+', text)  # simple tokenization using regex
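    # note: \W+ also splits on '-', so hyphenated part numbers will not match as a whole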
    matches = [token for token in tokens if token in broadcast_part_numbers.value]
    return matches

udf_extract_part_numbers = udf(extract_part_numbers_udf, ArrayType(StringType()))

qnotes_df = qnotes_df.withColumn("REPLACEMENTS", udf_extract_part_numbers(qnotes_df["LONG_TEXT"]))

qnotes_df.show(truncate=False)
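If the part-number list is too large to collect() onto the driver, you can avoid the UDF entirely with a tokenize-explode-join. This is a minimal sketch, assuming the same qnotes_df and part_numbers_df as above; the __row_id column is a hypothetical key added only so matches can be regrouped per note:

from pyspark.sql import functions as F

# Key each note so matches can be regrouped afterwards (hypothetical helper column)
notes = qnotes_df.withColumn("__row_id", F.monotonically_increasing_id())

# Split LONG_TEXT on non-word characters and explode to one token per row
tokens = notes.withColumn("token", F.explode(F.split(F.col("LONG_TEXT"), r"\W+")))

# Broadcast hash join against the part-number list
matched = tokens.join(F.broadcast(part_numbers_df), tokens["token"] == part_numbers_df["PART_NUMBER"])

# Regroup matched tokens into one array per note
replacements = matched.groupBy("__row_id").agg(F.collect_list("token").alias("REPLACEMENTS"))

qnotes_df = notes.join(replacements, "__row_id", "left").drop("__row_id")
qnotes_df.show(truncate=False)

One difference from the UDF version: notes with no matching token come back with NULL in REPLACEMENTS rather than an empty list.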
