pytorch DataLoader对象是否为可迭代对象?

uemypmqf  于 2023-01-20  发布在  其他
关注(0)|答案(1)|浏览(187)

这是密码

import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
import numpy as np
import math

# creating a custom class for our dataset, which inherits from Dataset.
class WineDataset(Dataset):

    # this function is used for data loading
    def __init__(self):
      # data loading
      xy = np.loadtxt('./wine.csv', delimiter=',', dtype=np.float32, skiprows=1)
      self.x = torch.from_numpy(xy[:, 1:])  # the first column is the output label
      self.y = torch.from_numpy(xy[:, [0]]) # n_samples, 1
      self.n_samples = xy.shape[0]

    # this function allows indexing in our dataset
    def __getitem__(self, index):
      return self.x[index], self.y[index] # the function returns a tuple.

    # this allows us to call len on our dataset.
    def __len__(self):
      return self.n_samples

dataset = WineDataset()
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True, num_workers=2)

dataiter = iter(dataloader)
data = next(dataiter)
features, labels = data
print(features, labels)

我的问题是,既然我们已经可以直接在dataloader上调用enumerate方法,这是否意味着dataloader对象是一个可迭代对象?如果这是真的,那么调用iter(dataloader)就等同于从一个迭代器对象创建一个迭代器对象?
我对这件事有点困惑,请帮我解决。
我需要知道当dataloader作为参数传递时enumerate方法在幕后做什么,还需要知道iter(dataloader)在做什么。

fnx2tebb

fnx2tebb1#

Iterator是实现__iter__方法的东西。Iterable是实现__next__方法的东西。iter()enumerate()都调用类的__iter___方法。例如

class A: # this is an iterable
    def __iter__(self):
        print ('iter called at A')
        return B()
    
class B: # this is an iterator
    def __next__(self):
        print( 'next called at B')
        return 1

注意,类B的任何对象都是迭代器,因为它实现了__next__,但它不是可迭代对象,因为它没有__iter__方法。类似地,类A的任何对象都是可迭代对象,但不是迭代器。
运行它,

a = A()

创建迭代器
一个二个一个一个
在迭代器b上调用next()

next(b)
"""
next called at B
1
"""

无法在可迭代的a上调用next()

next(a)
"""
TypeError: 'A' object is not an iterator
"""

我们可以在a上执行for循环

for i in a:
    print(i)
    break
"""
iter called at A
next called at B
1
"""

无法在b上执行for循环

for i in b:
    print(i)
    break
"""
TypeError: 'B' object is not iterable
"""

现在打电话给enumerate
一个12b1x一个13b1x
可以在c上执行for循环以及调用next()
一个一个十四个一个一个十五个一个一个一个十六个一个一个一个十七个一个
所以当我们在Dataloader上调用enumerate时,它的__iter__方法也被调用了。看看pytorch源代码中__iter__函数的签名:

class DataLoader(Generic[T_co]):
.
.
    def __iter__(self) -> '_BaseDataLoaderIter':

这个_BaseDataLoaderIter类实现了__iter____next__,因此它既是一个可迭代对象,也是一个迭代器。

class _BaseDataLoaderIter(object):
.
.
    def __iter__(self) -> '_BaseDataLoaderIter':
        return self

    def __next__(self) -> Any:
    .
    .
        return data

因此,您可以在Dataloader上调用enumerate()iter(),也可以在..\Lib\site-packages\torch\utils\data\dataloader.py中直接查看python中的源代码

相关问题