# Trees, Random forests and Classification
## Introduction
Classification trees are non-parametric methods that recursively partition the data into increasingly “pure” nodes, based on splitting rules.
Logistic regression or decision trees? The choice depends on the type of problem you are solving. Here are some key factors that help decide which algorithm to use:
* If the relationship between the dependent and independent variables is well approximated by a linear model, a linear model will generally outperform a tree-based model.
* If the relationship between the dependent and independent variables is highly non-linear and complex, a tree model will generally outperform a classical regression method.
* If you need a model that is easy to explain to people, a decision tree will usually do better than a linear model. Decision trees can be even simpler to interpret than linear regression!
The 2 main disadvantages of decision trees:
**Overfitting**: Overfitting is one of the main practical difficulties with decision tree models. It is addressed by setting constraints on model parameters and by pruning (discussed in detail below).
**Not fit for continuous variables**: When working with continuous numerical variables, a decision tree loses information because it bins them into categories.
Decision trees can use several algorithms to decide how to split a node into two or more sub-nodes. Creating sub-nodes increases the homogeneity of the resulting sub-nodes; in other words, the purity of the nodes increases with respect to the target variable. The tree evaluates candidate splits on all available variables and then selects the split that results in the most homogeneous sub-nodes.
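To make the idea of purity concrete, here is a minimal sketch of the Gini impurity, one common homogeneity measure and, to our knowledge, the default splitting criterion of `rpart` for classification. The helper names `gini` and `split_gini` are ours and not part of any package, and the cut-point of 2.45 on `Petal.Length` is used purely for illustration.

```{r}
# Gini impurity of a vector of class labels: 1 - sum(p_k^2)
gini <- function(labels) {
  p <- table(labels) / length(labels)
  1 - sum(p^2)
}

# Weighted Gini impurity of the two children produced by a logical split
split_gini <- function(labels, goes_left) {
  n <- length(labels)
  n_left <- sum(goes_left)
  (n_left / n) * gini(labels[goes_left]) +
    ((n - n_left) / n) * gini(labels[!goes_left])
}

data("iris")
root_impurity  <- gini(iris$Species)
split_impurity <- split_gini(iris$Species, iris$Petal.Length < 2.45)

# A useful split lowers the impurity
c(root = root_impurity, after_split = split_impurity)
```

A lower value after the split means the children are more homogeneous than the parent, which is what the tree looks for when it compares candidate splits.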
## First example.
Let's do a CART on the iris dataset. This is the `Hello World!` of CART.
```{r}
library(rpart)
library(rpart.plot)
data("iris")
str(iris)
table(iris$Species)
tree <- rpart(Species ~., data = iris, method = "class")
tree
```
The `method` argument can be switched according to the type of the response variable. It is `class` for a categorical response, `anova` for a numerical one, `poisson` for count data and `exp` for survival data.
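As a small illustration of switching the `method` argument, the sketch below grows a regression tree with `method = "anova"`. The `mtcars` data set and the object name `reg_tree` are our own choices for the example and are not used elsewhere in this chapter.

```{r}
library(rpart)

# Regression tree: the response mpg is numeric, so we use method = "anova".
# mtcars ships with base R and is used here only for illustration.
reg_tree <- rpart(mpg ~ ., data = mtcars, method = "anova")
reg_tree$method

# For an anova tree, predict() returns one fitted mean per observation
reg_pred <- predict(reg_tree)
head(reg_pred)
```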
*Important Terminology related to Decision Trees*
**Root Node**: It represents the entire population or sample, which then gets divided into two or more homogeneous sets.
**Splitting**: The process of dividing a node into two or more sub-nodes.
**Decision Node**: A sub-node that splits into further sub-nodes.
**Leaf / Terminal Node**: A node that does not split.
**Pruning**: Removing the sub-nodes of a decision node. It is the opposite of splitting.
**Branch / Sub-Tree**: A sub-section of the entire tree.
**Parent and Child Node**: A node that is divided into sub-nodes is called the parent of those sub-nodes, and the sub-nodes are its children.
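These terms map onto the `frame` component of an `rpart` object. The sketch below refits the iris tree under a separate name (`iris_tree`, our choice) and lists which nodes are leaves and which node is the parent of each.

```{r}
library(rpart)

iris_tree <- rpart(Species ~ ., data = iris, method = "class")

# One row per node; row names are the node numbers (1 is the root)
nodes <- iris_tree$frame[, c("var", "n")]
nodes$is_leaf <- nodes$var == "<leaf>"
nodes

# rpart numbers the children of node k as 2k and 2k + 1
node_ids  <- as.integer(rownames(nodes))
parent_of <- node_ids %/% 2      # the root reports parent 0, i.e. none
data.frame(node = node_ids, parent = parent_of, is_leaf = nodes$is_leaf)
```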
```{r iris_tree}
rpart.plot(tree)
```
This is a model with a **multi-class response**.
Each node shows
* the predicted class (setosa, versicolor, virginica),
* the predicted probability of each class,
* the percentage of observations in the node
```{r}
table(iris$Species, predict(tree, type = "class"))
```
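A table like this can be condensed into a few summary numbers. The sketch below computes the overall accuracy and the per-class recall; the object names are ours. Note that these figures are measured on the data used to grow the tree, so they are likely to be optimistic.

```{r}
library(rpart)

iris_fit <- rpart(Species ~ ., data = iris, method = "class")
confusion <- table(actual = iris$Species,
                   predicted = predict(iris_fit, type = "class"))

# Overall accuracy: correctly classified observations sit on the diagonal
accuracy <- sum(diag(confusion)) / sum(confusion)
accuracy

# Per-class recall: share of each actual species that is predicted correctly
recall <- diag(confusion) / rowSums(confusion)
recall
```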
## Second example.
The data set is `ptitanic`, the Titanic data shipped with the `rpart.plot` package. This is a model with a **binary response**.
```{r}
data("ptitanic")
str(ptitanic)
ptitanic$age <- as.numeric(ptitanic$age)
ptitanic$sibsp <- as.integer(ptitanic$sibsp)
ptitanic$parch <- as.integer(ptitanic$parch)
```
A table of survival proportions by sex gives a first look at the data.
```{r}
round(prop.table(table(ptitanic$sex, ptitanic$survived), 1), 2)
```
Here the proportions in each row sum to 1, because the margin argument of `prop.table` is `1`. To make each column sum to 1 instead, use `2`.
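A quick sketch of the two margins side by side. It works on a private copy of the data (named `titanic`, our choice) so that no other object in the chapter is modified.

```{r}
# Private copy of the data, loaded into a local environment
titanic <- local({
  data("ptitanic", package = "rpart.plot", envir = environment())
  ptitanic
})

counts <- table(titanic$sex, titanic$survived)
by_row <- prop.table(counts, 1)   # each row (sex) sums to 1
by_col <- prop.table(counts, 2)   # each column (outcome) sums to 1

round(by_col, 2)
rowSums(by_row)
colSums(by_col)
```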
You can find the default limits by typing `?rpart.control`. The first one we want to unleash is the `cp` parameter: this is the metric that stops splits that aren’t deemed important enough. The other one we want to open up is `minsplit`, which governs how many passengers must sit in a bucket before a split is even considered.
By setting a very low `cp`, we are asking for a very deep tree. The idea is that we prune it later. So in this first tree on `ptitanic` we'll set a very low `cp`.
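The effect of these two limits can be sketched by comparing a tree grown with the defaults to one grown with looser settings. The value `minsplit = 10` and all object names below are illustrative assumptions, not settings used in the rest of this chapter.

```{r}
library(rpart)

# Private copy of the data, loaded into a local environment
titanic <- local({
  data("ptitanic", package = "rpart.plot", envir = environment())
  ptitanic
})

# Defaults, for reference
defaults <- rpart.control()
c(cp = defaults$cp, minsplit = defaults$minsplit)

# Loosening both limits allows a deeper tree
loose <- rpart.control(cp = 0.00001, minsplit = 10)
set.seed(123)  # the cross-validation inside rpart is random
default_fit <- rpart(survived ~ ., data = titanic)
deep_fit    <- rpart(survived ~ ., data = titanic, control = loose)

# Number of leaves in each tree
n_leaves <- c(default = sum(default_fit$frame$var == "<leaf>"),
              deep    = sum(deep_fit$frame$var == "<leaf>"))
n_leaves
```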
```{r titanic_tree}
library(rpart)
library(rpart.plot)
set.seed(123)
tree <- rpart(survived ~ ., data = ptitanic, cp=0.00001)
rpart.plot(tree)
```
Each node shows
* the predicted class (died or survived),
* the predicted probability of survival,
* the percentage of observations in the node.
Let's do a confusion matrix based on this tree.
```{r}
conf.matrix <- round(prop.table(table(ptitanic$survived, predict(tree, type="class")), 2), 2)
rownames(conf.matrix) <- c("Actually died", "Actually Survived")
colnames(conf.matrix) <- c("Predicted dead", "Predicted Survived")
conf.matrix
```
Let's learn a bit more about trees. With the `names` function, one can see all the components of the `tree` object.
A few interesting ones:
The `$where` component indicates to which leaf each observation has been assigned.
```{r}
names(tree)
```
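As a sketch of how `$where` can be used, the chunk below translates it into node numbers and counts the passengers in each leaf. It fits its own tree on a private copy of the data; the names `titanic`, `where_fit`, `leaf_row` and `leaf_node` are ours.

```{r}
library(rpart)

# Private copy of the data, loaded into a local environment
titanic <- local({
  data("ptitanic", package = "rpart.plot", envir = environment())
  ptitanic
})

where_fit <- rpart(survived ~ ., data = titanic)

# `where` holds, for each passenger, the row of `frame` describing its leaf
leaf_row  <- where_fit$where
# Row names of `frame` are the node numbers
leaf_node <- rownames(where_fit$frame)[leaf_row]

# Passengers per leaf, and the first few assignments
table(leaf_node)
head(data.frame(survived = titanic$survived, leaf_node))
```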
How do we prune a tree? We want the `cp` value that minimizes `xerror`, the cross-validation error, ideally with a simpler tree. There are 2 ways to find it: the `plotcp` and `printcp` functions. `plotcp` is a visual representation of what `printcp` prints.
The problem with minimizing `xerror` is that the cross-validation error is a random quantity. There is no guarantee that, if we were to fit the sequence of trees again with a different random seed, the same tree would minimize the cross-validation error.
A more robust alternative to the minimum cross-validation error is the one standard deviation rule: choose the smallest tree whose cross-validation error is within one standard error of the minimum. Depending on how we define this, there are two possible choices: the smallest tree whose point estimate of the cross-validation error falls within one `xstd` of the minimum, or the smallest tree whose lower standard-error limit (`xerror - xstd`) falls within one `xstd` of the minimum.
Either of these is a reasonable choice, but insisting that the point estimate itself fall within the standard error limits is probably the more robust solution.
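The point-estimate version of the rule can be sketched as follows. This is one possible implementation under our own naming (`big_tree`, `cpt`, `i_min`, `i_1se`, `tree_1se`), working on a private copy of the data.

```{r}
library(rpart)

# Private copy of the data, loaded into a local environment
titanic <- local({
  data("ptitanic", package = "rpart.plot", envir = environment())
  ptitanic
})

set.seed(123)  # xerror and xstd come from random cross-validation folds
big_tree <- rpart(survived ~ ., data = titanic, cp = 0.00001)
cpt <- big_tree$cptable

# Row with the lowest cross-validation error, and the 1-SE threshold
i_min <- which.min(cpt[, "xerror"])
threshold <- cpt[i_min, "xerror"] + cpt[i_min, "xstd"]

# cptable rows run from the smallest tree to the largest, so the first
# row at or under the threshold is the smallest qualifying tree
i_1se <- which(cpt[, "xerror"] <= threshold)[1]
tree_1se <- prune(big_tree, cp = cpt[i_1se, "CP"])

# The 1-SE choice next to the minimum-xerror choice
cpt[c(i_1se, i_min), c("CP", "nsplit", "xerror", "xstd")]
```

Because the rule always picks a row at or above the minimum-error row in the table, the resulting tree is never larger than the one chosen by minimum `xerror`.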
As discussed earlier, growing a tree under constraints is a greedy approach: it picks the best split available at each step and moves on until one of the specified stopping conditions is reached. Consider the following situation when you’re driving:
There are 2 lanes:
* A lane with cars moving at 80 km/h
* A lane with trucks moving at 30 km/h
At this instant, you are a car in the fast lane and you have 2 choices:
* Take a left and overtake the other 2 cars quickly
* Keep moving in the present lane
Let’s analyze these choices. With the former, you immediately overtake the car ahead, end up behind the truck and start moving at 30 km/h, looking for an opportunity to move back right. Meanwhile, all the cars originally behind you move ahead. This would be the optimal choice if your objective were to maximize the distance covered in the next, say, 10 seconds. With the latter, you sail through at the same speed, pass the trucks and then overtake, depending on the situation ahead. Greedy you!
This is exactly the difference between a normal decision tree and pruning. A decision tree with constraints won’t see the truck ahead and will adopt a greedy approach by taking a left. If we use pruning, on the other hand, we in effect look a few steps ahead before making a choice.
So we know pruning is better.
```{r xval_titanic}
printcp(tree)
plotcp(tree)
tree$cptable[which.min(tree$cptable[,"xerror"]),"CP"]
```
Let's see whether we can prune the tree slightly.
```{r titanic_tree_2}
bestcp <- tree$cptable[which.min(tree$cptable[,"xerror"]),"CP"]
tree.pruned <- prune(tree, cp = bestcp)
#this time we add a few arguments to add some mojo to our graphed tree.
#Actually this will give us a very similar graphed tree as rattle (and we like that graph!)
rpart.plot(tree.pruned, extra=104, box.palette="GnBu",
branch.lty=3, shadow.col="gray", nn=TRUE)
conf.matrix <- round(prop.table(table(ptitanic$survived, predict(tree.pruned, type="class"))), 2)
rownames(conf.matrix) <- c("Actually died", "Actually Survived")
colnames(conf.matrix) <- c("Predicted dead", "Predicted Survived")
conf.matrix
```
Another way to check the output of the classifier is with a ROC (Receiver Operating Characteristic) curve. It plots the true positive rate against the false positive rate and gives us visual feedback on how well our model is performing. The package we will use for this is `ROCR`.
```{r rocr_titanic_tree}
library(ROCR)
fit.pr = predict(tree.pruned, type="prob")[,2]
fit.pred = prediction(fit.pr, ptitanic$survived)
fit.perf = performance(fit.pred,"tpr","fpr")
plot(fit.perf,lwd=2,col="blue",
main="ROC: Classification Trees on Titanic Dataset")
abline(a=0,b=1)
```
Ordinarily, using the confusion matrix to create the ROC curve would give us a single point (as it is based on one true positive rate and one false positive rate). What we do here is ask the prediction algorithm to give class probabilities for each observation, and then plot the performance of the prediction using each class probability as a cutoff. This gives us the “smooth” ROC curve.
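The sketch below shows both views: the single (false positive rate, true positive rate) point obtained from one cutoff, and the area under the whole curve as computed by `ROCR`. The helper `roc_point` and the other object names are ours, and the tree is a default `rpart` fit on a private copy of the data rather than the pruned tree above.

```{r}
library(rpart)
library(ROCR)

# Private copy of the data, loaded into a local environment
titanic <- local({
  data("ptitanic", package = "rpart.plot", envir = environment())
  ptitanic
})

roc_fit <- rpart(survived ~ ., data = titanic)
p_survived <- predict(roc_fit, type = "prob")[, "survived"]
actual <- titanic$survived == "survived"

# One cutoff gives one confusion matrix, hence one (FPR, TPR) point
roc_point <- function(cutoff) {
  predicted <- p_survived >= cutoff
  c(fpr = sum(predicted & !actual) / sum(!actual),
    tpr = sum(predicted & actual) / sum(actual))
}
roc_point(0.5)

# The area under the curve summarises all cutoffs in a single number
roc_pred <- prediction(p_survived, titanic$survived)
auc <- performance(roc_pred, "auc")@y.values[[1]]
auc
```

An area of 0.5 corresponds to the diagonal drawn by `abline(a=0,b=1)`, so the further the value is above 0.5, the better the classifier separates the two classes.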
## How does a tree decide where to split?
A bit more theory, before we go further. This part has been taken from this [great tutorial](https://www.analyticsvidhya.com/blog/2016/04/complete-tutorial-tree-based-modeling-scratch-in-python/).
## Third example.
The dataset I will be using for this third example is the “Adult” dataset hosted on UCI’s Machine Learning Repository. It contains approximately 32000 observations, with 15 variables. In all cases, the dependent variable we will be trying to predict is whether or not an “individual” has an income greater than \$50,000 a year.
Here is the set of variables contained in the data.
* age – The age of the individual
* type_employer – The type of employer the individual has. Whether they are government, military, private, and so on.
* fnlwgt – The number of people the census takers believe that observation represents. We will be ignoring this variable.
* education – The highest level of education achieved for that individual
* education_num – Highest level of education in numerical form
* marital – Marital status of the individual
* occupation – The occupation of the individual
* relationship – A bit more difficult to explain. Contains family relationship values like husband, father, and so on, but only one per observation. I’m not sure what this is supposed to represent.
* race – Description of the individual’s race. Black, White, Eskimo, and so on
* sex – Biological Sex
* capital_gain – Capital gains recorded
* capital_loss – Capital Losses recorded
* hr_per_week – Hours worked per week
* country – Country of origin for person
* income – Boolean Variable. Whether or not the person makes more than \$50,000 per annum income.
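As a sketch of how this layout could be read into R, the helper below assigns the variable names listed above and drops `fnlwgt`. It assumes a comma-separated file without a header row and with `?` marking missing values; the function name `read_adult` is ours, and the two sample rows are toy data written for this sketch, not real observations.

```{r}
adult_cols <- c("age", "type_employer", "fnlwgt", "education", "education_num",
                "marital", "occupation", "relationship", "race", "sex",
                "capital_gain", "capital_loss", "hr_per_week", "country",
                "income")

# Arguments are passed on to read.csv(), so both a file path and
# `text = ...` work as input
read_adult <- function(...) {
  adult <- read.csv(..., header = FALSE, col.names = adult_cols,
                    strip.white = TRUE, na.strings = "?",
                    stringsAsFactors = TRUE)
  adult$fnlwgt <- NULL   # ignored, as noted in the variable list
  adult
}

# Toy rows in the assumed layout, for illustration only
sample_rows <- c(
  "39, State-gov, 77516, Bachelors, 13, Never-married, Adm-clerical, Not-in-family, White, Male, 2174, 0, 40, United-States, <=50K",
  "54, ?, 180211, Some-college, 10, Married-civ-spouse, ?, Husband, White, Male, 0, 0, 60, United-States, >50K"
)
adult_sample <- read_adult(text = sample_rows)
str(adult_sample)
```

With a local copy of the data, the same helper would be called as `read_adult("path/to/file")`, where the path is whatever location the file was saved to.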
## References
* [Trees with the rpart package](http://machine-master.blogspot.com/2012/11/trees-with-rpart-package.html)
* [Wholesale customers Data Set](https://archive.ics.uci.edu/ml/datasets/Wholesale+customers) Origin of the data set of first example.
* [Titanic: Getting Started With R - Part 3: Decision Trees](http://trevorstephens.com/kaggle-titanic-tutorial/r-part-3-decision-trees/). First understanding on how to read the graph of a tree.
* [Classification and Regression Trees (CART) with rpart and rpart.plot](https://rpubs.com/minma/cart_with_rpart). Got the `Titanic` example from there as well as a first understanding on pruning.
* [Statistical Consulting Group](http://scg.sdsu.edu/ctrees_r/). We learned here how to use the ROC curve, and we got the `adult` dataset from it.
* [A Complete Tutorial on Tree Based Modeling from Scratch (in R & Python)](https://www.analyticsvidhya.com/blog/2016/04/complete-tutorial-tree-based-modeling-scratch-in-python/). This website is a real gem, as always.
* Stephen Milborrow. rpart.plot: Plot rpart Models. An Enhanced Version of plot.rpart., 2016. R Package. It is important to cite the very generous people who dedicate so much of their time to offering us great tools.