我想将固定长度的图像序列加载到相同大小的批次中(例如序列长度=批次大小= 7)。
有多个目录,每个目录都有来自不同数量图像序列的图像。来自不同目录的序列彼此不相关。
用我现在的代码,我可以处理几个子目录,但是如果一个目录中没有足够的图像来填充一批,剩下的图像将从下一个目录中取出。我想避免这种情况。
相反,如果当前目录中没有足够的图像,则应丢弃一个批次,而该批次应仅使用下一个目录中的图像填充。这样,我希望避免在同一批次中混合不相关的图像序列。如果目录中没有足够的图像来创建一个批次,则应完全跳过该批次。
例如,序列长度/批量大小为7:
- 目录A有15个图像→创建2个批次,每个批次有7个图像;忽略其余图像
- 目录B有10个图像→创建1个批次,其中有7个图像;忽略其余图像
- 目录C有3个图像→目录被完全跳过
我还在学习中,但我认为这可以用一个costum批量采样器来完成?不幸的是,我对此有些问题。也许有人可以帮我找到解决方案。
这是我现在的代码:
class MainDataset(Dataset):
def __init__(self, img_dir, use_folder_name=False):
self.gt_images = self._load_main_dataset(img_dir)
self.dataset_len = len(self.gt_images)
self.use_folder_name = use_folder_name
def __len__(self):
return self.dataset_len
def __getitem__(self, idx):
img_dir = self.gt_images[idx]
img_name = self._get_name(img_dir)
gt = self._load_img(img_dir)
# Skip non-image files
if gt is None:
return None
gt = torch.from_numpy(gt).permute(2, 0, 1)
return gt, img_name
def _get_name(self, img_dir):
if self.use_folder_name:
return img_dir.split(os.sep)[-2]
else:
return img_dir.split(os.sep)[-1].split('.')[0]
def _load_main_dataset(self, img_dir):
if not (os.path.isdir(img_dir)):
return [img_dir]
gt_images = []
for root, dirs, files in os.walk(img_dir):
for file in files:
if not is_valid_file(file):
continue
gt_images.append(os.path.join(root, file))
gt_images.sort()
return gt_images
def _load_img(self, img_path):
gt_image = io.imread(img_path)
gt_image_bd = getBitDepth(gt_image)
gt_image = np.array(gt_image).astype(np.float32) / ((2 ** (gt_image_bd / 3)) - 1)
return gt_image
def is_valid_file(file_name: str):
# Check if the file has a valid image extension
valid_image_extensions = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.tif']
for ext in valid_image_extensions:
if file_name.lower().endswith(ext):
return True
return False
sequence_data_store = MainDataset(img_dir=sdr_img_dir, use_folder_name=True)
sequence_loader = DataLoader(sequence_data_store, num_workers=0, pin_memory=False)
字符串
1条答案
按热度按时间1yjd4xko1#
虽然使用批处理采样器可能是一个好主意,可以有一个通用的自定义数据集,你可以不同的采样,我更喜欢一个简单的方法。
我会在init函数中构造一个数据结构,它已经包含了你要操作的所有图像序列。事实是,目前,你的Dataset类在撒谎,因为它说你的数据集的长度等于图像文件夹的数量。这是不正确的,因为它取决于文件夹中包含的图像数量。
目前,您的数据集一次只返回一个图像,而您需要序列。
你的问题中也缺少了一些关于数据集实际结构的信息。尽管如此,这里有一个Datatet类的建议:
字符串