PyTorch: recovering input IDs from input embeddings with GPT-2

Asked by 6ojccjat on 2023-03-02

Suppose I have the following text:

aim = 'Hello world! you are a wonderful place to be in.'

I want to use GPT-2 to produce the input IDs, then the embeddings, and finally recover the input IDs from those embeddings. To do this I ran:

import torch
from transformers import GPT2Tokenizer, GPT2Model

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2Model.from_pretrained("gpt2")

The input_ids can be defined as:

input_ids = tokenizer(aim)['input_ids']
#output: [15496, 995, 0, 345, 389, 257, 7932, 1295, 284, 307, 287, 13]

I can decode these to confirm that they reproduce the original text:

tokenizer.decode(input_ids)
#output: 'Hello world! you are a wonderful place to be in.'

As expected! To produce the embeddings, I convert input_ids to a tensor:

input_ids_tensor = torch.tensor([input_ids])

Then I can compute the embeddings with:

# Generate the embeddings for input IDs 
with torch.no_grad():
    model_output = model(input_ids_tensor)
    last_hidden_states = model_output.last_hidden_state
    
# Extract the embeddings for the input IDs from the last hidden layer
input_embeddings = last_hidden_states[0,1:-1,:]

As stated above, the goal is to take input_embeddings and recover the input IDs from them, so I did:

x = torch.unsqueeze(input_embeddings, 1) # to make the dim acceptable
with torch.no_grad():
    text = model(x.long())
    decoded_text = tokenizer.decode(text[0].argmax(dim=-1).tolist())

But this fails at the line text = model(x.long()) with:

IndexError: index out of range in self

What am I doing wrong? How can I use the embeddings I produced to recover the input IDs?

Answer (wqsoz72f):

You should use GPT2LMHeadModel instead of GPT2Model, because GPT2Model has no prediction head: it returns hidden states rather than token logits. The IndexError itself occurs because the model's forward pass expects integer token IDs; casting the continuous embeddings with x.long() produces values far outside the vocabulary range, so the embedding lookup fails.

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Instantiate the model and tokenizer
model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Set the input text
text = "Hello, how are you?"

# Tokenize the input text
input_ids = tokenizer.encode(text, return_tensors='pt')

# Use the model's forward function to obtain logits
logits = model(input_ids).logits

# Obtain the predicted token IDs by getting the argmax of the logits along the token dimension
predicted_token_ids = torch.argmax(logits, dim=-1)

# Decode the predicted token IDs back to text
output_text = tokenizer.decode(predicted_token_ids[0], skip_special_tokens=True)

# Print the output text and token IDs
print("Output text: ", output_text)
print("Output token IDs: ", predicted_token_ids.tolist())

Output:

Output text:  , I about you doing

Output token IDs:  [[11, 314, 546, 345, 1804, 198]]

The output text looks strange because the model only predicts the token at step t given the tokens from steps 1 through t-1:

Hello => ,
Hello, => I
Hello, how => about
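
This alignment can be printed directly; a minimal sketch reusing input_ids and predicted_token_ids from the snippet above:

# Show which prefix produced each predicted token
for t in range(1, input_ids.shape[1]):
    prefix = tokenizer.decode(input_ids[0, :t])
    prediction = tokenizer.decode(predicted_token_ids[0, t - 1].item())
    print(f"{prefix!r} => {prediction!r}")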

To generate text step by step, you should use the generate function instead: https://huggingface.co/docs/transformers/main_classes/text_generation
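
For example, a minimal greedy-decoding sketch (max_new_tokens=20 is an arbitrary illustrative choice):

# Generate a continuation token by token with greedy decoding
generated_ids = model.generate(
    input_ids,
    max_new_tokens=20,
    do_sample=False,
    pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token
)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))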

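As for the original goal of recovering input IDs from embeddings: the question extracts contextual vectors from last_hidden_state, and those cannot be mapped back to token IDs by a simple lookup. If instead you take the raw input embeddings (the output of the model's token-embedding layer), they can be inverted exactly by a nearest-neighbor search against the embedding matrix. A minimal sketch of that alternative approach:

import torch
from transformers import GPT2Tokenizer, GPT2Model

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2Model.from_pretrained("gpt2")

input_ids = tokenizer("Hello world! you are a wonderful place to be in.", return_tensors="pt")["input_ids"]

with torch.no_grad():
    wte = model.get_input_embeddings()    # token embedding table, shape [vocab_size, hidden]
    input_embeddings = wte(input_ids)[0]  # non-contextual embeddings, shape [seq_len, hidden]

    # Each embedding is an exact row of wte.weight, so the index of the
    # nearest row recovers the original token ID
    distances = torch.cdist(input_embeddings, wte.weight)
    recovered_ids = distances.argmin(dim=-1)

print(tokenizer.decode(recovered_ids.tolist()))
# 'Hello world! you are a wonderful place to be in.'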