Correct validation loss in PyTorch?

e0bqpujr · posted 2022-11-09

I'm a bit confused about how the validation loss should be computed. Should it be calculated once at the end of an epoch, or should the loss also be monitored during the batch iterations? Below I compute it with a running_loss that is accumulated over the batches, but I'd like to know whether this is the correct approach.

def validate(loader, model, criterion):
    correct = 0
    total = 0
    running_loss = 0.0
    model.eval()
    with torch.no_grad():
        for i, data in enumerate(loader):
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            running_loss = running_loss + loss.item()
    mean_val_accuracy = (100 * correct / total)
    mean_val_loss = running_loss
    #mean_val_accuracy = accuracy(outputs, labels)
    print('Validation Accuracy: %d %%' % (mean_val_accuracy))
    print('Validation Loss:', mean_val_loss)

Below is the training loop I'm using:

def train(loader, model, criterion, optimizer, epoch):
    correct = 0
    running_loss = 0.0
    i_max = 0
    for i, data in enumerate(loader):
        total_loss = 0.0
        #print('batch=', i)
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 2000 == 1999:
            print('[%d , %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

    print('finished training')
    return mean_val_loss, mean_val_accuracy

8yparm6h · answer 1

You can evaluate the network on the validation set whenever you want. That can be once per epoch, or, if that is too costly because the dataset is large, once every N epochs.
What you are doing looks correct: you compute the loss over the whole validation set. You may optionally divide it by its length to normalize the loss, so that the scale stays the same if you later enlarge the validation set.
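For instance, here is a minimal sketch of that normalization and of running validation every N epochs. It reuses your validate signature; the names N, num_epochs, train_loader and val_loader are placeholders I'm assuming, not taken from your code.

def validate(loader, model, criterion):
    correct = 0
    total = 0
    running_loss = 0.0
    model.eval()
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            running_loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # divide by the number of batches so the scale does not grow with the set size
    mean_val_loss = running_loss / len(loader)
    mean_val_accuracy = 100 * correct / total
    return mean_val_loss, mean_val_accuracy

# evaluate every N epochs (N = 1 means once per epoch); N, num_epochs,
# train_loader and val_loader are assumed names for illustration
N = 1
for epoch in range(num_epochs):
    train(train_loader, model, criterion, optimizer, epoch)
    if (epoch + 1) % N == 0:
        val_loss, val_acc = validate(val_loader, model, criterion)
        print('Validation Loss: %.4f  Accuracy: %.2f %%' % (val_loss, val_acc))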
