如何将我在PyTorch中使用“torch.utils.data.random_split()”获得的“子集”转换为切片/Tensor/矩阵?

2exbekwf  于 2023-03-02  发布在  其他
关注(0)|答案(1)|浏览(279)

我尝试在PyTorch中使用torch.utils.data.random_split()将数据划分为训练和测试数据集。

train, test = torch.utils.data.random_split(iris, [112, 38], generator=torch.Generator().manual_seed(42))

上述代码中列车和集合的输出为:〈torch.实用程序.数据.数据集.位于0x141a7379b50的子集〉
我尝试使用下面的代码将下面的值分配给我的训练集和测试集,尽管我得到:"类型错误:列表索引必须是整数或切片,而不是字符串"错误。

train_X = train[['Sepal.Length', 'Sepal.Width', 'Petal.Length',
                 'Petal.Width']]
train_y = train.Species

test_X = test[['Sepal.Length', 'Sepal.Width', 'Petal.Length',
                 'Petal.Width']]
test_y = test.Species

我浏览了stackoverflow上的多个帖子,但没有找到合适的解决方案。

q8l4jmvw

q8l4jmvw1#

我正在使用下面这个极其蹩脚的解决方案:

test = torch.stack([t for t in test])

其中test现在是长度为38的Tensor。当调用:

train, test = torch.utils.data.random_split(iris, [112, 38], generator=torch.Generator().manual_seed(42))

traintest是Subset(如您所注意到的),它们是一种奇特的生成器,您可以使用列表解析来评估生成器,这不是很好的编码,但在紧要关头可以使用。

相关问题