pytorch 用于文本生成的T5模型的输出logits

kse8i1jr  于 2023-05-07  发布在  Git
关注(0)|答案(1)|浏览(404)

我正在使用Hugging Face上的T5模型进行文本摘要。我如何直接输出T5模型的logits,给定文本输入用于生成目的(而不是训练)?
我想逐个生成输出令牌,这样我就可以分别计算每个输出令牌的熵。.generate()方法似乎不适用于此。
我实际上想创建自己的生成函数,但我需要获得模型的logits才能做到这一点。

zzwlnbp8

zzwlnbp81#

你可以使用forward函数来获取logits,并应用argmax:

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch.nn.functional as F

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

input_ids = tokenizer("test here", padding="longest",
    max_length=128
    truncation=True,
    return_tensors="pt"
)

logits = model(**input_ids).logits

preds = F.softmax(logits, dim=-1).argmax(dim=-1)
y = tokenizer.batch_decode(sequences=preds, skip_special_tokens=True)

你可以在这里查看原始来源:多个序列的正向输出错误

相关问题