Check which elements of an array tend to appear in pairs - pyspark

Posted by ncecgwcz on 2023-06-05 in Spark

I have a Spark DataFrame like this:

+-------------------+---------------------------+
|  movie_description|                       tags|
+-------------------+---------------------------+
|               ....|            [comedy,horror]|
|               ....|                  [romance]|
|               ....|         [thriller, sci-fi]|
|               ....| [sci-fi, horror, thriller]|
+-------------------+---------------------------+

How can I find which pair of tags appears together most often? There are 10 unique tag values.

Answer #1 (flseospp)

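After exploding the tags column, you can join the DataFrame with itself to generate every possible pair of tags within the same row, then groupBy each pair of tags, count the pairs, and keep the pair with the maximum count: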

from pyspark.sql import SparkSession, functions
from pyspark.sql.functions import col, explode, monotonically_increasing_id

spark = SparkSession.builder.master("local[*]").getOrCreate()
data = [
    ("A", ["comedy", "horror"]),
    ("B", ["romance"]),
    ("C", ["thriller", "sci-fi"]),
    ("D", ["sci-fi", "horror", "thriller"])
]
df = spark.createDataFrame(data, ["movie_description", "tags"])

# Give each row a unique id, then explode the tags into one (id, tag) row per tag
explodedDF = df.withColumn("id", monotonically_increasing_id()).withColumn("tags", explode(col("tags"))).select("id", "tags")
# Self-join on id to pair tags from the same row; "<" keeps each unordered pair only once
joinDf = explodedDF.join(explodedDF.withColumnRenamed("tags", "tags2"), ["id"]).filter(col("tags") < col("tags2"))
# Number of rows containing each pair
pairCounts = joinDf.groupBy("tags", "tags2").count()
maxCountValue = pairCounts.agg(functions.max(col("count"))).first()[0]
# Take both tags from the same row so they always belong to the same pair
top = pairCounts.filter(col("count") == maxCountValue).first()
print(top["tags"], top["tags2"])

Result:
sci-fi thriller
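
To make the self-join mechanics concrete, here is a plain-Python sketch (no Spark involved) that mirrors the same steps on the sample rows from the question: explode, join on the row id, keep `tag < tag2`, then count. The `enumerate` index is an assumption standing in for `monotonically_increasing_id()`; the variable names are illustrative only.

```python
from collections import Counter

data = [
    ("A", ["comedy", "horror"]),
    ("B", ["romance"]),
    ("C", ["thriller", "sci-fi"]),
    ("D", ["sci-fi", "horror", "thriller"]),
]

# "explode": one (row_id, tag) record per tag; the enumerate index acts as the row id
exploded = [(row_id, tag) for row_id, (_, tags) in enumerate(data) for tag in tags]

# "self-join on id": match records that share a row id,
# keeping tag < tag2 so each unordered pair is produced once per row
joined = [
    (tag, tag2)
    for id1, tag in exploded
    for id2, tag2 in exploded
    if id1 == id2 and tag < tag2
]

# "groupBy + count": how many rows contain each pair
pair_counts = Counter(joined)
top_pair, top_count = pair_counts.most_common(1)[0]
print(top_pair, top_count)
```

The nested loop compares every record with every other record, which is only acceptable for illustration; Spark performs the equivalent equi-join on `id` in a distributed way.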

Answer #2 (am46iovg)

You can use the FPGrowth algorithm from the Spark ML library to find frequently co-occurring tags:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, size, array_sort
from pyspark.ml.fpm import FPGrowth

spark = SparkSession.builder.master("local[*]").getOrCreate()

df = spark.createDataFrame([
    (0, ['comedy', 'horror']),
    (1, ['romance']),
    (2, ['thriller', 'sci-fi']),
    (3, ['sci-fi', 'horror', 'thriller'])
], ["id", "items"])

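# Sort tags within each row (optional: FPGrowth treats each row as an unordered set, and the item order in its output itemsets is not guaranteed)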
df = df.withColumn('items', array_sort(col('items')))

fpGrowth = FPGrowth(itemsCol="items", minSupport=0.1, minConfidence=0.1)
model = fpGrowth.fit(df)

freq_tags = model.freqItemsets.filter(size(col('items')) == 2).sort(col('freq').desc())

freq_tags.show()
+------------------+----+
|             items|freq|
+------------------+----+
|[sci-fi, thriller]|   2|
|[thriller, horror]|   1|
|  [comedy, horror]|   1|
|  [sci-fi, horror]|   1|
+------------------+----+
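
FPGrowth keeps an itemset when the fraction of rows containing it is at least `minSupport`. With 4 rows and `minSupport=0.1`, a pair that occurs in a single row (support 0.25) already qualifies, which is why all four pairs are listed. The sketch below is a brute-force check of only the size-2 frequent itemsets in plain Python, not the FP-growth algorithm itself; `frequent_pairs` is a hypothetical helper and the transactions are the sample rows above.

```python
from collections import Counter
from itertools import combinations

transactions = [
    ["comedy", "horror"],
    ["romance"],
    ["thriller", "sci-fi"],
    ["sci-fi", "horror", "thriller"],
]

def frequent_pairs(transactions, min_support):
    """Return {pair: count} for 2-item sets whose support is >= min_support.

    Support = (rows containing the pair) / (total rows).
    """
    counts = Counter()
    for items in transactions:
        # sorted(set(...)) treats each row as an unordered set of unique items
        for pair in combinations(sorted(set(items)), 2):
            counts[pair] += 1
    n = len(transactions)
    return {pair: freq for pair, freq in counts.items() if freq / n >= min_support}

result = frequent_pairs(transactions, min_support=0.1)
for pair, freq in sorted(result.items(), key=lambda kv: -kv[1]):
    print(list(pair), freq)
```

In this sketch, raising `min_support` to 0.3 leaves only the sci-fi/thriller pair (support 0.5), since every other pair has support 0.25.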

Answer #3 (htrmnn0y)

My approach is slightly different from the answers already posted. (Note: IMO the answer that uses the FPGrowth model seems to be the simplest.)

import pyspark.sql.functions as func

# create UDF that creates sorted tag pairs for lists having 3+ tags
def createPairPermutations(arr):
    from itertools import permutations

    return [sorted(k) for k in permutations(arr, 2)]

createPairPermutations_udf = func.udf(createPairPermutations, 'array<array<string>>')

# data_sdf is the DataFrame from the question (movie_description, tags)
# create pair permutations and explode
pairs_sdf = data_sdf. \
    withColumn('tag_pair_arr', 
               func.when(func.size('tags') < 3, func.array('tags')).
               otherwise(func.array_distinct(createPairPermutations_udf('tags')))
               ). \
    withColumn('tag_pairs', func.explode('tag_pair_arr')). \
    withColumn('tag_pairs', func.sort_array('tag_pairs'))

# +-----------------+--------------------------+----------------------------------------------------------+------------------+
# |movie_description|tags                      |tag_pair_arr                                              |tag_pairs         |
# +-----------------+--------------------------+----------------------------------------------------------+------------------+
# |A                |[comedy, horror]          |[[comedy, horror]]                                        |[comedy, horror]  |
# |B                |[romance]                 |[[romance]]                                               |[romance]         |
# |C                |[thriller, sci-fi]        |[[thriller, sci-fi]]                                      |[sci-fi, thriller]|
# |D                |[sci-fi, horror, thriller]|[[horror, sci-fi], [sci-fi, thriller], [horror, thriller]]|[horror, sci-fi]  |
# |D                |[sci-fi, horror, thriller]|[[horror, sci-fi], [sci-fi, thriller], [horror, thriller]]|[sci-fi, thriller]|
# |D                |[sci-fi, horror, thriller]|[[horror, sci-fi], [sci-fi, thriller], [horror, thriller]]|[horror, thriller]|
# +-----------------+--------------------------+----------------------------------------------------------+------------------+

# count occurrences of the pairs
pairs_sdf. \
    groupBy('tag_pairs'). \
    agg(func.countDistinct('movie_description').alias('freq')). \
    orderBy(func.desc('freq')). \
    show(truncate=False)

# +------------------+----+
# |tag_pairs         |freq|
# +------------------+----+
# |[sci-fi, thriller]|2   |
# |[horror, sci-fi]  |1   |
# |[comedy, horror]  |1   |
# |[horror, thriller]|1   |
# |[romance]         |1   |
# +------------------+----+
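
Why the UDF works: `permutations(arr, 2)` yields every pair in both orders; sorting each pair and then applying `array_distinct` collapses them to one entry per unordered pair, which is the same set that `itertools.combinations` gives on the sorted tags. A plain-Python check of that equivalence, plus the per-pair distinct-movie count done by `countDistinct('movie_description')`, assuming the sample rows from the question (`pairs_via_permutations` and `pairs_via_combinations` are hypothetical helpers):

```python
from collections import defaultdict
from itertools import combinations, permutations

def pairs_via_permutations(tags):
    # Mirrors createPairPermutations + array_distinct:
    # sort each ordered pair, then drop duplicates (first-seen order kept)
    seen = []
    for k in permutations(tags, 2):
        pair = sorted(k)
        if pair not in seen:
            seen.append(pair)
    return seen

def pairs_via_combinations(tags):
    # Each unordered pair exactly once, taken from the sorted tag list
    return [list(p) for p in combinations(sorted(tags), 2)]

data = {
    "A": ["comedy", "horror"],
    "B": ["romance"],
    "C": ["thriller", "sci-fi"],
    "D": ["sci-fi", "horror", "thriller"],
}

# Distinct movies per pair, like groupBy('tag_pairs').agg(countDistinct('movie_description'))
movies_per_pair = defaultdict(set)
for movie, tags in data.items():
    for pair in pairs_via_permutations(tags):
        movies_per_pair[tuple(pair)].add(movie)

freq = {pair: len(movies) for pair, movies in movies_per_pair.items()}
print(sorted(freq.items(), key=lambda kv: -kv[1]))
```

Unlike the answer's `size('tags') < 3` branch, this sketch applies the permutation step to every row, so a single-tag row such as `[romance]` yields no pair at all instead of a one-element "pair".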
