python 如何为海运混淆矩阵添加正确的标签

s71maibg  于 2023-05-05  发布在  Python
关注(0)|答案(3)|浏览(135)

我已经使用seaborn将数据绘制成混淆矩阵,但我遇到了一个问题。问题是它只在两个轴上显示从0到11的数字,因为我有12个不同的标签。
代码如下:

cf_matrix = confusion_matrix(y_test, y_pred)
fig, ax = plt.subplots(figsize=(15,10)) 
sns.heatmap(cf_matrix, linewidths=1, annot=True, ax=ax, fmt='g')

在这里你可以看到我的混淆矩阵:

我得到了我应该得到的混淆矩阵。唯一的问题是没有显示的标签名称。我在互联网上搜索了很长一段时间,没有运气。是否有任何参数可以附加标签或如何做到这一点?

xtfmy6hx

xtfmy6hx1#

当您分解类别时,您应该保留水平,因此您可以将其与pd.crosstab而不是confusion_matrix结合使用来绘制。以iris为例:

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification
from sklearn.metrics import classification_report, confusion_matrix

df = pd.read_csv("http://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data",
                 header=None,names=["s.wid","s.len","p.wid","p.len","species"])
X = df.iloc[:,:4]
y,levels = pd.factorize(df['species'])

在这一部分,你得到了[0,..1,..2]中的标签y和水平,作为0,1,2对应的原始标签:

Index(['Iris-setosa', 'Iris-versicolor', 'Iris-virginica'], dtype='object')

所以我们适合并做你所拥有的:

clf = RandomForestClassifier(max_depth=2, random_state=0)
clf.fit(X,y)
y_pred = clf.predict(X)
print(classification_report(y,y_pred,target_names=levels))

和一个混淆矩阵0,1,2:

cf_matrix = confusion_matrix(y, y_pred)
sns.heatmap(cf_matrix, linewidths=1, annot=True, fmt='g')

我们回头使用水平:

cf_matrix = pd.crosstab(levels[y],levels[y_pred])
fig, ax = plt.subplots(figsize=(5,5))
sns.heatmap(cf_matrix, linewidths=1, annot=True, ax=ax, fmt='g')

rlcwz9us

rlcwz9us2#

标签按字母顺序排序。因此,使用numpy DISTINCT ture_label,您将获得按字母顺序排序的ndarray

cm_labels = np.unique(true_label)
cm_array = confusion_matrix(true_label, predict_label)
cm_array_df = pd.DataFrame(cm_array, index=cm_labels, columns=cm_labels)
sn.heatmap(cm_array_df, annot=True, annot_kws={"size": 12})
ddarikpa

ddarikpa3#

import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

# Predict the labels of the test set
y_pred = model.predict(X_test)

# Compute the confusion matrix
cm = confusion_matrix(y_test, y_pred, labels=[0, 1, 2])

# Define the labels and titles for the confusion matrix
classes = ['Negative', 'Neutral', 'Positive']
title = 'Confusion matrix for Logistic Regression model'

# Create a heatmap of the confusion matrix
sns.heatmap(cm, annot=True, cmap='Blues', fmt='g', xticklabels=classes, yticklabels=classes)

# Set the axis labels and title
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title(title)

# Add legends for the heatmap
bottom, top = plt.ylim()
plt.ylim(bottom + 0.5, top - 0.5)
plt.xticks(rotation=45)
plt.yticks(rotation=0)
plt.show()

相关问题