我只想从pytorch customDataset的__getitem__
中获取某些“idx "。例如,我只希望1 5 8 11
作为idx
传递给__getitem__
的数据集,而我将数据加载器 Package 到该自定义数据集。
或者换句话说,假设我必须在def __getitem__(self, idx):
之后的行上有print(idx)
我能做些什么呢,所以print(idx)
,
1
5
8
11
我的意思是只有索引中的值被传递给__getitem__
,但现在它打印
0
1
2
3
from torch.utils.data import Dataset, DataLoader
class CustomDataset(Dataset):
def __init__(self, data, indexes):
self.data = data
self.indexes = indexes
def __len__(self):
if self.indexes is None:
return len(self.data)
return len(self.indexes)
def __getitem__(self, idx):
print(idx)
return 1
custom_dataset = CustomDataset([i for i in range(20)], [1,5,8,11])
custom_dataloader = DataLoader(custom_dataset, batch_size=64)
for batch in custom_dataloader:
pass
很明显,设置self.indexes=[1, 5, 8, 11]
不符合我的要求。
第一优先级是有一些直接的方法,如在customDataset中设置一些属性,如果可能的话,而不是使用额外的对象,如采样器。
注意,我对采样器有问题,我可能会问另一个问题。也是采样器,删除了使用shuffle=True
的自由。所以直线的方式更受欢迎。
1条答案
按热度按时间kq0g1dla1#
如果我理解正确的话,你有长度为20的数据,但你想限制使用长度为4的索引子集。
此外,您还希望将其 Package 在DataLoader中。我真的不认为这是你应该做的事情。Dataset对象是为了访问特定的索引而创建的,因此应该在请求时返回数据,而不会在内部修改它以输出其他内容。
因此,您应该只使用允许的数据示例化Dataset类,这些数据可以由加载器及其默认采样器以顺序方式迭代,或者通过实现默认采样器(您希望避免)来防止加载器首先访问禁止的索引。因此,我建议对您的代码进行以下最小更改...