在PyTorch中实现“无限循环”数据集和数据加载器

2w2cym1i  于 2023-02-19  发布在  其他
关注(0)|答案(4)|浏览(252)

我想实现一个无限循环的数据集和数据加载器。这里是我尝试:

class Infinite(Dataset):
    def __len__(self):
        return HPARAMS.batch_size
#         return 1<<30 # This causes huge memory usage.
    def __getitem__(self, idx):
        """Randomly generates one new example."""
        return sample_func_to_be_parallelized()

infinite_loader = DataLoader(
    dataset=Infinite(), 
    batch_size=HPARAMS.batch_size, 
    num_workers=16,
    worker_init_fn=lambda worker_id: np.random.seed(worker_id),  
)

while True:
    for idx, data in enumerate(infinite_loader):
        # forward + backward on "data"

正如您所看到的,这里的主要挑战是__len()__方法。如果我在那里输入足够大的数字,比如1〈〈30,症状是在train循环的第一次迭代中内存使用量将跳到10+GB。过一段时间,可能是由于OOM导致工作线程被杀死。
如果我在这里输入一个小数字,比如1或BATCH_SIZE,那么train循环中的采样“数据”将被周期性地复制。这不是我想要的,因为我希望在每次迭代中生成和训练新数据。
我猜过多内存使用的罪魁祸首是堆栈中的某个地方,有很多东西被缓存了,随便看看Python的东西,我不能确定在哪里。
有人能告诉我实现我想要的东西的最佳方法是什么吗?(使用DataLoader的并行加载,同时保证加载的每个批处理都是全新的。)

r55awzrz

r55awzrz1#

这似乎是在不定期复制数据的情况下工作:

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

BATCH_SIZE = 2

class Infinite(Dataset):

    def __len__(self):
        return BATCH_SIZE

    def __getitem__(self, idx):
        return torch.randint(0, 10, (3,))

data_loader = DataLoader(Infinite(), batch_size=BATCH_SIZE, num_workers=16)

batch_count = 0
while True:
    batch_count += 1
    print(f'Batch {batch_count}:')

    data = next(iter(data_loader))
    print(data)
    # forward + backward on "data"  

    if batch_count == 5:
        break

结果:

Batch 1:
tensor([[4, 7, 7],
        [0, 8, 0]])
Batch 2:
tensor([[6, 8, 6],
        [2, 6, 7]])
Batch 3:
tensor([[6, 6, 2],
        [8, 7, 0]])
Batch 4:
tensor([[9, 4, 8],
        [2, 4, 1]])
Batch 5:
tensor([[9, 6, 1],
        [2, 7, 5]])

我认为问题出在函数sample_func_to_be_parallelized()上。

    • 编辑**:如果我在__getitem__中使用np.random.randint(10, size=3)而不是torch.randint(0, 10, (3,))(作为sample_func_to_be_parallelized()的示例),则数据确实在每个批中重复。请参阅issue

因此,如果您在sample_func_to_be_parallelized()中的某个地方使用numpy的RGN,那么解决方法是使用

worker_init_fn=lambda worker_id: np.random.seed(np.random.get_state()[1][0] + worker_id)

并且在每次调用data = next(iter(data_loader))之前将种子复位np.random.seed()

kzipqqlq

kzipqqlq2#

DataLoader对你的数据集进行采样 * 而不进行替换 *。为了做到这一点,它会生成一个0到len(dataset)之间的索引的随机排列。我猜这个排列会消耗掉你的大部分内存。我不认为PyTorch API支持无限的集合,但是你可以尝试在DataLoader中派生代码并自己做。你可以使用batch_sampler参数,并传入一个基于RandomSampler实现的自定义变量,这将允许您保留DataLoader的并行加载部分。
也就是说,基于__len____getitem__的迭代协议并不适合无限次的收集,最好重新实现Dataset.__len__,只返回1,重新实现Dataset.__getitem__,无论索引如何,总是返回新的样本,然后从这个数据集中对n采样 * 次,并替换 *。它将为第0个样本请求n次,但由于您覆盖了__getitem__以返回不同的样本,因此这将有效地完成您所寻找的内容。

kpbpu008

kpbpu0083#

尝试使用itertools中的cycle。下面是一个简单数据集的示例:
代码:

from itertools import cycle

import torch
from torch.utils.data import Dataset, DataLoader

# Create some dummy data.
data = torch.tensor([[0, 0],
                     [1, 1],
                     [2, 2],
                     [3, 3]])

class DataSet(Dataset):
    """Our dataset. Iterates over tensor data"""

    def __init__(self, data):
        self.data = data
        self.n = self.data.shape[0]

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return self.data[idx]

bs = 1  # batch size
workers = 1  # number of workers

dataset = DataSet(data)
data_loader = DataLoader(dataset, batch_size=bs, shuffle=False, num_workers=workers)

# Infinite loop.
print(f'batch size: {bs} | number of workers: {workers}')
for i, data in cycle(enumerate(data_loader)):
    print(i, data)

输出:
x一个一个一个一个x一个一个二个x

dsekswqp

dsekswqp4#

from torch.utils.data import DataLoader, Dataset, Sampler
import random

class listDataset(Dataset):
    def __init__(self):
        self.varList = [1,2,3,4]
    def __len__(self):
        return len(self.varList)
    def __getitem__(self, idx) :
        return self.varList[idx]

class customSampler(Sampler) :
    def __init__(self, dataset, shuffle):
        assert len(dataset) > 0
        self.dataset = dataset
        self.shuffle = shuffle

    def __iter__(self):
        order = list(range((len(self.dataset))))
        idx = 0
        while True:
            yield order[idx]
            idx += 1
            if idx == len(order):
                if self.shuffle:
                    random.shuffle(order)
                idx = 0

dset = listDataset()
sampler = customSampler(dset, shuffle=True)
loader = iter(DataLoader(dataset=dset, sampler=sampler, batch_size=6, num_workers=2))
for x in range(10):
    i = next(loader)
    print(i)

相关问题