从PyTorch DataLoader获取单个随机示例

bvhaajcl  于 2023-08-05  发布在  其他
关注(0)|答案(7)|浏览(154)

如何从PyTorch DataLoader中获取单个随机示例?
如果我的DataLoader给出了多个图像和标签的minbatches,我如何获得一个随机的图像和标签?
请注意,我不希望每个minibatch只有一个图像和标签,我希望总共有一个示例。

smtd7mpg

smtd7mpg1#

如果你的DataLoader是这样的:

test_loader = DataLoader(image_datasets['val'], batch_size=batch_size, shuffle=True)

字符串
它给你一个大小为batch_size的批,你可以通过直接索引这个批来挑选一个随机的例子:

for test_images, test_labels in test_loader:  
    sample_image = test_images[0]    # Reshape them according to your needs.
    sample_label = test_labels[0]

备选方案

1.可以使用RandomSampler获取随机样本。
1.在DataLoader中使用1的batch_size
1.直接从数据集中提取样本,如下所示:

mnist_test = datasets.MNIST('../MNIST/', train=False, transform=transform)


现在使用此数据集采集样本:

for image, label in mnist_test:
      # do something with image and other attributes


1.**(可能是最好的)**看这里:

inputs, classes = next(iter(dataloader))

ki1q1bka

ki1q1bka2#

如果你想从Trainloader/Testloader中选择特定的图像,你应该从master中查看Subset函数:
下面是一个如何使用它的示例:

testset = ImageFolderWithPaths(root="path/to/your/Image_Data/Test/", transform=transform)
subset_indices = [0] # select your indices here as a list
subset = torch.utils.data.Subset(testset, subset_indices)
testloader_subset = torch.utils.data.DataLoader(subset, batch_size=1, num_workers=0, shuffle=False)

字符串
这样您就可以使用一个图像和标签。但是,你当然可以在subset_indices中使用不止一个索引。
如果您想使用DataFolder中的特定图像,可以使用dataset.sample并构建一个字典来获取您想要使用的图像的索引。

ltskdhd1

ltskdhd13#

(This答案是补充@parthagar答案的备选3
迭代dataset不会**返回“random”示例,您应该用途:

# Recovers the original `dataset` from the `dataloader`
dataset = dataloader.dataset
n_samples = len(dataset)

# Get a random sample
random_index = int(numpy.random.random()*n_samples)
single_example = dataset[random_index]

字符串

tp5buhyn

tp5buhyn4#

TL;DR:

DataLoader获取单个示例的一般形式为:

list = [ x[0] for x in iter(trainloader).next() ]

字符串
特别是对于所问的问题,其中返回了minbatches的图像和标签:

image, label = [ x[0] for x in iter(trainloader).next() ]

可能有趣的信息:

要从DataLoader获取单个小批,请用途:

iter(trainloader).next()


当运行类似for images, labels in dataloader:的代码时,底层发生的是通过iter(dataloader)创建迭代器,然后在每次循环执行时调用迭代器的.next()
要从DataLoader(返回图像和标签)获取单个图像,请用途:

image = iter(trainloader).next()[0][0]


这与执行以下操作相同:

images, labels = iter(trainloader).next()
image = images[0]

goqiplq2

goqiplq25#

DataLoader随机抽取样本

假设DataLoader(shuffle=True)用于其构造,可以从DataLoader中随机抽取一个示例:

example = next(iter(dataloader))[0]

字符串

Dataset中随机抽取样本

如果不是这种情况,您可以从数据集中绘制一个随机示例:

idx = torch.randint(len(dataset), (1,))
example = dataset[idx]

q7solyqu

q7solyqu6#

获取随机样本的关键是为DataLoader设置shuffle=True,获取单个图像的关键是将批量大小设置为1。
下面是loading the mnist dataset之后的例子。

from torch.utils.data import DataLoader, Dataset, TensorDataset
bs = 1
train_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)

for xb, yb in train_dl:
    print(xb.shape)
    x = xb.view(28,28) 
    print(x.shape)
    print(yb)
    break #just once

from matplotlib import pyplot as plt
plt.imshow(x, cmap="gray")

字符串


的数据

vd2z7a6w

vd2z7a6w7#

你可以简单地将trainloader转换为iterable,然后通过编写以下代码获取下一个批处理

dataiter = iter(trainloader)
images, labels = next(dataiter)

字符串

这里有一个例子

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 4

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# functions to show an image

def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

# get some random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)

# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(batch_size)))


参考号:https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html

相关问题