我想将一个列表转换为一种可以读入CUDA格式模型的格式。
我有这个:
print(type(train_dataset))
train_dataset = torch.tensor(train_dataset, device='cuda:0')
print(type(train_dataset))
输出为:
<class 'list'>
Traceback (most recent call last):
File "test_pytorch_test_gpu2.py", line 882, in <module>
train_dataset = torch.tensor(train_dataset, device='cuda:0')
ValueError: could not determine the shape of object type 'Data'
有没有人能解释一下如何将列表转换为cuda的格式/我所做的有什么问题?
1条答案
按热度按时间gwbalxhn1#
Dataset
是一个包裹训练数据的对象,它的主要功能是组织数据、标签,可能还有它们的扩充。在
Dataset
之上,通常使用DataLoader
:该对象从底层Dataset
收集训练样本,并将它们放入表示用于训练的小批量的Tensor中。您的程式码应该看起来像这样: