Describe the bug
When using a standard RAG pipeline, I run into the error shown below.
Error message
File "/home/felix/PycharmProjects/anychat/src/anychat/analysis/rag.py", line 124, in query_rag_in_document_store
result = self.llm_pipeline.run(
^^^^^^^^^^^^^^^^^^^^^^
File "/home/felix/anaconda3/envs/anychat/lib/python3.11/site-packages/haystack/core/pipeline/pipeline.py", line 197, in run
res = comp.run(**last_inputs[name])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/felix/anaconda3/envs/anychat/lib/python3.11/site-packages/haystack/components/generators/hugging_face_api.py", line 187, in run
return self._run_non_streaming(prompt, generation_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/felix/anaconda3/envs/anychat/lib/python3.11/site-packages/haystack/components/generators/hugging_face_api.py", line 211, in _run_non_streaming
tgr: TextGenerationOutput = self._client.text_generation(prompt, details=True, **generation_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/felix/anaconda3/envs/anychat/lib/python3.11/site-packages/huggingface_hub/inference/_client.py", line 2061, in text_generation
raise_text_generation_error(e)
File "/home/felix/anaconda3/envs/anychat/lib/python3.11/site-packages/huggingface_hub/inference/_common.py", line 457, in raise_text_generation_error
raise exception from http_error
huggingface_hub.errors.ValidationError: Input validation error: `inputs` must have less than 4095 tokens. Given: 4701
Expected behavior
I would expect a built-in truncation mechanism that trims the input so that too many tokens are never passed to the model. Ideally, the input would be truncated at a specific part of the prompt (for example, in the retrieved context rather than at the end of the prompt, where the question itself would get cut off).
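To illustrate, here is a minimal sketch of the behavior I have in mind: only the retrieved context is trimmed to a token budget, so the question and chat history at the end of the prompt are never cut off. The tokenizer id and the budget below are placeholders, not what my pipeline actually uses.

```python
# Sketch only: drop retrieved documents from the end of the context until the
# prompt fits a token budget, so the question at the end of the prompt survives.
# "google/flan-t5-large" and the budget of 3500 tokens are placeholders.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")

def truncate_context(doc_contents: list[str], budget: int = 3500) -> list[str]:
    kept, used = [], 0
    for content in doc_contents:
        n_tokens = len(tokenizer.encode(content, add_special_tokens=False))
        if used + n_tokens > budget:
            break
        kept.append(content)
        used += n_tokens
    return kept
```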
To Reproduce
query_template = """Beantworte die Frage basierend auf dem nachfolgenden Kontext und Chatverlauf. Antworte so detailliert wie möglich, aber nur mit Informationen aus dem Kontext. Wenn du eine Antwort nicht weißt, sag, dass du sie nicht kennst. Wenn sich eine Frage nicht auf den Kontext bezieht, sag, dass du die Frage nicht beantworten kannst und höre auf. Stelle nie selbst eine Frage.
Kontext:
{% for document in documents %}
{{ document.content }}
{% endfor %}
Vorheriger Chatverlauf:
{{ history }}
Frage: {{ question }}
Antwort: """
def _create_generator(self):
if AnyChatConfig.hf_use_local_generator:
return HuggingFaceLocalGenerator(
model=self.model_id,
task="text2text-generation",
device=ComponentDevice.from_str("cuda:0"),
huggingface_pipeline_kwargs={
"device_map": "auto",
"model_kwargs": {
"load_in_4bit": True,
"bnb_4bit_use_double_quant": True,
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_compute_dtype": torch.bfloat16,
},
},
generation_kwargs={"max_new_tokens": 350},
)
else:
return HuggingFaceAPIGenerator(
api_type="text_generation_inference",
api_params={"url": AnyChatConfig.hf_api_generator_url},
)
def create_llm_pipeline(self, document_store):
"""
Creates an LLM pipeline that employs RAG on a document store that must have been set up before.
:return:
"""
# create the pipeline with the individual components
self.llm_pipeline = Pipeline()
self.llm_pipeline.add_component(
"embedder",
SentenceTransformersTextEmbedder(
model=DocumentManager.embedding_model_id,
device=ComponentDevice.from_str(
AnyChatConfig.hf_device_rag_text_embedder
),
),
)
self.llm_pipeline.add_component(
"retriever",
InMemoryEmbeddingRetriever(document_store=document_store, top_k=8),
)
self.llm_pipeline.add_component(
"prompt_builder", PromptBuilder(template=query_template)
)
self.llm_pipeline.add_component("llm", self._create_generator())
# connect the individual nodes to create the final pipeline
self.llm_pipeline.connect("embedder.embedding", "retriever.query_embedding")
self.llm_pipeline.connect("retriever", "prompt_builder.documents")
self.llm_pipeline.connect("prompt_builder", "llm")
def _get_formatted_history(self):
history = ""
for message in self.conversation_history:
history += f"{message[0]}: {message[1]}\n"
history = history.strip()
return history
def query_rag_in_document_store(self, query):
"""
Uses the LLM and RAG to provide an answer to the given query based on the documents in the document store.
:param query:
:return:
"""
logger.debug("querying using rag with: {}", query)
# run the query through the pipeline
result = self.llm_pipeline.run(
{
"embedder": {"text": query},
"prompt_builder": {
"question": query,
"history": self._get_formatted_history(),
},
"llm": {"generation_kwargs": {"max_new_tokens": 350}},
}
)
response = result["llm"]["replies"][0]
post_processed_response = self._post_process_response(response, query)
logger.debug(post_processed_response)
self.conversation_history.append(("Frage", query))
self.conversation_history.append(("Antwort", post_processed_response))
return post_processed_response
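As a side note, to check how long the rendered prompt actually gets before it is sent to the generator, the same template can be rendered outside the pipeline. This is only a sketch: the tokenizer id is a placeholder, and it assumes the retrieved documents are already available.

```python
# Sketch: render the same Jinja template outside the pipeline and count the
# tokens of the resulting prompt. `docs` would be the documents returned by the
# retriever; "google/flan-t5-large" is a placeholder for the generator's tokenizer.
from haystack.components.builders import PromptBuilder
from transformers import AutoTokenizer

def count_prompt_tokens(docs, question, history, tokenizer_id="google/flan-t5-large"):
    prompt = PromptBuilder(template=query_template).run(
        documents=docs, question=question, history=history
    )["prompt"]
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
    return len(tokenizer.encode(prompt))
```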
FAQ Check
- Have you had a look at our new FAQ page?
System:
- OS: debian
- GPU/CPU: NVIDIA RTX A6000 (CUDA 12.4)
- Haystack version (commit or version number): 2.2.1
- DocumentStore: InMemoryDocumentStore
- Reader: PyPDFToDocument
- Retriever: InMemoryEmbeddingRetriever
3 Answers
mrphzbgm1#
Related to #6593
krcsximq2#
Thanks for the suggestion @fhamborg, we are tracking the truncation of specific parts of the prompt in #6593. Regarding the error caused by the input length, you can set the maximum length for truncation as part of the `generation_kwargs` of the `HuggingFaceAPIGenerator`. Would that work as a workaround for you? https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation
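A rough, untested sketch of that workaround; the endpoint URL is a placeholder, and depending on the server's total-token limit the `truncate` value may need to be lower:

```python
# Untested sketch: let the TGI backend truncate the prompt itself by passing
# `truncate` through generation_kwargs. 4094 stays just below the 4095-token
# input limit from the error above; the endpoint URL is a placeholder.
from haystack.components.generators import HuggingFaceAPIGenerator

generator = HuggingFaceAPIGenerator(
    api_type="text_generation_inference",
    api_params={"url": "http://localhost:8080"},  # placeholder TGI endpoint
    generation_kwargs={
        "truncate": 4094,        # cap the prompt below the server's input limit
        "max_new_tokens": 350,
    },
)
```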
wqnecbli3#
Thanks @julian-risch for the quick reply! Regarding setting the `truncation` parameter to some value: I think that while this helps avoid the error above, in cases where the input is too long it would cut off the actual question (since the question is the last item in my prompt), which might be even worse. Is there a way to retrieve the actual input to the LLM (or the text that gets converted into that input, i.e. the possibly truncated input)? That way I could compare my full prompt with the actual prompt (after potential truncation) and, if it really was truncated, re-run the pipeline with a lower `top_k` for the retriever component, for example. Or do you think it would be better to simply catch the exception above and re-run with a reduced `top_k`?
Edit: I just realized that the `top_k` parameter has to be set when the pipeline is created and cannot be set at runtime. So the idea above unfortunately does not work (unless I recreate the pipeline every time this situation occurs). Apart from setting `top_k` to a very low value, do you see a way to avoid the error above without having the question truncated?
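For reference, the catch-and-retry idea could be sketched roughly like this; `retriever_top_k` and `document_store` are hypothetical attributes standing in for however `top_k` would be wired into `create_llm_pipeline`, and the pipeline is simply rebuilt with a smaller value after each failure:

```python
# Hypothetical retry loop: rebuild the pipeline with a smaller top_k whenever
# the backend rejects the prompt as too long. `rag.retriever_top_k` and
# `rag.document_store` are made-up attributes standing in for however top_k
# would be passed to create_llm_pipeline in the real code.
from huggingface_hub.errors import ValidationError

def query_with_retry(rag, query, top_k_values=(8, 4, 2)):
    for top_k in top_k_values:
        rag.retriever_top_k = top_k
        rag.create_llm_pipeline(rag.document_store)  # recreate with the reduced top_k
        try:
            return rag.query_rag_in_document_store(query)
        except ValidationError:
            continue  # prompt still exceeds the input limit; try fewer documents
    raise RuntimeError("Prompt is too long even with the smallest top_k.")
```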