决策树可视化

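需要安装以下 Python 包。注意:`pip install graphviz` 只安装 Python 接口,还需另外安装 Graphviz 软件本体,并确保其 `dot` 命令在 PATH 中,因为下面的代码正是调用 `dot` 来生成图片;`pydotplus` 在下面的代码中并未用到,可按需安装: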

```bash
pip install graphviz
pip install pydotplus
```

绘制决策树的代码如下:

import os
import subprocess
import traceback

from sklearn.tree import export_graphviz

def visualize_tree(dt, path=None, feature_names=None):
    """Render a fitted decision tree to ``dt.dot`` and ``dt.png``.

    The tree is first exported in Graphviz DOT format with
    ``sklearn.tree.export_graphviz``; the DOT file is then converted to PNG by
    running the Graphviz ``dot`` command-line tool, which must be installed
    separately and available on PATH.

    :param dt: a fitted ``sklearn.tree.DecisionTreeClassifier`` instance.
    :param path: directory in which ``dt.dot`` and ``dt.png`` are written;
        when None, both files go to the current working directory.
    :param feature_names: names of the tree's input features; when None,
        generic names ``feature_0``, ``feature_1``, ... are generated.
    :return: the path of the generated PNG file.
    :raises SystemExit: if ``dot`` cannot be started or exits with an error.
    """
    # os.path.join('', name) == name, so None maps to the working directory.
    base_dir = '' if path is None else path
    dot_path = os.path.join(base_dir, 'dt.dot')
    png_path = os.path.join(base_dir, 'dt.png')

    if feature_names is None:
        # `n_features_` was removed in scikit-learn 1.2; `n_features_in_` is
        # its replacement. Fall back for very old scikit-learn versions.
        n_features = getattr(dt, 'n_features_in_', None)
        if n_features is None:
            n_features = dt.n_features_
        feature_names = ['feature_%s' % i for i in range(n_features)]

    # Graphviz reads DOT input as UTF-8 by default; writing with the platform
    # default encoding would garble non-ASCII feature names (e.g. on Windows).
    with open(dot_path, 'w', encoding='utf-8') as f:
        export_graphviz(dt, out_file=f, feature_names=feature_names)

    command = ["dot", "-Tpng", dot_path, "-o", png_path]
    try:
        subprocess.check_call(command)
    except (OSError, subprocess.CalledProcessError) as exc:
        # OSError: `dot` is not installed / not on PATH;
        # CalledProcessError: `dot` ran but returned a non-zero status.
        traceback.print_exc()
        # Same SystemExit the original `exit(...)` raised, without depending
        # on the site-module `exit` builtin (absent under `python -S`).
        raise SystemExit(
            "Could not run dot, ie graphviz, to produce visualization"
        ) from exc
    return png_path