pytorch 更改lr_find的检查点路径

hi3rlvi2  于 2022-11-09  发布在  其他
关注(0)|答案(1)|浏览(130)

我想调整我的PyTorch Lightning模型的学习速率。我的代码在GPU集群上运行,所以我只能写入绑定挂载的特定文件夹。但是,trainer.tuner.lr_find试图将检查点写入我的脚本运行的文件夹,由于该文件夹不可写,因此它失败,并显示以下错误:
OSError: [Errno 30] Read-only file system: '/opt/xrPose/.lr_find_43df1c5c-0aed-4205-ac56-2fe4523ca4a7.ckpt'
有没有办法更改lr_find的检查点路径?我检查了文档,但在与checkpointing相关的部分中找不到任何有关的信息。
我的代码如下:

res = trainer.tuner.lr_find(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader, min_lr=1e-5)
logging.info(f"suggested learning rate: {res.suggestion()}")
model.hparams.learning_rate = res.suggestion()
rsaldnfx

rsaldnfx1#

初始化Trainer时,您可能需要指定default_root_dir

trainer = Trainer(default_root_dir='./my_dir')

官方文件中的描述:

default_root_dir-未传递任何记录器或pytorch_lightning.callbacks.ModelCheckpoint回调时日志和权重的默认路径。

程式码范例:

import numpy as np
import torch
from pytorch_lightning import LightningModule, Trainer
from torch.utils.data import DataLoader, Dataset

class MyDataset(Dataset):
    def __init__(self) -> None:
        super().__init__()

    def __getitem__(self, index):
        x = np.zeros((10,), np.float32)
        y = np.zeros((1,), np.float32)
        return x, y

    def __len__(self):
        return 100

class MyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Linear(10, 1)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = torch.nn.MSELoss()(y_hat, y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

model = MyModel()
trainer = Trainer(default_root_dir='./my_dir')
train_dataloader = DataLoader(MyDataset())
trainer.tuner.lr_find(model, train_dataloader)

相关问题