如何在PyTorch中只对过采样产生的新数据应用增强?

3htmauhk  于 2023-03-18  发布在  其他
关注(0)|答案(1)|浏览(117)

所以我有一个训练数据集(用torch.utils.data.random_split创建),RGB图像大小为150 x150,共有7个类。有一个类不平衡,我用加权采样器修复了这个问题。但是,我也想添加图像增强到我创建的新数据中,以避免过度拟合,否则它们将是重复的图像。以下是我到目前为止所做的:

# define data augmentation transforms for training set
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomResizedCrop(size=150, scale=(0.8, 1.0)),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
    transforms.RandomRotation(degrees=15),
    transforms.ToTensor(),
])

# get the indices of examples in each class in the training set
class_indices = [[] for _ in range(7)]
for i in range(len(train_dataset)):
    _, label = train_dataset[i]
    class_indices[label].append(i)

# calculate the number of examples to sample from each class
max_class_size = max([len(class_indices[c]) for c in range(7)])
class_weights = [max_class_size / len(class_indices[c]) for c in range(7)]
num_samples = [int(class_weights[c] * len(class_indices[c])) for c in range(7)]

# create a WeightedRandomSampler to oversample the training set
sampler = data.WeightedRandomSampler(weights=class_weights, num_samples=sum(num_samples), replacement=True)
train_loader = data.DataLoader(train_dataset, batch_size=64, sampler=sampler)

# create new training set with oversampled examples
oversampled_train_dataset = data.Subset(train_dataset, indices=list(sampler))
oversampled_train_dataset.transform = transform_train

如您所见,当前我在末尾将转换应用于整个列车数据,这不是我想要的结果。在历元1得到的训练精度为1.0,而验证精度在0.25左右徘徊。它也不会随着历元的进展而提高/改变。在没有过采样的情况下,最终训练精度约为0.84,而验证精度约为0。71.
一个侧记,有大约20000图像在trainset之前过采样,所以为循环需要一段时间来运行,如果有优化我可以做,我会很感激的建议。谢谢

iswrvxsc

iswrvxsc1#

我的想法是创建两个DataLoader,然后在训练过程中将它们连接起来。

# No Augmentation
transform_train = transforms.Compose([
    transforms.ToTensor(),
])

# With Augmentation
transform_train_subset = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomResizedCrop(size=150, scale=(0.8, 1.0)),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
    transforms.RandomRotation(degrees=15),
    transforms.ToTensor(),
])

# create a WeightedRandomSampler to oversample the training set
sampler = data.WeightedRandomSampler(weights=class_weights, num_samples=sum(num_samples), replacement=True)
train_loader = data.DataLoader(train_dataset, batch_size=64, sampler=sampler)

# create new training set with oversampled examples
oversampled_train_dataset = data.Subset(train_dataset, indices=list(sampler))

# Sampling the subset
oversampled_train_dataset.transform = transform_train_subset
subset_train_loader = train_loader = data.DataLoader(oversampled_train_dataset, batch_size=64, sampler=sampler)

for i, (x_full, y_full), (x_subset, y_subset) in enumerate(zip(train_loader, subset_train_loader)):
    imgs = torch.cat((x_full, x_subset), dim=0)
    labels = torch.cat((y_full, y_subset), dim=0)
    # Your training code here

因为我没有你的数据,所以很难测试。

相关问题