python 在sklearn ConfusionMatrixDisplay.from_predictions中显示空行和列

jljoyd4f  于 2023-10-14  发布在  Python
关注(0)|答案(1)|浏览(100)

我想显示一个标签预测的混淆矩阵。在我的数据的某些子集上,一些类缺失(从地面实况和预测中),例如下面示例中的类6。
我想仍然显示6行和列,即使填充零。有没有办法用sklearn的ConfusionMatrixDisplay实现这一点?

from sklearn.metrics import ConfusionMatrixDisplay
import matplotlib.pyplot as plt

GT   = [4, 4, 4, 5, 2, 1, 1, 1, 2, 7]
pred = [4, 5, 4, 5, 3, 0, 3, 2, 2, 7]

fig, ax = plt.subplots()
ConfusionMatrixDisplay.from_predictions(GT, pred,
                                        cmap='Blues',
                                        ax=ax,)
plt.show()

我尝试将标签列表传递给display_labels参数,但失败了,因为长度与数据中实际存在的类数量不匹配

GT   = [4, 4, 4, 5, 2, 1, 1, 1, 2, 7]
pred = [4, 5, 4, 5, 3, 0, 3, 2, 2, 7]

fig, ax = plt.subplots()
ConfusionMatrixDisplay.from_predictions(GT, pred,
                                        cmap='Blues',
                                        ax=ax,
                                        display_labels = range(8))
plt.show()

67up9zun

67up9zun1#

from_estimatorfrom_predictions方法有一个labels参数来直接设置标签,覆盖数据中的内容。查看文档页面中的完整选项列表。

相关问题