numpy LDA -地块边界

pxyaymoc  于 2023-08-05  发布在  其他
关注(0)|答案(1)|浏览(145)

我正在使用Iris -文件,并希望绘制三个类的分离边界/区域:
我首先使用以下代码将所有功能减少到两个LDA - Coponents:

# Load the Iris dataset
iris = load_iris()
X = iris.data
y = iris.target

# Perform LDA
lda = LinearDiscriminantAnalysis()
X_lda = lda.fit_transform(X, y)

# Plot the data points
cmap = ListedColormap(['red', 'green', 'blue'])
for target, color, marker in zip(np.unique(y), ['r', 'g', 'b'], ['s', 'x', 'o']):
    plt.scatter(X_lda[y == target, 0], X_lda[y == target, 1], c=color, cmap=cmap, marker=marker, label=target, edgecolors='black')

# Set plot labels and limits
plt.title('Linear Discriminant Analysis (Iris Dataset)')
plt.xlabel('LDA Component 1')
plt.ylabel('LDA Component 2')
plt.legend()

字符串

结果:

x1c 0d1x的数据
现在我想画出这三个类的决策边界。我试过使用:

# Define the decision boundary
x1_min, x1_max = X_lda[:, 0].min() - 1, X_lda[:, 0].max() + 1
x2_min, x2_max = X_lda[:, 1].min() - 1, X_lda[:, 1].max() + 1
xx1, xx2 = np.meshgrid(np.arange(x1_min, x1_max, 0.02),
                      np.arange(x2_min, x2_max, 0.02))

Z = lda.predict(np.array([xx1.ravel(), xx2.ravel()]).T)

Z = Z.reshape(xx1.shape)


但我得到了错误:ValueError:X有2个特征,但LinearDiscriminantAnalysis需要4个特征作为输入。
我如何划分这三个类之间的界限?

aoyhnmkz

aoyhnmkz1#

错误是说你已经用X训练了数据,这是一个N*4的数组,四个特征,然后你使用相同的模型lda来预测N*2的数组,所以维度是不匹配的。
如果你想用np.array([xx1.ravel(), xx2.ravel()]).T预测数据,你需要初始化另一个模型,然后用两个特征训练模型。
然后,您可以使用DecisionBoundaryDisplay.from_estimator绘制边界
完整代码:

from sklearn.datasets import load_iris
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn.inspection import DecisionBoundaryDisplay


iris = load_iris()
X = iris.data
y = iris.target

lda = LinearDiscriminantAnalysis()
X_lda = lda.fit_transform(X, y)

# Create another model for two features
lda_2 = LinearDiscriminantAnalysis().fit(X_lda, y)

_, ax = plt.subplots()
DecisionBoundaryDisplay.from_estimator(
    lda_2,
    X_lda,
    cmap=plt.cm.Paired,
    ax=ax,
    response_method="predict",
    plot_method="pcolormesh",
    shading="auto",
    eps=0.2,
)

# Plot the data points
cmap = ListedColormap(['red', 'green', 'blue'])
for target, color, marker in zip(np.unique(y), ['r', 'g', 'b'], ['s', 'x', 'o']):
    ax.scatter(X_lda[y == target, 0], X_lda[y == target, 1], c=color, cmap=cmap, marker=marker, label=target, edgecolors='black')

# Set plot labels and limits
plt.title('Linear Discriminant Analysis (Iris Dataset)')
plt.xlabel('LDA Component 1')
plt.ylabel('LDA Component 2')
plt.legend()

字符串
输出:

的数据

相关问题