One advantage of decision trees over many other algorithms is that the fitted model can be visualized. Decision trees are subdivided into classification trees, which predict class labels, and regression trees, which predict continuous values. Visualizing a tree is a very intuitive way to understand the details of the model, but a few problems can come up in practice. The notes below collect the main points to watch out for.

Graphviz visualization tool

Graphviz is an open-source graph visualization package that represents structured information as abstract graphs and networks. One of its uses in data science is decision tree visualization. There are, however, a few pitfalls. If you install only the Python bindings with pip install graphviz, rendering fails with the following error:

ExecutableNotFound: failed to execute 'dot', make sure the Graphviz executables are on your systems' PATH

The solution is to install the Graphviz executable package itself and add its installation directory to the PATH environment variable.
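If modifying the system PATH is inconvenient, a quick workaround is to extend PATH from within Python before rendering. A minimal sketch; the directory below is a typical Windows default and only an assumption, so adjust it to your actual install location:

import os

# Assumed Graphviz install location -- replace with your actual path.
os.environ["PATH"] += os.pathsep + r"C:\Program Files\Graphviz\bin"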

export_graphviz: export the tree to Graphviz format

from sklearn import tree
from sklearn.datasets import load_iris

# Train a decision tree classifier on the iris data set.
iris = load_iris()
clf = tree.DecisionTreeClassifier()
clf.fit(iris.data, iris.target)

# Write the fitted tree in Graphviz DOT format.
with open("iris.dot", 'w') as f:
    tree.export_graphviz(clf, out_file=f)

This generates a plain-text file, iris.dot, which you can open directly; its contents look something like this:

digraph Tree {
node [shape=box] ;
0 [label="X[2] <= 2.45\ngini = 0.667\nsamples = 150\nvalue = [50, 50, 50]"] ;
1 [label="gini = 0.0\nsamples = 50\nvalue = [50, 0, 0]"] ;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
2 [label="X[3] <= 1.75\ngini = 0.5\nsamples = 100\nvalue = [0, 50, 50]"] ;
0 -> 2 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
3 [label="X[2] <= 4.95\ngini = 0.168\nsamples = 54\nvalue = [0, 49, 5]"] ;
2 -> 3 ;
4 [label="X[3] <= 1.65\ngini = 0.041\nsamples = 48\nvalue = [0, 47, 1]"] ;
3 -> 4 ;
5 [label="gini = 0.0\nsamples = 47\nvalue = [0, 47, 0]"] ;
...
}

Convert the .dot file into a graphic

For a better view, you can use the dot program from the Graphviz executable package to convert it into a PDF document.

This is done by executing the following command.

dot -Tpdf iris.dot -o iris.pdf

Opening the converted PDF shows a rendering of the tree.
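
If you have the graphviz Python package installed, you can also render the .dot file without leaving Python. A minimal sketch (this still calls the dot executable under the hood, so Graphviz must be on the PATH):

import graphviz

# Read the DOT source generated by export_graphviz and render it to PDF.
with open("iris.dot") as f:
    dot_graph = graphviz.Source(f.read())
dot_graph.render("iris", format="pdf")  # writes iris.pdf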

Running the command line by hand is cumbersome; another approach is to install pydotplus (pip install pydotplus) and generate the PDF directly from Python. In addition, tree.export_graphviz accepts extra parameters that make the picture much easier to read.

from sklearn import tree
from sklearn.datasets import load_iris
import pydotplus

iris = load_iris()
clf = tree.DecisionTreeClassifier()
clf.fit(iris.data, iris.target)

# With out_file=None, export_graphviz returns the DOT source as a string.
dot_data = tree.export_graphviz(clf, out_file=None,
                                feature_names=iris.feature_names,
                                class_names=iris.target_names,
                                filled=True, rounded=True,
                                special_characters=True)

# Let pydotplus call Graphviz and write the PDF.
graph = pydotplus.graph_from_dot_data(dot_data)
graph.write_pdf('iris.pdf')

For reference, the full signature of tree.export_graphviz is:

sklearn.tree.export_graphviz(decision_tree, out_file=None, *, max_depth=None, feature_names=None, class_names=None, label='all', filled=False, leaves_parallel=False, impurity=True, node_ids=False, proportion=False, rotate=False, rounded=False, special_characters=False, precision=3)

The parameters:

  • decision_tree: the decision tree object to export
  • out_file: handle or name of the output file; if None, the DOT source is returned as a string
  • max_depth: the maximum depth of the tree to render
  • feature_names: list of feature names
  • class_names: list of class names, in ascending numerical order of the class labels
  • label: where to show impurity and related information, one of {'all', 'root', 'none'}
  • filled: when True, color nodes to indicate the majority class (classification), extremity of values (regression), or node purity (multi-output)
  • leaves_parallel: when True, draw all leaf nodes at the bottom of the tree
  • impurity: whether to show the impurity at each node
  • node_ids: whether to show the ID number of each node
  • proportion: when True, show 'value' and 'samples' as proportions and percentages instead of absolute counts
  • rotate: when True, draw the tree from left to right instead of top to bottom
  • rounded: when True, draw node boxes with rounded corners
  • special_characters: when False, ignore special characters for PostScript compatibility
  • precision: number of digits of precision for the floating-point values in each node
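
To illustrate a few of these parameters, the sketch below (parameter values chosen arbitrarily) renders only the top two levels of the tree, labels each node with its ID, and shows class proportions instead of raw counts:

from sklearn import tree
from sklearn.datasets import load_iris

iris = load_iris()
clf = tree.DecisionTreeClassifier()
clf.fit(iris.data, iris.target)

# Truncate the rendering at depth 2, show node IDs and class proportions.
dot_data = tree.export_graphviz(clf, out_file=None,
                                max_depth=2,
                                node_ids=True,
                                proportion=True,
                                feature_names=iris.feature_names,
                                class_names=iris.target_names)
print(dot_data[:300])  # inspect the generated DOT source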

If generating and viewing a PDF is too much trouble, you can produce an image instead.

from sklearn import tree
from sklearn.datasets import load_iris
import pydotplus

iris = load_iris()
clf = tree.DecisionTreeClassifier()
clf.fit(iris.data, iris.target)

dot_data = tree.export_graphviz(clf, out_file=None,
                                feature_names=iris.feature_names,
                                class_names=iris.target_names,
                                filled=True, rounded=True,
                                special_characters=True)

graph = pydotplus.graph_from_dot_data(dot_data)
graph.write_png("dtree.png")  # write a PNG image instead of a PDF

scikit-learn’s tree.plot_tree

Starting with scikit-learn version 0.21, you can use the tree.plot_tree method to visualize decision trees with matplotlib instead of relying on the dot executable (no need to install Graphviz). The following Python code shows how to visualize a decision tree this way.

from sklearn import tree
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt

iris = load_iris()
clf = tree.DecisionTreeClassifier()
clf.fit(iris.data, iris.target)

plt.figure(figsize=(10, 8))
tree.plot_tree(clf,
               feature_names=iris.feature_names,
               class_names=iris.target_names,
               filled=True, rounded=True,
               )
plt.show()

Since its usage is similar to tree.export_graphviz, we will not go into the details here.
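
One convenience of the matplotlib route is that saving the result works like saving any other figure. A minimal sketch that writes a PNG instead of showing the plot:

from sklearn import tree
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt

iris = load_iris()
clf = tree.DecisionTreeClassifier()
clf.fit(iris.data, iris.target)

plt.figure(figsize=(10, 8))
tree.plot_tree(clf,
               feature_names=iris.feature_names,
               class_names=iris.target_names,
               filled=True, rounded=True)
plt.savefig("dtree.png", dpi=200, bbox_inches="tight")  # save instead of plt.show()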

Beautify the output with dtreeviz

dtreeviz is a library that produces much prettier output and is very simple to use.

from sklearn import tree
from sklearn.datasets import load_iris
from dtreeviz.trees import dtreeviz

iris = load_iris()
clf = tree.DecisionTreeClassifier()
clf.fit(iris.data, iris.target)
viz = dtreeviz(clf,
               x_data=iris.data,
               y_data=iris.target,
               target_name='class',
               feature_names=iris.feature_names,
               class_names=list(iris.target_names),
               title="Decision Tree - Iris data set")
viz.save('dtreeviz.svg')

At each node we can see a stacked histogram of the feature used for the split, colored by class, which shows how the classes are separated; the small triangles on the x-axis are the split points. Leaf nodes are drawn as pie charts showing which class the observations in that leaf belong to, so we can easily see the dominant class and therefore the model's prediction. We can create the same visualization for a test set simply by passing different x_data and y_data when calling the function. If you don't like the histograms and want a simpler plot, you can specify fancy=False to get a simplified version, as in the sketch below.
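
A minimal sketch of the simplified version, reusing the same setup with one extra argument:

from sklearn import tree
from sklearn.datasets import load_iris
from dtreeviz.trees import dtreeviz

iris = load_iris()
clf = tree.DecisionTreeClassifier()
clf.fit(iris.data, iris.target)

# fancy=False drops the histograms and draws plain split nodes.
viz = dtreeviz(clf,
               x_data=iris.data,
               y_data=iris.target,
               target_name='class',
               feature_names=iris.feature_names,
               class_names=list(iris.target_names),
               fancy=False)
viz.save('dtreeviz_simple.svg')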

Another handy dtreeviz feature improves the interpretability of the model: highlighting the path of a particular observation on the plot. This makes it clear which features contribute to the class prediction. Using the following code snippet, we highlight the path of the first sample:

from sklearn import tree
from sklearn.datasets import load_iris
from dtreeviz.trees import dtreeviz

iris = load_iris()
clf = tree.DecisionTreeClassifier()
clf.fit(iris.data, iris.target)
viz = dtreeviz(clf,
               x_data=iris.data,
               y_data=iris.target,
               target_name='class',
               feature_names=iris.feature_names,
               class_names=list(iris.target_names),
               title="Decision Tree - Iris data set",
               X=iris.data[0])
viz.save('dtreeviz.svg')

This plot is very similar to the previous one, but the orange highlights clearly show the path the sample follows. In addition, there is an orange triangle on each histogram marking that sample's value for the given feature. We can also change the orientation of the plot from top-down to left-to-right by setting orientation="LR", as in the sketch below.
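
A minimal sketch, reusing the same highlighted-path setup:

from sklearn import tree
from sklearn.datasets import load_iris
from dtreeviz.trees import dtreeviz

iris = load_iris()
clf = tree.DecisionTreeClassifier()
clf.fit(iris.data, iris.target)

# orientation="LR" draws the tree left-to-right instead of top-down.
viz = dtreeviz(clf,
               x_data=iris.data,
               y_data=iris.target,
               target_name='class',
               feature_names=iris.feature_names,
               class_names=list(iris.target_names),
               X=iris.data[0],
               orientation="LR")
viz.save('dtreeviz_lr.svg')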

Finally, we can print the decisions used for this observation's prediction in plain English.

from sklearn import tree
from sklearn.datasets import load_iris
from dtreeviz.trees import dtreeviz
from dtreeviz.trees import explain_prediction_path

iris = load_iris()
clf = tree.DecisionTreeClassifier()
clf.fit(iris.data, iris.target)
viz = dtreeviz(clf,
               x_data=iris.data,
               y_data=iris.target,
               target_name='class',
               feature_names=iris.feature_names,
               class_names=list(iris.target_names),
               title="Decision Tree - Iris data set",
               X=iris.data[0])
viz.save('dtreeviz.svg')

# Print a plain-English explanation of the prediction path.
print(explain_prediction_path(clf, iris.data[0], feature_names=iris.feature_names, explanation_type="plain_english"))
# petal width (cm) < 0.8

The examples so far have covered decision tree classification; next, let's look at decision tree regression.

from sklearn import tree
from sklearn.datasets import load_boston
from dtreeviz.trees import dtreeviz

boston = load_boston()  # note: load_boston is deprecated and removed in recent scikit-learn versions

reg = tree.DecisionTreeRegressor(max_depth=3)
reg.fit(boston.data, boston.target)

viz = dtreeviz(reg,
               x_data=boston.data,
               y_data=boston.target,
               target_name='price',
               feature_names=boston.feature_names,
               title="Decision Tree - Boston housing",
               show_node_labels=True)

viz.save('dtreeviz.svg')

Here we can see how regression trees differ from classification trees: instead of histograms, the nodes show scatter plots of the splitting feature against the target. These scatter plots contain some dashed lines, interpreted as follows:

  • The horizontal dashed lines are the mean of the target for the left and right branches of a decision node.
  • The vertical dashed line is the split point; it conveys exactly the same information as the black triangle.

In the leaf nodes, the dashed line indicates the mean of the target within that leaf, which is also the model's prediction. We can go a step further and plot only the nodes used for a given prediction by specifying show_just_path=True (together with the observation X to follow). The snippet below renders only the nodes of the tree above that are selected for the first observation.

from sklearn import tree
from sklearn.datasets import load_boston
from dtreeviz.trees import dtreeviz

boston = load_boston()

reg = tree.DecisionTreeRegressor(max_depth=3)
reg.fit(boston.data, boston.target)

viz = dtreeviz(reg,
               x_data=boston.data,
               y_data=boston.target,
               target_name='price',
               feature_names=boston.feature_names,
               title="Decision Tree - Boston housing",
               X=boston.data[0],
               show_just_path=True)

viz.save('dtreeviz.svg')