Accessing all batch outputs at the end of an epoch in a PyTorch Lightning callback

eqoofvh9 · asked 12 months ago

The documentation for on_train_epoch_end (https://pytorch-lightning.readthedocs.io/en/stable/extensions/callbacks.html#on-train-epoch-end) states:
To access all batch outputs at the end of the epoch, either:
1. Implement training_epoch_end in the LightningModule and access outputs via the module, OR
2. Cache data across train batch hooks inside the callback implementation to post-process in this hook.
I am trying the first alternative. Here is my LightningModule and Callback setup:

import pytorch_lightning as pl
from pytorch_lightning import Callback

class LightningModule(pl.LightningModule):
    def __init__(self, *args):
        super().__init__()
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        return {'batch': batch}

    def training_epoch_end(self, training_step_outputs):
        # training_step_outputs has all my batches
        return

class MyCallback(Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # pl_module.batch ???
        return

How do I access the outputs via pl_module in the callback? What is the recommended way to access training_step_outputs inside a callback?

ercv8c1e1#

You can store the output of each training batch in a state on the callback and access it at the end of the training epoch. Here is an example:

from pytorch_lightning import Callback

class MyCallback(Callback):
    def __init__(self):
        super().__init__()
        self.state = []

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # `outputs` is whatever training_step returned for this batch
        # (the `unused=0` parameter from Lightning 1.x was removed in 2.0)
        self.state.append(outputs)

    def on_train_epoch_end(self, trainer, pl_module):
        # access the cached outputs of every batch in the epoch
        all_outputs = self.state
        self.state.clear()  # reset for the next epoch
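Since the snippet above only runs inside a live Trainer, here is a torch-free, pure-Python sketch of the same caching pattern. The hand-rolled loop below is a stand-in for Lightning's trainer loop (the `trainer`/`pl_module` arguments are just `None` placeholders), not real Lightning machinery:

```python
# Sketch of the cache-across-batch-hooks pattern, with no Lightning
# dependency: a fake "trainer loop" invokes the same hooks in the same
# order that Lightning would.

class MyCallback:
    def __init__(self):
        self.state = []

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # cache each batch's step output
        self.state.append(outputs)

    def on_train_epoch_end(self, trainer, pl_module):
        # all batch outputs of the epoch are now available
        all_outputs = self.state
        print(f"epoch ended with {len(all_outputs)} batch outputs")
        self.state.clear()  # reset for the next epoch

# Simulated single epoch over 3 batches
cb = MyCallback()
for batch_idx, batch in enumerate([[1, 2], [3, 4], [5, 6]]):
    cb.on_train_batch_end(None, None, {"batch": batch}, batch, batch_idx)
cb.on_train_epoch_end(None, None)  # prints "epoch ended with 3 batch outputs"
```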

Hope this helps! 😀

yuvru6vn2#

Here is how to convert the epoch-end hooks from older versions of Lightning to >= 2.0:
Before:

import torch
import lightning as L

class LitModel(L.LightningModule):
    
    def training_step(self, batch, batch_idx):
        ...
        return {"loss": loss, "banana": banana}
    
    # `outputs` is a list of everything training_step returned during the epoch
    def training_epoch_end(self, outputs):
        # torch.cat needs a list/tuple of tensors, not a generator
        avg_banana = torch.cat([out["banana"] for out in outputs]).mean()

After:

import torch
import lightning as L

class LitModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        # 1. Create a list to hold the outputs of `*_step`
        self.bananas = []
    
    def training_step(self, batch, batch_idx):
        ...
        # 2. Add the outputs to the list
        # You should be aware of the implications on memory usage
        self.bananas.append(banana)
        return loss
    
    # 3. Rename the hook to `on_*_epoch_end`
    def on_train_epoch_end(self):
        # 4. Do something with all outputs
        avg_banana = torch.cat(self.bananas).mean()
        # Don't forget to clear the memory for the next epoch!
        self.bananas.clear()

The example here uses the training hooks, but the same applies to the corresponding validation, test, and predict hooks.
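As a concrete illustration of that same accumulate-then-clear pattern on a validation hook, here is a torch-free sketch using plain floats and statistics.mean in place of tensors and torch.cat(...).mean(). LitModelSketch and its fake per-batch loss are made up for this example, but the hook names match Lightning >= 2.0:

```python
from statistics import mean

class LitModelSketch:
    def __init__(self):
        self.val_losses = []  # 1. list created in __init__

    def validation_step(self, batch, batch_idx):
        loss = float(sum(batch)) / len(batch)  # stand-in "loss" for the sketch
        self.val_losses.append(loss)           # 2. accumulate per step
        return loss

    def on_validation_epoch_end(self):         # 3. renamed on_*_epoch_end hook
        avg = mean(self.val_losses)            # 4. aggregate all step outputs
        self.val_losses.clear()                # clear memory for the next epoch
        return avg

m = LitModelSketch()
for i, b in enumerate([[1.0, 3.0], [2.0, 4.0]]):
    m.validation_step(b, i)
print(m.on_validation_epoch_end())  # → 2.5
```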
Source:
