python 如何在转换器库中实现'stopping_criteria'参数?

yv5phkfx  于 2023-03-16  发布在  Python
关注(0)|答案(1)|浏览(349)

我正在为text-generation模型使用python huggingface transformers库,我需要知道如何实现我正在使用的generator()函数中的stopping_criteria参数。
我在此文档中找到了stopping_criteria参数:https://huggingface.co/transformers/main_classes/pipelines.html#transformers.TextGenerationPipeline
问题是,我只是不知道如何实施它。
我的代码:

from transformers import pipeline
generator = pipeline('text-generation', model='EleutherAI/gpt-neo-125M')
stl = StoppingCriteria(['###'])
res = generator(prompt, do_sample=True,stopping_criteria = stl)
6jjcrrmo

6jjcrrmo1#

这两种方法对我很有效。当你想停止时,your_condition为True。

class CustomStoppingCriteria(StoppingCriteria):
    def __init__(self):
        pass
    
    def __call__(self, input_ids: torch.LongTensor, score: torch.FloatTensor, **kwargs) -> bool:
        return your_condition

stopping_criteria = StoppingCriteriaList([CustomStoppingCriteria()])

def custom_stopping_criteria(input_ids: torch.LongTensor, score: torch.FloatTensor, **kwargs) -> bool:
    return your_condition

stopping_criteria = StoppingCriteriaList([custom_stopping_criteria])

相关问题