pyspark中最长的公共子串

iqih9akk  于 2023-05-16  发布在  Spark
关注(0)|答案(1)|浏览(86)

我正在努力尝试在Spark中的两个列之间进行最长公共子串比较。
理论上,我可以用这样的函数来解决它:

def longest_common_substring(input_string, string_list):
    longest_substrings = []
    for string in string_list:
        longest_substring = ""
        for i in range(len(input_string)):
            for j in range(i, len(input_string)):
                if string.startswith(input_string[i:j+1]):
                    if len(input_string[i:j+1]) > len(longest_substring):
                        longest_substring = input_string[i:j+1]
        longest_substrings.append(longest_substring)
    return longest_substrings

longest_common_substring("Node 1 - 2643", 
                         ['Node 1 - 2643', 'Node ', 'Node 1 - 2643 - Node 1 A032 - 32432'])

Output:
['Node 1 - 2643', 'Node ', 'Node 1 - 2643']

但由于我在实际数据中有1亿到几十亿行,因此性能是关键。因此,《城市设计框架》可能不是最佳解决办法。
有没有一种方法可以只使用spark sql函数来实现这一点?

MWE

这里的一些示例数据包括列lcs,它表示我的目标列。

from pyspark.sql import SparkSession

spark = SparkSession.builder \
    .appName("MWE - longest common substring") \
    .getOrCreate()

data = [
    ["AB1234|BC24412|DE34322", "AB1234|BC24412|DE9687", "AB1234|BC24412|DE"],
    ["AA2222|BB3333|CC6666", "AA2222|BD23123|EE12|ZZ929292", "AA2222|B"]
]
schema = ["pathA", "pathB", "lcs (goal)"]

df = spark.createDataFrame(data, schema=schema)

数据看起来像什么:

+----------------------+----------------------------+-----------------+                                                                                                                                                                                                                                                                                                              
|pathA                 |pathB                       |lcs (goal)       | 
+----------------------+----------------------------+-----------------+
|AB1234|BC24412|DE34322|AB1234|BC24412|DE9687       |AB1234|BC24412|DE|
|AA2222|BB3333|CC6666  |AA2222|BD23123|EE12|ZZ929292|AA2222|B         |
+----------------------+----------------------------+-----------------+
fnvucqvd

fnvucqvd1#

我在周围调整,找到了一个适合我的解决方案。也许将来能帮上忙。
最后,它比预期的要简单得多。它只是将split放入数组,然后是array_intersect,最后是concat_ws,以获取字符串。

import pyspark.sql.functions as F

df_final = (
    df
    .withColumn("pathA_arr", F.split("pathA", "\\|"))
    .withColumn("pathB_arr", F.split("pathB", "\\|"))
    .withColumn("common", F.array_intersect("pathA_arr", "pathB_arr"))
    .withColumn("lcs", F.concat_ws("|", "common"))
    .drop("pathA_arr", "pathB_arr", "common")
)

结果如下所示:

+----------------------+----------------------------+-----------------+--------------+
|pathA                 |pathB                       |lcs (goal)       |lcs           |
+----------------------+----------------------------+-----------------+--------------+
|AB1234|BC24412|DE34322|AB1234|BC24412|DE9687       |AB1234|BC24412|DE|AB1234|BC24412|
|AA2222|BB3333|CC6666  |AA2222|BD23123|EE12|ZZ929292|AA2222|B         |AA2222        |
+----------------------+----------------------------+-----------------+--------------+

lcs (goal)lcs之间有细微的差别。缺少|DE,但在我的真实的数据中这不是问题。我之所以使用这个例子,是因为我脑子里总是有一个典型的最长公共子串函数。
我认为这个解决方案的最大优点是,它非常快。在我的真实的数据中,这导致的额外时间可以忽略不计。例如,整个pyspark查询现在需要90s。当我试图用一个UDF解决它时,我在2小时后杀死了这个工作。
会让这个主题开放一些天,如果有人知道正则表达式的解决方案,这可能是有趣的比较。

相关问题