matplotlib Python无法识别3D数组

zvokhttg  于 2023-03-23  发布在  Python
关注(0)|答案(1)|浏览(147)

我想创建一个自定义函数,它采用机器学习模型和corespondng X和y数组,并将为支持向量机绘制3D决策函数,但当我运行代码时,它给了我以下错误:ax.plot3D(X[y == 0, 0], X[y == 0, 1], X[y == 0, 2], 'ob') IndexError: too many indices for array: array is 2-dimensional, but 3 were indexed
但问题是,我已经创建了3列数组,这里是相应的例子:X, y= make_classification(n_samples=4000,n_features=3,n_informative=3,n_redundant=0,random_state=1)
下面是完整代码:

from sklearn.datasets import make_classification
import numpy as np
from sklearn.svm import LinearSVC
import  matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
def plot_3D_surface(mode,X,y):
    model.fit(X, y)
    z = lambda x, y: (-model.intercept_[0] - model.coef_[0][0] * x - model.coef_[0][1] * y) / model.coef_[0][2]

    tmp = np.linspace(-5, 5, 30)
    x, y = np.meshgrid(tmp, tmp)

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.plot3D(X[y == 0, 0], X[y == 0, 1], X[y == 0, 2], 'ob')
    ax.plot3D(X[y == 1, 0], X[y== 1, 1], X[y == 1, 2], 'sr')
    ax.plot_surface(x, y, z(x, y))
    ax.view_init(30, 60)
    plt.show()

def plot_2D_surface(model,X,y,h=0.2,**params):
    X0, X1 = X[:, 0], X[:, 1]
    x_min, x_max = X0.min() - 1, X0.max() + 1
    y_min, y_max = X1.min() - 1, X1.max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
    fig, ax = plt.subplots()
    Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
    out = ax.contourf(xx, yy, Z, **params)
    ax.scatter(X0, X1, c=y, cmap=plt.cm.coolwarm, s=20, edgecolors='k')
    ax.set_ylabel('y label here')
    ax.set_xlabel('x label here')
    ax.set_xticks(())
    ax.set_yticks(())
    ax.set_title('Decision surface of linear SVC ')
    ax.legend()
    plt.show()
X, y= make_classification(n_samples=4000,n_features=3,n_informative=3,n_redundant=0,random_state=1)
print(X[y==0,2])
model = LinearSVC()
model.fit(X, y)
#plot_2D_surface(model ,X, y, cmap=plt.cm.coolwarm, alpha=0.8)
plot_3D_surface(model,X,y)

我猜不出这个错误的原因,即使我将检查像这样的个别行:print(X[y==0,2])它工作并返回以下结果:[-2.03765616 0.00669218 -2.93745161 ... -1.31748082 0.55638536 -1.39934604]
请帮我找出错误的原因

dced5bon

dced5bon1#

plot_3D_surface()中重新定义y,然后尝试在此行中绘制它:

x, y = np.meshgrid(tmp, tmp)

因此,当你尝试用plot3D绘制它时,你不再使用作为参数传递的y值。我会将变量名更改为:

model_x, model_y = np.meshgrid(tmp, tmp)
...
ax.plot_surface(model_x, model_y, z(model_x, model_y))

相关问题