我有一个比内存还大的文本文件,我想在PyTorch中创建一个数据集,它可以逐行读取,这样我就不必在内存中一次加载它了。我发现pytorch IterableDataset
是解决我的问题的潜在方案。它只在使用一个工作线程时才能按预期工作,如果使用多个工作线程,它将创建重复的记录。让我给你看一个例子:
具有testfile.txt
,其中包含:
0 - Dummy line
1 - Dummy line
2 - Dummy line
3 - Dummy line
4 - Dummy line
5 - Dummy line
6 - Dummy line
7 - Dummy line
8 - Dummy line
9 - Dummy line
定义可迭代数据集:
class CustomIterableDatasetv1(IterableDataset):
def __init__(self, filename):
#Store the filename in object's memory
self.filename = filename
def preprocess(self, text):
### Do something with text here
text_pp = text.lower().strip()
###
return text_pp
def line_mapper(self, line):
#Splits the line into text and label and applies preprocessing to the text
text, label = line.split('-')
text = self.preprocess(text)
return text, label
def __iter__(self):
#Create an iterator
file_itr = open(self.filename)
#Map each element using the line_mapper
mapped_itr = map(self.line_mapper, file_itr)
return mapped_itr
我们现在可以测试它:
base_dataset = CustomIterableDatasetv1("testfile.txt")
#Wrap it around a dataloader
dataloader = DataLoader(base_dataset, batch_size = 1, num_workers = 1)
for X, y in dataloader:
print(X,y)
它输出:
('0',) (' Dummy line\n',)
('1',) (' Dummy line\n',)
('2',) (' Dummy line\n',)
('3',) (' Dummy line\n',)
('4',) (' Dummy line\n',)
('5',) (' Dummy line\n',)
('6',) (' Dummy line\n',)
('7',) (' Dummy line\n',)
('8',) (' Dummy line\n',)
('9',) (' Dummy line',)
是的。但是如果我把工人的数量改为2,产量就变成
('0',) (' Dummy line\n',)
('0',) (' Dummy line\n',)
('1',) (' Dummy line\n',)
('1',) (' Dummy line\n',)
('2',) (' Dummy line\n',)
('2',) (' Dummy line\n',)
('3',) (' Dummy line\n',)
('3',) (' Dummy line\n',)
('4',) (' Dummy line\n',)
('4',) (' Dummy line\n',)
('5',) (' Dummy line\n',)
('5',) (' Dummy line\n',)
('6',) (' Dummy line\n',)
('6',) (' Dummy line\n',)
('7',) (' Dummy line\n',)
('7',) (' Dummy line\n',)
('8',) (' Dummy line\n',)
('8',) (' Dummy line\n',)
('9',) (' Dummy line',)
('9',) (' Dummy line',)
这是不正确的,因为在数据加载器中为每个工作线程创建了每个样本的副本。
pytorch有办法解决这个问题吗?所以可以创建一个数据加载器,不加载内存中的所有文件,支持多个工作者。
2条答案
按热度按时间zdwk9cvp1#
所以我在torch讨论论坛https://discuss.pytorch.org/t/iterable-pytorch-dataset-with-multiple-workers/135475/3中找到了答案,他们指出我应该使用worker info连续切片到批量大小。
新数据集将如下所示:
特别感谢@Ivan,他也指出了切片解决方案。
对于两个工作进程,它只返回与一个工作进程相同的数据
eqqqjvef2#
你可以使用
torch.utils.data.get_worker_info
实用程序访问Dataset
的__iter__
函数中的worker标识符。这意味着你可以遍历迭代器并根据worker * id * 添加一个偏移量。你可以用itertools.islice
Package 迭代器,这样你就可以遍历start
索引和step
。下面是一个最小的例子:
即使我们使用的是
num_workers > 1
,遍历数据加载器也会产生唯一的示例:在您的情况下,您可以: