Python: visualize a decision tree with the test-set label distribution, not just the training-set one

5jvtdoz2 posted on 2023-02-18 in Python

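Bounty (+50 reputation): user1627466 is looking for an answer from a reputable source: "I'm looking for an answer that lets me visualize how a specific fitted tree performs on the test set. Ideally, it would add more detail to each node by showing the split of the y_pred and y_test values."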

For example, we can visualize a decision tree with the training-set distribution:

from matplotlib import pyplot as plt
from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier 
from sklearn import tree

# Prepare the data (row or column sampling could be done here)
iris = datasets.load_iris()
X = iris.data
y = iris.target
# Fit the classifier with default hyper-parameters
clf = DecisionTreeClassifier(random_state=1234)
clf.fit(X, y)

fig = plt.figure(figsize=(25,20))
_ = tree.plot_tree(clf, 
                   feature_names=iris.feature_names,  
                   class_names=iris.target_names,
                   filled=True)
plt.show()

This draws the tree with the training-set distribution, e.g. value = [50, 50, 50] in the root node.
However, I can't pass it a test set and have the test-set distribution shown in the visualized tree.

Answer:

I don't think there is an sklearn method that does this (yet).
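Both options below rely on `clf.decision_path`, which reports every node each sample passes through. As a minimal sketch (the helper name `node_class_counts` is hypothetical, not an sklearn API), the per-node class counts for a test set can be computed with one sparse matrix product instead of nested loops:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier


def node_class_counts(clf, X, y):
    """Hypothetical helper: (n_nodes, n_classes) array of how many samples
    of each class pass through each node of the fitted tree."""
    # Sparse indicator matrix of shape (n_samples, n_nodes): entry [s, k] is 1
    # when sample s passes through node k on its way to a leaf.
    path = clf.decision_path(X)
    # One-hot encode y in the column order of clf.classes_ -> (n_samples, n_classes)
    onehot = (np.asarray(y)[:, None] == clf.classes_[None, :]).astype(int)
    # (n_nodes, n_samples) @ (n_samples, n_classes) -> (n_nodes, n_classes)
    return np.asarray(path.T @ onehot)


iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, shuffle=True, random_state=1234)
clf = DecisionTreeClassifier(random_state=1234).fit(X_train, y_train)

counts = node_class_counts(clf, X_test, y_test)
print(counts[0])  # root node: class counts of the whole test set
```

Row `k` of `counts` is the "test value" for node `k`, so the children's rows of any internal node add up to the parent's row.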

    • Option 1: change the tree's plot annotations by adding the X_test information

You can use the following custom function:

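import numpy as np

def plot_tree_test(clf, tree_plot, X_test, y_test):
    """Append test-set sample counts and per-class counts to each node's text box."""
    y_test = np.asarray(y_test)

    # decision_path gives a sparse (n_samples, n_nodes) indicator matrix;
    # after transposing, row i marks the test samples that pass through node i
    path = clf.decision_path(X_test).toarray().transpose()

    # Assumes tree_plot[i] is the box of node i (plot_tree draws nodes depth-first;
    # this breaks if plot_tree is called with max_depth)
    for i, annotation in enumerate(tree_plot):
        mask = path[i] == 1
        # count the test samples of each class (in clf.classes_ order) reaching node i
        value = [int(np.sum(y_test[mask] == c)) for c in clf.classes_]
        annotation.set_text(annotation.get_text()
                            + f'\ntest samples = {mask.sum()}\ntest value = {value}')

    return tree_plot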
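`plot_tree_test` indexes `tree_plot[i]` as the box for node `i`, so it only works when `plot_tree` returns exactly one annotation per node. A quick sanity check, sketched with the iris setup from the question (the Agg backend line is only there so it runs without a display):

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; no window is needed for this check
from matplotlib import pyplot as plt
from sklearn import tree
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(random_state=1234).fit(iris.data, iris.target)

fig = plt.figure(figsize=(25, 20))
tree_plot = tree.plot_tree(clf, feature_names=iris.feature_names,
                           class_names=iris.target_names, filled=True)

# One annotation per node is required for tree_plot[i] <-> node i to make sense
print(len(tree_plot), clf.tree_.node_count)
plt.close(fig)
```

If `plot_tree` is given `max_depth`, the truncated subtrees are drawn as extra "(...)" boxes, and the two numbers no longer match.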

Then modify the script slightly:

from matplotlib import pyplot as plt
from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier 
from sklearn import tree
from sklearn.model_selection import train_test_split

# Prepare the data (row or column sampling could be done here)
iris = datasets.load_iris()
X = iris.data
y = iris.target

# Creating a train and test set
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, shuffle=True, random_state=1234)

# Fit the classifier with default hyper-parameters
clf = DecisionTreeClassifier(random_state=1234)
clf.fit(X_train, y_train)

fig = plt.figure(figsize=(25,20))
tree_plot = tree.plot_tree(clf, 
                   feature_names=iris.feature_names,  
                   class_names=iris.target_names,
                   filled=True)

tree_plot = plot_tree_test(clf, tree_plot, X_test, y_test)
plt.show()

Output: [figure not preserved]
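The bounty asks for the y_pred/y_test split at each node. Following the same idea as Option 1, and under the same node-ordering assumption, here is a sketch with a hypothetical helper `annotate_test_predictions` (not part of sklearn). It appends the true class counts, the predicted class counts, and the number of correctly classified test samples to every node box. The Agg backend line and the output filename are assumptions made so the sketch runs headless; for interactive use, drop the first two lines and call `plt.show()` instead of `savefig`.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend so the sketch also runs without a display
import numpy as np
from matplotlib import pyplot as plt
from sklearn import tree
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier


def annotate_test_predictions(clf, tree_plot, X_test, y_test):
    """Hypothetical helper: add true (y_test) and predicted (y_pred) class counts
    of the test samples reaching each node to that node's text box."""
    y_test = np.asarray(y_test)
    y_pred = clf.predict(X_test)
    # Row i of the transposed indicator matrix marks the test samples in node i
    path = clf.decision_path(X_test).toarray().T
    # Assumes tree_plot[i] is the box of node i (plot_tree called without max_depth)
    for i, ann in enumerate(tree_plot):
        mask = path[i] == 1
        true_counts = [int(np.sum(y_test[mask] == c)) for c in clf.classes_]
        pred_counts = [int(np.sum(y_pred[mask] == c)) for c in clf.classes_]
        correct = int(np.sum(y_test[mask] == y_pred[mask]))
        ann.set_text(ann.get_text()
                     + f"\ny_test = {true_counts}\ny_pred = {pred_counts}"
                     + f"\ncorrect = {correct}/{int(mask.sum())}")
    return tree_plot


iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, shuffle=True, random_state=1234)
clf = DecisionTreeClassifier(random_state=1234).fit(X_train, y_train)

fig = plt.figure(figsize=(25, 20))
tree_plot = tree.plot_tree(clf, feature_names=iris.feature_names,
                           class_names=iris.target_names, filled=True)
tree_plot = annotate_test_predictions(clf, tree_plot, X_test, y_test)
fig.savefig("tree_test_predictions.png")  # hypothetical output filename
```

In a leaf every test sample gets the leaf's class, so `y_pred` is concentrated in one column there; at internal nodes the two vectors show where the misclassified test samples end up.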

    • Option 2: change the classifier itself using the X_test information

You can use the following custom function:

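import copy
import numpy as np

def tree_test(clf, X_test, y_test):
    """Return a copy of clf whose node statistics are computed from X_test/y_test.

    Uses the private Tree.__getstate__/__setstate__ layout, so it may break between
    sklearn versions. Impurity values are left as computed on the training set.
    """
    clf = copy.deepcopy(clf)  # keep the original model (and its predictions) intact
    y_test = np.asarray(y_test)
    state = clf.tree_.__getstate__()
    n = len(state['values'])

    # Getting the path for each item in X_test (after transposing, row i = node i)
    path = clf.decision_path(X_test).toarray().transpose()

    # Looping through each node/leaf and replacing training counts with test counts
    values = []
    for i in range(n):
        mask = path[i] == 1
        value = [float(np.sum(y_test[mask] == c)) for c in clf.classes_]
        values.append([value])
        state['nodes']['n_node_samples'][i] = mask.sum()
        # a weighted count of 0 raises an error when plotting, so use 0.1 instead
        state['nodes']['weighted_n_node_samples'][i] = max(mask.sum(), 0.1)

    state['values'] = np.array(values)
    clf.tree_.__setstate__(state)
    return clf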
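Option 2 writes into the private state of `clf.tree_`, where `nodes` is a NumPy structured array (the original answer addresses its fields by position, `[5]` and `[6]`). A small sketch to inspect that layout on your installed version before relying on it; this is private API, so field names and order are not guaranteed across releases:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

clf = DecisionTreeClassifier(random_state=1234).fit(*load_iris(return_X_y=True))
state = clf.tree_.__getstate__()

print(sorted(state.keys()))        # keys of the private state dict, incl. 'nodes', 'values'
print(state['nodes'].dtype.names)  # per-node fields, e.g. 'n_node_samples'
print(state['values'].shape)       # (node_count, n_outputs, max_n_classes)
```

Addressing the fields by name (`state['nodes']['n_node_samples']`) rather than by position is less fragile if a release adds or reorders fields.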

Then modify the script slightly:

from matplotlib import pyplot as plt
from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier 
from sklearn import tree
from sklearn.model_selection import train_test_split
import numpy as np

# Prepare the data (row or column sampling could be done here)
iris = datasets.load_iris()
X = iris.data
y = iris.target

# Creating a train and test set
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, shuffle=True, random_state=1234)

# Fit the classifier with default hyper-parameters
clf = DecisionTreeClassifier(random_state=1234)
clf.fit(X_train, y_train)

clf = tree_test(clf, X_test, y_test)

fig = plt.figure(figsize=(25,20))
tree_plot = tree.plot_tree(clf, 
                   feature_names=iris.feature_names,  
                   class_names=iris.target_names,
                   filled=True)

plt.show()

Output: [figure not preserved]
