pytorch python - TypeError:__init__()获取了意外的关键字参数“checkpoint_callback”

djmepvbi  于 11个月前  发布在  Python
关注(0)|答案(2)|浏览(225)

我得到这个错误消息:

TypeError                                 Traceback (most recent call last)
<ipython-input-41-2892cdd4e738> in <module>()
      5   max_epochs=N_EPOCHS,
      6   gpus=1, #GPU
----> 7   progress_bar_refresh_rate=30
      8 )

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/utilities/argparse.py in insert_env_defaults(self, *args, **kwargs)
    343 
    344         # all args were already moved to kwargs
--> 345         return fn(self, **kwargs)
    346 
    347     return cast(_T, insert_env_defaults)

TypeError: __init__() got an unexpected keyword argument 'checkpoint_callback'

字符串
当我运行这个块时:

trainer = pl.Trainer(
  logger=logger, 
  checkpoint_callback=checkpoint_callback,
  callbacks=[early_stopping_callback],
  max_epochs=N_EPOCHS,
  gpus=1, #GPU
  progress_bar_refresh_rate=30
)


checkpoint_callback的定义如下:

checkpoint_callback = ModelCheckpoint(
  dirpath="checkpoints",
  filename="best-checkpoint",
  save_top_k=1,
  verbose=True,
  monitor="val_loss",
  mode="min"
)


我不知道是什么原因导致了这个错误,有人能帮我吗?
查看完整的源代码:https://colab.research.google.com/drive/1hT7PDVb0oGSpLejMGFBMWzRKTPwsSwwS?usp=sharing

lf3rwulv

lf3rwulv1#

当我查看pytorch_lightning github时,在inithttps://github.com/Lightning-AI/lightning/blob/master/src/pytorch_lightning/trainer/trainer.py)中没有看到checkpoint_callback变量
你确定它是这么叫的吗?你想通过传递这个checkpoint_callback来实现什么?
//edit:我认为您只需将checkpoint_callback追加到callbacks列表中即可

jc3wubiy

jc3wubiy2#

要保存最佳模型,您可以使用enable_checkpointing True或False。

相关问题