numpy: Split a dataset into training and test sets by a given ratio

Posted by 7z5jn7bk on 2022-12-26 in Other

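For a school project, I need to split a dataset into a training set and a test set by a given ratio. The ratio is the fraction of the data used for the training set, and the rest of the data is used for testing. I wrote a basic implementation based on my professor's requirements, but I can't get it to pass the tests he wrote. Below is my implementation, along with what the parameters and return values represent: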

def splitData(X, y, split_ratio = 0.8):
    '''
    X: numpy.ndarray. Shape = [n+1, m]
    y: numpy.ndarray. Shape = [m, ]
    split_ratio: the ratio of examples go into the Training, Validation, and Test sets.
    Split the whole dataset into Training, Validation, and Test sets.
    :return: return (training_X, training_y), (test_X, test_y).
            training_X is a (n+1, m_tr) matrix with m_tr training examples;
            training_y is a (m_tr, ) column vector;
            test_X is a (n+1, m_test) matrix with m_test test examples;
            test_y is a (m_test, ) column vector.
    '''
    ## Need to possibly shuffle X array and Y array

    ## amount used for training
    m_tr = len(X) * train_ratio

    ##m_test = len(X) - m_tr Amount that is used for testing

    training_X = X[1:m_tr]
    training_y = y[1:m_tr]
    test_X = [m_tr:len(X)]
    test_y = [m_tr:len(y)]
    return training_X, training_y, test_X, test_y

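Because of the instructions, I left my declaration of m_test as a comment. But I'm fairly sure that slicing the array from the first element up to m_tr gives the full training portion, and the rest is the test data, which I get by taking each list from m_tr to len(X) or len(y). Am I misunderstanding how splitting works?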
PS: The professor said we can skip the validation split.
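The asker's picture of splitting is right in principle: for an integer index k, `a[:k]` and `a[k:]` divide a sequence into two disjoint parts that together cover every element. The minimal sketch below, on a toy 1-D array (the data and the 0.8 ratio are assumptions for illustration), shows that, and also what starting the slice at 1 loses:

```python
import numpy as np

a = np.arange(10)          # toy data: 10 examples labelled 0..9
k = int(len(a) * 0.8)      # number of training examples; must be an int (8 here)

train = a[:k]              # elements 0..k-1 (same as a[0:k])
test = a[k:]               # elements k..end

# Together the two slices contain every element exactly once.
assert np.array_equal(np.concatenate([train, test]), a)

# Starting at 1 instead of 0 silently drops the first example.
dropped = a[1:k]
print(train)    # [0 1 2 3 4 5 6 7]
print(test)     # [8 9]
print(dropped)  # [1 2 3 4 5 6 7]
```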


Answer #1 (tzxcd3kk)

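There are 3 main problems:
1. The docstring specifies that you need to slice columns, not rows.
2. You should return 2 pairs, not a tuple of length 4.
3. For some reason you drop the 0th sample, because you slice with "1:" instead of "0:".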
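To illustrate the first two points, here is a small sketch on a toy X of shape (n+1, m); the sizes and the `split_pairs` helper are hypothetical, chosen only for illustration. Because each column is one example, `len(X)` counts rows rather than examples, and `X[:m_tr]` cuts along the wrong axis:

```python
import numpy as np

# Toy data (an assumption for illustration): 2 features plus a bias row, 5 examples.
n, m = 2, 5
X = np.arange((n + 1) * m).reshape(n + 1, m)  # shape (3, 5): one example per column
y = np.arange(m)                              # shape (5,): one label per example

print(len(X))      # 3 -> len() counts rows (features), not examples
print(X.shape[1])  # 5 -> the number of examples m

m_tr = int(X.shape[1] * 0.8)  # 4 training examples

rows = X[:m_tr]     # wrong axis: takes rows, here all 3 of them
cols = X[:, :m_tr]  # right axis: takes the first m_tr columns (examples)
print(rows.shape, cols.shape)  # (3, 5) (3, 4)

# Returning two pairs lets the caller unpack training and test data directly.
def split_pairs(X, y, m_tr):  # hypothetical helper, for illustration only
    return (X[:, :m_tr], y[:m_tr]), (X[:, m_tr:], y[m_tr:])

(train_X, train_y), (test_X, test_y) = split_pairs(X, y, m_tr)
print(train_X.shape, test_X.shape)  # (3, 4) (3, 1)
```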

def splitData(X, y, split_ratio = 0.8):
    '''
    X: numpy.ndarray. Shape = [n+1, m]
    y: numpy.ndarray. Shape = [m, ]
    split_ratio: the ratio of examples go into the Training, Validation, and Test sets.
    Split the whole dataset into Training, Validation, and Test sets.
    :return: return (training_X, training_y), (test_X, test_y).
            training_X is a (n+1, m_tr) matrix with m_tr training examples;
            training_y is a (m_tr, ) column vector;
            test_X is a (n+1, m_test) matrix with m_test test examples;
            test_y is a (m_test, ) column vector.
    '''
    # X has shape (n+1, m): examples are columns, so count them with X.shape[1], not len(X)
    m_tr = int(X.shape[1] * split_ratio)
    training_X = X[:, :m_tr]   # first m_tr columns (examples)
    training_y = y[:m_tr]
    test_X = X[:, m_tr:]       # remaining columns
    test_y = y[m_tr:]
    return (training_X, training_y), (test_X, test_y)

Answer #2 (lsmepo6l)

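1. The function parameter is named split_ratio, but the body uses train_ratio.
2. m_tr is the length of the data (len) multiplied by the ratio (split_ratio). That product can be a float, but the slices used to split the data only accept integers.
3. For test_X and test_y, there is no array in front of the slice, so no data is being sliced (as written, this is also a syntax error).
4. For training_X and training_y, slicing starts at the second element because you wrote 1 instead of 0, so you lose the first data element.
I corrected your mistakes: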
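Point 2 can be checked directly. In this sketch (the 7-element label array is an assumption for illustration), NumPy rejects a float slice bound with a `TypeError`, and `int()` truncates the product to a usable index:

```python
import numpy as np

y = np.arange(7)         # toy labels for 7 examples
raw = len(y) * 0.8       # about 5.6: a float, not a valid slice bound

try:
    y[:raw]              # non-integer slice bounds are rejected
except TypeError as err:
    print("TypeError:", err)

m_tr = int(raw)          # int() truncates toward zero -> 5
print(m_tr, y[:m_tr], y[m_tr:])  # 5 [0 1 2 3 4] [5 6]
```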

def splitData(X, y, split_ratio = 0.8):
    # Examples are the columns of X (shape (n+1, m)), so count them with X.shape[1]
    m_tr = int(X.shape[1] * split_ratio)
    training_X = X[:, :m_tr]
    training_y = y[:m_tr]
    test_X = X[:, m_tr:]
    test_y = y[m_tr:]
    return (training_X, training_y), (test_X, test_y)
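The question's code also carries a note about possibly shuffling X and y. If the data are stored in some order (for example sorted by label), taking the first m_tr examples can give unrepresentative training and test sets. The sketch below is one way to shuffle before splitting, assuming the assignment allows it (the professor's tests may expect the unshuffled order); the name `split_data_shuffled` and the `seed` parameter are hypothetical. It draws one permutation with `np.random.default_rng(seed).permutation(m)` and applies it to both the columns of X and the entries of y, so every example keeps its label:

```python
import numpy as np

def split_data_shuffled(X, y, split_ratio=0.8, seed=None):
    """Hypothetical variant of splitData that shuffles examples before splitting.

    X has shape (n+1, m) with one example per column; y has shape (m,).
    """
    m = X.shape[1]                                     # number of examples (columns)
    perm = np.random.default_rng(seed).permutation(m)  # one shared random order
    X_shuf = X[:, perm]                                # reorder the columns of X ...
    y_shuf = y[perm]                                   # ... and y with the same permutation
    m_tr = int(m * split_ratio)                        # training-set size, truncated to int
    return (X_shuf[:, :m_tr], y_shuf[:m_tr]), (X_shuf[:, m_tr:], y_shuf[m_tr:])

# Usage on toy data: a bias row plus one feature, where the label of example j is 2*j.
X = np.vstack([np.ones(10), np.arange(10)])  # shape (2, 10)
y = np.arange(10) * 2.0
(tr_X, tr_y), (te_X, te_y) = split_data_shuffled(X, y, 0.8, seed=0)
print(tr_X.shape, te_X.shape)  # (2, 8) (2, 2)
```

Passing a fixed `seed` makes the split reproducible between runs; with `seed=None`, each call gives a different split.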
