无法确定对象类型“Data”的形状-是否将pytorch列表转换为cuda?

2vuwiymt  于 2022-11-09  发布在  其他
关注(0)|答案(1)|浏览(221)

我想将一个列表转换为一种可以读入CUDA格式模型的格式。
我有这个:

print(type(train_dataset))
train_dataset = torch.tensor(train_dataset, device='cuda:0')
print(type(train_dataset))

输出为:

<class 'list'>
Traceback (most recent call last):
  File "test_pytorch_test_gpu2.py", line 882, in <module>
    train_dataset = torch.tensor(train_dataset, device='cuda:0')
ValueError: could not determine the shape of object type 'Data'

有没有人能解释一下如何将列表转换为cuda的格式/我所做的有什么问题?

gwbalxhn

gwbalxhn1#

Dataset是一个包裹训练数据的对象,它的主要功能是组织数据、标签,可能还有它们的扩充。
Dataset之上,通常使用DataLoader:该对象从底层Dataset收集训练样本,并将它们放入表示用于训练的小批量的Tensor中。
您的程式码应该看起来像这样:


# one-time setup of training data

train_dataset = ...  # your code that construct the Dataset
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=2)

# training loop

for e in range(num_epochs):
  # each epoch represents one pass over all the training samples.
  for (x, y) in train_loader:  # iterate over the data one _batch_ at a time
    # move training tensors, extracted from the Dataset, to GPU
    x = x.to(device)  
    y = y.to(device)
    # rest of your training code here:
    pred = model(x) 
    # ...

相关问题