pytorch 错误:尝试在自定义HF数据集上使用trainer.train()时,vars()参数必须具有__dict__属性?

mlmc2os5  于 2023-10-20  发布在  其他
关注(0)|答案(1)|浏览(159)

我有下面的模型,我试图微调(CLIP_ViT +分类头)。以下是我的模型定义:

class CLIPNN(nn.Module):
    """CLIP vision backbone plus a small MLP classification head.

    ``forward`` currently runs only the classification head on inputs that are
    assumed to be precomputed embeddings. ``self.transformer`` and
    ``self.processor`` are loaded but not used by ``forward``.
    """

    def __init__(
        self,
        num_labels,
        pretrained_name="openai/clip-vit-base-patch32",
        dropout=0.1,
        embedding_dim=512,
    ):
        """Build the backbone, processor and classification head.

        Args:
            num_labels: Number of output classes produced by the head.
            pretrained_name: Checkpoint name passed to
                ``CLIPVisionModel.from_pretrained`` and
                ``CLIPProcessor.from_pretrained``.
            dropout: Dropout probability applied between the two linear layers.
            embedding_dim: Feature size expected by the first linear layer.
                Defaults to 512, the value previously hard-coded.
        """
        super().__init__()
        self.num_labels = num_labels
        # Load the pre-trained transformer & processor (not used by forward yet).
        self.transformer = CLIPVisionModel.from_pretrained(pretrained_name)
        self.processor = CLIPProcessor.from_pretrained(pretrained_name)
        # NOTE(review): embedding_dim must equal the last dimension of the
        # precomputed features; projected CLIP image features and the vision
        # model's pooled output may differ in size -- confirm which is used.
        self.classifier = nn.Sequential(
            nn.Linear(embedding_dim, 128, bias=True),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout, inplace=False),
            nn.Linear(128, self.num_labels, bias=True),
        )

    # Defined at class level: nested inside __init__ it would only be a local
    # function, and nn.Module would have no usable forward().
    def forward(self, inputs=None, labels=None, embeddings=None, **kwargs):
        """Classify a batch of precomputed embeddings.

        Args:
            inputs: Feature tensor fed to the head (kept for positional callers).
            labels: Optional class indices; any shape that flattens to one
                index per row of ``logits`` is accepted.
            embeddings: Keyword alias for ``inputs``. It matches the key that
                ``CLIPDataset`` puts in each item, so batches unpacked as keyword
                arguments bind here.
            **kwargs: Extra batch entries, ignored.

        Returns:
            ``SequenceClassifierOutput`` with ``logits`` and, when ``labels`` is
            given, the cross-entropy ``loss``.

        Raises:
            ValueError: If neither ``inputs`` nor ``embeddings`` is provided.
        """
        if inputs is None:
            inputs = embeddings
        if inputs is None:
            raise ValueError("CLIPNN.forward requires `inputs` or `embeddings`")
        logits = self.classifier(inputs)
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # Flatten so labels stored as 1-element tensors per item (B, 1)
            # line up with the (B, num_labels) logits.
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
        )

我也有以下数据集的定义:

import torch
from torch.utils.data import Dataset


class CLIPDataset(Dataset):
    """Map-style dataset pairing precomputed embeddings with class labels.

    Each item is a dict with an ``"embeddings"`` float32 tensor and a
    ``"labels"`` int64 tensor holding a single class index.
    """

    def __init__(self, embeddings, labels):
        """Store the embeddings and labels.

        Args:
            embeddings: Indexable collection of per-sample feature vectors
                (anything ``torch.as_tensor`` accepts per element).
            labels: Indexable collection of integer class indices, one per sample.

        Raises:
            ValueError: If ``embeddings`` and ``labels`` differ in length.
        """
        if len(embeddings) != len(labels):
            raise ValueError(
                "embeddings and labels must have the same length, "
                f"got {len(embeddings)} and {len(labels)}"
            )
        self.embeddings = embeddings
        self.labels = labels

    def __getitem__(self, idx):
        # torch.nn has no Tensor/LongTensor; build the tensors via torch with
        # explicit dtypes. The label keeps its original 1-element shape.
        return {
            "embeddings": torch.as_tensor(self.embeddings[idx], dtype=torch.float32),
            "labels": torch.tensor([self.labels[idx]], dtype=torch.long),
        }

    def __len__(self):
        return len(self.labels)

注意:这里我假设模型的输入是预先计算好的嵌入,模型本身不再计算嵌入。我知道如果想微调 CLIP 基础模型,这并不是正确的逻辑,我只是想先让代码跑通。
类似这样的东西会抛出一个错误:

# Build the classifier with a 2-class head.
model = CLIPNN(num_labels=2)
# Wrap the raw features/labels in datasets. These lines rebind the names:
# afterwards train_data/test_data refer to the CLIPDataset objects.
train_data = CLIPDataset(train_data, y_train)
test_data = CLIPDataset(test_data, y_test)

# NOTE(review): training_args is defined outside this snippet. Presumably the
# Trainer unpacks each collated batch into model(...) keyword arguments, so the
# dataset keys ("embeddings", "labels") must be accepted by model.forward; with
# remove_unused_columns enabled, keys missing from forward's signature may be
# dropped -- confirm against the TrainingArguments in use.
trainer = Trainer(
    model=model, args=training_args, train_dataset=train_data, eval_dataset=test_data
)
# The traceback below shows the TypeError is raised from inside this call,
# while the default data collator processes a fetched batch.
trainer.train()

TypeError                                 Traceback (most recent call last)
in <module>
----> 1 trainer.train()

~/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/transformers/trainer.py in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1256             self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
   1257 
-> 1258             for step, inputs in enumerate(epoch_iterator):
   1259 
   1260                 # Skip past any already trained steps if resuming training

~/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __next__(self)
    515             if self._sampler_iter is None:
    516                 self._reset()
--> 517             data = self._next_data()
    518             self._num_yielded += 1
    519             if self._dataset_kind == _DatasetKind.Iterable and \

~/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _next_data(self)
    555     def _next_data(self):
    556         index = self._next_index()  # may raise StopIteration
--> 557         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    558         if self._pin_memory:
    559             data = _utils.pin_memory.pin_memory(data)

~/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
     45         else:
     46             data = self.dataset[possibly_batched_index]
---> 47         return self.collate_fn(data)

~/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/transformers/data/data_collator.py in default_data_collator(features, return_tensors)
     64 
     65     if return_tensors == "pt":
---> 66         return torch_default_data_collator(features)
     67     elif return_tensors == "tf":
     68         return tf_default_data_collator(features)

~/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/transformers/data/data_collator.py in torch_default_data_collator(features)
     80 
     81     if not isinstance(features[0], (dict, BatchEncoding)):
---> 82         features = [vars(f) for f in features]
     83     first = features[0]
     84     batch = {}

~/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/transformers/data/data_collator.py in <listcomp>(.0)
     80 
     81     if not isinstance(features[0], (dict, BatchEncoding)):
---> 82         features = [vars(f) for f in features]
     83     first = features[0]
     84     batch = {}

TypeError: vars() argument must have __dict__ attribute
知道我哪里做错了吗?

xqk2d5yq

xqk2d5yq1#

您需要设置 label_names 参数。注意:label_names 是 TrainingArguments 的参数,而不是 Trainer 构造函数的参数:
training_args = TrainingArguments(..., label_names=['labels'])
trainer = Trainer(model=model, args=training_args, train_dataset=train_data, eval_dataset=test_data)

相关问题