What is the difference between forward and training_step in PyTorch Lightning?

Asked by omvjsjqw on 2022-12-29 · 3 answers · 267 views

I have set up a transfer-learning ResNet in PyTorch Lightning. The structure is borrowed from this wandb tutorial: https://wandb.ai/wandb/wandb-lightning/reports/Image-Classification-using-PyTorch-Lightning--VmlldzoyODk1NzY
and from looking at the documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html
I am confused about the difference between the def forward() and def training_step() methods.
Initially, following the PL docs, the model is not called in the training step, only in forward. But forward is also not called in the training step. I have been running the model on data and the outputs look sensible (I have an image callback and I can see that the model is learning, and it gets a good accuracy result at the end). But I am worried that, given the forward method is not being called, the model is somehow not being implemented?
The model code is:

# imports needed to run the snippet (the metric classes come from torchmetrics)
import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
from torchmetrics import Accuracy, Precision, Recall, JaccardIndex

class TransferLearning(pl.LightningModule):
    "Works for Resnet at the moment"
    def __init__(self, model, learning_rate, optimiser = 'Adam', weights = [ 1/2288  , 1/1500], av_type = 'macro' ):
        super().__init__()
        self.class_weights = torch.FloatTensor(weights)
        self.optimiser = optimiser
        self.thresh  =  0.5
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        
        #add metrics for tracking 
        self.accuracy = Accuracy()
        self.loss= nn.CrossEntropyLoss()
        self.recall = Recall(num_classes=2, threshold=self.thresh, average = av_type)
        self.prec = Precision( num_classes=2, average = av_type )
        self.jacq_ind = JaccardIndex(num_classes=2)
        

        # init model
        backbone = model
        num_filters = backbone.fc.in_features
        layers = list(backbone.children())[:-1]
        self.feature_extractor = nn.Sequential(*layers)

        # use the pretrained model to classify damage 2 classes
        num_target_classes = 2
        self.classifier = nn.Linear(num_filters, num_target_classes)

    def forward(self, x):
        self.feature_extractor.eval()
        with torch.no_grad():
            representations = self.feature_extractor(x).flatten(1)
        x = self.classifier(representations)
        return x
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss(logits, y)
        
        # training metrics
        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, y)
        recall = self.recall(preds, y)
        precision = self.prec(preds, y)
        jac = self.jacq_ind(preds, y)

        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, logger=True)
        self.log('train_recall', recall, on_step=True, on_epoch=True, logger=True)
        self.log('train_precision', precision, on_step=True, on_epoch=True, logger=True)
        self.log('train_jacc', jac, on_step=True, on_epoch=True, logger=True)
        return loss
  
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss(logits, y)

        # validation metrics
        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, y)
        recall = self.recall(preds, y)
        precision = self.prec(preds, y)
        jac = self.jacq_ind(preds, y)

        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        self.log('val_recall', recall, prog_bar=True)
        self.log('val_precision', precision, prog_bar=True)
        self.log('val_jacc', jac, prog_bar=True)

        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss(logits, y)
        
        # test metrics
        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, y)
        recall = self.recall(preds, y)
        precision = self.prec(preds, y)
        jac = self.jacq_ind(preds, y)

        self.log('test_loss', loss, prog_bar=True)
        self.log('test_acc', acc, prog_bar=True)
        self.log('test_recall', recall, prog_bar=True)
        self.log('test_precision', precision, prog_bar=True)
        self.log('test_jacc', jac, prog_bar=True)

        return loss
    
    def configure_optimizers(self):
        print('Optimise with {}'.format(self.optimiser))
        # optimizer = self.optimiser_dict[self.optimiser](self.parameters(), lr=self.learning_rate)

        # Support Adam, SGD, RMSprop and Adagrad as optimizers.
        if self.optimiser == "Adam":
            optimiser = optim.AdamW(self.parameters(), lr = self.learning_rate)
        elif self.optimiser == "SGD":
            optimiser = optim.SGD(self.parameters(), lr = self.learning_rate)
        elif self.optimiser == "Adagrad":
            optimiser = optim.Adagrad(self.parameters(), lr = self.learning_rate)
        elif self.optimiser == "RMSProp":
            optimiser = optim.RMSprop(self.parameters(), lr = self.learning_rate)
        else:
            assert False, f"Unknown optimizer: \"{self.optimiser}\""

        return optimiser
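
For context, here is a minimal sketch of how a module like this would be instantiated and trained. The backbone choice, the random dummy data and the trainer settings below are illustrative assumptions, not taken from the original post:

import torch
import torchvision
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset

# tiny random dataset purely for illustration (8 RGB images, 2 classes)
xs = torch.randn(8, 3, 224, 224)
ys = torch.randint(0, 2, (8,))
train_loader = DataLoader(TensorDataset(xs, ys), batch_size=4)

backbone = torchvision.models.resnet50(pretrained=True)  # any torchvision ResNet should work as the backbone
model = TransferLearning(backbone, learning_rate=1e-3, optimiser='Adam')

trainer = pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
trainer.fit(model, train_loader)  # Lightning calls training_step, which calls self(x) -> forward()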

t30tvxxf #1

I am confused about the difference between the def forward() and def training_step() methods.
Quoting the documentation:
"In Lightning we suggest separating training from inference. The training_step defines the full training loop. We encourage users to use forward to define inference actions."
So forward() defines your prediction/inference actions. It does not even have to be part of your training_step, which is where you define the whole training loop, but you can choose to use it there if you want. For example, here forward() is not part of training_step:

def forward(self, x):
    # in lightning, forward defines the prediction/inference actions
    embedding = self.encoder(x)
    return embedding

def training_step(self, batch, batch_idx):
    # training_step defines the train loop.
    # in this case it is independent of forward
    x, y = batch
    x = x.view(x.size(0), -1)
    z = self.encoder(x)
    x_hat = self.decoder(z)
    loss = F.mse_loss(x_hat, x)
    # Logging to TensorBoard by default
    self.log("train_loss", loss)
    return loss
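
For contrast, here is a sketch of my own (not from the linked docs) of the other option mentioned above, where training_step reuses forward() through self(...), which is essentially what the OP's code does:

def forward(self, x):
    # prediction/inference path
    x = x.view(x.size(0), -1)
    return self.encoder(x)

def training_step(self, batch, batch_idx):
    # training loop that reuses forward()
    x, y = batch
    z = self(x)  # equivalent to self.forward(x)
    x_hat = self.decoder(z)
    loss = F.mse_loss(x_hat, x.view(x.size(0), -1))
    self.log("train_loss", loss)
    return loss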

the model is not called in the training step, only in forward. But forward is also not called in the training step
forward() is not called inside training_step because self(x) does that for you: calling the module goes through __call__, which in turn runs forward(). You can also call self.forward(x) explicitly instead of using self(x).
I am worried that, given the forward method is not being called, the model is somehow not being implemented.
As long as you see the metrics logged with self.log moving in the right direction, you know that your model is being called correctly and is learning.
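
If you want extra reassurance, one option (a debugging sketch of mine, not part of the original answer) is to register a forward hook, which fires every time forward() runs:

def report_forward(module, inputs, output):
    # called automatically after every forward() pass of the module
    print("forward() ran, output shape:", tuple(output.shape))

handle = model.register_forward_hook(report_forward)  # model: your TransferLearning instance (assumed)
# ... now run trainer.fit(model, ...) or simply model(x); the message prints on every forward pass
handle.remove()  # detach the hook when you are done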


xpszyzbs #2

self(x) in training_step refers to the class's __call__ method, which in turn uses the forward() function.
You can find more detail about what happens inside self(x) in the PyTorch source code: https://github.com/pytorch/pytorch/blob/b6672b10e153b63748874ca9008fd3160f38c3dd/torch/nn/modules/module.py#L1124
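
A minimal sketch of my own to illustrate the point: calling the module object goes through nn.Module.__call__, which dispatches to forward() (and also runs any registered hooks):

import torch
import torch.nn as nn

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        print("forward() was called")
        return self.linear(x)

m = Tiny()
x = torch.randn(1, 4)
out1 = m(x)          # routed through nn.Module.__call__, which calls forward()
out2 = m.forward(x)  # same computation, but bypasses __call__ and therefore any hooks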


kpbpu008 #3

The main difference is in how the output of the model is used.
In Lightning, the idea is to organise the code so that the training logic is separated from the inference logic.

**forward:** encapsulates the way the model would be used regardless of whether you are training or performing inference.
**training_step:** contains all computations necessary to produce a loss value for training the model. There are often extra layers such as decoders, discriminators, loss functions, etc. that are only useful for training and are not needed when the trained model is used for inference. This is also where we typically call forward().

The way the OP has organised the code is best practice:

def forward(self, x):
    self.feature_extractor.eval()
    with torch.no_grad():
        representations = self.feature_extractor(x).flatten(1)
    x = self.classifier(representations)
    return x

def training_step(self, batch, batch_idx):
    x, y = batch

    ## self(x) is the same as calling self.forward(x)
    logits = self(x)  
    
    # Loss computation is not part of forward because it's only
    # needed for training
    loss = self.loss(logits, y)
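
As a follow-up sketch of my own (the names model and images are assumptions here: a trained TransferLearning instance and a batch tensor of shape (N, 3, H, W)), the same forward() is what you call at inference time:

model.eval()                 # put batch-norm/dropout layers in eval mode
with torch.no_grad():        # no gradients needed for inference
    logits = model(images)   # goes through __call__ -> forward()
    preds = torch.argmax(logits, dim=1)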

Reference: Introduction to PyTorch Lightning (see the "Forward vs Training_step" section)
