如何在使用dataloader时仅从pytorch dataset __getitem__中获取某些值

9lowa7mx  于 2023-10-20  发布在  其他
关注(0)|答案(1)|浏览(105)

我只想从pytorch customDataset的__getitem__中获取某些“idx "。例如,我只希望1 5 8 11作为idx传递给__getitem__的数据集,而我将数据加载器 Package 到该自定义数据集。
或者换句话说,假设我必须在def __getitem__(self, idx):之后的行上有print(idx)
我能做些什么呢,所以print(idx)

1
5
8
11

我的意思是只有索引中的值被传递给__getitem__,但现在它打印

0
1
2
3
from torch.utils.data import Dataset, DataLoader
class CustomDataset(Dataset):
    def __init__(self, data, indexes):
        self.data = data
        self.indexes = indexes

    def __len__(self):
        if self.indexes is None:
            return len(self.data)
        return len(self.indexes)

    def __getitem__(self, idx):
        print(idx)
        return 1
custom_dataset = CustomDataset([i for i in range(20)], [1,5,8,11])
custom_dataloader = DataLoader(custom_dataset, batch_size=64)
for batch in custom_dataloader:
    pass

很明显,设置self.indexes=[1, 5, 8, 11]不符合我的要求。
第一优先级是有一些直接的方法,如在customDataset中设置一些属性,如果可能的话,而不是使用额外的对象,如采样器。
注意,我对采样器有问题,我可能会问另一个问题。也是采样器,删除了使用shuffle=True的自由。所以直线的方式更受欢迎。

kq0g1dla

kq0g1dla1#

如果我理解正确的话,你有长度为20的数据,但你想限制使用长度为4的索引子集。
此外,您还希望将其 Package 在DataLoader中。我真的不认为这是你应该做的事情。Dataset对象是为了访问特定的索引而创建的,因此应该在请求时返回数据,而不会在内部修改它以输出其他内容。
因此,您应该只使用允许的数据示例化Dataset类,这些数据可以由加载器及其默认采样器以顺序方式迭代,或者通过实现默认采样器(您希望避免)来防止加载器首先访问禁止的索引。因此,我建议对您的代码进行以下最小更改...

from torch.utils.data import Dataset, DataLoader
class CustomDataset(Dataset):
    def __init__(self, data, indexes=None):
        self.data = data
        self.indexes = indexes

    def __len__(self):
        if self.indexes is None:
            return len(self.data)
        return len(self.indexes)

    def __getitem__(self, idx):
        print("index not the same as before:", idx)
        return data[idx]

#all data
data = np.array([i for i in range(20)]]
#allowed indices
allowed = [1,5,8,11]
# limit data
data = data[allowed]

custom_dataset = CustomDataset(data)
custom_dataloader = DataLoader(custom_dataset, batch_size=64)
for batch in custom_dataloader:
    pass

相关问题