描述bug我使用的模型(TrOCR模型):
使用以下内容时会出现问题:
- [x] 官方示例脚本:由nice tutorial (fine_tune)@NielsRogge完成
- [x] 我自己修改的脚本:(如下面的脚本)
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-large-handwritten")
class Dataset(Dataset):
def __init__(self, root_dir, df, processor, max_target_length=128):
self.root_dir = root_dir
self.df = df
self.processor = processor
self.max_target_length = max_target_length
def __len__(self):
return len(self.df)
def __getitem__(self, idx):
# get file name + text
file_name = self.df['file_name'][idx]
text = self.df['text'][idx]
# prepare image (i.e. resize + normalize)
image = Image.open(self.root_dir + file_name).convert("RGB")
pixel_values = self.processor(image, return_tensors="pt").pixel_values
# add labels (input_ids) by encoding the text
labels = self.processor.tokenizer(text,
padding="max_length",
max_length=self.max_target_length).input_ids
# important: make sure that PAD tokens are ignored by the loss function
labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]
# encoding
return {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-large-handwritten")
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.vocab_size = model.config.decoder.vocab_size
model.config.eos_token_id = processor.tokenizer.sep_token_id
# python3 train.py path/to/labels path/to/images/
- 平台:Linux上的Linux Ubuntu发行版[GCC 9.4.0]
- PyTorch版本(GPU?):0.8.2+Cu110
- 变压器:4.22.2
- Python版本:3.8.10
清晰简洁地描述bug是什么。要重现重现行为的步骤:
1.在训练模型之后或在评估指标计算的训练阶段期间,我看到模型添加了标记<s><s>
或ids [0,0, ......,2,1,1, 1 ]
的双开头
1.以下是training
阶段的一个示例,显示了在compute_metrics输入预测中生成的令牌:[[0,0,506,4422,8046,2,1,1,1,1,1]]
输入参考:[[0,597,2747 ...,1,1,1]]
testing
型号[x1c 0d1x]期间的其他示例
预期行为对您预期发生的事情的清晰而简洁的描述。在2个重现的问题中:我在训练期间期望输入预测:[[,0,,506,4422,8046,2,1,1,1,1,1 ]]
此外,在测试阶段:生成的文本没有双tensor([[0,11867,405,22379,1277,..........,368,2]])
<s>ennyit erről, tőlem fényképezz amennyit akarsz, a véleményem akkor</s>
1条答案
按热度按时间qlvxas9a1#
这个问题来自于传递的令牌ID。我正在添加来自tokenizer的开始令牌+来自TrOCR模型的另一个开始令牌,因此会发生重复。解决方案非常简单,只需使用
labels = labels[1:]
跳过来自tokenizer的开始令牌