Commit 5a04aba

Merge pull request #199 from JuliaAI/dev

For a 0.12.0 release

ablaom authored Nov 29, 2022
2 parents b045bb9 + a072539

Showing 4 changed files with 61 additions and 18 deletions.
2 changes: 1 addition & 1 deletion Project.toml

```diff
@@ -2,7 +2,7 @@ name = "DecisionTree"
 uuid = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
 license = "MIT"
 desc = "Julia implementation of Decision Tree (CART) and Random Forest algorithms"
-version = "0.11.3"
+version = "0.12.0"
 
 [deps]
 AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
```
65 changes: 56 additions & 9 deletions README.md

[![CI](https://github.com/JuliaAI/DecisionTree.jl/workflows/CI/badge.svg)](https://github.com/JuliaAI/DecisionTree.jl/actions?query=workflow%3ACI)
[![Codecov](https://codecov.io/gh/JuliaAI/DecisionTree.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/JuliaAI/DecisionTree.jl)
[![Docs Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://juliahub.com/docs/DecisionTree/pEDeB/0.10.11/)
[![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.7359268.svg)](https://doi.org/10.5281/zenodo.7359268)

Julia implementation of Decision Tree (CART) and Random Forest algorithms

## ScikitLearn.jl API

Available models: `DecisionTreeClassifier, DecisionTreeRegressor, RandomForestClassifier, RandomForestRegressor, AdaBoostStumpClassifier`
See each model's help (e.g., `?DecisionTreeRegressor` at the REPL) for more information.

### Classification Example

Load the DecisionTree package

```julia
using DecisionTree
```

Separate Fisher's Iris dataset features and labels

```julia
features, labels = load_data("iris")    # also see "adult" and "digits" datasets

# the data loaded are of type Array{Any};
# cast them to concrete types for better performance
features = float.(features)
labels   = string.(labels)
```

Pruned Tree Classifier

```julia
# train depth-truncated classifier
model = DecisionTreeClassifier(max_depth=2)
# run n-fold cross validation over 3 CV folds
# (requires ScikitLearn.jl to be installed)
using ScikitLearn.CrossValidation: cross_val_score
accuracy = cross_val_score(model, features, labels, cv=3)
```
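
For single predictions, a minimal sketch (`fit!`, `predict`, and `predict_proba` come with DecisionTree's ScikitLearn-style API; the sample values are illustrative):

```julia
# fit the model, then classify one new sample
fit!(model, features, labels)
predict(model, [5.9, 3.0, 5.1, 1.9])        # predicted label
predict_proba(model, [5.9, 3.0, 5.1, 1.9])  # per-class probabilities
```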
Also, have a look at these [classification](https://github.com/cstjean/ScikitLearn.jl/blob/master/examples/Classifier_Comparison_Julia.ipynb) and [regression](https://github.com/cstjean/ScikitLearn.jl/blob/master/examples/Decision_Tree_Regression_Julia.ipynb) notebooks.

## Native API

### Classification Example

Decision Tree Classifier

```julia
# train full-tree classifier
model = build_tree(labels, features)
accuracy = nfoldCV_tree(labels, features,
                        rng = seed)
```
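
Inspecting and applying the trained tree, as a minimal sketch (`print_tree` and `apply_tree` are part of the native API):

```julia
# pretty-print the trained tree, to a depth of 5 nodes
print_tree(model, 5)
# classify one new sample
apply_tree(model, [5.9, 3.0, 5.1, 1.9])
```
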
Random Forest Classifier

```julia
# train random forest classifier
# using 2 random features, 10 trees, 0.5 portion of samples per tree, and a maximum tree depth of 6
model = build_forest(labels, features, 2, 10, 0.5, 6)

accuracy = nfoldCV_forest(labels, features,
                          verbose = true,
                          rng = seed)
```
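
Classifying a new sample with the trained forest, as a minimal sketch using the native `apply_forest`:

```julia
# classify one new sample with the forest's majority vote
apply_forest(model, [5.9, 3.0, 5.1, 1.9])
```
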

Adaptive-Boosted Decision Stumps Classifier

```julia
# train adaptive-boosted stumps, using 7 iterations
model, coeffs = build_adaboost_stumps(labels, features, 7);
accuracy = nfoldCV_stumps(labels, features)
```
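
Applying the trained stumps to a new sample, as a minimal sketch (`apply_adaboost_stumps` takes the stumps and coefficients returned by `build_adaboost_stumps`):

```julia
# classify one new sample with the boosted stumps
apply_adaboost_stumps(model, coeffs, [5.9, 3.0, 5.1, 1.9])
```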

### Regression Example

```julia
n, m = 10^3, 5                # 1000 samples, 5 features
features = randn(n, m)
weights = rand(-2:2, m)
labels = features * weights   # synthetic linear targets
```
Regression Tree

```julia
# train regression tree
model = build_tree(labels, features)
r2 = nfoldCV_tree(labels, features,
                  verbose = true,
                  rng = seed)
```

Regression Random Forest

```julia
# train regression forest, using 2 random features, 10 trees,
# averaging of 5 samples per leaf, and 0.7 portion of samples per tree
model = build_forest(labels, features, 2, 10, 0.7, 5)

r2 = nfoldCV_forest(labels, features,
                    rng = seed)
```

## Saving Models

Models can be saved to disk and later loaded back using the [JLD2.jl](https://github.com/JuliaIO/JLD2.jl) package.
```julia
using JLD2
@save "model_file.jld2" model
```
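
To load a saved model back, a minimal sketch using JLD2's `@load` macro:

```julia
using JLD2
@load "model_file.jld2" model   # restores the variable `model`
```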
Note that even though features and labels of type `Array{Any}` are supported, it is highly recommended that data be cast to explicit types (i.e., with `float.()`, `string.()`, etc.). This significantly improves model training and prediction execution times, and also drastically reduces the size of saved models.

## MLJ.jl API

To use DecisionTree.jl models in
[MLJ.jl](https://github.com/alan-turing-institute/MLJ.jl), see the MLJ documentation.

The following methods provide measures of feature importance for all models:
`impurity_importance`, `split_importance`, `permutation_importance`. Query the document
strings for details.
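
For example, a minimal sketch with a model built via the native API (assuming `model = build_tree(labels, features)` or a forest from `build_forest`):

```julia
# feature importances from the trained model; see the docstrings
# for keyword options and normalization details
impurity_importance(model)   # mean decrease in impurity, one value per feature
split_importance(model)      # importance based on split counts
```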


## Visualization

A `DecisionTree` model can be visualized using the `print_tree` function of its native interface
(for an example, see the 'Classification Example' section above).

In addition, an abstraction layer using `AbstractTrees.jl` has been implemented.

Apart from this, `AbstractTrees.jl` brings its own implementation of `print_tree`.
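
For instance, a minimal sketch of printing a trained native-API tree through that layer (the `featurenames` info key is an assumption for display purposes; consult the `wrap` docstring):

```julia
using AbstractTrees
using DecisionTree

# wrap the trained model with display metadata, then print it
# via AbstractTrees' generic tree printer
wrapped = DecisionTree.wrap(model,
    (featurenames = ["sepal length", "sepal width", "petal length", "petal width"],))
AbstractTrees.print_tree(wrapped)
```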


## Citing the package in publications

DOI: [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.7359268.svg)](https://doi.org/10.5281/zenodo.7359268).

BibTeX entry:

```bibtex
@software{ben_sadeghi_2022_7359268,
  author    = {Ben Sadeghi and
               Poom Chiarawongse and
               Kevin Squire and
               Daniel C. Jones and
               Andreas Noack and
               Cédric St-Jean and
               Rik Huijzer and
               Roland Schätzle and
               Ian Butterworth and
               Yu-Fong Peng and
               Anthony Blaom},
  title     = {{DecisionTree.jl - A Julia implementation of the
                CART Decision Tree and Random Forest algorithms}},
  month     = nov,
  year      = 2022,
  publisher = {Zenodo},
  version   = {0.11.3},
  doi       = {10.5281/zenodo.7359268},
  url       = {https://doi.org/10.5281/zenodo.7359268}
}
```
6 changes: 2 additions & 4 deletions src/classification/main.jl

```diff
@@ -370,12 +370,10 @@ function build_forest(
     loss = (ns, n) -> util.entropy(ns, n, entropy_terms)
 
     if rng isa Random.AbstractRNG
+        shared_seed = rand(rng, UInt)
         Threads.@threads for i in 1:n_trees
             # The Mersenne Twister (Julia's default) is not thread-safe.
-            _rng = copy(rng)
-            # Take some elements from the ring to have different states for each tree. This
-            # is the only way given that only a `copy` can be expected to exist for RNGs.
-            rand(_rng, i)
+            _rng = Random.seed!(copy(rng), shared_seed + i)
             inds = rand(_rng, 1:t_samples, n_samples)
             forest[i] = build_tree(
                 labels[inds],
```
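
The change replaces the ad-hoc per-thread RNG derivation with explicit reseeding: one shared seed is drawn from the caller's RNG, and each thread reseeds its own copy with an offset. A standalone sketch of the same pattern (the `MersenneTwister(42)` and loop body are illustrative, not library code):

```julia
using Random

# derive one shared seed, then give each thread its own reseeded copy,
# so no RNG state is shared across threads and results stay
# reproducible for a given starting `rng`
rng = MersenneTwister(42)
shared_seed = rand(rng, UInt)
results = Vector{Float64}(undef, 8)
Threads.@threads for i in 1:8
    _rng = Random.seed!(copy(rng), shared_seed + i)
    results[i] = rand(_rng)
end
```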
6 changes: 2 additions & 4 deletions src/regression/main.jl

```diff
@@ -95,12 +95,10 @@ function build_forest(
     forest = impurity_importance ? Vector{Root{S, T}}(undef, n_trees) : Vector{LeafOrNode{S, T}}(undef, n_trees)
 
     if rng isa Random.AbstractRNG
+        shared_seed = rand(rng, UInt)
         Threads.@threads for i in 1:n_trees
             # The Mersenne Twister (Julia's default) is not thread-safe.
-            _rng = copy(rng)
-            # Take some elements from the ring to have different states for each tree.
-            # This is the only way given that only a `copy` can be expected to exist for RNGs.
-            rand(_rng, i)
+            _rng = Random.seed!(copy(rng), shared_seed + i)
             inds = rand(_rng, 1:t_samples, n_samples)
             forest[i] = build_tree(
                 labels[inds],
```
