我想显示一个标签预测的混淆矩阵。在我的数据的某些子集上,一些类缺失(从地面实况和预测中),例如下面示例中的类6。
我想仍然显示6
行和列,即使填充零。有没有办法用sklearn的ConfusionMatrixDisplay
实现这一点?
from sklearn.metrics import ConfusionMatrixDisplay
import matplotlib.pyplot as plt
GT = [4, 4, 4, 5, 2, 1, 1, 1, 2, 7]
pred = [4, 5, 4, 5, 3, 0, 3, 2, 2, 7]
fig, ax = plt.subplots()
ConfusionMatrixDisplay.from_predictions(GT, pred,
cmap='Blues',
ax=ax,)
plt.show()
我尝试将标签列表传递给display_labels
参数,但失败了,因为长度与数据中实际存在的类数量不匹配
GT = [4, 4, 4, 5, 2, 1, 1, 1, 2, 7]
pred = [4, 5, 4, 5, 3, 0, 3, 2, 2, 7]
fig, ax = plt.subplots()
ConfusionMatrixDisplay.from_predictions(GT, pred,
cmap='Blues',
ax=ax,
display_labels = range(8))
plt.show()
1条答案
按热度按时间67up9zun1#
from_estimator
和from_predictions
方法有一个labels
参数来直接设置标签,覆盖数据中的内容。查看文档页面中的完整选项列表。