用Dataframescala spark计算余弦相似性

j2datikz  于 2021-07-12  发布在  Spark
关注(0)|答案(2)|浏览(400)

我有这样一个Dataframe:

+-------+-------+------------------+-------+----+
|userId1|movieId|              rat1|userId2|rat2|
+-------+-------+------------------+-------+----+
|      1|      1|               1.0|      2| 1.0|
|      1|      2|               1.0|      2| 2.0|
|      1|      3|               2.0|      2| 3.0|
|      2|      1|               1.0|      3| 0.0|
|      2|      2|               2.0|      3| 0.0|
|      2|      3|               3.0|      3| 0.0|
|      3|      1|               0.0|      1| 1.0|
|      3|      2|               0.0|      1| 1.0|

....

其中rat1和rat2是user1和user2的评级。我想要的是计算两个用户之间的余弦相似度,我的想法是从这个Dataframe中提取数组,然后计算余弦相似度,例如:

arrayUser1 = (1,1,2)
arrayUser2 = (1,2,3)
arrayUser3 = (0,0,0)

问题是我不知道如何提取这些数组,有人有办法吗?或者用更好的方法计算相似性的技巧?

jdgnovmf

jdgnovmf1#

您可以先将rat1和rat2相乘,然后按userid1和userid2分组,并将乘积相加:

df.show
+-------+-------+----+-------+----+
|userId1|movieId|rat1|userId2|rat2|
+-------+-------+----+-------+----+
|      1|      1| 1.0|      2| 1.0|
|      1|      2| 1.0|      2| 2.0|
|      1|      3| 2.0|      2| 3.0|
|      2|      1| 1.0|      3| 0.0|
|      2|      2| 2.0|      3| 0.0|
|      2|      3| 3.0|      3| 0.0|
|      3|      1| 0.0|      1| 1.0|
|      3|      2| 0.0|      1| 1.0|
|      3|      3| 0.0|      1| 2.0|
+-------+-------+----+-------+----+
val cos_sim = df.withColumn(
    "rat1",    // normalize rat1
    coalesce(
        $"rat1" / sqrt(sum($"rat1" * $"rat1").over(Window.partitionBy("userId1"))),
        lit(0)
    )
).withColumn(
    "rat2",    // normalize rat2
    coalesce(
        $"rat2" / sqrt(sum($"rat2" * $"rat2").over(Window.partitionBy("userId2"))),
        lit(0)
    )
).withColumn(
    "rat1_times_rat2",
    $"rat1" * $"rat2"
).groupBy("userId1", "userId2").agg(sum("rat1_times_rat2").alias("cos_sim"))

cos_sim.show
+-------+-------+-----------------+
|userId1|userId2|          cos_sim|
+-------+-------+-----------------+
|      3|      1|              0.0|
|      2|      3|              0.0|
|      1|      2|0.981980506061966|
+-------+-------+-----------------+
flvtvl50

flvtvl502#

您可以使用Dataframe groupBy 操作并执行 collect_set 聚合
下面是示例代码。

scala> someDF.show
+-------+-------+----+-------+----+
|userId1|movieId|rat1|userId2|rat2|
+-------+-------+----+-------+----+
|      1|      1| 1.0|      2| 1.0|
|      1|      2| 1.0|      2| 2.0|
|      1|      3| 2.0|      2| 3.0|
|      2|      1| 1.0|      3| 0.0|
|      2|      2| 2.0|      3| 0.0|
|      2|      3| 3.0|      3| 0.0|
|      3|      1| 0.0|      1| 1.0|
|      3|      2| 0.0|      1| 1.0|
+-------+-------+----+-------+----+

scala> someDF.groupBy("userId1").agg(collect_set("rat1").alias("ratinglist")).show
+-------+---------------+
|userId1|     ratinglist|
+-------+---------------+
|      1|     [2.0, 1.0]|
|      3|          [0.0]|
|      2|[2.0, 1.0, 3.0]|
+-------+---------------+

相关问题