I'm using BruteForce from TensorFlow Recommenders:
index = tfrs.layers.factorized_top_k.BruteForce(model.customer_model, k = 400)
The candidates dataset looks like this:
<ZipDataset element_spec=({'article_id': TensorSpec(shape=(None,), dtype=tf.string, name=None), 'prod_name': TensorSpec(shape=(None,), dtype=tf.string, name=None), 'product_type_name': TensorSpec(shape=(None,), dtype=tf.string, name=None)}, TensorSpec(shape=(None, 64), dtype=tf.float32, name=None))>
But when I try to build the retrieval index with
index.index_from_dataset(candidates)
I get the following error:
AttributeError Traceback (most recent call last)
Input In [28], in <cell line: 6>()
4 candidates = tf.data.Dataset.zip((articles.batch(128), articles.batch(128).map(model.article_model)))
5 print(candidates)
----> 6 index.index_from_dataset(candidates)
File ~/miniconda3/envs/tf/lib/python3.9/site-packages/tensorflow_recommenders/layers/factorized_top_k.py:197, in TopK.index_from_dataset(self, candidates)
174 def index_from_dataset(
175 self,
176 candidates: tf.data.Dataset
177 ) -> "TopK":
178 """Builds the retrieval index.
179
180 When called multiple times the existing index will be dropped and a new one
(...)
194 ValueError if the dataset does not have the correct structure.
195 """
--> 197 _check_candidates_with_identifiers(candidates)
199 spec = candidates.element_spec
201 if isinstance(spec, tuple):
File ~/miniconda3/envs/tf/lib/python3.9/site-packages/tensorflow_recommenders/layers/factorized_top_k.py:127, in _check_candidates_with_identifiers(candidates)
119 raise ValueError(
120 "The dataset must yield candidate embeddings or "
121 "tuples of (candidate identifiers, candidate embeddings). "
122 f"Got {spec} instead."
123 )
125 identifiers_spec, candidates_spec = spec
--> 127 if candidates_spec.shape[0] != identifiers_spec.shape[0]:
128 raise ValueError(
129 "Candidates and identifiers have to have the same batch dimension. "
130 f"Got {candidates_spec.shape[0]} and {identifiers_spec.shape[0]}."
131 )
AttributeError: 'dict' object has no attribute 'shape'
I assume the problem is that my dataset was created from a dictionary.
How should I pass the candidates dataset so that I don't get this error?
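The traceback supports that assumption: _check_candidates_with_identifiers unpacks element_spec into (identifiers_spec, candidates_spec) and calls identifiers_spec.shape[0], but when the first element of the zipped dataset is the whole feature dictionary, its spec is a plain Python dict. The sketch below reproduces this with only TensorFlow. The toy articles data and the fake_article_model function are hypothetical stand-ins for the question's articles and model.article_model.

```python
import tensorflow as tf

# Hypothetical toy data with the same structure as the question's `articles`.
articles = tf.data.Dataset.from_tensor_slices({
    "article_id": ["a1", "a2", "a3"],
    "prod_name": ["shirt", "sock", "hat"],
    "product_type_name": ["top", "socks", "hat"],
})
batched = articles.batch(2)

def fake_article_model(features):
    """Hypothetical stand-in for model.article_model: zero embeddings of width 64."""
    return tf.zeros([tf.shape(features["article_id"])[0], 64])

# Same construction as in the question: (feature dict, embeddings).
candidates = tf.data.Dataset.zip((batched, batched.map(fake_article_model)))
identifiers_spec, embeddings_spec = candidates.element_spec

print(type(identifiers_spec).__name__)      # dict
print(hasattr(identifiers_spec, "shape"))   # False, hence the AttributeError

# Selecting a single feature yields a TensorSpec, which does have a shape.
ids_only = batched.map(lambda x: x["article_id"])
print(ids_only.element_spec.shape)          # (None,)
```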
Answer
I figured it out.
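I had been building the candidates dataset as in the traceback above, zipping the batched feature dictionaries with their embeddings.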
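But the first element of each pair has to be the candidate identifiers themselves (a single tensor such as article_id), not the whole feature dictionary, passed together with the candidate embeddings.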