matplotlib: Suppress scientific notation in sklearn.metrics.plot_confusion_matrix

Posted by 6bc51xsx on 2023-04-21 in Other

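I am trying to plot a confusion matrix nicely, so I followed the built-in plot_confusion_matrix function introduced in scikit-learn 0.22. However, one of the values in my confusion matrix is 153, and it is displayed as 1.5e+02 in the plot.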

In the scikit-learn documentation I found a parameter called values_format, but I don't know how to set it so that it suppresses scientific notation. My code is below.

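import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import plot_confusion_matrix  # available in scikit-learn 0.22 to 1.1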

# import some data to play with

X = pd.read_csv("datasets/X.csv")
y = pd.read_csv("datasets/y.csv")

class_names = ['Not Fraud (positive)', 'Fraud (negative)']

# Split the data into a training set and a test set
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

# Fit a logistic regression classifier with default hyperparameters
logreg = LogisticRegression()
logreg.fit(X_train, y_train)

np.set_printoptions(precision=2)

# Plot the confusion matrix twice: raw counts, then normalized over the true labels (rows)
titles_options = [("Confusion matrix, without normalization", None),
                  ("Normalized confusion matrix", 'true')]
for title, normalize in titles_options:
    disp = plot_confusion_matrix(logreg, X_test, y_test,
                                 display_labels=class_names,
                                 cmap=plt.cm.Greens,
                                 normalize=normalize, values_format = '{:.5f}'.format)
    disp.ax_.set_title(title)

    print(title)
    print(disp.confusion_matrix)

plt.show()

wgx48brx #1

Just remove the ".format" and the {} curly braces from the argument in the call:

disp = plot_confusion_matrix(logreg, X_test, y_test,
                             display_labels=class_names,
                             cmap=plt.cm.Greens,
                             normalize=normalize, values_format='.5f')

You can also use '.5g' to avoid trailing zeros after the decimal point.
(Taken from the source.)
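To check what a given values_format string will produce, you can call Python's built-in format() on sample values directly. This relies on an assumption from reading the scikit-learn source: each cell label appears to be built as format(value, values_format), with '.2g' used when no format is given, which is where 1.5e+02 comes from. The sample values below are made up for illustration. The sketch also shows why the question's '{:.5f}'.format fails: format() only accepts a string as its second argument.

```python
# Made-up cell values: one raw count and two normalized rates (illustration only)
values = [153, 0.5, 0.123456]
specs = [".2g", ".5f", ".5g"]

# Assumed to mirror how each cell label is built: format(value, values_format)
table = {spec: [format(v, spec) for v in values] for spec in specs}
for spec, labels in table.items():
    print(f"{spec:>4}: {labels}")

# Passing a bound str.format method (as in the question) instead of a spec string
try:
    format(153, "{:.5f}".format)
    method_error = None
except TypeError as exc:  # format() argument 2 must be str
    method_error = exc
print("TypeError:", method_error)
```

Here '.2g' turns 153 into 1.5e+02, '.5f' pads it to 153.00000, and '.5g' keeps 153 while rounding 0.123456 to 0.12346.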


hrirmatl #2

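If you are using seaborn's heatmap to plot the confusion matrix and none of the answers above work, turn off scientific notation in the seaborn heatmap with fmt='g', as follows: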

import seaborn as sns
sns.heatmap(conf_matrix, annot=True, fmt='g')
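A self-contained sketch of the difference, assuming seaborn is installed and using a hypothetical 2x2 count matrix: sns.heatmap defaults to fmt='.2g', which is what turns 153 into 1.5e+02, while fmt='g' prints the count in full. The two heatmaps are drawn side by side so their annotation texts can be compared.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Hypothetical confusion matrix; 153 sits in the top-left cell
conf_matrix = np.array([[153, 7], [12, 29]])

fig, (ax_default, ax_g) = plt.subplots(1, 2, figsize=(8, 3))
sns.heatmap(conf_matrix, annot=True, ax=ax_default)     # default fmt=".2g"
sns.heatmap(conf_matrix, annot=True, fmt="g", ax=ax_g)  # general format, full integer
ax_default.set_title("fmt='.2g' (default)")
ax_g.set_title("fmt='g'")

# Annotations are stored as Text objects on each Axes
default_labels = [t.get_text() for t in ax_default.texts]
g_labels = [t.get_text() for t in ax_g.texts]
print(default_labels)
print(g_labels)
```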

lokaqttq #3

Just pass values_format='', for example:

plot_confusion_matrix(clf, X_test, Y_test, values_format = '')
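The same option can be sketched with ConfusionMatrixDisplay, which also works on current scikit-learn versions. The count matrix below is hypothetical. With an empty values_format, each label becomes format(value, ''), which for numbers amounts to str(value): counts print in full, but a normalized matrix is not rounded at all, so its labels can get long.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import numpy as np
from sklearn.metrics import ConfusionMatrixDisplay

# Hypothetical counts containing the 153 from the question
cm = np.array([[153, 7], [12, 29]])
disp = ConfusionMatrixDisplay(confusion_matrix=cm).plot(values_format="")
count_labels = [[t.get_text() for t in row] for row in disp.text_]
print(count_labels)  # [['153', '7'], ['12', '29']]

# Row-normalized version: '' applies no rounding, so full-precision floats appear
cm_norm = cm / cm.sum(axis=1, keepdims=True)
disp_norm = ConfusionMatrixDisplay(confusion_matrix=cm_norm).plot(values_format="")
print(disp_norm.text_[1, 0].get_text())  # 12/41 printed with all its digits
```

For normalized matrices, an explicit spec such as '.2f' is usually more readable than ''.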

xlpyo6sf #4

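In newer scikit-learn versions (1.0 and later), plot_confusion_matrix is deprecated (it was removed in 1.2). Its replacement, sklearn.metrics.ConfusionMatrixDisplay, accepts the same values_format parameter in its plot method.
scikit-learn ConfusionMatrixDisplay documentation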

import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
X, y = make_classification(random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                    random_state=0)
clf = SVC(random_state=0)
clf.fit(X_train, y_train)

predictions = clf.predict(X_test)
cm = confusion_matrix(y_test, predictions, labels=clf.classes_)
disp = ConfusionMatrixDisplay(confusion_matrix=cm,
                              display_labels=clf.classes_)

# # #
# added `values_format`, which the original sklearn docs example does not pass
# # #
disp.plot(values_format="d")

plt.show()
