我想创建一个自定义函数,它采用机器学习模型和corespondng X和y数组,并将为支持向量机绘制3D决策函数,但当我运行代码时,它给了我以下错误:ax.plot3D(X[y == 0, 0], X[y == 0, 1], X[y == 0, 2], 'ob') IndexError: too many indices for array: array is 2-dimensional, but 3 were indexed
但问题是,我已经创建了3列数组,这里是相应的例子:X, y= make_classification(n_samples=4000,n_features=3,n_informative=3,n_redundant=0,random_state=1)
下面是完整代码:
from sklearn.datasets import make_classification
import numpy as np
from sklearn.svm import LinearSVC
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
def plot_3D_surface(mode,X,y):
model.fit(X, y)
z = lambda x, y: (-model.intercept_[0] - model.coef_[0][0] * x - model.coef_[0][1] * y) / model.coef_[0][2]
tmp = np.linspace(-5, 5, 30)
x, y = np.meshgrid(tmp, tmp)
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot3D(X[y == 0, 0], X[y == 0, 1], X[y == 0, 2], 'ob')
ax.plot3D(X[y == 1, 0], X[y== 1, 1], X[y == 1, 2], 'sr')
ax.plot_surface(x, y, z(x, y))
ax.view_init(30, 60)
plt.show()
def plot_2D_surface(model,X,y,h=0.2,**params):
X0, X1 = X[:, 0], X[:, 1]
x_min, x_max = X0.min() - 1, X0.max() + 1
y_min, y_max = X1.min() - 1, X1.max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
fig, ax = plt.subplots()
Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
out = ax.contourf(xx, yy, Z, **params)
ax.scatter(X0, X1, c=y, cmap=plt.cm.coolwarm, s=20, edgecolors='k')
ax.set_ylabel('y label here')
ax.set_xlabel('x label here')
ax.set_xticks(())
ax.set_yticks(())
ax.set_title('Decision surface of linear SVC ')
ax.legend()
plt.show()
X, y= make_classification(n_samples=4000,n_features=3,n_informative=3,n_redundant=0,random_state=1)
print(X[y==0,2])
model = LinearSVC()
model.fit(X, y)
#plot_2D_surface(model ,X, y, cmap=plt.cm.coolwarm, alpha=0.8)
plot_3D_surface(model,X,y)
我猜不出这个错误的原因,即使我将检查像这样的个别行:print(X[y==0,2])
它工作并返回以下结果:[-2.03765616 0.00669218 -2.93745161 ... -1.31748082 0.55638536 -1.39934604]
请帮我找出错误的原因
1条答案
按热度按时间dced5bon1#
在
plot_3D_surface()
中重新定义y
,然后尝试在此行中绘制它:因此,当你尝试用plot3D绘制它时,你不再使用作为参数传递的
y
值。我会将变量名更改为: