我在像这样的unet训练之前在pytorch中使用了图像增强
class ProcessTrainDataset(Dataset):
def __init__(self, x, y):
self.x = x
self.y = y
self.pre_process = transforms.Compose([
transforms.ToTensor()])
self.transform_data = transforms.Compose([
transforms.ColorJitter(brightness=0.2, contrast=0.2)])
self.transform_all = transforms.Compose([
transforms.RandomVerticalFlip(),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(10),
transforms.RandomPerspective(distortion_scale=0.2, p=0.5),
transforms.RandomAffine(degrees=0, translate=(0.2,0.2), scale=(0.9,1.1),),])
def __len__(self):
return len(self.x)
def __getitem__(self, idx):
img_x = Image.open(self.x[idx])
img_y = Image.open(self.y[idx]).convert("L")
#First get into the right range of 0 - 1, permute channels first, and put to tensor
img_x = self.pre_process(img_x)
img_y = self.pre_process(img_y)
#Apply resize and shifting transforms to all; this ensures each pair has the identical transform applied
img_all = torch.cat([img_x, img_y])
img_all = self.transform_all(img_all)
#Split again and apply any color/saturation/hue transforms to data only
img_x, img_y = img_all[:-1, ...], img_all[-1:,...]
img_x = self.transform_data(img_x)
#Add augmented data to dataset
self.x_augmented.append(img_x)
self.y_augmented.append(img_y)
return img_x, img_y
但是我们如何知道是否所有的增强都已经应用于数据集,以及我们如何看到增强后的数据集的数量?
1条答案
按热度按时间cetgtptt1#
如何查看转换后数据集的长度?-用于增强的Pytorch数据转换(例如初始化中定义的随机转换)是 * 动态 * 的,这意味着每次调用
__getitem__(idx)
时,都会计算一个新的随机转换并应用于数据idx
。通过这种方式,数据集在功能上提供了无限数量的图像,即使数据集中只有一个示例(当然,这些图像将彼此高度相关,因为它们是从相同的基础图像生成的)。因此,将pytorch变换视为增加数据集中的元素数量实际上是不正确的。元素的数量总是等于len(self.x)
,但是每次返回元素idx
时,它会略有不同。考虑到这一点,您的实现不需要将转换后的数据存储在
x_augmented
和y_augmented
中,除非您除了使用__getitem__()
的返回值之外,还对这些数据有特定的下游用途。我怎么知道所有的变换都被应用了?-可以肯定地说,如果你定义了变换,然后你应用了变换,那么它们都被应用了。如果你想验证这一点,你可以显示变换前后的图像,并一次注解掉除了一个变换之外的所有变换,以确保每个变换都有效果。