python 使用seaborn绘制的sklearn混淆矩阵不正确

7ajki6be  于 2023-01-29  发布在  Python
关注(0)|答案(1)|浏览(219)

测试集仅包含1类和3类,如打印件所示。
使用seaborn绘制混淆矩阵热图时。
但是,海运热图绘图类0和2。
绘图应该向下移一行。我假设问题是由索引引起的。

from sklearn.metrics import confusion_matrix
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import seaborn as sns

cf_matrix = confusion_matrix(y_true, y_pred)
print(Counter(y_pred))
print(Counter(y_true))

cmn = cf_matrix.astype('float') / cf_matrix.sum(axis=1)[:, np.newaxis]
plt.figure(figsize = (15,15))
sns.heatmap(cmn, annot=True, fmt='.1f')
Counter({3: 100489, 12: 11306, 11: 4314, 4: 3303, 8: 2510, 7: 1850, 5: 185, 10: 132, 2: 69})
Counter({3.0: 117955, 1.0: 6203})

wkyowqbh

wkyowqbh1#

由于cmn是一个numpy数组,seaborn不知道行和列的名称,默认值是0,1,2,...,这也有助于确保y_predy_true是相同的整数类型,例如y_true = y_true.astype(int)
Scikit-learn提供了unique_labels函数来获取它使用的标签。
您可以通过with np.errstate(invalid='ignore'):暂时禁止被零除的警告。
为了进行测试,您可以创建一些易于手动计数的简单数组,并研究confusion_matrix(y_true, y_pred)在这种情况下是如何工作的。

from sklearn.metrics import confusion_matrix
from sklearn.utils.multiclass import unique_labels
from matplotlib import pyplot as plt
import seaborn as sns
import numpy as np

y_true = [3, 3, 3, 3, 1, 1]
y_pred = [7, 3, 3, 1, 2, 1]

# make sure both arrays are of the same type
y_true = np.array(y_true).astype(int)
y_pred = np.array(y_pred).astype(int)

cf_matrix = confusion_matrix(y_true, y_pred)
with np.errstate(invalid='ignore'):
    cmn = cf_matrix.astype('float') / cf_matrix.sum(axis=1, keepdims=True)

fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(12, 6))

sns.set()
sns.heatmap(cmn, annot=True, fmt='.1f', annot_kws={"fontsize": 14},
            linewidths=2, linecolor='black', clip_on=False, ax=ax1)
ax1.set_title('using default labels', fontsize=16)

labels = unique_labels(y_true, y_pred)
sns.heatmap(cmn, xticklabels=labels, yticklabels=labels,
            annot=True, fmt='.1f', annot_kws={"fontsize": 14},
            linewidths=2, linecolor='black', clip_on=False, ax=ax2)
ax2.set_title('using the same labels as sk-learn', fontsize=16)

for ax in (ax1, ax2):
    ax.tick_params(labelsize=20, rotation=0)
    ax.set_xlabel('Predicted value', fontsize=18)
    ax.set_ylabel('True value', fontsize=18)
plt.tight_layout()
plt.show()

相关问题