keras 如何利用sklearn.metrics得到多类语义切分的混淆矩阵?

dwbf0jvd  于 2023-01-26  发布在  其他
关注(0)|答案(1)|浏览(205)

我遇到这个错误时,试图获得混淆矩阵的多类语义分割问题(11类)。y_true和y的形状如下所示。我已经尝试了.argmax,但它返回错误。有人能帮我解决这个问题吗?

y_true.shape
(29, 16, 16, 11)

y_pred.shape
(29, 16, 16, 11)

from sklearn.metrics import confusion_matrix

cf_matrix = confusion_matrix(y_true.argmax(axis=1),y_pred.argmax(axis=1))
print(cf_matrix)

ValueError                                Traceback (most recent call last)
<ipython-input-100-cea12dc5adac> in <module>()
      1 from sklearn.metrics import confusion_matrix
      2 
----> 3 cf_matrix = confusion_matrix(y_true.argmax(axis=1),y_pred.argmax(axis=1))
      4 print(cf_matrix)

1 frames
/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py in _check_targets(y_true, y_pred)
    102     # No metrics support "multiclass-multioutput" format
    103     if y_type not in ["binary", "multiclass", "multilabel-indicator"]:
--> 104         raise ValueError("{0} is not supported".format(y_type))
    105 
    106     if y_type in ["binary", "multiclass"]:

ValueError: unknown is not supported
7rfyedvj

7rfyedvj1#

它的工作原理是将y_pred_argmax和y_true都拉平,我会使用以下代码:

y_pred_argmax = np.argmax(y_pred, axis=3)
y_pred_argmax = np.expand_dims(y_pred_argmax, axis=3) 
y_pred_argmax = to_categorical(y_pred_argmax, num_classes=11)
y_pred_flattened = y_pred_argmax.flatten()
y_flattened = y.flatten()
cf_matrix = confusion_matrix(y_flattened, y_pred_flattened, normalize='true')

相关问题