pytorch Huggingface的“resume_from_checkpoint”是否有效?

mzsu5hc0  于 2023-04-21  发布在  其他
关注(0)|答案(3)|浏览(1029)

我现在的教练是:

training_args = TrainingArguments(
    output_dir=f"./results_{model_checkpoint}",
    evaluation_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=2,
    weight_decay=0.01,
    push_to_hub=True,
    save_total_limit = 1,
    resume_from_checkpoint=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_qa["train"],
    eval_dataset=tokenized_qa["validation"],
    tokenizer=tokenizer,
    data_collator=DataCollatorForMultipleChoice(tokenizer=tokenizer),
    compute_metrics=compute_metrics
)

训练结束后,在我的output_dir中,我有几个培训师保存的文件:

['README.md',
 'tokenizer.json',
 'training_args.bin',
 '.git',
 '.gitignore',
 'vocab.txt',
 'config.json',
 'checkpoint-5000',
 'pytorch_model.bin',
 'tokenizer_config.json',
 'special_tokens_map.json',
 '.gitattributes']

从文档中可以看出,resume_from_checkpoint将从最后一个检查点继续训练模型:
resume_from_checkpoint (str or bool, optional) — If a str, local path to a saved checkpoint as saved by a previous instance of Trainer. If a bool and equals True, load the last checkpoint in args.output_dir as saved by a previous instance of Trainer. If present, training will resume from the model/optimizer/scheduler states loaded here.
但是当我调用trainer.train()时,它似乎删除了最后一个检查点并开始一个新的检查点:

Saving model checkpoint to ./results_distilbert-base-uncased/checkpoint-500
...
Deleting older checkpoint [results_distilbert-base-uncased/checkpoint-5000] due to args.save_total_limit

它是真的从最后一个检查点(即5000)开始继续训练,并从0开始新检查点的计数(保存500步后的第一个检查点-“checkpoint-500”),还是只是不继续训练?我还没有找到一种方法来测试它,文档也不清楚这一点。

omtl5h9j

omtl5h9j1#

查看代码,它首先加载检查点状态,更新已经运行的epoch数,并从那里继续训练到运行作业的epoch总数(没有重置为0)。
要看到它继续训练,请在检查点上调用trainer.train()之前增加num_train_epochs

icomxhvb

icomxhvb2#

您还应该将resume_from_checkpoint参数添加到trainer.train,并链接到checkpoint
trainer.train(resume_from_checkpoint="{<path-where-checkpoint-were_stored>/checkpoint-0000”)
0000-检查点编号的示例。
不要忘记在整个过程中安装您的驱动器。

5w9g7ksd

5w9g7ksd3#

是的,它工作了!当你调用trainer.train()时,你隐含地告诉它覆盖所有检查点并从头开始。你应该调用trainer.train(resume_from_checkpoint=True)或将resume_from_checkpoint设置为指向检查点路径的字符串。

相关问题