我有一个多类问题,其中0
是我的负类,1
和2
是正类。检查以下代码:
import numpy as np
from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay
from sklearn.metrics import f1_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
# Outputs
y_true = np.array((1, 2, 2, 0, 1, 0))
y_pred = np.array((1, 0, 0, 0, 0, 1))
# Metrics
precision_macro = precision_score(y_true, y_pred, average='macro')
precision_weighted = precision_score(y_true, y_pred, average='weighted')
recall_macro = recall_score(y_true, y_pred, average='macro')
recall_weighted = recall_score(y_true, y_pred, average='weighted')
f1_macro = f1_score(y_true, y_pred, average='macro')
f1_weighted = f1_score(y_true, y_pred, average='weighted')
# Confusion Matrix
cm = confusion_matrix(y_true, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot()
plt.show()
在这种情况下,使用Sklearn
计算的度量如下:
precision_macro = 0.25
precision_weighted = 0.25
recall_macro = 0.33333
recall_weighted = 0.33333
f1_macro = 0.27778
f1_weighted = 0.27778
这就是混淆矩阵:
macro
和weighted
是相同的,因为我对每个类都有相同数量的样本?这是我手动做的。
1 -精度= TP/(TP+FP)。因此对于类1
和2
,我们得到:
Precision1 = TP1/(TP1+FP1) = 1/(1+1) = 0.5
Precision2 = TP2/(TP2+FP2) = 0/(0+0) = 0 (this returns 0 according Sklearn documentation)
Precision_Macro = (Precision1 + Precision2)/2 = 0.25
Precision_Weighted = (2*Precision1 + 2*Precision2)/4 = 0.25
2 - Recall = TP/(TP+FN)。因此对于类1
和2
,我们得到:
Recall1 = TP1/(TP1+FN1) = 1/(1+1) = 0.5
Recall2 = TP2/(TP2+FN2) = 0/(0+2) = 0
Recall_Macro = (Recall1+Recall2)/2 = (0.5+0)/2 = 0.25
Recall_Weighted = (2*Recall1+2*Recall2)/4 = (2*0.5+2*0)/4 = 0.25
3 - F1 = 2*(精确度 * 召回率)/(精确度+召回率)
F1_Macro = 2*(Precision_Macro*Recall_Macro)/(Precision_Macro*Recall_Macro) = 0.25
F1_Weighted = 2*(Precision_Weighted*Recall_Weighted)/(Precision_Weighted*Recall_Weighted) = 0.25
因此,精度分数与Sklearn
相同。但召回和F1不同。我在这里做错了什么?即使您使用Sklearn
的精度和召回值(即0.25
和0.3333
),您也无法获得0.27778
F1分数。
1条答案
按热度按时间klr1opcd1#
对于平均得分,您还需要类0的得分。类0的精度是
1/4
(因此平均值不变)。类0的召回率是1/2
,因此平均召回率是(1/2+1/2+0)/3 = 1/3
。平均F1得分不是平均查准率和查全率的调和均值;这里,类0的F1为
1/3
,类1的F1为1/2
,类2的F1未定义,但取为0
,平均值为5/18
。