PyTorch collate_fn: reject a sample and sample another one

zzwlnbp8 · published 2023-02-19 in Other
Follow (0) | Answers (6) | Views (162)

I have built a dataset where I run various checks on the images being loaded, and I then pass this dataset to a DataLoader.
In my dataset class, if an image fails the checks I return the sample as None, and I have a custom collate_fn that removes all the Nones from the retrieved batch and returns the remaining valid samples.
However, the batch returned at this point can vary in size. Is there a way to tell collate_fn to keep sourcing data until the batch reaches a certain length?

class DataSet():
    def __init__(self, csv_file, image_dir):
        # initialise dataset:
        # load csv file and image directory
        self.csv_file = csv_file
        self.image_dir = image_dir

    def __getitem__(self, idx):
        # load one sample;
        # if the image is too dark, return None,
        # else return one image and its matching label
        ...

dataset = DataSet(csv_file='../', image_dir='../../')

dataloader = DataLoader(dataset, batch_size=4,
                        shuffle=True, num_workers=1, collate_fn=my_collate)

def my_collate(batch):
    # batch is a list of 4 samples, e.g. [{image, label}, {}, {}, {}],
    # but may contain Nones, e.g. G = [None, {}, {}, {}]
    batch = list(filter(lambda x: x is not None, batch))  # drop the Nones; above example becomes G = [{}, {}, {}]
    # I want len(G) = 4 -- so how do I sample another dataset entry here?
    return torch.utils.data.dataloader.default_collate(batch)
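To make the shrinking concrete, here is a minimal self-contained sketch (the toy tensors and the "image"/"label" field names are made up for illustration) of what the filtering collate does to a batch of 4 with one rejected sample:

```python
import torch
from torch.utils.data.dataloader import default_collate

def my_collate(batch):
    # Drop samples the dataset flagged as bad (returned as None)
    batch = [sample for sample in batch if sample is not None]
    return default_collate(batch)

# A toy batch of 4, where one sample failed the image checks
batch = [
    {"image": torch.zeros(3, 8, 8), "label": torch.tensor(0)},
    None,
    {"image": torch.ones(3, 8, 8), "label": torch.tensor(1)},
    {"image": torch.ones(3, 8, 8), "label": torch.tensor(2)},
]
collated = my_collate(batch)
print(collated["image"].shape)  # torch.Size([3, 3, 8, 8]) -- the batch shrank from 4 to 3
```

The batch dimension silently drops from 4 to 3, which is exactly the variable-batch-size problem the question is about.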

xuo3flqw1#

There are two hacks you can use to work around the problem; pick one.
Quick option: reuse samples from the original batch

def my_collate(batch):
    len_batch = len(batch) # original batch length
    batch = list(filter (lambda x:x is not None, batch)) # filter out all the Nones
    if len_batch > len(batch): # if there are samples missing just use existing members, doesn't work if you reject every sample in a batch
        diff = len_batch - len(batch)
        for i in range(diff):
            batch = batch + batch[:diff]
    return torch.utils.data.dataloader.default_collate(batch)

Better option: instead, load another sample from the dataset at random

def my_collate(batch):
    len_batch = len(batch)  # original batch length
    batch = list(filter(lambda x: x is not None, batch))  # filter out all the Nones
    if len_batch > len(batch):  # source the missing samples from the original dataset at random
        diff = len_batch - len(batch)
        for i in range(diff):
            # note: relies on `dataset` and `np` (numpy) being in scope
            batch.append(dataset[np.random.randint(0, len(dataset))])
    return torch.utils.data.dataloader.default_collate(batch)

xj3cbfub2#

This worked better for me, since sometimes even those randomly drawn replacements are None.

def my_collate(batch):
    len_batch = len(batch)
    batch = list(filter(lambda x: x is not None, batch))

    if len_batch > len(batch):                
        db_len = len(dataset)
        diff = len_batch - len(batch)
        while diff != 0:
            a = dataset[np.random.randint(0, db_len)]
            if a is None:                
                continue
            batch.append(a)
            diff -= 1

    return torch.utils.data.dataloader.default_collate(batch)

hs1rzwqc3#

For anyone who wants to reject training examples on the fly, instead of using tricks to solve the problem in the dataloader's collate_fn, you can simply use an IterableDataset and write the __iter__ and __next__ functions, like this:

def __iter__(self):
    return self

def __next__(self):
    # load and return the next non-None example
    ...
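As a rough self-contained sketch of that idea (the SkippingDataset class and the toy data list are made up for illustration), an IterableDataset can skip bad samples before batching ever happens:

```python
import torch
from torch.utils.data import IterableDataset, DataLoader

class SkippingDataset(IterableDataset):
    """Yields only samples that pass the check; bad ones (None) are skipped on the fly."""
    def __init__(self, raw_samples):
        self.raw_samples = raw_samples  # stand-in for the csv/image loading

    def __iter__(self):
        self._it = iter(self.raw_samples)
        return self

    def __next__(self):
        # Keep pulling until a valid (non-None) sample comes up
        for sample in self._it:
            if sample is not None:
                return sample
        raise StopIteration

data = [1, None, 2, None, None, 3, 4]
loader = DataLoader(SkippingDataset(data), batch_size=2)
print([batch.tolist() for batch in loader])  # [[1, 2], [3, 4]]
```

Because the Nones never reach the collate step, every batch (except possibly the last) comes out at the full batch size, with no custom collate_fn at all.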

cnjp1d6j4#

[Edit] An updated version of the snippet below can be found here: https://github.com/project-lighter/lighter/blob/main/lighter/utils/collate.py
Thanks to Brian Formento for asking about this and suggesting a workaround. As mentioned, the best option (replacing bad examples with new ones) has two issues:
1. the newly sampled examples could also be corrupted;
2. the dataset is not in scope.
Below is a solution to both: issue 1 is handled with a recursive call, and issue 2 by creating a partial of the collate function with the dataset fixed in place.

import random
import torch

def collate_fn_replace_corrupted(batch, dataset):
    """Collate function that allows to replace corrupted examples in the
    dataloader. It expects that the dataloader returns 'None' when that occurs.
    The 'None's in the batch are replaced with other examples sampled randomly.

    Args:
        batch (torch.Tensor): batch from the DataLoader.
        dataset (torch.utils.data.Dataset): dataset that the DataLoader is loading.
            Specify it with functools.partial and pass the resulting partial function
            that only requires the 'batch' argument to DataLoader's 'collate_fn' option.

    Returns:
        torch.Tensor: batch with new examples instead of corrupted ones.
    """
    # Idea from https://stackoverflow.com/a/57882783
    original_batch_len = len(batch)
    # Filter out all the Nones (corrupted examples)
    batch = list(filter(lambda x: x is not None, batch))
    filtered_batch_len = len(batch)
    # Num of corrupted examples
    diff = original_batch_len - filtered_batch_len
    if diff > 0:
        # Replace the corrupted examples with randomly drawn ones
        # (random.randint is inclusive on both ends, hence len(dataset) - 1)
        batch.extend([dataset[random.randint(0, len(dataset) - 1)] for _ in range(diff)])
        # Recursive call to replace the replacements if they are corrupted
        return collate_fn_replace_corrupted(batch, dataset)
    # Finally, when the whole batch is fine, return it
    return torch.utils.data.dataloader.default_collate(batch)

However, since a collate function should take only one argument (the batch), it cannot be passed to the DataLoader directly. Instead, we create a partial function with the dataset fixed in place and pass that to the DataLoader:

import functools
from torch.utils.data import DataLoader

collate_fn = functools.partial(collate_fn_replace_corrupted, dataset=dataset)
dataloader = DataLoader(dataset,
                        batch_size=batch_size,
                        num_workers=num_workers,
                        pin_memory=pin_memory,
                        collate_fn=collate_fn)
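For a quick end-to-end check of the pattern, here is a self-contained toy version (FlakyDataset and the compact collate_replace below are illustrative stand-ins, not the library code): odd indices come back as None and get replaced recursively:

```python
import functools
import random
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataloader import default_collate

class FlakyDataset(Dataset):
    """Toy dataset: odd indices are 'corrupted' and come back as None."""
    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return None if idx % 2 else torch.tensor(idx)

def collate_replace(batch, dataset):
    # Compact version of the recursive replacement collate
    original_len = len(batch)
    batch = [s for s in batch if s is not None]
    diff = original_len - len(batch)
    if diff > 0:
        batch.extend(dataset[random.randint(0, len(dataset) - 1)] for _ in range(diff))
        return collate_replace(batch, dataset)  # replacements may be None too
    return default_collate(batch)

dataset = FlakyDataset()
loader = DataLoader(dataset, batch_size=4,
                    collate_fn=functools.partial(collate_replace, dataset=dataset))
for batch in loader:
    print(batch.shape)  # every batch keeps its original size; all samples are valid
```

Note that the recursion terminates with probability 1 as long as the dataset contains at least one valid sample, since every retry redraws the replacements at random.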

myss37ts5#

Why not solve this inside the dataset class itself, in the __getitem__ method? When a sample is no good, instead of returning None you can recursively ask for a different, random index:

class DataSet():
    def __getitem__(self, idx):
        sample = load_sample(idx)
        if is_no_good(sample):
            # np.random.randint's upper bound is exclusive, so this covers every index
            idx = np.random.randint(0, len(self))
            sample = self[idx]
        return sample

This way you never have to deal with batches of different sizes.
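A self-contained toy version of this approach (ResamplingDataset is made up for illustration, with `idx % 2` standing in for the quality check and `torch.tensor(idx)` for the loaded sample):

```python
import random
import torch
from torch.utils.data import Dataset, DataLoader

class ResamplingDataset(Dataset):
    """Toy dataset: odd indices fail the 'quality check', so __getitem__
    retries with a random index instead of returning None."""
    def __len__(self):
        return 10

    def __getitem__(self, idx):
        sample = torch.tensor(idx)  # stand-in for load_sample(idx)
        if idx % 2:                 # stand-in for is_no_good(sample)
            return self[random.randrange(len(self))]
        return sample

loader = DataLoader(ResamplingDataset(), batch_size=4)
for batch in loader:
    print(batch.shape)  # batches keep their full size; no collate tricks needed
```

One caveat: if most of the dataset is bad, the recursive retry can get deep; an explicit while-loop that redraws indices would avoid any risk of hitting the recursion limit.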


vfwfrxfs6#

The quick option has a bug; below is a corrected version.
def my_collate(batch):
    len_batch = len(batch) # original batch length
    batch = list(filter (lambda x:x is not None, batch)) # filter out all the Nones
    if len_batch > len(batch): # if there are samples missing just use existing members, doesn't work if you reject every sample in a batch
        diff = len_batch - len(batch)
        batch = batch + batch[:diff] # assume diff < len(batch)
    return torch.utils.data.dataloader.default_collate(batch)
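To see why the original loop over-padded, here is a minimal pure-Python check (toy string values standing in for samples): with `diff` missing samples, appending `batch[:diff]` once restores the original length, while doing it `diff` times inside a loop adds `diff * diff` elements and overshoots whenever `diff > 1`:

```python
# Toy batch of 4 with two rejected samples (None)
batch = ['a', None, 'b', None]
len_batch = len(batch)
batch = [x for x in batch if x is not None]
diff = len_batch - len(batch)

# Corrected padding: append the first `diff` survivors exactly once
padded = batch + batch[:diff]
print(padded)  # ['a', 'b', 'a', 'b'] -- back to the original batch size

# The buggy loop instead repeats the padding `diff` times, overshooting:
buggy = list(batch)
for _ in range(diff):
    buggy = buggy + buggy[:diff]
print(len(buggy))  # 6, not 4
```

As the corrected version's comment notes, both variants still assume `diff < len(batch)`, i.e. at least `diff` valid samples survive the filter.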
