Combining losses and predictions from a multi-GPU setup in PyTorch Lightning

Asked by amrnrhlw on 2023-08-05

Hi, I'm running into a problem collecting all the losses and predictions in a multi-GPU setup. I'm using PyTorch Lightning 2.0.4 and DeepSpeed with the distributed strategy deepspeed_stage_2.
I'm including my skeleton code below for reference.

import numpy as np
import torch
from lightning.pytorch import LightningModule


class MyLightningModule(LightningModule):

    def __init__(self):
        super().__init__()
        self.batch_train_preds = []
        self.batch_train_losses = []

    def training_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']

        # Model step
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

        train_preds = torch.argmax(outputs.logits, dim=-1)

        return {'loss': outputs[0],
                'train_preds': train_preds}

    def on_train_batch_end(self, outputs, batch, batch_idx):
        # aggregate metrics or outputs at batch level
        train_batch_loss = outputs["loss"].mean()
        train_batch_preds = outputs["train_preds"]  # already a tensor for this batch

        self.batch_train_preds.append(train_batch_preds)
        self.batch_train_losses.append(train_batch_loss.item())

    def on_train_epoch_end(self) -> None:
        # Aggregate epoch-level training metrics
        epoch_train_preds = torch.cat(self.batch_train_preds)
        epoch_train_loss = np.mean(self.batch_train_losses)

        self.logger.log_metrics({"epoch_train_loss": epoch_train_loss})

In the code block above, I'm trying to combine all predictions into a single tensor at the end of the epoch by appending each batch to the global lists defined in __init__. But with multi-GPU training I run into an error: each GPU processes its own batches on its own device, so I can't merge the results into one global list.
Question:
What should I do in on_train_batch_end, on_train_epoch_end, or training_step so that the results from all GPUs end up merged into the lists created in __init__? I want to compute some additional metrics (precision, recall, etc.) inside the on_*_epoch_end() hooks for training, validation, and testing
(my validation and test code mirrors the three training functions above, i.e. it also combines losses and predictions).
I have come across all_gather; it is invoked on every device (GPU), but it does merge the results the way I want. The remaining question is how to use the all_gather output on a single device. A code snippet would be very helpful.

Answer by eaf3rand

The Lightning docs recommend using all_gather. Also, you don't need to accumulate the loss manually; just log it with self.log(..., on_epoch=True) and Lightning will accumulate and log it correctly:

import torch
from lightning.pytorch import LightningModule


class MyLightningModule(LightningModule):

    def __init__(self):
        super().__init__()
        self.batch_train_preds = []

    def training_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']

        # Model Step
        outputs = self.model(
            input_ids=input_ids, attention_mask=attention_mask, labels=labels
        )

        loss = outputs[0]

        train_preds = torch.argmax(outputs.logits, dim=-1)
        self.batch_train_preds.append(train_preds)

        self.log('train/loss', loss, on_step=True, on_epoch=True, sync_dist=True)
        return loss

    def on_train_epoch_end(self) -> None:

        # Aggregate epoch level training metrics
        epoch_train_preds = torch.cat(self.batch_train_preds, dim=0)

        # all_gather stacks the predictions from all distributed processes on a new
        # leading dimension, giving a tensor of shape (world_size, N, *other_dims)
        epoch_train_preds = self.all_gather(epoch_train_preds)

        # flatten the process dimension -> (world_size * N, *other_dims)
        epoch_train_preds = epoch_train_preds.view(-1, *epoch_train_preds.shape[2:])

        # compute here your metrics over `epoch_train_preds`

        self.batch_train_preds.clear()  # free memory

If you only want to compute the metrics on a single process, guard the metric computation with if self.trainer.global_rank == 0:.
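For instance, a minimal sketch of that guard (not part of the original answer; compute_my_metrics is a hypothetical helper standing in for your precision/recall computation):

    def on_train_epoch_end(self) -> None:
        epoch_train_preds = self.all_gather(torch.cat(self.batch_train_preds, dim=0))
        epoch_train_preds = epoch_train_preds.view(-1, *epoch_train_preds.shape[2:])

        # only rank 0 computes and logs, so the metrics are not duplicated per process
        if self.trainer.global_rank == 0:
            metrics = compute_my_metrics(epoch_train_preds)  # hypothetical helper
            self.logger.log_metrics(metrics)

        self.batch_train_preds.clear()  # free memory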
I would also recommend taking a look at torchmetrics, which can automatically synchronize metrics across a distributed setup with just a few lines of code.
In addition, I have written a framework for easily training and testing several Transformer models for NLP.

Additional example using torchmetrics

import torch
from torchmetrics.classification import BinaryAccuracy
from lightning.pytorch import LightningModule

class MyLightningModule(LightningModule):

    def __init__(self):
        super().__init__()
        self.train_accuracy = BinaryAccuracy()

    def training_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']

        # Model Step
        outputs = self.model(
            input_ids=input_ids, attention_mask=attention_mask, labels=labels
        )

        loss = outputs[0]

        train_preds = torch.argmax(outputs.logits, dim=-1)
        self.train_accuracy(train_preds, labels)  # updates the metric internal state with predictions and labels

        self.log('train/loss', loss, on_step=True, on_epoch=True, sync_dist=True)
        self.log('train/acc', self.train_accuracy, on_step=True, on_epoch=True, sync_dist=True)
        return loss

    def on_train_epoch_end(self) -> None:
        pass  # no need to reset the metric as lightning will take care of that after each epoch
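
Since the question also asks about precision and recall, here is a hedged sketch (my addition, not from the original answer) of how a torchmetrics MetricCollection could track several metrics at once; the binary-classification metrics are an assumption chosen to match BinaryAccuracy above:

import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import BinaryAccuracy, BinaryPrecision, BinaryRecall
from lightning.pytorch import LightningModule

class MyLightningModule(LightningModule):

    def __init__(self):
        super().__init__()
        # one collection updates and synchronizes all three metrics together
        self.train_metrics = MetricCollection(
            {'acc': BinaryAccuracy(), 'precision': BinaryPrecision(), 'recall': BinaryRecall()},
            prefix='train/',
        )

    def training_step(self, batch, batch_idx):
        outputs = self.model(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            labels=batch['labels'],
        )
        train_preds = torch.argmax(outputs.logits, dim=-1)

        # a single call updates every metric in the collection
        self.train_metrics(train_preds, batch['labels'])
        # log_dict logs each metric; epoch values are synced across GPUs automatically
        self.log_dict(self.train_metrics, on_step=False, on_epoch=True)
        return outputs[0]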
