这是我的代码
def train_dataloader(self):
if self._is_weighted_sampler:
weights = list(self.label_weight_by_name.values())
sampler = torch.utils.data.sampler.WeightedRandomSampler(
torch.tensor(weights), len(weights)
)
else:
sampler = torch.utils.data.RandomSampler(self._train_dataset)
return DataLoader(self._train_dataset, batch_size=self._batch_size, shuffle=True, sampler=sampler)
注意,在加权采样器的情况下,它不需要数据集,但RandomSampler需要。
在RandomSampler的情况下,这意味着数据集被传递了两次。
我一定是错过了一些关于如何使用它,请纠正我。
1条答案
按热度按时间cczfrluj1#
实际上,看起来你对这种差异的看法是正确的;为什么一个调用需要数据集对象而另一个不需要,似乎没有一个直接明显的原因。根据文档,函数原型是:
深入研究源代码,可以看到
data_source
从未被Random_Sampler
索引,并且只被用作len(data_source)
。该对象生成索引,而数据集对象仅用于确定数据的长度。从
weights
(即,应该设置为与数据集具有相同数量的元素),并返回一组索引,然后必须单独用于索引数据集对象。也许开发人员的基本原理是“在
WeightedRandomSampler
的情况下,用户必须为数据源中的每个项定义一个权重。在RandomSampler
的情况下,所有的权重都是1,所以用户可以简单地传递数据集对象本身,而不是定义一个1的向量。”我不明白为什么他们不把数据集的长度作为整数传递给RandomSampler
。