测试集仅包含1类和3类,如打印件所示。
使用seaborn
绘制混淆矩阵热图时。
但是,海运热图绘图类0和2。
绘图应该向下移一行。我假设问题是由索引引起的。
from sklearn.metrics import confusion_matrix
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import seaborn as sns
cf_matrix = confusion_matrix(y_true, y_pred)
print(Counter(y_pred))
print(Counter(y_true))
cmn = cf_matrix.astype('float') / cf_matrix.sum(axis=1)[:, np.newaxis]
plt.figure(figsize = (15,15))
sns.heatmap(cmn, annot=True, fmt='.1f')
Counter({3: 100489, 12: 11306, 11: 4314, 4: 3303, 8: 2510, 7: 1850, 5: 185, 10: 132, 2: 69})
Counter({3.0: 117955, 1.0: 6203})
1条答案
按热度按时间wkyowqbh1#
由于
cmn
是一个numpy数组,seaborn不知道行和列的名称,默认值是0,1,2,...
,这也有助于确保y_pred
和y_true
是相同的整数类型,例如y_true = y_true.astype(int)
。Scikit-learn提供了unique_labels函数来获取它使用的标签。
您可以通过
with np.errstate(invalid='ignore'):
暂时禁止被零除的警告。为了进行测试,您可以创建一些易于手动计数的简单数组,并研究
confusion_matrix(y_true, y_pred)
在这种情况下是如何工作的。