python-3.x 类型错误:__new__()缺少1个必需的位置参数:运行timm示例代码时执行“task”

wbrvyc0a  于 2023-01-18  发布在  Python
关注(0)|答案(1)|浏览(1041)

我正在尝试下面的代码www.example.comhttps://gist.github.com/Chris-hughes10/f5ff0b9100990a72dd62f29b2d93a803#file-train-py
我将data_dir更改为下面对本地数据集的描述

data_dir = r"C:\Users\walte\Desktop\dataset\0111"
if __name__ == "__main__":
    main(data_dir)

并显示错误消息

TypeError                                 Traceback (most recent call last)
c:\Users\walte\Desktop\timm\coAt_classifier_v2.ipynb Cell 9 in <cell line: 1>()
      1 if __name__ == "__main__":
----> 2     main(data_dir)

c:\Users\walte\Desktop\timm\coAt_classifier_v2.ipynb Cell 9 in main(data_path)
     46 validate_loss_fn = torch.nn.CrossEntropyLoss()
     48 # print(optimizer)
     49 # print(train_loss_fn)
     50 # print(validate_loss_fn,)
     51 # print(mixup_args)
     52 
     53 # Create trainer and start training
---> 54 trainer = TimmMixupTrainer(
     55     model=model,
     56     optimizer=optimizer,
     57     loss_func=train_loss_fn,
     58     eval_loss_fn=validate_loss_fn,
     59     num_classes=num_classes,
     60     callbacks=[
     61         *DEFAULT_CALLBACKS,
     62         SaveBestModelCallback(watch_metric="accuracy", greater_is_better=True),
     63     ],
     64 )
...
----> 7 self.accuracy = torchmetrics.Accuracy(num_classes=num_classes)
      8 self.ema_accuracy = torchmetrics.Accuracy(num_classes=num_classes)
      9 self.ema_model = None

TypeError: __new__() missing 1 required positional argument: 'task'

我该怎么补救呢?
我发现这个错误发生在TimmMixupTrainer类中,但仍然不知道为什么会发生这种情况。

k10s72fa

k10s72fa1#

根据您的任务,您需要将task值分配给Accuracy,例如task: ['binary', 'multiclass', 'multilabel']。请参阅:https://torchmetrics.readthedocs.io/en/stable/classification/accuracy.html#module-interface

相关问题