我理解检索的任务-我已经看过了代码;也研究了替代方法,如SCNN,这是一种超快的最近邻。
但是,我仍然很难理解下面代码的机制
# Create a model that takes in raw query features, and
index = tfrs.layers.factorized_top_k.BruteForce(model.user_model)
# recommends movies out of the entire movies dataset.
index.index_from_dataset(
tf.data.Dataset.zip((movies.batch(100), movies.batch(100).map(model.movie_model)))
)
# Get recommendations.
_, titles = index(tf.constant(["42"]))
print(f"Recommendations for user 42: {titles[0, :3]}")
字符串model.user_model
经过训练,现在应该返回user_id的嵌入。BruteForce
层的输入是model.user_model
;然后应该索引它?
我猜输出是user_id
42,返回3个标题,在movies.batch(100)
之外。但我不明白BruteForce和索引的功能!
2条答案
按热度按时间rslzwgfq1#
BruteForce层测试从模型的最后一层提取的嵌入之间的所有组合。
根据tensorflow documentation for the layer,层返回最接近每个索引的top k结果(默认为10)索引的索引。
r7knjye22#
你误解为“The input for BruteForce layer is model.user_model”。user_model不是BruteForce layer的输入。它是ButeForce类的参数。所以“input”是BruteForce的示例。user_model是嵌入输入的query_model,是两个塔之一。
而index.index_from_dataset()设置了另一个候选塔的嵌入。
batch(100)并不只是输出100部电影,而是输出许多100部电影的块。