Error when building a retrieval index in TensorFlow Recommenders

gk7wooem posted on 2022-12-30 in Other

I'm using BruteForce from TensorFlow Recommenders:

index = tfrs.layers.factorized_top_k.BruteForce(model.customer_model, k=400)

The candidates dataset looks like this:

<ZipDataset element_spec=({'article_id': TensorSpec(shape=(None,), dtype=tf.string, name=None), 'prod_name': TensorSpec(shape=(None,), dtype=tf.string, name=None), 'product_type_name': TensorSpec(shape=(None,), dtype=tf.string, name=None)}, TensorSpec(shape=(None, 64), dtype=tf.float32, name=None))>

But when I try to build the retrieval index with

index.index_from_dataset(candidates)

I get the following error:

AttributeError                            Traceback (most recent call last)
Input In [28], in <cell line: 6>()
      4 candidates = tf.data.Dataset.zip((articles.batch(128), articles.batch(128).map(model.article_model)))
      5 print(candidates)
----> 6 index.index_from_dataset(candidates)

File ~/miniconda3/envs/tf/lib/python3.9/site-packages/tensorflow_recommenders/layers/factorized_top_k.py:197, in TopK.index_from_dataset(self, candidates)
    174 def index_from_dataset(
    175     self,
    176     candidates: tf.data.Dataset
    177 ) -> "TopK":
    178   """Builds the retrieval index.
    179 
    180   When called multiple times the existing index will be dropped and a new one
   (...)
    194     ValueError if the dataset does not have the correct structure.
    195   """
--> 197   _check_candidates_with_identifiers(candidates)
    199   spec = candidates.element_spec
    201   if isinstance(spec, tuple):

File ~/miniconda3/envs/tf/lib/python3.9/site-packages/tensorflow_recommenders/layers/factorized_top_k.py:127, in _check_candidates_with_identifiers(candidates)
    119   raise ValueError(
    120       "The dataset must yield candidate embeddings or "
    121       "tuples of (candidate identifiers, candidate embeddings). "
    122       f"Got {spec} instead."
    123   )
    125 identifiers_spec, candidates_spec = spec
--> 127 if candidates_spec.shape[0] != identifiers_spec.shape[0]:
    128   raise ValueError(
    129       "Candidates and identifiers have to have the same batch dimension. "
    130       f"Got {candidates_spec.shape[0]} and {identifiers_spec.shape[0]}."
    131   )

AttributeError: 'dict' object has no attribute 'shape'

I assume the problem is that my dataset was created from a dictionary.
How should I pass the candidates dataset so that I don't get this error?
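Why a dictionary cannot serve as the identifiers becomes clearer from what a brute-force index does conceptually: it scores the query against every candidate embedding, keeps the k highest-scoring rows, and returns the identifier stored at each of those rows. That row lookup only works if the identifiers are a single tensor whose first dimension matches the embeddings, which is what the shape[0] comparison in the traceback checks. The NumPy sketch below illustrates the idea under the assumption of dot-product scoring; it is not the TFRS implementation, and brute_force_top_k and the toy data are hypothetical.

```python
import numpy as np

def brute_force_top_k(query, candidate_embeddings, identifiers, k):
    """Conceptual sketch (not the TFRS code): score one query against all candidates.

    query:                shape (d,)
    candidate_embeddings: shape (n, d)
    identifiers:          shape (n,), identifiers[i] labels candidate_embeddings[i]
    """
    # Identifiers and embeddings must line up row by row.
    if candidate_embeddings.shape[0] != identifiers.shape[0]:
        raise ValueError("identifiers and embeddings need the same first dimension")
    scores = candidate_embeddings @ query   # one dot-product score per candidate, shape (n,)
    top = np.argsort(-scores)[:k]           # row indices of the k highest scores
    return scores[top], identifiers[top]    # identifiers are looked up by row index

# Hypothetical toy data: 3 articles with 2-dimensional embeddings.
ids = np.array([b"a1", b"a2", b"a3"])
emb = np.array([[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]])
scores, top_ids = brute_force_top_k(np.array([1.0, 0.2]), emb, ids, k=2)
print(top_ids)  # → [b'a1' b'a3']
```

A dict of feature tensors has no single first dimension to index with top, so there is nothing sensible to return per row; a plain array (or tensor) of IDs does.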

o2rvlv0m


I figured it out.
I had been building the candidates dataset like this:

candidates = tf.data.Dataset.zip(articles.batch(128).map(model.article_model))
index.index_from_dataset(candidates)

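But I also needed to pass the candidate identifiers, as a single tensor of IDs (the article_id feature) rather than a dictionary of features, together with the candidate embeddings: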

candidates = tf.data.Dataset.zip((articles.batch(128).map(lambda x: x["article_id"]), articles.batch(128).map(model.article_model)))
index.index_from_dataset(candidates)
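Before calling index_from_dataset, it can help to inspect the dataset's element_spec: the first component of the tuple should be a TensorSpec of IDs, not a dict. The sketch below assumes TensorFlow 2.x; articles and article_model are hypothetical stand-ins for the question's dataset and model.article_model (the dummy model just returns zero vectors), and is_valid_index_input is a hypothetical helper that only approximates the structural check seen in the traceback.

```python
import tensorflow as tf

# Hypothetical stand-in for the question's articles dataset (same three string features).
articles = tf.data.Dataset.from_tensor_slices({
    "article_id": ["a1", "a2", "a3"],
    "prod_name": ["shirt", "dress", "hat"],
    "product_type_name": ["top", "dress", "accessory"],
})

def article_model(features):
    # Hypothetical stand-in for model.article_model: one 64-dim zero vector per row.
    batch_size = tf.shape(features["article_id"])[0]
    return tf.zeros((batch_size, 64), dtype=tf.float32)

def is_valid_index_input(dataset):
    """Approximate the traceback's check: a 2-tuple of specs with matching batch dims."""
    spec = dataset.element_spec
    if not (isinstance(spec, tuple) and len(spec) == 2):
        return False
    ids_spec, emb_spec = spec
    # A dict of features has no .shape, which is what raised the AttributeError.
    if not (isinstance(ids_spec, tf.TensorSpec) and isinstance(emb_spec, tf.TensorSpec)):
        return False
    return ids_spec.shape[0] == emb_spec.shape[0]

# Question's version: the identifiers component is the whole feature dict.
broken = tf.data.Dataset.zip((articles.batch(2), articles.batch(2).map(article_model)))

# Answer's version: the identifiers component is only the article_id tensor.
fixed = tf.data.Dataset.zip((
    articles.batch(2).map(lambda x: x["article_id"]),
    articles.batch(2).map(article_model),
))

print(type(broken.element_spec[0]).__name__)  # → dict
print(is_valid_index_input(broken), is_valid_index_input(fixed))  # → False True
```

With a dataset shaped like fixed, calling the built index on query features returns a (scores, identifiers) pair, so the identifiers come back as article_id strings rather than row positions.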
