haystack 提示生成器产生huggingface_hub.errors.ValidationError:输入验证错误:输入必须少于4095个标记,给定:4701

ruarlubt  于 3个月前  发布在  其他
关注(0)|答案(3)|浏览(48)

描述bug

在使用标准的RAG管道时,我遇到了上述错误。

错误信息

File "/home/felix/PycharmProjects/anychat/src/anychat/analysis/rag.py", line 124, in query_rag_in_document_store
    result = self.llm_pipeline.run(
             ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/felix/anaconda3/envs/anychat/lib/python3.11/site-packages/haystack/core/pipeline/pipeline.py", line 197, in run
    res = comp.run(**last_inputs[name])
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/felix/anaconda3/envs/anychat/lib/python3.11/site-packages/haystack/components/generators/hugging_face_api.py", line 187, in run
    return self._run_non_streaming(prompt, generation_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/felix/anaconda3/envs/anychat/lib/python3.11/site-packages/haystack/components/generators/hugging_face_api.py", line 211, in _run_non_streaming
    tgr: TextGenerationOutput = self._client.text_generation(prompt, details=True, **generation_kwargs)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/felix/anaconda3/envs/anychat/lib/python3.11/site-packages/huggingface_hub/inference/_client.py", line 2061, in text_generation
    raise_text_generation_error(e)
  File "/home/felix/anaconda3/envs/anychat/lib/python3.11/site-packages/huggingface_hub/inference/_common.py", line 457, in raise_text_generation_error
    raise exception from http_error
huggingface_hub.errors.ValidationError: Input validation error: `inputs` must have less than 4095 tokens. Given: 4701

预期行为

我的预期是,内置有一个截断功能,可以截断输入,以便不会将太多的令牌传递给模型。理想情况下,输入应该在特定部分进行截断(例如,而不是在提示的末尾截断,那样问题也会被截断)。

重现

query_template = """Beantworte die Frage basierend auf dem nachfolgenden Kontext und Chatverlauf. Antworte so detailliert wie möglich, aber nur mit Informationen aus dem Kontext. Wenn du eine Antwort nicht weißt, sag, dass du sie nicht kennst. Wenn sich eine Frage nicht auf den Kontext bezieht, sag, dass du die Frage nicht beantworten kannst und höre auf. Stelle nie selbst eine Frage.

Kontext:
{% for document in documents %}
    {{ document.content }}
{% endfor %}

Vorheriger Chatverlauf:
{{ history }}

Frage: {{ question }}
Antwort: """

    def _create_generator(self):
        if AnyChatConfig.hf_use_local_generator:
            return HuggingFaceLocalGenerator(
                model=self.model_id,
                task="text2text-generation",
                device=ComponentDevice.from_str("cuda:0"),
                huggingface_pipeline_kwargs={
                    "device_map": "auto",
                    "model_kwargs": {
                        "load_in_4bit": True,
                        "bnb_4bit_use_double_quant": True,
                        "bnb_4bit_quant_type": "nf4",
                        "bnb_4bit_compute_dtype": torch.bfloat16,
                    },
                },
                generation_kwargs={"max_new_tokens": 350},
            )
        else:
            return HuggingFaceAPIGenerator(
                api_type="text_generation_inference",
                api_params={"url": AnyChatConfig.hf_api_generator_url},
            )

    def create_llm_pipeline(self, document_store):
        """
        Creates an LLM pipeline that employs RAG on a document store that must have been set up before.
        :return:
        """
        # create the pipeline with the individual components
        self.llm_pipeline = Pipeline()
        self.llm_pipeline.add_component(
            "embedder",
            SentenceTransformersTextEmbedder(
                model=DocumentManager.embedding_model_id,
                device=ComponentDevice.from_str(
                    AnyChatConfig.hf_device_rag_text_embedder
                ),
            ),
        )
        self.llm_pipeline.add_component(
            "retriever",
            InMemoryEmbeddingRetriever(document_store=document_store, top_k=8),
        )
        self.llm_pipeline.add_component(
            "prompt_builder", PromptBuilder(template=query_template)
        )

        self.llm_pipeline.add_component("llm", self._create_generator())

        # connect the individual nodes to create the final pipeline
        self.llm_pipeline.connect("embedder.embedding", "retriever.query_embedding")
        self.llm_pipeline.connect("retriever", "prompt_builder.documents")
        self.llm_pipeline.connect("prompt_builder", "llm")

    def _get_formatted_history(self):
        history = ""
        for message in self.conversation_history:
            history += f"{message[0]}: {message[1]}\n"
        history = history.strip()

        return history

    def query_rag_in_document_store(self, query):
        """
        Uses the LLM and RAG to provide an answer to the given query based on the documents in the document store.
        :param query:
        :return:
        """
        logger.debug("querying using rag with: {}", query)

        # run the query through the pipeline
        result = self.llm_pipeline.run(
            {
                "embedder": {"text": query},
                "prompt_builder": {
                    "question": query,
                    "history": self._get_formatted_history(),
                },
                "llm": {"generation_kwargs": {"max_new_tokens": 350}},
            }
        )
        response = result["llm"]["replies"][0]
        post_processed_response = self._post_process_response(response, query)
        logger.debug(post_processed_response)

        self.conversation_history.append(("Frage", query))
        self.conversation_history.append(("Antwort", post_processed_response))

        return post_processed_response

常见问题解答检查

  • 你查看过our new FAQ page吗?
    系统:
  • OS: debian
  • GPU/CPU: NVIDIA RTX A6000 (CUDA 12.4)
  • Haystack版本(提交或版本号): 2.2.1
  • 文档存储:InMemoryDocumentStore
  • 阅读器:PyPDFToDocument
  • 检索器:InMemoryEmbeddingRetriever
krcsximq

krcsximq2#

感谢@fhamborg的建议,我们正在使用#6593跟踪截断提示的特定部分。关于由输入长度引起的错误,您可以将截断的最大长度作为generation_kwargsHuggingFaceAPIGenerator的一部分进行设置。这对您来说是否作为解决方法有效?https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation

wqnecbli

wqnecbli3#

感谢@julian-risch的快速回复!关于将truncation参数设置为某个值,我想虽然这有助于避免上述错误,但在输入过长的情况下(例如,问题是我提示符中的最后一个项目),它会截断实际问题,这可能会更糟。

有没有办法检索LLM的实际输入(或转换为该输入的文本,即可能被截断的输入)?这样我就可以比较我的完整提示符和实际的提示符(在可能的截断之后),如果实际上确实被截断了,我可以重新运行管道,但将检索器组件的top_k设置一个较低的值,例如。或者您认为最好只是捕获上面的异常,然后用减少的top_k重新运行吗?

编辑:我刚刚发现top_k参数必须在创建管道时设置,而不是在运行时设置。因此,上面的主意不幸地不起作用(除非我每次出现上述情况时都重新创建管道)。除了将top_k设置为非常低的值之外,您有什么办法既避免上述错误又截断问题吗?

相关问题