我构建了一个数据集,在这里我对加载的图像进行各种检查,然后将这个数据集传递给一个数据加载器。
在我的DataSet类中,如果图片未通过我的检查,我会将样本返回为None,并且我有一个自定义的collate_fn函数,该函数从检索的批处理中删除所有None,并返回剩余的有效样本。
然而,此时返回的批大小可能不同。是否有方法告诉collate_fn保持来源数据,直到批大小达到一定长度?
class DataSet():
def __init__(self, example):
# initialise dataset
# load csv file and image directory
self.example = example
def __getitem__(self,idx):
# load one sample
# if image is too dark return None
# else
# return one image and its equivalent label
dataset = Dataset(csv_file='../', image_dir='../../')
dataloader = DataLoader(dataset , batch_size=4,
shuffle=True, num_workers=1, collate_fn = my_collate )
def my_collate(batch): # batch size 4 [{tensor image, tensor label},{},{},{}] could return something like G = [None, {},{},{}]
batch = list(filter (lambda x:x is not None, batch)) # this gets rid of nones in batch. For example above it would result to G = [{},{},{}]
# I want len(G) = 4
# so how to sample another dataset entry?
return torch.utils.data.dataloader.default_collate(batch)
6条答案
按热度按时间xuo3flqw1#
有2个黑客可以用来解决问题,选择一种方式:
使用原始批次样品快速选项:
否则,只需从数据集中随机加载另一个样本更好的选项:
xj3cbfub2#
这对我很有效,因为有时候甚至那些随机值都是None。
hs1rzwqc3#
对于希望即时拒绝训练示例的任何人,可以简单地使用IterableDataset并编写__iter__和__next__函数,而不是使用技巧来解决数据加载器的collate_fn中的问题,如下所示
cnjp1d6j4#
[编辑]从下面截取的代码的更新版本可以在这里找到https://github.com/project-lighter/lighter/blob/main/lighter/utils/collate.py
感谢Brian Formento提出问题并给出解决方法。如前所述,用新示例替换坏示例的最佳选择有两个问题:
1.新采样的示例也可能被破坏;
1.数据集不在范围内。
下面是这两个问题的解决方案-问题1通过递归调用解决,问题2通过创建collate函数的partial函数并将数据集固定在适当位置来解决。
但是,由于collate函数应该只有一个参数-
batch
,所以不能直接将其传递给DataLoader
,为此,我们使用指定的数据集创建一个partial函数,并将其传递给DataLoader
。myss37ts5#
为什么不在dataset类内部使用__ get_item__方法来解决这个问题呢?当数据不好时,您可以递归地请求一个不同的随机索引,而不是返回None。
这样您就不必处理不同大小的批次。
vfwfrxfs6#