Preface

The dataset comes from Kaggle.
Link: https://www.kaggle.com/c/titanic/data
The train and test CSV files on that page are the datasets we need.

Code
```python
import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier  # the classifier only accepts numeric features
import matplotlib.pyplot as plt
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_val_score
```
```python
data = pd.read_csv("train.csv")
data.head()
```
|   | PassengerId | Survived | Pclass | Name | Sex | Age | SibSp | Parch | Ticket | Fare | Cabin | Embarked |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 0 | 3 | Braund, Mr. Owen Harris | male | 22.0 | 1 | 0 | A/5 21171 | 7.2500 | NaN | S |
| 1 | 2 | 1 | 1 | Cumings, Mrs. John Bradley (Florence Briggs Th... | female | 38.0 | 1 | 0 | PC 17599 | 71.2833 | C85 | C |
| 2 | 3 | 1 | 3 | Heikkinen, Miss. Laina | female | 26.0 | 0 | 0 | STON/O2. 3101282 | 7.9250 | NaN | S |
| 3 | 4 | 1 | 1 | Futrelle, Mrs. Jacques Heath (Lily May Peel) | female | 35.0 | 1 | 0 | 113803 | 53.1000 | C123 | S |
| 4 | 5 | 0 | 3 | Allen, Mr. William Henry | male | 35.0 | 0 | 0 | 373450 | 8.0500 | NaN | S |
```python
data.info()
```

```
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 12 columns):
 #   Column       Non-Null Count  Dtype
---  ------       --------------  -----
 0   PassengerId  891 non-null    int64
 1   Survived     891 non-null    int64
 2   Pclass       891 non-null    int64
 3   Name         891 non-null    object
 4   Sex          891 non-null    object
 5   Age          714 non-null    float64
 6   SibSp        891 non-null    int64
 7   Parch        891 non-null    int64
 8   Ticket       891 non-null    object
 9   Fare         891 non-null    float64
 10  Cabin        204 non-null    object
 11  Embarked     889 non-null    object
dtypes: float64(2), int64(5), object(5)
memory usage: 83.7+ KB
```
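The info() output above already shows which columns contain missing values (Age, Cabin, Embarked). As an optional check that is not part of the original notebook, the per-column null counts can also be listed directly:

```python
# Optional sketch: count missing values per column before cleaning.
# `data` is the DataFrame read from train.csv above.
missing = data.isnull().sum()
print(missing[missing > 0])  # based on info() above: Age 177, Cabin 687, Embarked 2
```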
```python
# Select features: drop columns that are hard to use directly
data.drop(["Name", "Ticket", "Cabin"], inplace=True, axis=1)
# Handle missing values: fill Age with its mean
data["Age"] = data["Age"].fillna(data["Age"].mean())
# Drop the few remaining rows with missing values (the two missing Embarked entries)
data = data.dropna()
```

```python
data.info()
```
```
<class 'pandas.core.frame.DataFrame'>
Int64Index: 889 entries, 0 to 890
Data columns (total 9 columns):
 #   Column       Non-Null Count  Dtype
---  ------       --------------  -----
 0   PassengerId  889 non-null    int64
 1   Survived     889 non-null    int64
 2   Pclass       889 non-null    int64
 3   Sex          889 non-null    object
 4   Age          889 non-null    float64
 5   SibSp        889 non-null    int64
 6   Parch        889 non-null    int64
 7   Fare         889 non-null    float64
 8   Embarked     889 non-null    object
dtypes: float64(2), int64(5), object(2)
memory usage: 69.5+ KB
```
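Dropping the two rows whose Embarked value is missing is the simplest choice here. An alternative sketch, not used in this post, would be to fill them with the most frequent port so that all 891 rows are kept:

```python
# Alternative sketch (not applied above): fill missing Embarked values with the
# most frequent port instead of dropping the two affected rows.
data["Embarked"] = data["Embarked"].fillna(data["Embarked"].mode()[0])
```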
```python
labels = data["Embarked"].unique().tolist()
```
```python
# Convert the multi-class categorical column (Embarked) to integer codes
data["Embarked"] = data["Embarked"].apply(lambda x: labels.index(x))
# data["Sex"] = data["Sex"].apply(lambda x: labels.index(x))  # would not work for Sex: its values are not in `labels`
```
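For reference, pandas can produce the same kind of integer encoding through its categorical dtype. This is only an equivalent sketch applied to the raw string column, and the code assigned to each port may differ from labels.index(x), because inferred categories are sorted alphabetically rather than taken in order of appearance:

```python
# Equivalent sketch: integer-encode Embarked via pandas' categorical codes.
# Assumes the column still holds the original "S"/"C"/"Q" strings.
codes = data["Embarked"].astype("category").cat.codes
```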
```python
# Convert the binary categorical column to a 0/1 variable
data["Sex"] = (data["Sex"] == "male").astype("int")
# Re-running the comparison after the conversion returns all zeros,
# because the column now holds 0/1 integers rather than "male"/"female":
(data["Sex"] == "male").astype("int")
```
```
0      0
1      0
2      0
3      0
4      0
      ..
886    0
887    0
888    0
889    0
890    0
Name: Sex, Length: 889, dtype: int64
```
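An equally valid, slightly more explicit sketch of the same conversion, assuming Sex still contains the original strings, is to map the two values directly; this also avoids the confusing all-zeros re-comparison above:

```python
# Alternative sketch: explicit 0/1 mapping for the binary column.
# Assumes Sex still contains "male"/"female" strings.
data["Sex"] = data["Sex"].map({"male": 1, "female": 0})
```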
```python
data
```

|   | PassengerId | Survived | Pclass | Sex | Age | SibSp | Parch | Fare | Embarked |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 0 | 3 | 1 | 22.000000 | 1 | 0 | 7.2500 | 0 |
| 1 | 2 | 1 | 1 | 0 | 38.000000 | 1 | 0 | 71.2833 | 1 |
| 2 | 3 | 1 | 3 | 0 | 26.000000 | 0 | 0 | 7.9250 | 0 |
| 3 | 4 | 1 | 1 | 0 | 35.000000 | 1 | 0 | 53.1000 | 0 |
| 4 | 5 | 0 | 3 | 1 | 35.000000 | 0 | 0 | 8.0500 | 0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 886 | 887 | 0 | 2 | 1 | 27.000000 | 0 | 0 | 13.0000 | 0 |
| 887 | 888 | 1 | 1 | 0 | 19.000000 | 0 | 0 | 30.0000 | 0 |
| 888 | 889 | 0 | 3 | 0 | 29.699118 | 1 | 2 | 23.4500 | 0 |
| 889 | 890 | 1 | 1 | 1 | 26.000000 | 0 | 0 | 30.0000 | 1 |
| 890 | 891 | 0 | 3 | 1 | 32.000000 | 0 | 0 | 7.7500 | 2 |

889 rows × 9 columns
```python
x = data.iloc[:, data.columns != "Survived"]
x
```
|   | PassengerId | Pclass | Sex | Age | SibSp | Parch | Fare | Embarked |
|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 3 | 1 | 22.000000 | 1 | 0 | 7.2500 | 0 |
| 1 | 2 | 1 | 0 | 38.000000 | 1 | 0 | 71.2833 | 1 |
| 2 | 3 | 3 | 0 | 26.000000 | 0 | 0 | 7.9250 | 0 |
| 3 | 4 | 1 | 0 | 35.000000 | 1 | 0 | 53.1000 | 0 |
| 4 | 5 | 3 | 1 | 35.000000 | 0 | 0 | 8.0500 | 0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 886 | 887 | 2 | 1 | 27.000000 | 0 | 0 | 13.0000 | 0 |
| 887 | 888 | 1 | 0 | 19.000000 | 0 | 0 | 30.0000 | 0 |
| 888 | 889 | 3 | 0 | 29.699118 | 1 | 2 | 23.4500 | 0 |
| 889 | 890 | 1 | 1 | 26.000000 | 0 | 0 | 30.0000 | 1 |
| 890 | 891 | 3 | 1 | 32.000000 | 0 | 0 | 7.7500 | 2 |

889 rows × 8 columns
```python
y = data.iloc[:, data.columns == "Survived"]
y
```
|   | Survived |
|---|---|
| 0 | 0 |
| 1 | 1 |
| 2 | 1 |
| 3 | 1 |
| 4 | 0 |
| ... | ... |
| 886 | 0 |
| 887 | 1 |
| 888 | 0 |
| 889 | 1 |
| 890 | 0 |

889 rows × 1 columns
```python
Xtrain, Xtest, ytrain, ytest = train_test_split(x, y, test_size=0.3)
```
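Because train_test_split shuffles at random, the split above changes on every run, and so does the test score further down. A variant sketch, not what the original code does, fixes the random seed and keeps the survived/not-survived ratio similar in both splits:

```python
# Sketch only: reproducible, stratified split. random_state=25 is an arbitrary
# choice (matching the tree's seed used later), not part of the original code.
Xtrain, Xtest, ytrain, ytest = train_test_split(
    x, y, test_size=0.3, random_state=25, stratify=y
)
```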
```python
Xtrain.index = range(Xtrain.shape[0])
# Rebuild the index so it runs 0..n-1 again after the shuffle
# Xtrain.reset_index(drop=True, inplace=True)
# reset_index() without drop=True would add an extra "index" column, which is awkward
```
```python
# Rebuild the index for all four splits
for i in [Xtrain, Xtest, ytrain, ytest]:
    i.index = range(i.shape[0])
```
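The same re-indexing can also be written with reset_index; the following is just an equivalent sketch of the loop above:

```python
# Equivalent sketch: drop=True discards the old index instead of adding it back
# as an extra "index" column.
for i in [Xtrain, Xtest, ytrain, ytest]:
    i.reset_index(drop=True, inplace=True)
```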
```python
clf = DecisionTreeClassifier(random_state=25)
clf = clf.fit(Xtrain, ytrain)
score = clf.score(Xtest, ytest)
score
```

```
0.7602996254681648
```
```python
clf = DecisionTreeClassifier(random_state=25)
# No need to split train/test manually here; cross_val_score does the splitting itself
score = cross_val_score(clf, x, y, cv=10).mean()
score
```

```
0.7469611848825333
```
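The mean hides how much the score varies from fold to fold. A small optional check, not in the original, looks at the individual fold scores and their spread:

```python
# Optional sketch: inspect the 10 individual fold scores and their spread
# instead of only the mean.
scores = cross_val_score(clf, x, y, cv=10)
print(scores)
print(scores.mean(), scores.std())
```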
```python
# Learning curve over max_depth: compare training score with cross-validation score
tr = []  # training-set scores
te = []  # cross-validation scores
for i in range(10):
    clf = DecisionTreeClassifier(random_state=25
                                 , max_depth=i + 1
                                 , criterion="entropy"  # entropy; usually only worth trying when the model underfits
                                 )
    clf = clf.fit(Xtrain, ytrain)
    score_tr = clf.score(Xtrain, ytrain)
    score_te = cross_val_score(clf, x, y, cv=10).mean()
    tr.append(score_tr)
    te.append(score_te)
print(max(te))
```

```
0.8166624106230849
```
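max(te) only reports the best cross-validation score; a one-line addition (not in the original) also recovers which depth produced it:

```python
# Optional sketch: the list index of the best score corresponds to depth - 1.
best_depth = te.index(max(te)) + 1
print(best_depth, max(te))
```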
```python
plt.plot(range(1, 11), tr, color="red", label="train")
plt.plot(range(1, 11), te, color="blue", label="test")
plt.xticks(range(1, 11))
plt.legend()
plt.show()
```
Grid search: tuning several parameters at the same time
```python
# min_impurity_decrease
gini_thresholds = np.linspace(0, 0.5, 20)  # 20 evenly spaced values in [0, 0.5]
# entropy_thresholds = np.linspace(0, 1, 50)  # 50 evenly spaced values in [0, 1]
# `parameters` is essentially a set of hyperparameters together with the ranges
# of values we want the grid search to try for each of them
parameters = {"criterion": ("gini", "entropy")  # Gini impurity or information entropy
              , "splitter": ("best", "random")
              , "max_depth": [*range(1, 10)]  # depths 1 to 9
              , "min_samples_leaf": [*range(1, 50, 5)]  # 1 to 46 in steps of 5
              , "min_impurity_decrease": [*np.linspace(0, 0.5, 20)]  # the sensible range depends on whether Gini or entropy is chosen
              }
clf = DecisionTreeClassifier(random_state=25)  # decision tree classifier with a fixed random seed of 25
GS = GridSearchCV(clf, parameters, cv=10)
GS = GS.fit(Xtrain, ytrain)
```
```python
GS.best_params_  # the best combination found among the parameter values we supplied
```

```
{'criterion': 'entropy',
 'max_depth': 8,
 'min_impurity_decrease': 0.0,
 'min_samples_leaf': 1,
 'splitter': 'random'}
```
```python
GS.best_score_  # the best cross-validation score reached during the grid search
```

```
0.8200204813108039
```
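GS.best_score_ is a cross-validation score computed on Xtrain. As a final optional step that the post does not show, the refitted best estimator can also be scored on the held-out Xtest split:

```python
# Optional sketch: GridSearchCV refits the best parameter combination on all of
# Xtrain by default (refit=True), so best_estimator_ can be scored on Xtest.
test_score = GS.best_estimator_.score(Xtest, ytest)
print(test_score)
```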