我有一个LSTM模型,我想用它来进行时间预测。我有一个扩展pl.LightningDataModule
的类,在这个类中,我试图创建数据加载器来训练我的模型。
类BTCPriceDataModule
class BTCPriceDataModule(pl.LightningDataModule):
def __init__(self, train_sequences, test_sequences, batch_size = 8):
super().__init__()
self.train_sequence = train_sequences
self.test_sequences = test_sequences
self.batch_size = batch_size
def setup(self, stage=None):
self.train_dataset = BTCDataset(self.train_sequence)
self.test_dataset = BTCDataset(self.test_sequences)
def train_dataloader(self):
print("coming here")
return DataLoader(
self.train_dataset,
batch_size = self.batch_size,
shuffle = False,
num_workers=2
)
def val_dataloader(self):
return DataLoader(
self.test_dataset,
batch_size=1,
shuffle=False,
num_workers=1
)
def test_dataloader(self):
return DataLoader(
self.test_dataset,
batch_size=1,
shuffle=False,
num_workers=1
)
字符串
BTCDataset类
class BTCDataset(Dataset):
def __init__(self,sequences):
self.sequences = sequences
def __len__(self):
return len(self.sequences)
def __getitem__(self, idx):
sequence, label = self.sequences[idx]
return dict(
sequence=torch.Tensor(sequence.to_numpy()),
label = torch.tensor(label).float()
)
型
当我运行trainer.fit()
时,我的健全性检查需要很长时间才能完成,所以我开始调试代码,检查在处理数据时是否有任何问题。当我通过DataLoader传递顺序数据时,它甚至无法完成一个循环。示例如下:
data_module = BTCPriceDataModule(train_sequences, test_sequences, batch_size=BATCH_SIZE)
data_module.setup()
for i in data_module.train_dataloader():
print(i['sequence'].shape)
print(i['label'].shape)
break
型
这个循环永远不会结束。我想我给了它足够的时间,也减少了我的数据集大小,但它仍然无法编译。len(data_module.train_dataloader())
是1921年,所以我很确定它不是很重,对我的CPU需要超过5分钟的编译。如何调试此问题?或者有解决办法吗?
1条答案
按热度按时间sdnqo3pr1#
改变num_workers = 0对我来说很有用。