pytorch 即使在使用SubsetRandomSampler之后也无法加载并行数据集

dced5bon  于 2022-11-23  发布在  其他
关注(0)|答案(1)|浏览(154)

我有两个并行数据集dataset1dataset2,下面是我使用SubsetRandomSampler并行加载它们的代码,其中我提供train_indices用于数据加载。
P.S.即使在设置num_workers=0并播种nptorch之后,示例也不会并行加载。衷心欢迎您提出任何建议,包括SubsetRandomSampler以外的方法。

import torch, numpy as np
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler

dataset1 = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
dataset2 = torch.tensor([10, 11, 12, 13, 14, 15, 16, 17, 18, 19])

train_indices = list(range(len(dataset1)))
torch.manual_seed(12)
np.random.seed(12)
np.random.shuffle(train_indices)
sampler = SubsetRandomSampler(train_indices)

dataloader1 = DataLoader(dataset1, batch_size=2, num_workers=0, sampler=sampler)
dataloader2 = DataLoader(dataset2, batch_size=2, num_workers=0, sampler=sampler)

for i, (data1, data2) in enumerate(zip(dataloader1, dataloader2)):
  x = data1
  y = data2
  print(x, y)

输出量:

tensor([5, 1]) tensor([15, 18])
tensor([0, 2]) tensor([14, 12])
tensor([4, 6]) tensor([16, 10])
tensor([8, 9]) tensor([11, 19])
tensor([7, 3]) tensor([17, 13])

预期输出:

tensor([5, 1]) tensor([15, 11])
tensor([0, 2]) tensor([10, 12])
tensor([4, 6]) tensor([14, 16])
tensor([8, 9]) tensor([18, 19])
tensor([7, 3]) tensor([17, 13])
sg24os4d

sg24os4d1#

看起来您正在尝试并行加载两个数据集,但让它们保持相同的打乱顺序。
目前,代码对dataset1的索引进行混洗,然后使用这些混洗后的索引对dataset1dataset2进行采样。但是,这并不能保证相同的元素在输出中配对在一起,因为dataset2是与dataset1分开混洗的。
为了达到预期的输出,您需要将两个数据集混洗在一起,然后使用混洗索引从两个数据集进行采样。一种方法是首先将两个数据集合并为一个数据集,该数据集包含每个数据集对应元素的元组,然后对合并后的数据集进行混洗。然后,您可以使用混洗索引创建两个单独的数据加载器。每个函数将返回来自每个数据集的相应元素。
下面是一个如何实现这一点的示例:

# combine the two datasets into a single dataset of tuples
combined_dataset = list(zip(dataset1, dataset2))

# shuffle the combined dataset
train_indices = list(range(len(combined_dataset)))
np.random.seed(12)
np.random.shuffle(train_indices)

# create the dataloaders
dataloader = DataLoader(combined_dataset, batch_size=2, num_workers=0, sampler=SubsetRandomSampler(train_indices))

# unpack the elements from the tuples in each batch
for i, (data1, data2) in enumerate(dataloader):
  x = data1
  y = data2
  print(x, y)

相关问题