python array_intersect使用效率低

41ik7eoe  于 2023-03-16  发布在  Python
关注(0)|答案(1)|浏览(123)
df = df.withColumn('valid_tokens', array_intersect(
                array([lit(x) for x in
                       broadcasted_valid_list.value]),
                col("input_tokens")))

我有一个数组类型的列input_token。(df ~ 10 k行)我有一个广播列表,它有超过100 k的值。
我只需要该行中属于该列表的那些令牌。
array_intersect花费的时间太长,因为对于每一行,每个标记都要执行100万次搜索(每个ArrayType行值大约10个标记)
这里有没有其他的操作可以帮助你?有效的标记作为字典,但是如果有帮助的话,用什么来代替array_intersect呢?

92dk7w1h

92dk7w1h1#

代码[lit(x) for x in broadcasted_valid_list.value]实际上是在驱动程序上运行的,它创建了一个约100 k值的数组,然后试图将其作为执行计划发送给驱动程序,这是一个太长的计划,并导致了您所谈论的延迟。

新-第2版

我们仍然保存有效标记的列表,但是我们首先解压缩输入的所有标记,加入标记并且分组回到输入ID以得到所需的格式。

# mock - save to disk valid tokens
arr = list(range(19, 105000))
pddf = pd.DataFrame(arr, columns=["id"])
pddf.to_parquet("valid_tokens.parquet")

# read as spark dataframe and collect all tokens to one array row
df = spark.read.parquet("valid_tokens.parquet")

my_input.alias('input') \
.select('id', explode('input_tokens').alias('input_token')) \
.join(df.alias('good_tokens'), on=col('input_token') == col('good_tokens.id')) \
.groupBy('input.id') \
.agg(collect_list('input_token').alias('input_tokens')) \
.show()

版本1

我发现运行您的类似代码的最简单方法是将100 k有效令牌的列表保存在某个地方。

# mock - save to disk valid tokens
arr = list(range(19, 105000))
pddf = pd.DataFrame(arr, columns=["id"])
pddf.to_parquet("valid_tokens.parquet")

# read as spark dataframe and collect all tokens to one array row
df = spark.read.parquet("valid_tokens.parquet")
df_agg = df.agg(collect_list('id').alias('ids')) #one row for all tokens

# input of 100 rows with 3 tokens: 10, 20, 50
my_input = spark.range(100) \
.withColumn("input_tokens", array(lit(10), lit(20), lit(50))) \
.drop('id')

my_input \
.crossJoin(df_agg) \
.withColumn('final', array_intersect('input_tokens', 'ids').alias('final')) \
.select('input_tokens', 'final') \
.show(truncate=0)

可以看出,10被丢弃,因为它不在我们保存为 parquet 文件的范围内。x1c 0d1x

相关问题