将CSV解析为PytorchTensor

m528fe3b  于 11个月前  发布在  其他
关注(0)|答案(5)|浏览(124)

我有一个CSV文件,除了头行之外,所有的值都是数值。当试图构建Tensor时,我得到了以下异常:

Traceback (most recent call last):
  File "pytorch.py", line 14, in <module>
    test_tensor = torch.tensor(test)
ValueError: could not determine the shape of object type 'DataFrame'

字符串
下面是我的代码:

import torch
import dask.dataframe as dd

device = torch.device("cuda:0")

print("Loading CSV...")
test = dd.read_csv("test.csv", encoding = "UTF-8")
train = dd.read_csv("train.csv", encoding = "UTF-8")

print("Converting to Tensor...")
test_tensor = torch.tensor(test)
train_tensor = torch.tensor(train)


使用pandas而不是Dask进行CSV解析产生了相同的错误。我还试图在调用torch.tensor(data)时指定dtype=torch.float64,但又得到了相同的错误。

ig9co6j1

ig9co6j11#

我想你只是缺少了.values

import torch
import pandas as pd

train = pd.read_csv('train.csv')
train_tensor = torch.tensor(train.values)

字符串

mftmpeh8

mftmpeh82#

较新版本的pandas强烈建议使用to_numpy而不是values

train_tensor = torch.tensor(train.to_numpy())

字符串

ezykj2lf

ezykj2lf3#

使用NumPy

import numpy as np
import torch

tensor = torch.from_numpy(
    np.genfromtxt("train.csv", delimiter=",")
)

字符串

pgx2nnw8

pgx2nnw84#

所有导入函数似乎都需要一个带有数字数组的.csv。您在最初的问题案例中提到,您的.csv包含列标题。请尝试在.csv文件中不包含标题的代码。

nwwlzxa7

nwwlzxa75#

尝试先将其转换为数组:

test_tensor = torch.Tensor(test.values)

字符串

相关问题