Pytorch -'method'类型的对象没有len()

0mkxixxg  于 2023-01-02  发布在  其他
关注(0)|答案(1)|浏览(226)

所以我试着用pytorch lightning 库来训练一个时间序列模型。
但是在运行下面的代码之后,

trainer = pl.Trainer(
    max_epochs = N_EPOCHS,
)

trainer.fit(model,data_module)

我得到这个错误,

/usr/local/lib/python3.8/dist-packages/torch/utils/data/sampler.py in __iter__(self)
     64 
     65     def __iter__(self) -> Iterator[int]:
---> 66         return iter(range(len(self.data_source)))
     67 
     68     def __len__(self) -> int:

TypeError: object of type 'method' has no len()

这是我的数据模块初始化

data_module = StockPriceDataModule(train_sequences, test_sequences, batch_size=BATCH_SIZE)
data_module.setup()

数据模块类

class StockPriceDataModule(pl.LightningDataModule):
    def __init__(self, train_sequences, test_sequences, batch_size = 8):
        super().__init__()
        self.train_sequences = train_sequences
        self.test_sequences = test_sequences
        self.batch_size = batch_size
        
    def setup(self, stage=None):
        self.train_dataset = StockDataset(self.train_sequences)
        self.test_dataset = StockDataset(self.test_sequences)
        
    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size = self.batch_size,
            shuffle = False,
            num_workers = 2
        )

    def val_dataloader(self):
        return DataLoader(
            self.test_dataloader,
            batch_size=1,
            shuffle = False,
            num_workers=1,
        )
    
    def test_dataloader(self):
        return DataLoader(
            self.test_dataloader,
            batch_size=1,
            shuffle = False,
            num_workers=1,
        )

我有点初学者。所以我只是按照这个视频https://www.youtube.com/watch?v=ODEGJ_kh2aA来学习使用LSTM的多功能时间序列预测。我的代码几乎是一样的...
我做错什么了?

eqqqjvef

eqqqjvef1#

这里可能是错误的,但是你只需要在你的数据加载器类中创建一个方法,比如deflen(self):返回len(数据加载器)。
custom data loaders in Pytorch

相关问题