python-3.x: Generating an embedding for a given text string

Posted by jgovgodb on 2023-01-22 in Python

I am trying to generate an embedding for a given text string:

conversation = 'This is the text'

Note that the text I am actually working with is a huge block of text with more than 512 tokens.
The code is below:

import torch
from transformers import BertModel, BertTokenizer

# Load a pre-trained BERT model and tokenizer
model = BertModel.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')


# Set the chunk size
chunk_size = 512

# Process the document in chunks
for i in range(0, len(conversation), chunk_size):
    # Encode the text input
    encoded_text = tokenizer.encode(conversation[i:i+chunk_size], add_special_tokens=True)

    # Convert input to a tensor
    input_ids = torch.tensor([encoded_text]).unsqueeze(0)

    # Generate the embeddings
    outputs = model(input_ids)[1]
    embeddings = outputs[0]

However, when I run it, I get the following error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-57-d63fec5d74fa> in <module>
     20 
     21     # Generate the embeddings
---> 22     outputs = model(input_ids)[1]
     23     embeddings = outputs[0]
     24 

1 frames
/usr/local/lib/python3.8/dist-packages/transformers/models/bert/modeling_bert.py in forward(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)
    973             raise ValueError("You have to specify either input_ids or inputs_embeds")
    974 
--> 975         batch_size, seq_length = input_shape
    976         device = input_ids.device if input_ids is not None else inputs_embeds.device
    977 

ValueError: too many values to unpack (expected 2)

I am expecting to get an embedding that represents the content of the text. Any help would be greatly appreciated.
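
The last frame of the traceback fails on `batch_size, seq_length = input_shape`, which suggests that `input_ids` has more than two dimensions. A minimal sketch of how that can happen, assuming only PyTorch is installed; the token ids below are illustrative placeholders, not output taken from the tokenizer:

```python
import torch

# Placeholder token ids standing in for the result of tokenizer.encode(...)
token_ids = [101, 2023, 2003, 1996, 3793, 102]

# Wrapping the list in [...] already adds the batch dimension
batched = torch.tensor([token_ids])        # shape (1, 6): (batch, seq_len)

# unsqueeze(0) adds a second leading dimension on top of that
over_batched = batched.unsqueeze(0)        # shape (1, 1, 6)

print(batched.shape)
print(over_batched.shape)

# Unpacking a 3-element shape into two names raises the same ValueError
try:
    batch_size, seq_length = over_batched.shape
    message = ""
except ValueError as err:
    message = str(err)
print(message)
```

Under this reading, `torch.tensor([encoded_text])` already has the `(batch, seq_len)` shape the model expects, and the extra `.unsqueeze(0)` is what triggers the error.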

Answer #1 (hjzp0vay)

I suggest calling the tokenizer directly, once per chunk, to get the model inputs, as shown here:

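chunk_size = 512  # measured in characters here, not tokens

# Process the document in chunks
for i in range(0, len(conversation), chunk_size):
    text_chunk = conversation[i:i+chunk_size]
    # Calling the tokenizer returns batched tensors (input_ids, attention_mask, ...);
    # truncation guards against a chunk that still exceeds 512 tokens
    inputs = tokenizer(text_chunk, return_tensors='pt', truncation=True, max_length=512)

    # Index 1 of the model output is the pooled [CLS] representation
    outputs = model(**inputs)[1]
    embeddings = outputs[0]  # shape: (768,) for bert-base-uncased
    #print(embeddings.shape)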
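
The loop above slices the text by characters and overwrites `embeddings` on every iteration, so it ends with the vector of the last chunk only. If a single vector for the whole text is wanted, one simple option (not necessarily the best one) is to split by tokens and average the per-chunk vectors. The sketch below assumes the `model` and `tokenizer` objects from the question; the helper names `split_into_windows` and `embed_long_text` are hypothetical, and mean pooling is an assumption of this sketch, not something the original answer proposes:

```python
def split_into_windows(token_ids, window_size):
    """Split a list of token ids into consecutive windows of at most window_size ids."""
    return [token_ids[i:i + window_size] for i in range(0, len(token_ids), window_size)]


def embed_long_text(text, tokenizer, model, max_length=512):
    """Return one vector for `text` by averaging the pooled output of each token window."""
    import torch  # imported here so split_into_windows has no dependencies

    # Tokenize once without [CLS]/[SEP]; they are added to each window below.
    # (The tokenizer may warn about the sequence length; only windows reach the model.)
    token_ids = tokenizer.encode(text, add_special_tokens=False)

    # Reserve two positions per window for the special tokens.
    # An empty text still yields one window containing only [CLS] and [SEP].
    windows = split_into_windows(token_ids, max_length - 2) or [[]]

    chunk_vectors = []
    model.eval()  # disable dropout for deterministic embeddings
    with torch.no_grad():  # no gradients are needed for inference
        for window in windows:
            ids = [tokenizer.cls_token_id] + window + [tokenizer.sep_token_id]
            input_ids = torch.tensor([ids])  # shape (1, seq_len)
            pooled = model(input_ids=input_ids)[1]  # pooled output, shape (1, hidden_size)
            chunk_vectors.append(pooled[0])

    # Mean over windows: a simple way to collapse them into one document vector
    return torch.stack(chunk_vectors).mean(dim=0)
```

With the objects from the question, `embed_long_text(conversation, tokenizer, model)` would return a tensor of shape `(768,)` for `bert-base-uncased`. Averaging treats every window equally, so a short final window counts as much as a full one; weighting by window length is a possible refinement.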
