我有以下代码:
elements = spark.createDataFrame([
('g1', 'a', 1), ('g1', 'a', 2), ('g1', 'b', 1), ('g1', 'b', 3),
('g2', 'c', 1), ('g2', 'c', 3), ('g2', 'c', 2), ('g2', 'd', 4),
], ['group', 'instance', 'element'])
all_elements_per_instance = elements.groupBy("group", "instance").agg(f.collect_set('element').alias('elements'))
# +-----+--------+---------+
# |group|instance| elements|
# +-----+--------+---------+
# | g1| b| [1, 3]|
# | g1| a| [1, 2]|
# | g2| c|[1, 2, 3]|
# | g2| d| [4]|
# +-----+--------+---------+
@f.udf(ArrayType(IntegerType()))
def intersect(elements):
return list(functools.reduce(lambda x, y: set(x).intersection(set(y)), elements))
all_intersect_elements_per_group = all_elements_per_instance.groupBy("group")\
.agg(intersect(f.collect_list("elements")).alias("intersection"))
# +-----+------------+
# |group|intersection|
# +-----+------------+
# | g1| [1]|
# | g2| []|
# +-----+------------+
有没有办法避免使用udf(因为它很昂贵),并以某种方式使用 f.array_intersect
或者类似于聚合函数的函数?
2条答案
按热度按时间7lrncoxx1#
你可以使用高阶函数
aggregate
做一个array_intersect
关于要素:camsedfj2#
如果你想找到
elements
至少由2个共享的instances
在每个group
,您实际上可以通过使用窗口然后使用groupby来计算每个组/元素的不同示例来简化它group
仅收集计数大于1的元素:我曾经
collect_set
+size
作为函数countDistinct
不支持窗口。