I'm using PyTorch Lightning 1.4.0 and have defined the following class for the dataset:
import numpy as np
import torch
from torch.utils.data import Dataset


class CustomTrainDataset(Dataset):
    '''
    Custom PyTorch Dataset for training
    Args:
        data (pd.DataFrame) - DF containing product info (and maybe also ratings)
        all_itemIds (list) - Python3 list containing all Item IDs
    '''
    def __init__(self, data, all_itemIds):
        self.users, self.items, self.labels = self.get_dataset(data, all_itemIds)

    def __len__(self):
        return len(self.users)

    def __getitem__(self, idx):
        return self.users[idx], self.items[idx], self.labels[idx]

    def get_dataset(self, data, all_itemIds):
        users, items, labels = [], [], []
        # Every observed (customer, item) pair is a positive example.
        # CustomerID and ItemCode must be integer-encoded for torch.tensor below.
        user_item_set = set(zip(data['CustomerID'], data['ItemCode']))
        num_negatives = 7  # negative samples drawn per positive pair
        for u, i in user_item_set:
            users.append(u)
            items.append(i)
            labels.append(1)
            for _ in range(num_negatives):
                # Redraw until the item is one this customer has not interacted with.
                negative_item = np.random.choice(all_itemIds)
                while (u, negative_item) in user_item_set:
                    negative_item = np.random.choice(all_itemIds)
                users.append(u)
                items.append(negative_item)
                labels.append(0)
        return torch.tensor(users), torch.tensor(items), torch.tensor(labels)
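The core of `get_dataset` is rejection-based negative sampling: each positive pair is followed by `num_negatives` items the user has not interacted with. A minimal standalone sketch of just that step, assuming plain Python lists instead of a DataFrame and using the standard `random` module in place of `np.random.choice`; the function name `sample_with_negatives` is hypothetical, and `rng` is an injected `random.Random` so results are reproducible.

```python
import random


def sample_with_negatives(pairs, all_item_ids, num_negatives, rng):
    """Hypothetical helper: return (users, items, labels) where each positive
    (user, item) pair is followed by num_negatives unseen items labelled 0."""
    positive_set = set(pairs)
    users, items, labels = [], [], []
    # Sorted only so the output order is reproducible across runs.
    for u, i in sorted(positive_set):
        users.append(u)
        items.append(i)
        labels.append(1)
        for _ in range(num_negatives):
            negative_item = rng.choice(all_item_ids)
            # Rejection sampling: redraw while the pair is a known positive.
            # Note: this never terminates if u has interacted with every item.
            while (u, negative_item) in positive_set:
                negative_item = rng.choice(all_item_ids)
            users.append(u)
            items.append(negative_item)
            labels.append(0)
    return users, items, labels
```

As in the original class, the same negative item can be drawn more than once for a user; rejection sampling stays cheap as long as each user has interacted with only a small fraction of the catalogue.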
Then the PL class:

[code not preserved]
To load the saved checkpoint, I have tried:

[code not preserved]
But these don't seem to work. How can I load this saved checkpoint?
Thank you!
2 Answers

Answer 1:
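As shown here, `load_from_checkpoint` is the main way to load weights in PyTorch Lightning, and it automatically restores the hyperparameters that were saved with the checkpoint during training. So you don't need to pass any arguments unless you want to override the saved values. My suggestion is to try `trained_model = NCF.load_from_checkpoint("NCF_Trained.ckpt")`.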
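Calling `load_from_checkpoint` without arguments only works if the checkpoint actually contains the hyperparameters. A minimal sketch for checking this, assuming the `.ckpt` file is the plain dict Lightning writes with `torch.save` and that hyperparameters are stored under the `"hyper_parameters"` key (which Lightning does when the module calls `self.save_hyperparameters()`); `inspect_checkpoint` is a hypothetical helper name.

```python
import torch


def inspect_checkpoint(path_or_file):
    """Hypothetical helper: return (sorted top-level keys, saved hyperparameters
    or None) for a Lightning checkpoint."""
    # A Lightning checkpoint is a dict written with torch.save; load it on CPU.
    # On recent PyTorch versions, a checkpoint holding arbitrary Python objects
    # may need torch.load(..., weights_only=False).
    checkpoint = torch.load(path_or_file, map_location="cpu")
    # Only present if the module recorded its init arguments, e.g. via
    # self.save_hyperparameters() in __init__.
    hparams = checkpoint.get("hyper_parameters")
    return sorted(checkpoint.keys()), hparams
```

For example, `inspect_checkpoint("NCF_Trained.ckpt")`. If the second value is `None`, the constructor arguments have to be supplied when loading, e.g. `NCF.load_from_checkpoint("NCF_Trained.ckpt", num_users=..., num_items=...)`: extra keyword arguments are forwarded to `__init__`. The argument names shown are assumptions about `NCF`'s signature.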
Answer 2:

Add a line to the `__init__` method:

[code not preserved]

Then call:

[code not preserved]