# Decision Trees
## Load packages
```{r load_packages}
library(ggplot2)
library(rpart)
library(rpart.plot)
```
## Load data
Load the `train_x_class`, `train_y_class`, `test_x_class`, and `test_y_class` variables we defined in `02-preprocessing.Rmd` for this *classification* task.
```{r setup_data}
# Objects: task_reg, task_class
load("data/preprocessed.RData")
```
## Overview
Decision trees are recursive partitioning methods that divide the predictor space into simpler regions and can be visualized as a tree-like structure. They classify data by repeatedly splitting it into subsets, choosing at each step the predictor and cut point that best separate the values of the outcome variable Y.
Let's see how a decision tree classifies whether a person suffers from heart disease (`target` = 1) or not (`target` = 0).
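
To make the idea of a "good" split concrete, here is a minimal sketch of information gain, the criterion this notebook asks `rpart` to use when fitting the tree. The helper functions `entropy()` and `info_gain()` and the toy vectors are made up for illustration; they are not part of `rpart`, whose internal computation may differ in detail.

```{r}
# Shannon entropy (in bits) of a vector of class labels.
entropy = function(y) {
  p = table(y) / length(y)
  p = p[p > 0]  # drop empty classes so 0 * log2(0) is never evaluated
  -sum(p * log2(p))
}

# Information gain from splitting labels `y` on the rule `x < threshold`:
# parent entropy minus the size-weighted entropy of the two children.
info_gain = function(y, x, threshold) {
  left = y[x < threshold]
  right = y[x >= threshold]
  weighted_child_entropy =
    length(left) / length(y) * entropy(left) +
    length(right) / length(y) * entropy(right)
  entropy(y) - weighted_child_entropy
}

# Toy data (hypothetical): the classes separate perfectly at age 50.
toy_age = c(30, 35, 40, 45, 55, 60, 65, 70)
toy_disease = c(0, 0, 0, 0, 1, 1, 1, 1)

info_gain(toy_disease, toy_age, threshold = 50)  # perfect separation
info_gain(toy_disease, toy_age, threshold = 38)  # a less useful split
```

A tree-growing algorithm of this kind evaluates many candidate splits, keeps the one with the largest gain, and then repeats the search inside each resulting subset.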
## Fit Model
```{r}
set.seed(3)
tree = rpart::rpart(train_y_class ~ ., data = train_x_class,
                    # Use method = "anova" for a continuous outcome.
                    method = "class",
                    # Split on information gain; the default is "gini" (Gini index).
                    parms = list(split = "information"))
# https://stackoverflow.com/questions/4553947/decision-tree-on-information-gain
# Here is the text-based display of the decision tree. Yikes! :^(
print(tree)
```
Although interpreting the text can be intimidating, a decision tree's main strength is its tree-like plot, which is much easier to interpret.
## Investigate Results
```{r plot_tree}
rpart.plot::rpart.plot(tree)
```
We can also look inside `tree` to see what we can unpack. `variable.importance` is one we should check out!
```{r}
names(tree)
tree$variable.importance
```
Plot variable importance:
```{r}
# Turn the tree$variable.importance vector into a dataframe
tree_varimp = data.frame(tree$variable.importance)
# Add rownames as their own column
tree_varimp$x = rownames(tree_varimp)
# Reorder columns
tree_varimp = tree_varimp[, c(2,1)]
# Reset row names
rownames(tree_varimp) = NULL
# Rename columns
names(tree_varimp) = c("Variable", "Importance")
tree_varimp
# Plot
ggplot(tree_varimp, aes(x = reorder(Variable, Importance),
                        y = Importance)) +
  geom_bar(stat = "identity") +
  theme_bw() + coord_flip() + xlab("")
```
In decision trees, the main hyperparameter (configuration setting) is the **complexity parameter** (CP). The name is a little counterintuitive: a high CP results in a simple decision tree with few splits, whereas a low CP results in a larger decision tree with many splits.
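
The effect of CP on tree size can be seen with a small sketch. It uses the `kyphosis` data that ships with `rpart` only so that the chunk is self-contained; the two CP values and the helper `n_splits()` are arbitrary choices for illustration.

```{r}
library(rpart)

set.seed(1)
# A high CP demands a large improvement before a split is kept.
fit_high_cp = rpart(Kyphosis ~ Age + Number + Start, data = kyphosis,
                    method = "class", cp = 0.5)
# A low CP keeps splits that improve the fit only slightly.
fit_low_cp = rpart(Kyphosis ~ Age + Number + Start, data = kyphosis,
                   method = "class", cp = 0.001)

# Count splits: internal nodes are the rows of `frame` whose `var` is not "<leaf>".
n_splits = function(fit) sum(fit$frame$var != "<leaf>")
n_splits(fit_high_cp)
n_splits(fit_low_cp)
```

The tree fit with the higher CP should have no more splits than the one fit with the lower CP.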
`rpart` uses cross-validation internally to estimate the accuracy at various CP settings. We can review those to see what setting seems best.
Print the results for various CP settings - we want the one with the lowest "xerror". We can also plot the performance estimates for different CP settings.
```{r plotcp_tree}
# Show estimated error rate at different complexity parameter settings.
printcp(tree)
# Plot those estimated error rates.
plotcp(tree)
# Trees of similar sizes might appear to be tied for lowest "xerror",
# but a tree with fewer splits might be easier to interpret.
tree_pruned2 = prune(tree, cp = 0.028986) # 2 splits
tree_pruned6 = prune(tree, cp = 0.010870) # 6 splits
```
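
Rather than copying CP values by hand from the printed table, the same table can be read from the fitted object's `cptable` component. The sketch below again uses the built-in `kyphosis` data so that it stands alone; the object names are hypothetical, and picking the row with the lowest `xerror` is just one reasonable rule, not the only one.

```{r}
library(rpart)

set.seed(1)
fit_cp_demo = rpart(Kyphosis ~ Age + Number + Start, data = kyphosis,
                    method = "class", cp = 0.001)

# `cptable` is the matrix that printcp() displays.
cp_table = fit_cp_demo$cptable

# Find the row with the lowest cross-validated error and take its CP.
best_row = which.min(cp_table[, "xerror"])
best_cp = cp_table[best_row, "CP"]

# Prune the full tree back to that CP.
fit_cp_pruned = prune(fit_cp_demo, cp = best_cp)
best_cp
```

Because the cross-validation folds are assigned at random, the selected CP can change with the seed, which is one reason to prefer the simpler tree when several rows have nearly the same `xerror`.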
Print detailed results, variable importance, and summary of splits.
```{r}
summary(tree_pruned2)
rpart.plot(tree_pruned2)
```
```{r}
summary(tree_pruned6)
rpart.plot(tree_pruned6)
```
You can also get more fine-grained control by checking out the "control" argument inside the rpart function. Type `?rpart` to learn more.
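
As a rough illustration of the `control` argument, the sketch below builds a settings object with `rpart.control()` and passes it to `rpart()`. The specific values are arbitrary, not tuned recommendations, and the built-in `kyphosis` data is used only to keep the chunk self-contained.

```{r}
library(rpart)

# Hypothetical settings chosen only for illustration.
ctrl = rpart.control(minsplit = 10,  # min observations in a node before a split is attempted
                     minbucket = 5,  # min observations allowed in any leaf
                     maxdepth = 3,   # max depth of any node (the root is depth 0)
                     cp = 0.01,      # min relative improvement needed to attempt a split
                     xval = 10)      # number of cross-validation folds

fit_ctrl = rpart(Kyphosis ~ Age + Number + Start, data = kyphosis,
                 method = "class", control = ctrl)

# The settings used are stored on the fitted object.
fit_ctrl$control$maxdepth
```

Limiting `maxdepth` or raising `minbucket` constrains the tree while it is being grown, whereas `prune()` trims a tree after it has been fit.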
Be sure to check out the excellent overview from [GormAnalysis](https://www.gormanalysis.com/blog/decision-trees-in-r-using-rpart/) to help internalize what you learned in this example.
## Challenge 2
Open Challenge 2 in the "Challenges" folder.