pytorch 我必须将数据集同时传递给加载器和RandomSampler吗?

uxhixvfz  于 2023-10-20  发布在  其他
关注(0)|答案(1)|浏览(131)

这是我的代码

def train_dataloader(self):
        if self._is_weighted_sampler:
            weights = list(self.label_weight_by_name.values())
            sampler = torch.utils.data.sampler.WeightedRandomSampler(
                torch.tensor(weights), len(weights)
            )
        else:
            sampler = torch.utils.data.RandomSampler(self._train_dataset)
        return DataLoader(self._train_dataset, batch_size=self._batch_size, shuffle=True, sampler=sampler)

注意,在加权采样器的情况下,它不需要数据集,但RandomSampler需要。
在RandomSampler的情况下,这意味着数据集被传递了两次。
我一定是错过了一些关于如何使用它,请纠正我。

cczfrluj

cczfrluj1#

实际上,看起来你对这种差异的看法是正确的;为什么一个调用需要数据集对象而另一个不需要,似乎没有一个直接明显的原因。根据文档,函数原型是:

torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None, generator=None)

深入研究源代码,可以看到data_source从未被Random_Sampler索引,并且只被用作len(data_source)。该对象生成索引,而数据集对象仅用于确定数据的长度。

torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True, generator=None)

weights(即,应该设置为与数据集具有相同数量的元素),并返回一组索引,然后必须单独用于索引数据集对象。
也许开发人员的基本原理是“在WeightedRandomSampler的情况下,用户必须为数据源中的每个项定义一个权重。在RandomSampler的情况下,所有的权重都是1,所以用户可以简单地传递数据集对象本身,而不是定义一个1的向量。”我不明白为什么他们不把数据集的长度作为整数传递给RandomSampler

相关问题