---
title: "8_MLPS_R_tree_methods"
author: "Runshan Fu (TA - Heinz CMU PhD)"
date: "2/27/2017"
output:
  html_document:
    css: '~/Dropbox/avenir-white.css'
---
```{r setup, include=FALSE}
knitr::opts_chunk$set(echo = TRUE, warning = FALSE, error = FALSE, message = FALSE)
```
## Lecture 8: Tree-based Methods
Key specific tasks we covered in this lecture:
* estimating a decision tree
* predicting both class probabilities and class labels
* pre-pruning options
* chi-square hypothesis tests
Commonly used packages for decision trees include `tree` and `rpart`. ISLR (pages 323-331) has detailed examples for `tree`; here we show basic usage of `rpart` on the same dataset.
```{r}
# Load the dataset
library(ISLR)
# Create a variable 'High' that is "Yes" if 'Sales' exceeds 8 and "No" otherwise.
# Storing it as a factor makes the response explicitly categorical for rpart.
High <- factor(ifelse(Carseats$Sales <= 8, "No", "Yes"))
# Merge 'High' with the rest of the data
Carseats <- data.frame(Carseats, High)
# Split the data into a training set (200 rows) and a test set (the remaining rows)
set.seed(2)
train <- sample(1:nrow(Carseats), 200)
carseats.train <- Carseats[train, ]
carseats.test <- Carseats[-train, ]
```
Estimate a decision tree using the training set:
```{r}
library(rpart)
tree.carseats <- rpart(High ~ .-Sales, data = Carseats, subset = train)
summary(tree.carseats)
```
Note that `rpart` takes several optional parameters to further configure the tree. See the documentation (https://cran.r-project.org/web/packages/rpart/rpart.pdf) for details.
Visualize the tree:
```{r}
library(rattle)
fancyRpartPlot(tree.carseats, main = "Decision Tree for Carseats")
```
Predict class probabilities for the test set:
```{r}
pred.prob <- predict(tree.carseats, newdata = carseats.test, type = "prob")
# first 10 rows of predicted probabilities (R indexes from 1, not 0)
pred.prob[1:10, ]
```
Predict classes for the test set and compare them with the true labels:
```{r}
pred.class <- predict(tree.carseats, newdata = carseats.test, type = "class")
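# the confusion matrix (rows: predicted class, columns: true class)
table(pred.class, carseats.test$High)
# correct classification rate, computed from the predictions rather than hard-coded counts
mean(pred.class == carseats.test$High)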
```
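
The classification rate can also be read off the confusion matrix itself: the diagonal holds the correctly classified observations. A minimal sketch with toy labels (the vectors `truth` and `pred` are made up for illustration and are not the Carseats results):

```{r}
# Toy labels, made up for illustration (not the Carseats results)
truth <- factor(c("No", "No", "No", "Yes", "Yes", "No", "Yes", "Yes"),
                levels = c("No", "Yes"))
pred <- factor(c("No", "No", "Yes", "Yes", "No", "No", "Yes", "Yes"),
               levels = c("No", "Yes"))

# Confusion matrix: rows are predictions, columns are true labels
cm <- table(pred, truth)
cm

# Correct predictions sit on the diagonal
accuracy <- sum(diag(cm)) / sum(cm)
accuracy
```

This gives the same number as `mean(pred == truth)`, but it works when only the table is available.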
Pre-prune the tree by depth:
```{r}
# Pre-prune the tree to have a maximum depth of 4
tree.pruned <- rpart(High ~ . - Sales,
                     data = carseats.train,
                     control = rpart.control(maxdepth = 4))
fancyRpartPlot(tree.pruned, main = "Decision Tree Pruned by Depth")
```
Prune the tree by number of terminal nodes:
```{r}
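# Inspect the complexity table: each row lists a cp value and the number of splits (nsplit);
# a tree with nsplit splits has nsplit + 1 terminal nodes
tree.carseats$cptable
# Prune with a cp read from the table; 0.05 was chosen here to keep 4 terminal nodes
# (check the table, since the right value depends on the fitted tree)
tree.pruned2 <- prune(tree.carseats, cp = 0.05)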
fancyRpartPlot(tree.pruned2, main = "Decision Tree Pruned by Size")
```
You can also set `cp` in `rpart.control()` when estimating the tree to limit the number of terminal nodes.
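
A minimal sketch of setting `cp` at estimation time, assuming the same Carseats setup as above. The data is rebuilt inside the chunk so that it runs on its own; `cs`, `idx`, `fit.default`, `fit.cp` and `n.leaves` are illustrative names, not part of the lecture code.

```{r}
library(ISLR)
library(rpart)

# Rebuild the training data so this chunk runs on its own
cs <- transform(Carseats, High = factor(ifelse(Sales <= 8, "No", "Yes")))
set.seed(2)
idx <- sample(nrow(cs), 200)

# Default complexity parameter (cp = 0.01)
fit.default <- rpart(High ~ . - Sales, data = cs[idx, ])
# A larger cp rules out splits whose relative improvement in fit is below 0.05
fit.cp <- rpart(High ~ . - Sales, data = cs[idx, ],
                control = rpart.control(cp = 0.05))

# Count terminal nodes: leaves are marked "<leaf>" in the fitted frame
n.leaves <- function(fit) sum(fit$frame$var == "<leaf>")
c(default = n.leaves(fit.default), cp_0.05 = n.leaves(fit.cp))
```

The larger `cp` should give a tree with no more terminal nodes than the default fit. The exact counts depend on the training sample.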
`rpart` does not offer a built-in chi-square test, but you can use the standard chi-square test in base R (`chisq.test()`). See: https://www.r-bloggers.com/chi-squared-test/.
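
As a minimal sketch, a chi-square test of independence between one categorical predictor and the High/not-High outcome could look like the following. The choice of `ShelveLoc` as the predictor and the names `high`, `tab` and `chi` are assumptions made for illustration.

```{r}
library(ISLR)

# Contingency table: shelf location quality vs. whether sales exceed 8
high <- ifelse(Carseats$Sales <= 8, "No", "Yes")
tab <- table(ShelveLoc = Carseats$ShelveLoc, High = high)
tab

# Pearson's chi-squared test of independence (stats package, loaded by default)
chi <- chisq.test(tab)
chi$statistic  # test statistic
chi$parameter  # degrees of freedom: (rows - 1) * (columns - 1)
chi$p.value
```

A small p-value suggests that the outcome is not independent of the predictor, which is the kind of evidence a chi-square based stopping rule looks for before splitting.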
Additional resources: https://www.r-bloggers.com/party-with-the-first-tribe/
### Additional Notes
* Use `FSelector` to calculate information gain:
```{r}
library("FSelector")
information.gain(High~., carseats.train)
```
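
To see what an information gain score measures, here is a minimal base-R sketch for a categorical attribute. The helpers `entropy` and `info.gain` and the toy vectors are hypothetical names made up for illustration. The sketch reports bits and treats every distinct attribute value as its own category, so its values are not expected to match `information.gain()` output.

```{r}
# Shannon entropy (in bits) of a categorical vector
entropy <- function(y) {
  p <- table(y) / length(y)
  p <- p[p > 0]  # drop empty levels so 0 * log(0) is never evaluated
  -sum(p * log2(p))
}

# Information gain from splitting 'y' by the categorical attribute 'x':
# entropy before the split minus the weighted entropy after it
info.gain <- function(x, y) {
  groups <- split(y, x, drop = TRUE)
  weights <- sapply(groups, length) / length(y)
  entropy(y) - sum(weights * sapply(groups, entropy))
}

# Toy data, made up for illustration
y <- c("Yes", "Yes", "No", "No")
gain.perfect <- info.gain(c("a", "a", "b", "b"), y)  # attribute separates the classes
gain.none <- info.gain(c("a", "b", "a", "b"), y)     # attribute says nothing about y
c(perfect = gain.perfect, none = gain.none)
```

An attribute that separates the classes perfectly recovers the full entropy of `y`, while an uninformative one gains nothing.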
* `rpart` uses surrogate variables to deal with missing values, i.e. when the split variable is missing for an observation, it uses other independent variables that mimic the split to decide which branch the observation follows. "Any observation which is missing the split variable is then classified using the first surrogate variable, or if missing that, the second surrogate is used, and etc." See Section 5 of the introduction to `rpart` for details: https://cran.r-project.org/web/packages/rpart/vignettes/longintro.pdf
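
A minimal sketch of this behaviour, assuming the same Carseats setup as above (rebuilt here so the chunk runs on its own; `cs`, `idx`, `fit`, `root.var`, `new.obs` and `pred.na` are illustrative names). It blanks out the root split variable for a few test rows and checks that `predict()` still returns a class for each of them.

```{r}
library(ISLR)
library(rpart)

# Rebuild the data and fit a tree so this chunk runs on its own
cs <- transform(Carseats, High = factor(ifelse(Sales <= 8, "No", "Yes")))
set.seed(2)
idx <- sample(nrow(cs), 200)
fit <- rpart(High ~ . - Sales, data = cs[idx, ])

# Name of the variable used at the root split (first row of the fitted frame)
root.var <- as.character(fit$frame$var[1])

# Take a few test rows and set that variable to NA, keeping the column type
new.obs <- cs[-idx, ][1:5, ]
is.na(new.obs[[root.var]]) <- TRUE

# Rows missing the split variable are routed using surrogate splits
# (or the majority direction if no surrogate is usable)
pred.na <- predict(fit, newdata = new.obs, type = "class")
pred.na
```

Running `summary(fit)` lists the surrogate splits that `rpart` stored for each node.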