所以我试着用pytorch lightning 库来训练一个时间序列模型。
但是在运行下面的代码之后,
trainer = pl.Trainer(
max_epochs = N_EPOCHS,
)
trainer.fit(model,data_module)
我得到这个错误,
/usr/local/lib/python3.8/dist-packages/torch/utils/data/sampler.py in __iter__(self)
64
65 def __iter__(self) -> Iterator[int]:
---> 66 return iter(range(len(self.data_source)))
67
68 def __len__(self) -> int:
TypeError: object of type 'method' has no len()
这是我的数据模块初始化
data_module = StockPriceDataModule(train_sequences, test_sequences, batch_size=BATCH_SIZE)
data_module.setup()
数据模块类
class StockPriceDataModule(pl.LightningDataModule):
def __init__(self, train_sequences, test_sequences, batch_size = 8):
super().__init__()
self.train_sequences = train_sequences
self.test_sequences = test_sequences
self.batch_size = batch_size
def setup(self, stage=None):
self.train_dataset = StockDataset(self.train_sequences)
self.test_dataset = StockDataset(self.test_sequences)
def train_dataloader(self):
return DataLoader(
self.train_dataset,
batch_size = self.batch_size,
shuffle = False,
num_workers = 2
)
def val_dataloader(self):
return DataLoader(
self.test_dataloader,
batch_size=1,
shuffle = False,
num_workers=1,
)
def test_dataloader(self):
return DataLoader(
self.test_dataloader,
batch_size=1,
shuffle = False,
num_workers=1,
)
我有点初学者。所以我只是按照这个视频https://www.youtube.com/watch?v=ODEGJ_kh2aA来学习使用LSTM的多功能时间序列预测。我的代码几乎是一样的...
我做错什么了?
1条答案
按热度按时间eqqqjvef1#
这里可能是错误的,但是你只需要在你的数据加载器类中创建一个方法,比如deflen(self):返回len(数据加载器)。
custom data loaders in Pytorch