pytorch FashionMNIST数据集未转换为Tensor

xmd2e60i  于 2022-12-04  发布在  其他
关注(0)|答案(1)|浏览(146)

尝试计算数据集的平均值和标准差,以便随后对其进行归一化。
当前代码:

train_dataset = datasets.FashionMNIST('data', train=True, download = True, transform=[transforms.ToTensor()])
test_dataset = datasets.FashionMNIST('data', train=False, download = True, transform=[transforms.ToTensor()])

def calc_torch_mean_std(tens):   
    mean = torch.mean(tens, dim=1)
    std = torch.sqrt(torch.mean((tens - mean[:, None]) ** 2, dim=1))
    return(std, mean)

train_mean, train_std = calc_torch_mean_std(train_dataset)

test_mean, test_std = calc_torch_mean_std(test_dataset)

然而,我得到的错误:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/var/folders/16/crymx03s6pzfspm_3qfrlkx00000gn/T/ipykernel_72423/605045038.py in <module>
      8     return(std, mean)
      9 
---> 10 train_mean, train_std = calc_torch_mean_std(train_dataset)
     11 
     12 test_mean, test_std = calc_torch_mean_std(test_dataset)

/var/folders/16/crymx03s6pzfspm_3qfrlkx00000gn/T/ipykernel_72423/605045038.py in calc_torch_mean_std(tens)
      4 
      5 def calc_torch_mean_std(tens):
----> 6     mean = torch.mean(tens, dim=1)
      7     std = torch.sqrt(torch.mean((tens - mean[:, None]) ** 2, dim=1))
      8     return(std, mean)

TypeError: mean() received an invalid combination of arguments - got (FashionMNIST, dim=int), but expected one of:
 * (Tensor input, *, torch.dtype dtype)
 * (Tensor input, tuple of ints dim, bool keepdim, *, torch.dtype dtype, Tensor out)
 * (Tensor input, tuple of names dim, bool keepdim, *, torch.dtype dtype, Tensor out)

它应该得到一个Tensor,因为我用transforms.ToTensor()来转换数据。
已检查转换的导入,一切正常。已检查数据集的参数。FashionMNIST()和转换使用正确(在有和没有[ ]的情况下都应工作)。
预期无错误,并获得两个数据集的平均值和标准差。

ars1skjm

ars1skjm1#

datasets.FashionMNIST返回(image,target),其中target是目标类的索引。所以如果你想取平均值,你只需要提取图像。

images = torch.vstack([pair[0] for pair in train_dataset])

图像现在应该是形状(N,H,W),你可以做任何你想从那里。
OP提到的另一个解决方案是使用train_dataset.data直接访问数据。

相关问题