我想实现一个无限循环的数据集和数据加载器。这里是我尝试:
class Infinite(Dataset):
def __len__(self):
return HPARAMS.batch_size
# return 1<<30 # This causes huge memory usage.
def __getitem__(self, idx):
"""Randomly generates one new example."""
return sample_func_to_be_parallelized()
infinite_loader = DataLoader(
dataset=Infinite(),
batch_size=HPARAMS.batch_size,
num_workers=16,
worker_init_fn=lambda worker_id: np.random.seed(worker_id),
)
while True:
for idx, data in enumerate(infinite_loader):
# forward + backward on "data"
正如您所看到的,这里的主要挑战是__len()__
方法。如果我在那里输入足够大的数字,比如1〈〈30,症状是在train循环的第一次迭代中内存使用量将跳到10+GB。过一段时间,可能是由于OOM导致工作线程被杀死。
如果我在这里输入一个小数字,比如1或BATCH_SIZE,那么train循环中的采样“数据”将被周期性地复制。这不是我想要的,因为我希望在每次迭代中生成和训练新数据。
我猜过多内存使用的罪魁祸首是堆栈中的某个地方,有很多东西被缓存了,随便看看Python的东西,我不能确定在哪里。
有人能告诉我实现我想要的东西的最佳方法是什么吗?(使用DataLoader的并行加载,同时保证加载的每个批处理都是全新的。)
4条答案
按热度按时间r55awzrz1#
这似乎是在不定期复制数据的情况下工作:
结果:
我认为问题出在函数
sample_func_to_be_parallelized()
上。__getitem__
中使用np.random.randint(10, size=3)
而不是torch.randint(0, 10, (3,))
(作为sample_func_to_be_parallelized()
的示例),则数据确实在每个批中重复。请参阅issue。因此,如果您在
sample_func_to_be_parallelized()
中的某个地方使用numpy的RGN,那么解决方法是使用并且在每次调用
data = next(iter(data_loader))
之前将种子复位np.random.seed()
。kzipqqlq2#
DataLoader
对你的数据集进行采样 * 而不进行替换 *。为了做到这一点,它会生成一个0到len(dataset)
之间的索引的随机排列。我猜这个排列会消耗掉你的大部分内存。我不认为PyTorch API支持无限的集合,但是你可以尝试在DataLoader
中派生代码并自己做。你可以使用batch_sampler
参数,并传入一个基于RandomSampler
实现的自定义变量,这将允许您保留DataLoader
的并行加载部分。也就是说,基于
__len__
和__getitem__
的迭代协议并不适合无限次的收集,最好重新实现Dataset.__len__
,只返回1
,重新实现Dataset.__getitem__
,无论索引如何,总是返回新的样本,然后从这个数据集中对n
采样 * 次,并替换 *。它将为第0个样本请求n
次,但由于您覆盖了__getitem__
以返回不同的样本,因此这将有效地完成您所寻找的内容。kpbpu0083#
尝试使用
itertools
中的cycle
。下面是一个简单数据集的示例:代码:
输出:
x一个一个一个一个x一个一个二个x
dsekswqp4#