df = df.withColumn('valid_tokens', array_intersect(
array([lit(x) for x in
broadcasted_valid_list.value]),
col("input_tokens")))
我有一个数组类型的列input_token。(df ~ 10 k行)我有一个广播列表,它有超过100 k的值。
我只需要该行中属于该列表的那些令牌。
array_intersect花费的时间太长,因为对于每一行,每个标记都要执行100万次搜索(每个ArrayType行值大约10个标记)
这里有没有其他的操作可以帮助你?有效的标记作为字典,但是如果有帮助的话,用什么来代替array_intersect呢?
1条答案
按热度按时间92dk7w1h1#
代码
[lit(x) for x in broadcasted_valid_list.value]
实际上是在驱动程序上运行的,它创建了一个约100 k值的数组,然后试图将其作为执行计划发送给驱动程序,这是一个太长的计划,并导致了您所谈论的延迟。新-第2版
我们仍然保存有效标记的列表,但是我们首先解压缩输入的所有标记,加入标记并且分组回到输入ID以得到所需的格式。
版本1
我发现运行您的类似代码的最简单方法是将100 k有效令牌的列表保存在某个地方。
可以看出,10被丢弃,因为它不在我们保存为 parquet 文件的范围内。x1c 0d1x