I am trying the code at https://gist.github.com/Chris-hughes10/f5ff0b9100990a72dd62f29b2d93a803#file-train-py
I changed data_dir to point to my local dataset, as follows:
data_dir = r"C:\Users\walte\Desktop\dataset\0111"
if __name__ == "__main__":
    main(data_dir)
and I get the following error message:
TypeError Traceback (most recent call last)
c:\Users\walte\Desktop\timm\coAt_classifier_v2.ipynb Cell 9 in <cell line: 1>()
1 if __name__ == "__main__":
----> 2 main(data_dir)
c:\Users\walte\Desktop\timm\coAt_classifier_v2.ipynb Cell 9 in main(data_path)
46 validate_loss_fn = torch.nn.CrossEntropyLoss()
48 # print(optimizer)
49 # print(train_loss_fn)
50 # print(validate_loss_fn,)
51 # print(mixup_args)
52
53 # Create trainer and start training
---> 54 trainer = TimmMixupTrainer(
55 model=model,
56 optimizer=optimizer,
57 loss_func=train_loss_fn,
58 eval_loss_fn=validate_loss_fn,
59 num_classes=num_classes,
60 callbacks=[
61 *DEFAULT_CALLBACKS,
62 SaveBestModelCallback(watch_metric="accuracy", greater_is_better=True),
63 ],
64 )
...
----> 7 self.accuracy = torchmetrics.Accuracy(num_classes=num_classes)
8 self.ema_accuracy = torchmetrics.Accuracy(num_classes=num_classes)
9 self.ema_model = None
TypeError: __new__() missing 1 required positional argument: 'task'
How can I fix this?
I found that the error occurs inside the TimmMixupTrainer class, but I still don't understand why it happens.
1 Answer
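Depending on your task, you need to pass a task value to Accuracy, i.e. one of 'binary', 'multiclass' or 'multilabel'. Since torchmetrics 0.11, the Accuracy constructor requires this argument, which is why the two Accuracy(num_classes=num_classes) calls in TimmMixupTrainer fail. For a single-label image classifier like this one, use task="multiclass" and keep passing num_classes. See: https://torchmetrics.readthedocs.io/en/stable/classification/accuracy.html#module-interface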