pytorch: how do I compute the loss when the samples in each batch have different shapes?

myss37ts · asked 12 months ago in Other

I have a training function like this:

def training():
    model.train()
    
    train_mae = []
    
    progress = tqdm(train_dataloader, desc='Training')
    for batch_index, batch in enumerate(progress):
        x = batch['x'].to(device)
        x_lengths = batch['x_lengths'].to(device)
        y = batch['y'].to(device)
        y_type = batch['y_type'].to(device)
        y_valid_indices = batch['y_valid_indices'].to(device)

        # Zero Gradients
        optimizer.zero_grad()

        # Forward pass
        y_first, y_second = model(x)

        losses = []

        for j in range(len(x_lengths)):
            x_length = x_lengths[j].item()

            if y_type[j].item() == 0:
                predicted = y_first[j]
            else:
                predicted = y_second[j]

            actual = y[j]
            
            valid_mask = torch.zeros_like(predicted, dtype=torch.bool)
            valid_mask[:x_length] = 1
            
            # Padding of -1 is removed from y
            indices_mask = y[j].ne(-1)
            valid_indices = y[j][indices_mask]

            valid_predicted = predicted[valid_mask]
            valid_actual = actual[valid_mask]
            
            loss = mae_fn(valid_predicted, valid_actual, valid_indices)
            
            losses.append(loss)

        # Backward pass and update
        loss = torch.stack(losses).mean()   # This fails due to different shapes
        loss.backward()

        optimizer.step()
        
        train_mae.append(loss.detach().cpu().numpy())

        progress.set_description(
            f"mae: {loss.detach().cpu().numpy():.4f}"
        )

    # Return the epoch-average MAE
    return np.mean(train_mae)

Obviously these losses cannot be stacked, because the indices give them different shapes. Taking the mean over maes[indices] would avoid the error, but it leads to a very poor test loss. How should I compute the loss here, given that the indices, and therefore the shapes, depend on y_type?
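For context, a minimal illustration of the failure, assuming mae_fn returns an unreduced (per-element) loss so that each sample's loss tensor has as many entries as that sample has valid indices:

import torch

# Hypothetical ragged per-sample losses: one value per valid index,
# so samples with different numbers of valid indices yield tensors
# of different lengths.
losses = [torch.randn(5), torch.randn(3)]

torch.stack(losses).mean()
# RuntimeError: stack expects each tensor to be equal size,
# but got [5] at entry 0 and [3] at entry 1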

q9rjltbz1#

Could you take the mean within each group and then combine those means, weighted by each group's size? Larger groups then contribute more to the final average than smaller ones, which should be more stable than a plain mean over everything. Example below.

import torch

# Test data: per-sample losses from three groups of different sizes
losses_perbatch = [torch.randn(8, 1), torch.randn(4, 1), torch.randn(2, 1)]

# Weighted mean: each group contributes its sum, normalised by the total
# number of samples. Using torch.stack (rather than torch.tensor) keeps
# the autograd graph intact, so the result can still be backpropagated.
total_samples = sum(len(batch) for batch in losses_perbatch)
weighted_mean_perbatch = torch.stack([batch.sum()
                                      for batch in losses_perbatch]) / total_samples
# Equivalent to:
# weighted_mean_perbatch = torch.stack([batch.mean() * len(batch)
#                                       for batch in losses_perbatch]) / total_samples

final_weighted_loss = weighted_mean_perbatch.sum()
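
Applied to the training loop from the question, the same idea looks roughly like the sketch below. It assumes mae_fn can return an unreduced, per-element loss (one value per valid index); that signature is an assumption, since the question does not show mae_fn itself.

# Inside the batch loop, replacing the failing torch.stack(losses).mean():
losses = []
total_elements = 0

for j in range(len(x_lengths)):
    # ... build valid_predicted, valid_actual, valid_indices as before ...
    per_element = mae_fn(valid_predicted, valid_actual, valid_indices)  # assumed unreduced
    losses.append(per_element.sum())       # 0-dim tensor per sample, so stacking works
    total_elements += per_element.numel()  # weight = number of valid elements

# Dividing by total_elements yields the element-weighted mean proposed
# above, and the autograd graph is preserved for backward().
loss = torch.stack(losses).sum() / total_elements
loss.backward()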

