如何在图像分类pytorch中实现早期停止

xriantvc  于 2022-12-04  发布在  其他
关注(0)|答案(4)|浏览(260)

我是Pytorch和机器学习新手,我在本教程https://www.learnopencv.com/image-classification-using-transfer-learning-in-pytorch/中遵循本教程,并使用我的自定义数据集。然后我在本教程中遇到了同样的问题,但我不知道如何在Pytorch中进行提前停止,如果你有更好的不创建提前停止过程,请告诉我。

izkcnapc

izkcnapc1#

这是我在每个时代所做的

val_loss += loss
val_loss = val_loss / len(trainloader)
if val_loss < min_val_loss:
  #Saving the model
  if min_loss > loss.item():
    min_loss = loss.item()
    best_model = copy.deepcopy(loaded_model.state_dict())
    print('Min loss %0.2f' % min_loss)
  epochs_no_improve = 0
  min_val_loss = val_loss

else:
  epochs_no_improve += 1
  # Check early stopping condition
  if epochs_no_improve == n_epochs_stop:
    print('Early stopping!' )
    loaded_model.load_state_dict(best_model)

不知道有多正确(这段代码大部分是从别的网站上的帖子里拿来的,但是忘记了在哪里,所以我不能放参考链接。我只是稍微修改了一下),希望你能觉得有用,如果我说错了,请指出错误。谢谢

9rnv2umw

9rnv2umw2#

请尝试以下代码。

# Check early stopping condition
     if epochs_no_improve == n_epochs_stop:
        print('Early stopping!' )
        early_stop = True
        break
     else:
        continue
     break
if early_stop:
    print("Stopped")
    break
2uluyalo

2uluyalo3#

提前停止的想法是,如果在监控的数量上没有改善的迹象,则通过停止训练过程来避免过拟合,例如,验证损失在几次迭代之后停止减少。提前停止的最小实现需要3个组件:

  • best_score变量,用于存储验证损耗的最佳值
  • counter变量,用于跟踪运行的迭代次数
  • patience变量定义了允许在确认丢失时继续训练而不进行改进的历元数。如果counter超过此值,则停止训练过程。

伪代码如下所示

# Define best_score, counter, and patience for early stopping:
best_score = None
counter = 0
patience = 10
path = ./checkpoints # user_defined path to save model

# Training loop:
for epoch in range(num_epochs):
    # Compute training loss
    loss = model(features,labels,train_mask)
    
    # Compute validation loss
    val_loss = evaluate(model, features, labels, val_mask)
    
    if best_score is None:
        best_score = val_loss
    else:
        # Check if val_loss improves or not.
        if val_loss < best_score:
            # val_loss improves, we update the latest best_score, 
            # and save the current model
            best_score = val_loss
            torch.save({'state_dict':model.state_dict()}, path)
        else:
            # val_loss does not improve, we increase the counter, 
            # stop training if it exceeds the amount of patience
            counter += 1
            if counter >= patience:
                break

# Load best model 
print('loading model before testing.')
model_checkpoint = torch.load(path)

model.load_state_dict(model_checkpoint['state_dict'])

acc = evaluate_test(model, features, labels, test_mask)

我为Pytorch实现了一个通用的提前停止类,用于我的一些项目。它允许你选择任何感兴趣的验证量(损失,准确性等)。如果你喜欢更好的提前停止,那么请随时在repo early-stopping中检查它。也有一个示例笔记本供参考

wqsoz72f

wqsoz72f4#

在PyTorch中实现提前停止的一种方法是使用一个回调函数,在每个历元结束时调用该函数。该函数可以检查验证丢失,如果丢失在一定数量的历元内没有改善,则停止训练。
下面是一个如何实现这一点的示例:
定义一个函数来检查验证损失是否改善了def check_validation_loss(model,best_loss,current_epoch):

计算验证损失

val_loss = calculate_validation_loss(model)

# If the validation loss has not improved for 3 epochs, stop training
if current_epoch - best_loss['epoch'] >= 3:
    print('Stopping training, validation loss has not improved for 3 epochs')
    return True

# If the validation loss is better than the best loss, update the best loss
if val_loss < best_loss['loss']:
    best_loss['loss'] = val_loss
    best_loss['epoch'] = current_epoch

return False

定义函数以计算验证损失def calculate_validation_loss(model):

TODO:计算验证损失

定义培训循环

best_loss = {'loss': float('inf'), 'epoch': 0}

for epoch in range(1, num_epochs + 1):

训练模型一个时期

train_model(model, epoch)

# Check if we should stop training
if check_validation_loss(model, best_loss, epoch):
    break

这段代码使用一个字典来跟踪最佳验证损失和它发生的时间。check_validation_loss函数计算验证损失,并将其与最佳损失进行比较,如果应该停止训练,则返回True。
请注意,calculate_validation_loss函数并未在此程式码中实作,因此您需要为此加入自己的实作。train_model函数也未实作,但可以用您自己的训练程式码取代。
或者,您可以使用PyTorch中现有的提前停止实现之一,如torch.optim.lr.ReduceLROnPlateau或torch.utils.callbacks.EarlyStopping,而不是实现您自己的提前停止。这些实现的使用方式与上述代码类似,但为控制提前停止行为提供了更多的灵活性和选项。

相关问题