这是密码
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
import numpy as np
import math
# creating a custom class for our dataset, which inherits from Dataset.
class WineDataset(Dataset):
# this function is used for data loading
def __init__(self):
# data loading
xy = np.loadtxt('./wine.csv', delimiter=',', dtype=np.float32, skiprows=1)
self.x = torch.from_numpy(xy[:, 1:]) # the first column is the output label
self.y = torch.from_numpy(xy[:, [0]]) # n_samples, 1
self.n_samples = xy.shape[0]
# this function allows indexing in our dataset
def __getitem__(self, index):
return self.x[index], self.y[index] # the function returns a tuple.
# this allows us to call len on our dataset.
def __len__(self):
return self.n_samples
dataset = WineDataset()
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True, num_workers=2)
dataiter = iter(dataloader)
data = next(dataiter)
features, labels = data
print(features, labels)
我的问题是,既然我们已经可以直接在dataloader上调用enumerate方法,这是否意味着dataloader对象是一个可迭代对象?如果这是真的,那么调用iter(dataloader)就等同于从一个迭代器对象创建一个迭代器对象?
我对这件事有点困惑,请帮我解决。
我需要知道当dataloader作为参数传递时enumerate方法在幕后做什么,还需要知道iter(dataloader)在做什么。
1条答案
按热度按时间fnx2tebb1#
Iterator是实现
__iter__
方法的东西。Iterable是实现__next__
方法的东西。iter()
和enumerate()
都调用类的__iter___
方法。例如注意,类
B
的任何对象都是迭代器,因为它实现了__next__
,但它不是可迭代对象,因为它没有__iter__
方法。类似地,类A
的任何对象都是可迭代对象,但不是迭代器。运行它,
创建迭代器
一个二个一个一个
在迭代器
b
上调用next()
无法在可迭代的
a
上调用next()
我们可以在
a
上执行for
循环无法在
b
上执行for
循环现在打电话给
enumerate
一个12b1x一个13b1x
可以在c上执行
for
循环以及调用next()
一个一个十四个一个一个十五个一个一个一个十六个一个一个一个十七个一个
所以当我们在
Dataloader
上调用enumerate
时,它的__iter__
方法也被调用了。看看pytorch源代码中__iter__
函数的签名:这个
_BaseDataLoaderIter
类实现了__iter__
和__next__
,因此它既是一个可迭代对象,也是一个迭代器。因此,您可以在
Dataloader
上调用enumerate()
和iter()
,也可以在..\Lib\site-packages\torch\utils\data\dataloader.py
中直接查看python中的源代码