载入PyTorch lightning 训练检查点

9w11ddsr  于 2022-11-09  发布在  其他
关注(0)|答案(2)|浏览(255)

我使用的是PyTorch Lightning 1.4.0版,并为数据集定义了以下类:

class CustomTrainDataset(Dataset):
    '''
    Custom PyTorch Dataset for training

    Args:
        data (pd.DataFrame) - DF containing product info (and maybe also ratings)
        all_itemIds (list) - Python3 list containing all Item IDs
    '''

    def __init__(self, data, all_orderIds):
        self.users, self.items, self.labels = self.get_dataset(data, all_orderIds)

    def __len__(self):
        return len(self.users)

    def __getitem__(self, idx):
        return self.users[idx], self.items[idx], self.labels[idx]

    def get_dataset(self, data, all_orderIds):
        users, items, labels = [], [], []
        user_item_set = set(zip(train_ratings['CustomerID'], train_ratings['ItemCode']))

        num_negatives = 7
        for u, i in user_item_set:
            users.append(u)
            items.append(i)
            labels.append(1)
            for _ in range(num_negatives):
                negative_item = np.random.choice(all_itemIds)
                while (u, negative_item) in user_item_set:
                    negative_item = np.random.choice(all_itemIds)
                users.append(u)
                items.append(negative_item)
                labels.append(0)

        return torch.tensor(users), torch.tensor(items), torch.tensor(labels)

然后是PL类:
第一个
要加载保存的检查点,我已尝试:
第一个
但这些似乎不起作用。我如何加载这个保存的检查点?
谢谢你!

jljoyd4f

jljoyd4f1#

如这里所示,load_from_checkpoint是pytorch-lightning中加载权重的主要方式,它会自动加载训练中使用的超参数。因此,除了覆盖现有参数外,您不需要传递参数。我的建议是尝试trained_model = NCF.load_from_checkpoint("NCF_Trained.ckpt")

jvlzgdj9

jvlzgdj92#

在init方法中添加一行:

self.save_hyperparameters(logger=False)

那就叫

trained_model = NCF.load_from_checkpoint("NCF_Trained.ckpt")

相关问题