pytorch 用'generate'方法进行推理时,如何得到T5模型的logits?

vsnjm48y  于 2022-11-09  发布在  Git
关注(0)|答案(1)|浏览(342)

我目前正在使用HuggingFace的T5实现来进行文本生成,更具体地说,我正在使用T5ForConditionalGeneration来解决文本分类问题。
经过训练,模型的性能总体上是非常令人满意的,但是我想知道的是我如何才能得到生成的logits?
我目前正在通过model.generate(**tokenizer_outputs)按照文档中的建议执行推理,但这只是输出ID本身,而没有任何其他内容。
我之所以需要logit值,是因为我想衡量模型生成的置信度,我不能100%确定我的方法是否正确,但我认为如果我能得到每个生成的标记的logit值并取平均值,我就能得到生成序列的整体置信度得分。
有人知道我怎么做吗?谢谢。

6g8kf2rb

6g8kf2rb1#

我对此很纠结,因为我不熟悉Transformers库是如何工作的,但是在看了源代码之后,您所要做的就是将参数output_scoresreturn_dict_in_generate设置为True
要了解更多信息,请查看方法transformers.generation_utils.GenerationMixin.generate

相关问题