python pytorch-lightning train_dataloader对象在数据循环中花费了很长时间

eqoofvh9  于 2023-08-02  发布在  Python
关注(0)|答案(1)|浏览(100)

我有一个LSTM模型,我想用它来进行时间预测。我有一个扩展pl.LightningDataModule的类,在这个类中,我试图创建数据加载器来训练我的模型。

类BTCPriceDataModule

class BTCPriceDataModule(pl.LightningDataModule):

    def __init__(self, train_sequences, test_sequences, batch_size = 8):
        super().__init__()
        self.train_sequence = train_sequences
        self.test_sequences = test_sequences
        self.batch_size = batch_size

    def setup(self, stage=None):
        self.train_dataset = BTCDataset(self.train_sequence)
        self.test_dataset = BTCDataset(self.test_sequences)

    def train_dataloader(self):
        print("coming here")
        return DataLoader(
            self.train_dataset,
            batch_size = self.batch_size,
            shuffle = False,
            num_workers=2
        )
    
    def val_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=1
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=1
        )

字符串

BTCDataset类

class BTCDataset(Dataset):
    def __init__(self,sequences):
        self.sequences = sequences
    
    def __len__(self):
        return len(self.sequences)
    
    def __getitem__(self, idx):
        sequence, label = self.sequences[idx]

        return dict(
            sequence=torch.Tensor(sequence.to_numpy()),
            label = torch.tensor(label).float()
        )


当我运行trainer.fit()时,我的健全性检查需要很长时间才能完成,所以我开始调试代码,检查在处理数据时是否有任何问题。当我通过DataLoader传递顺序数据时,它甚至无法完成一个循环。示例如下:

data_module = BTCPriceDataModule(train_sequences, test_sequences, batch_size=BATCH_SIZE)
data_module.setup()

for i in data_module.train_dataloader():
    print(i['sequence'].shape)
    print(i['label'].shape)
    break


这个循环永远不会结束。我想我给了它足够的时间,也减少了我的数据集大小,但它仍然无法编译。len(data_module.train_dataloader())是1921年,所以我很确定它不是很重,对我的CPU需要超过5分钟的编译。如何调试此问题?或者有解决办法吗?

sdnqo3pr

sdnqo3pr1#

改变num_workers = 0对我来说很有用。

相关问题