代码
1
2
3
|
from sklearn import tree
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
|
(178, 13)
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2])
1
2
|
import pandas as pd
# Show features and labels side by side as one table: 178 rows x 14 columns,
# with the class label (0/1/2) as the last column.
# NOTE(review): assumes `wine = load_wine()` was executed in an earlier cell — confirm.
pd.concat([pd.DataFrame(wine.data),pd.DataFrame(wine.target)],axis=1)
|
|
0 |
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
0 |
0 |
14.23 |
1.71 |
2.43 |
15.6 |
127.0 |
2.80 |
3.06 |
0.28 |
2.29 |
5.64 |
1.04 |
3.92 |
1065.0 |
0 |
1 |
13.20 |
1.78 |
2.14 |
11.2 |
100.0 |
2.65 |
2.76 |
0.26 |
1.28 |
4.38 |
1.05 |
3.40 |
1050.0 |
0 |
2 |
13.16 |
2.36 |
2.67 |
18.6 |
101.0 |
2.80 |
3.24 |
0.30 |
2.81 |
5.68 |
1.03 |
3.17 |
1185.0 |
0 |
3 |
14.37 |
1.95 |
2.50 |
16.8 |
113.0 |
3.85 |
3.49 |
0.24 |
2.18 |
7.80 |
0.86 |
3.45 |
1480.0 |
0 |
4 |
13.24 |
2.59 |
2.87 |
21.0 |
118.0 |
2.80 |
2.69 |
0.39 |
1.82 |
4.32 |
1.04 |
2.93 |
735.0 |
0 |
... |
... |
... |
... |
... |
... |
... |
... |
... |
... |
... |
... |
... |
... |
... |
173 |
13.71 |
5.65 |
2.45 |
20.5 |
95.0 |
1.68 |
0.61 |
0.52 |
1.06 |
7.70 |
0.64 |
1.74 |
740.0 |
2 |
174 |
13.40 |
3.91 |
2.48 |
23.0 |
102.0 |
1.80 |
0.75 |
0.43 |
1.41 |
7.30 |
0.70 |
1.56 |
750.0 |
2 |
175 |
13.27 |
4.28 |
2.26 |
20.0 |
120.0 |
1.59 |
0.69 |
0.43 |
1.35 |
10.20 |
0.59 |
1.56 |
835.0 |
2 |
176 |
13.17 |
2.59 |
2.37 |
20.0 |
120.0 |
1.65 |
0.68 |
0.53 |
1.46 |
9.30 |
0.60 |
1.62 |
840.0 |
2 |
177 |
14.13 |
4.10 |
2.74 |
24.5 |
96.0 |
2.05 |
0.76 |
0.56 |
1.35 |
9.20 |
0.61 |
1.60 |
560.0 |
2 |
178 rows × 14 columns
['alcohol',
'malic_acid',
'ash',
'alcalinity_of_ash',
'magnesium',
'total_phenols',
'flavanoids',
'nonflavanoid_phenols',
'proanthocyanins',
'color_intensity',
'hue',
'od280/od315_of_diluted_wines',
'proline']
array(['class_0', 'class_1', 'class_2'], dtype='<U7')
1
2
|
# Split features (X) and labels (y) into a 70% training / 30% test set.
Xtrain,Xtest,ytrain,ytest=train_test_split(wine.data,wine.target,test_size=0.3)
|
(124, 13)
(54, 13)
array([2, 0, 0, 1, 1, 2, 0, 2, 0, 0, 1, 0, 1, 2, 1, 2, 1, 0, 1, 1, 0, 1,
1, 1, 0, 0, 0, 2, 0, 2, 1, 0, 2, 1, 1, 1, 1, 2, 1, 0, 2, 0, 0, 2,
0, 2, 1, 2, 1, 1, 1, 0, 0, 0, 0, 1, 2, 0, 1, 1, 0, 1, 1, 0, 0, 0,
1, 1, 2, 2, 0, 0, 1, 2, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 2, 1, 1, 0,
2, 2, 2, 2, 1, 0, 1, 2, 0, 0, 1, 0, 0, 2, 1, 0, 0, 2, 2, 1, 2, 0,
2, 1, 2, 1, 1, 1, 0, 1, 1, 1, 1, 2, 0, 0])
1
2
3
4
5
6
7
|
# Classify by computing the impurity of candidate splits over the features.
clf = tree.DecisionTreeClassifier(criterion="entropy",random_state=30,splitter="random")# impurity measure
# random_state=30 fixes the random seed so the result is reproducible
# splitter="best" favours the most important feature at each split;
# "random" splits more randomly, which helps prevent overfitting
clf = clf.fit(Xtrain,ytrain)
score = clf.score(Xtest,ytest)# returns prediction accuracy on the test set
score
|
0.8888888888888888
1
2
3
4
5
6
7
8
9
10
11
12
|
feature_names=wine.feature_names
# Render the fitted decision tree with graphviz.
import graphviz
dot_data = tree.export_graphviz(clf
,feature_names=feature_names
,class_names=["白酒","红酒","伏特加"]
,filled=True # colour each node by its majority class
,rounded=True# rounded corners instead of right-angled boxes
,out_file=None
)
graph = graphviz.Source(dot_data)
graph
|
1
|
# Per-feature importances learned by the tree (they sum to 1; 0 = unused feature).
clf.feature_importances_
|
array([0.26346976, 0. , 0. , 0. , 0. ,
0.02075703, 0.36896763, 0. , 0. , 0.03519618,
0.048165 , 0.20016672, 0.0632777 ])
1
2
|
# Feature importance as a list of (feature_name, importance) tuples.
list(zip(feature_names, clf.feature_importances_))
|
[('alcohol', 0.2634697570171739),
('malic_acid', 0.0),
('ash', 0.0),
('alcalinity_of_ash', 0.0),
('magnesium', 0.0),
('total_phenols', 0.020757027052013897),
('flavanoids', 0.36896762677034245),
('nonflavanoid_phenols', 0.0),
('proanthocyanins', 0.0),
('color_intensity', 0.035196177552898056),
('hue', 0.04816499533787735),
('od280/od315_of_diluted_wines', 0.20016671726273566),
('proline', 0.0632776990069589)]
1
2
|
# Accuracy on the training set. A perfect 1.0 here, paired with the lower
# test accuracy above, signals the unpruned tree has memorised the training
# data (overfitting). Fixes the `scroe_train` typo in the variable name.
score_train = clf.score(Xtrain, ytrain)
score_train
|
1.0
剪枝策略
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
|
# Pruning strategies: cap the tree depth with max_depth (e.g. 3).
# min_samples_leaf: a split is kept only if each resulting leaf holds at least
# this many samples (try 5 down to 1, or a fraction like 0.x of the data).
# Classify by computing the impurity of candidate splits over the features.
clf = tree.DecisionTreeClassifier(criterion="entropy"# impurity measure
,random_state=30
,splitter="random"
#,max_depth=3
,min_samples_leaf=10
#,min_impurity_split=1  # NOTE(review): removed in scikit-learn 1.0 — use min_impurity_decrease instead
)
# random_state=30 fixes the random seed so the result is reproducible
# splitter="best" favours the most important feature; "random" is more random
# and helps prevent overfitting
clf = clf.fit(Xtrain,ytrain)
score = clf.score(Xtest,ytest)# prediction accuracy on the test set
score
feature_names=wine.feature_names
# Render the pruned tree with graphviz.
import graphviz
dot_data = tree.export_graphviz(clf
,feature_names=feature_names
,class_names=["白酒","红酒","伏特加"]
,filled=True # colour each node by its majority class
,rounded=True# rounded corners instead of right-angled boxes
,out_file=None
)
graph = graphviz.Source(dot_data)
graph
|
1
2
|
# Test accuracy of the pruned tree (slightly lower than the unpruned tree
# here, but the model is simpler and less overfit).
score = clf.score(Xtest,ytest)# prediction accuracy on the test set
score
|
0.8518518518518519
1
2
|
#max_features: cap how many features the tree may consider at each split
#min_impurity_decrease: require at least this impurity decrease (parent vs. child nodes) for a split to happen
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
|
# Plot the accuracy-vs-max_depth curve to choose the hyperparameter.
# The pasted original lost the for-loop body indentation (invalid Python);
# restored here with conventional formatting.
import matplotlib.pyplot as plt

test = []
for i in range(10):
    clf = tree.DecisionTreeClassifier(max_depth=i + 1,
                                      criterion="entropy",
                                      random_state=30,
                                      splitter="random")
    clf = clf.fit(Xtrain, ytrain)
    score = clf.score(Xtest, ytest)  # test-set accuracy at this depth
    test.append(score)
plt.plot(range(1, 11), test, color="red", label="max_depth")
plt.legend()
plt.show()
|
1
2
3
4
|
# Class-weight parameters
#class_weight: per-class weights applied to the samples
#min_weight_fraction_leaf: minimum weighted fraction of samples a leaf must hold
#By default the tree leans toward the majority class; set weights to favour the minority class
|
1
2
|
# apply() returns the index of the leaf node each test sample lands in.
clf.apply(Xtest)
|
array([ 9, 4, 32, 16, 4, 4, 32, 22, 28, 32, 27, 16, 25, 16, 32, 16, 10,
16, 4, 4, 32, 4, 16, 4, 25, 16, 16, 16, 32, 32, 6, 22, 10, 4,
4, 4, 32, 29, 25, 13, 16, 4, 9, 25, 32, 4, 4, 22, 4, 22, 32,
32, 22, 4])
1
2
|
# predict() returns the predicted class (or regression value) for each test sample.
clf.predict(Xtest)
|
array([1, 2, 0, 1, 2, 2, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 2, 1, 2, 2, 0, 2,
1, 2, 1, 1, 1, 1, 0, 0, 1, 1, 2, 2, 2, 2, 0, 1, 1, 1, 1, 2, 1, 1,
0, 2, 2, 1, 2, 1, 0, 0, 1, 2])
1
|
#The feature matrix must be at least 2-D; for a single feature use reshape(-1, 1) to add a dimension.
|
交叉验证
1
2
3
|
# NOTE(review): load_boston was deprecated in scikit-learn 1.0 and removed in
# 1.2 (ethical concerns about the dataset). On newer versions substitute
# fetch_california_housing.
from sklearn.datasets import load_boston
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeRegressor
|
1
2
|
# Load the Boston housing dataset and display its feature matrix.
# NOTE(review): fails on scikit-learn >= 1.2 where load_boston was removed.
boston = load_boston()
boston.data
|
array([[6.3200e-03, 1.8000e+01, 2.3100e+00, ..., 1.5300e+01, 3.9690e+02,
4.9800e+00],
[2.7310e-02, 0.0000e+00, 7.0700e+00, ..., 1.7800e+01, 3.9690e+02,
9.1400e+00],
[2.7290e-02, 0.0000e+00, 7.0700e+00, ..., 1.7800e+01, 3.9283e+02,
4.0300e+00],
...,
[6.0760e-02, 0.0000e+00, 1.1930e+01, ..., 2.1000e+01, 3.9690e+02,
5.6400e+00],
[1.0959e-01, 0.0000e+00, 1.1930e+01, ..., 2.1000e+01, 3.9345e+02,
6.4800e+00],
[4.7410e-02, 0.0000e+00, 1.1930e+01, ..., 2.1000e+01, 3.9690e+02,
7.8800e+00]])
array([24. , 21.6, 34.7, 33.4, 36.2, 28.7, 22.9, 27.1, 16.5, 18.9, 15. ,
18.9, 21.7, 20.4, 18.2, 19.9, 23.1, 17.5, 20.2, 18.2, 13.6, 19.6,
15.2, 14.5, 15.6, 13.9, 16.6, 14.8, 18.4, 21. , 12.7, 14.5, 13.2,
13.1, 13.5, 18.9, 20. , 21. , 24.7, 30.8, 34.9, 26.6, 25.3, 24.7,
21.2, 19.3, 20. , 16.6, 14.4, 19.4, 19.7, 20.5, 25. , 23.4, 18.9,
35.4, 24.7, 31.6, 23.3, 19.6, 18.7, 16. , 22.2, 25. , 33. , 23.5,
19.4, 22. , 17.4, 20.9, 24.2, 21.7, 22.8, 23.4, 24.1, 21.4, 20. ,
20.8, 21.2, 20.3, 28. , 23.9, 24.8, 22.9, 23.9, 26.6, 22.5, 22.2,
23.6, 28.7, 22.6, 22. , 22.9, 25. , 20.6, 28.4, 21.4, 38.7, 43.8,
33.2, 27.5, 26.5, 18.6, 19.3, 20.1, 19.5, 19.5, 20.4, 19.8, 19.4,
21.7, 22.8, 18.8, 18.7, 18.5, 18.3, 21.2, 19.2, 20.4, 19.3, 22. ,
20.3, 20.5, 17.3, 18.8, 21.4, 15.7, 16.2, 18. , 14.3, 19.2, 19.6,
23. , 18.4, 15.6, 18.1, 17.4, 17.1, 13.3, 17.8, 14. , 14.4, 13.4,
15.6, 11.8, 13.8, 15.6, 14.6, 17.8, 15.4, 21.5, 19.6, 15.3, 19.4,
17. , 15.6, 13.1, 41.3, 24.3, 23.3, 27. , 50. , 50. , 50. , 22.7,
25. , 50. , 23.8, 23.8, 22.3, 17.4, 19.1, 23.1, 23.6, 22.6, 29.4,
23.2, 24.6, 29.9, 37.2, 39.8, 36.2, 37.9, 32.5, 26.4, 29.6, 50. ,
32. , 29.8, 34.9, 37. , 30.5, 36.4, 31.1, 29.1, 50. , 33.3, 30.3,
34.6, 34.9, 32.9, 24.1, 42.3, 48.5, 50. , 22.6, 24.4, 22.5, 24.4,
20. , 21.7, 19.3, 22.4, 28.1, 23.7, 25. , 23.3, 28.7, 21.5, 23. ,
26.7, 21.7, 27.5, 30.1, 44.8, 50. , 37.6, 31.6, 46.7, 31.5, 24.3,
31.7, 41.7, 48.3, 29. , 24. , 25.1, 31.5, 23.7, 23.3, 22. , 20.1,
22.2, 23.7, 17.6, 18.5, 24.3, 20.5, 24.5, 26.2, 24.4, 24.8, 29.6,
42.8, 21.9, 20.9, 44. , 50. , 36. , 30.1, 33.8, 43.1, 48.8, 31. ,
36.5, 22.8, 30.7, 50. , 43.5, 20.7, 21.1, 25.2, 24.4, 35.2, 32.4,
32. , 33.2, 33.1, 29.1, 35.1, 45.4, 35.4, 46. , 50. , 32.2, 22. ,
20.1, 23.2, 22.3, 24.8, 28.5, 37.3, 27.9, 23.9, 21.7, 28.6, 27.1,
20.3, 22.5, 29. , 24.8, 22. , 26.4, 33.1, 36.1, 28.4, 33.4, 28.2,
22.8, 20.3, 16.1, 22.1, 19.4, 21.6, 23.8, 16.2, 17.8, 19.8, 23.1,
21. , 23.8, 23.1, 20.4, 18.5, 25. , 24.6, 23. , 22.2, 19.3, 22.6,
19.8, 17.1, 19.4, 22.2, 20.7, 21.1, 19.5, 18.5, 20.6, 19. , 18.7,
32.7, 16.5, 23.9, 31.2, 17.5, 17.2, 23.1, 24.5, 26.6, 22.9, 24.1,
18.6, 30.1, 18.2, 20.6, 17.8, 21.7, 22.7, 22.6, 25. , 19.9, 20.8,
16.8, 21.9, 27.5, 21.9, 23.1, 50. , 50. , 50. , 50. , 50. , 13.8,
13.8, 15. , 13.9, 13.3, 13.1, 10.2, 10.4, 10.9, 11.3, 12.3, 8.8,
7.2, 10.5, 7.4, 10.2, 11.5, 15.1, 23.2, 9.7, 13.8, 12.7, 13.1,
12.5, 8.5, 5. , 6.3, 5.6, 7.2, 12.1, 8.3, 8.5, 5. , 11.9,
27.9, 17.2, 27.5, 15. , 17.2, 17.9, 16.3, 7. , 7.2, 7.5, 10.4,
8.8, 8.4, 16.7, 14.2, 20.8, 13.4, 11.7, 8.3, 10.2, 10.9, 11. ,
9.5, 14.5, 14.1, 16.1, 14.3, 11.7, 13.4, 9.6, 8.7, 8.4, 12.8,
10.5, 17.1, 18.4, 15.4, 10.8, 11.8, 14.9, 12.6, 14.1, 13. , 13.4,
15.2, 16.1, 17.8, 14.9, 14.1, 12.7, 13.5, 14.9, 20. , 16.4, 17.7,
19.5, 20.2, 21.4, 19.9, 19. , 19.1, 19.1, 20.1, 19.9, 19.6, 23.2,
29.8, 13.8, 13.3, 16.7, 12. , 14.6, 21.4, 23. , 23.7, 25. , 21.8,
20.6, 21.2, 19.1, 20.6, 15.2, 7. , 8.1, 13.6, 20.1, 21.8, 24.5,
23.1, 19.7, 18.3, 21.2, 17.5, 16.8, 22.4, 20.6, 23.9, 22. , 11.9])
1
2
3
4
5
6
7
|
regressor = DecisionTreeRegressor(random_state=0)# instantiate the model
cross_val_score(regressor # estimator to evaluate
,boston.data # feature matrix
,boston.target # regression targets (labels — the original comment mislabelled this as "features")
,cv=10 # number of cross-validation folds
,scoring="neg_mean_squared_error" # negated MSE; if omitted the default is R^2
)
|
array([-18.08941176, -10.61843137, -16.31843137, -44.97803922,
-17.12509804, -49.71509804, -12.9986 , -88.4514 ,
-55.7914 , -25.0816 ])