on_train_epoch_end
的文档https://pytorch-lightning.readthedocs.io/en/stable/extensions/callbacks.html#on-train-epoch-end指出:
要在epoch结束时访问所有批处理输出,请执行以下操作之一:
1.在LightningModule中实现training_epoch_end,并通过模块OR访问输出
1.在回调实现中跨train batch钩子缓存数据,以在此钩子中进行后处理。
我正在尝试使用第一种替代方法,下面是LightningModule和Callback设置:
import pytorch_lightning as pl
from pytorch_lightning import Callback
class LightningModule(pl.LightningModule):
def __init__(self, *args):
super().__init__()
self.automatic_optimization = False
def training_step(self, batch, batch_idx):
return {'batch': batch}
def training_epoch_end(self, training_step_outputs):
# training_step_outputs has all my batches
return
class MyCallback(Callback):
def on_train_epoch_end(self, trainer, pl_module):
# pl_module.batch ???
return
如何通过回调中的pl_module
访问输出?在回调中访问training_step_outputs
的推荐方法是什么?
2条答案
按热度按时间ercv8c1e1#
您可以将每个训练批次的输出存储在一个状态中,并在训练时期结束时访问它。下面是一个例子-
希望这对你有帮助!😀
yuvru6vn2#
下面是将epoch-end钩子从旧版本的Lightning转换到>= 2.0的方法:
使用前:
之后:
这里的例子使用了训练钩子,但它也适用于相应的验证、测试和预测钩子。
资料来源: