Skip to content

Commit

Permalink
[python] Improved python tree plots (#2304)
Browse files Browse the repository at this point in the history
* Some basic changes to the plot of the trees to make them readable.

* Squeezed the information in the nodes.

* Added colouring when a dictionnary mapping the features to the constraints is passed.

* Fix spaces.

* Added data percentage as an option in the nodes.

* Squeezed the information in the leaves.

* Important information is now in bold.

* Added a legend for the color of monotone splits.

* Changed "split_gain" to "gain" and "internal_value" to "value".

* Sqeezed leaves a bit more.

* Changed description in the legend.

* Revert "Sqeezed leaves a bit more."

This reverts commit dd8bf14a3ba604b0dfae3b7bb1c64b6784d15e03.

* Increased the readability for the gain.

* Tidied up the legend.

* Added the data percentage in the leaves.

* Added the monotone constraints to the dumped model.

* Monotone constraints are now specified automatically when plotting trees.

* Raise an exception instead of the bug that was here before.

* Removed operators on the branches for a clearer design.

* Small cleaning of the code.

* Setting a monotone constraint on a categorical feature now returns an exception instead of doing nothing.

* Fix bug when monotone constraints are empty.

* Fix another bug when monotone constraints are empty.

* Variable name change.

* Added is / isn't on every edge of the trees.

* Fix test "tree_create_digraph".

* Add new test for plotting trees with monotone constraints.

* Typo.

* Update documentation of categorical features.

* Typo.

* Information in nodes more explicit.

* Used regular strings instead of raw strings.

* Small refactoring.

* Some cleaning.

* Added future statement.

* Changed output for consistency.

* Updated documentation.

* Added comments for colors.

* Changed text on edges for more clarity.

* Small refactoring.

* Modified text in leaves for consistency with nodes.

* Updated default values and documentaton for consistency.

* Replaced CHECK with Log::Fatal for user-friendliness.

* Updated tests.

* Typo.

* Simplify imports.

* Swapped count and weight to improve readibility of the leaves in the plotted trees.

* Thresholds in bold.

* Made information in nodes written in a specific order.

* Added information to clarify legend.

* Code cleaning.
  • Loading branch information
CharlesAuguste authored and StrikerRUS committed Sep 8, 2019
1 parent b6d4ad8 commit f52be9b
Show file tree
Hide file tree
Showing 12 changed files with 163 additions and 47 deletions.
2 changes: 2 additions & 0 deletions docs/Parameters.rst
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,8 @@ IO Parameters

- **Note**: all negative values will be treated as **missing values**

- **Note**: the output cannot be monotonically constrained with respect to a categorical feature

- ``predict_raw_score`` :raw-html:`<a id="predict_raw_score" title="Permalink to this parameter" href="#predict_raw_score">&#x1F517;&#xFE0E;</a>`, default = ``false``, type = bool, aliases: ``is_predict_raw_score``, ``predict_rawscore``, ``raw_score``

- used only in ``prediction`` task
Expand Down
1 change: 1 addition & 0 deletions include/LightGBM/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,7 @@ struct Config {
// desc = **Note**: all values should be less than ``Int32.MaxValue`` (2147483647)
// desc = **Note**: using large values could be memory consuming. Tree decision rule works best when categorical features are presented by consecutive integers starting from zero
// desc = **Note**: all negative values will be treated as **missing values**
// desc = **Note**: the output cannot be monotonically constrained with respect to a categorical feature
std::string categorical_feature = "";

// alias = is_predict_raw_score, predict_rawscore, raw_score
Expand Down
40 changes: 40 additions & 0 deletions include/LightGBM/utils/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,31 @@ inline static const char* Atoi(const char* p, T* out) {
return p;
}

template <typename T>
inline void SplitToIntLike(const char *c_str, char delimiter,
std::vector<T> &ret) {
CHECK(ret.empty());
std::string str(c_str);
size_t i = 0;
size_t pos = 0;
while (pos < str.length()) {
if (str[pos] == delimiter) {
if (i < pos) {
ret.push_back({});
Atoi(str.substr(i, pos - i).c_str(), &ret.back());
}
++pos;
i = pos;
} else {
++pos;
}
}
if (i < pos) {
ret.push_back({});
Atoi(str.substr(i).c_str(), &ret.back());
}
}

template<typename T>
inline static double Pow(T base, int power) {
if (power < 0) {
Expand Down Expand Up @@ -551,6 +576,21 @@ inline static std::string Join(const std::vector<T>& strs, const char* delimiter
return str_buf.str();
}

template<>
inline std::string Join<int8_t>(const std::vector<int8_t>& strs, const char* delimiter) {
if (strs.empty()) {
return std::string("");
}
std::stringstream str_buf;
str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
str_buf << static_cast<int16_t>(strs[0]);
for (size_t i = 1; i < strs.size(); ++i) {
str_buf << delimiter;
str_buf << static_cast<int16_t>(strs[i]);
}
return str_buf.str();
}

template<typename T>
inline static std::string Join(const std::vector<T>& strs, size_t start, size_t end, const char* delimiter) {
if (end - start <= 0) {
Expand Down
1 change: 1 addition & 0 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,7 @@ def __init__(self, data, label=None, reference=None,
All values in categorical features should be less than int32 max value (2147483647).
Large values could be memory consuming. Consider using consecutive integers starting from zero.
All negative values in categorical features will be treated as missing values.
The output cannot be monotonically constrained with respect to a categorical feature.
params : dict or None, optional (default=None)
Other parameters for Dataset.
free_raw_data : bool, optional (default=True)
Expand Down
2 changes: 2 additions & 0 deletions python-package/lightgbm/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def train(params, train_set, num_boost_round=100,
All values in categorical features should be less than int32 max value (2147483647).
Large values could be memory consuming. Consider using consecutive integers starting from zero.
All negative values in categorical features will be treated as missing values.
The output cannot be monotonically constrained with respect to a categorical feature.
early_stopping_rounds : int or None, optional (default=None)
Activates early stopping. The model will train until the validation score stops improving.
Validation score needs to improve at least every ``early_stopping_rounds`` round(s)
Expand Down Expand Up @@ -451,6 +452,7 @@ def cv(params, train_set, num_boost_round=100,
All values in categorical features should be less than int32 max value (2147483647).
Large values could be memory consuming. Consider using consecutive integers starting from zero.
All negative values in categorical features will be treated as missing values.
The output cannot be monotonically constrained with respect to a categorical feature.
early_stopping_rounds : int or None, optional (default=None)
Activates early stopping.
CV score needs to improve at least every ``early_stopping_rounds`` round(s)
Expand Down
112 changes: 79 additions & 33 deletions python-package/lightgbm/plotting.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# coding: utf-8
# pylint: disable = C0103
"""Plotting library."""
from __future__ import absolute_import
from __future__ import absolute_import, division

import warnings
from copy import deepcopy
Expand Down Expand Up @@ -369,7 +369,7 @@ def plot_metric(booster, metric=None, dataset_names=None,
return ax


def _to_graphviz(tree_info, show_info, feature_names, precision=None, **kwargs):
def _to_graphviz(tree_info, show_info, feature_names, precision=3, constraints=None, **kwargs):
"""Convert specified tree to graphviz instance.
See:
Expand All @@ -380,48 +380,90 @@ def _to_graphviz(tree_info, show_info, feature_names, precision=None, **kwargs):
else:
raise ImportError('You must install graphviz to plot tree.')

def add(root, parent=None, decision=None):
def add(root, total_count, parent=None, decision=None):
"""Recursively add node or edge."""
if 'split_index' in root: # non-leaf
name = 'split{0}'.format(root['split_index'])
if feature_names is not None:
label = 'split_feature_name: {0}'.format(feature_names[root['split_feature']])
else:
label = 'split_feature_index: {0}'.format(root['split_feature'])
label += r'\nthreshold: {0}'.format(_float2str(root['threshold'], precision))
for info in show_info:
if info in {'split_gain', 'internal_value', 'internal_weight'}:
label += r'\n{0}: {1}'.format(info, _float2str(root[info], precision))
elif info == 'internal_count':
label += r'\n{0}: {1}'.format(info, root[info])
graph.node(name, label=label)
l_dec = 'yes'
r_dec = 'no'
if root['decision_type'] == '<=':
l_dec, r_dec = '<=', '>'
lte_symbol = "&#8804;"
operator = lte_symbol
elif root['decision_type'] == '==':
l_dec, r_dec = 'is', "isn't"
operator = "="
else:
raise ValueError('Invalid decision type in tree model.')
add(root['left_child'], name, l_dec)
add(root['right_child'], name, r_dec)
name = 'split{0}'.format(root['split_index'])
if feature_names is not None:
label = '<B>{0}</B> {1} '.format(feature_names[root['split_feature']], operator)
else:
label = 'feature <B>{0}</B> {1} '.format(root['split_feature'], operator)
label += '<B>{0}</B>'.format(_float2str(root['threshold'], precision))
for info in ['split_gain', 'internal_value', 'internal_weight', "internal_count", "data_percentage"]:
if info in show_info:
output = info.split('_')[-1]
if info in {'split_gain', 'internal_value', 'internal_weight'}:
label += '<br/>{0} {1}'.format(_float2str(root[info], precision), output)
elif info == 'internal_count':
label += '<br/>{0}: {1}'.format(output, root[info])
elif info == "data_percentage":
label += '<br/>{0}% of data'.format(_float2str(root['internal_count'] / total_count * 100, 2))

fillcolor = "white"
style = ""
if constraints:
if constraints[root['split_feature']] == 1:
fillcolor = "#ddffdd" # light green
if constraints[root['split_feature']] == -1:
fillcolor = "#ffdddd" # light red
style = "filled"
label = "<" + label + ">"
graph.node(name, label=label, shape="rectangle", style=style, fillcolor=fillcolor)
add(root['left_child'], total_count, name, l_dec)
add(root['right_child'], total_count, name, r_dec)
else: # leaf
name = 'leaf{0}'.format(root['leaf_index'])
label = 'leaf_index: {0}'.format(root['leaf_index'])
label += r'\nleaf_value: {0}'.format(_float2str(root['leaf_value'], precision))
if 'leaf_count' in show_info:
label += r'\nleaf_count: {0}'.format(root['leaf_count'])
label = 'leaf {0}: '.format(root['leaf_index'])
label += '<B>{0}</B>'.format(_float2str(root['leaf_value'], precision))
if 'leaf_weight' in show_info:
label += r'\nleaf_weight: {0}'.format(_float2str(root['leaf_weight'], precision))
label += '<br/>{0} weight'.format(_float2str(root['leaf_weight'], precision))
if 'leaf_count' in show_info:
label += '<br/>count: {0}'.format(root['leaf_count'])
if "data_percentage" in show_info:
label += '<br/>{0}% of data'.format(_float2str(root['leaf_count'] / total_count * 100, 2))
label = "<" + label + ">"
graph.node(name, label=label)
if parent is not None:
graph.edge(parent, name, decision)

graph = Digraph(**kwargs)
add(tree_info['tree_structure'])

graph.attr("graph", nodesep="0.05", ranksep="0.3", rankdir="LR")
if "internal_count" in tree_info['tree_structure']:
add(tree_info['tree_structure'], tree_info['tree_structure']["internal_count"])
else:
raise Exception("Cannnot plot trees with no split")

if constraints:
# "#ddffdd" is light green, "#ffdddd" is light red
legend = """<
<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" CELLPADDING="4">
<TR>
<TD COLSPAN="2"><B>Monotone constraints</B></TD>
</TR>
<TR>
<TD>Increasing</TD>
<TD BGCOLOR="#ddffdd"></TD>
</TR>
<TR>
<TD>Decreasing</TD>
<TD BGCOLOR="#ffdddd"></TD>
</TR>
</TABLE>
>"""
graph.node("legend", label=legend, shape="rectangle", color="white")
return graph


def create_tree_digraph(booster, tree_index=0, show_info=None, precision=None,
def create_tree_digraph(booster, tree_index=0, show_info=None, precision=3,
old_name=None, old_comment=None, old_filename=None, old_directory=None,
old_format=None, old_engine=None, old_encoding=None, old_graph_attr=None,
old_node_attr=None, old_edge_attr=None, old_body=None, old_strict=False, **kwargs):
Expand All @@ -441,8 +483,9 @@ def create_tree_digraph(booster, tree_index=0, show_info=None, precision=None,
show_info : list of strings or None, optional (default=None)
What information should be shown in nodes.
Possible values of list items:
'split_gain', 'internal_value', 'internal_count', 'internal_weight', 'leaf_count', 'leaf_weight'.
precision : int or None, optional (default=None)
'split_gain', 'internal_value', 'internal_count', 'internal_weight',
'leaf_count', 'leaf_weight', 'data_percentage'.
precision : int or None, optional (default=3)
Used to restrict the display of floating point values to a certain precision.
**kwargs
Other parameters passed to ``Digraph`` constructor.
Expand Down Expand Up @@ -482,6 +525,8 @@ def create_tree_digraph(booster, tree_index=0, show_info=None, precision=None,
else:
feature_names = None

monotone_constraints = model.get('monotone_constraints', None)

if tree_index < len(tree_infos):
tree_info = tree_infos[tree_index]
else:
Expand All @@ -490,14 +535,14 @@ def create_tree_digraph(booster, tree_index=0, show_info=None, precision=None,
if show_info is None:
show_info = []

graph = _to_graphviz(tree_info, show_info, feature_names, precision, **kwargs)
graph = _to_graphviz(tree_info, show_info, feature_names, precision, monotone_constraints, **kwargs)

return graph


def plot_tree(booster, ax=None, tree_index=0, figsize=None,
old_graph_attr=None, old_node_attr=None, old_edge_attr=None,
show_info=None, precision=None, **kwargs):
show_info=None, precision=3, **kwargs):
"""Plot specified tree.
Note
Expand All @@ -519,8 +564,9 @@ def plot_tree(booster, ax=None, tree_index=0, figsize=None,
show_info : list of strings or None, optional (default=None)
What information should be shown in nodes.
Possible values of list items:
'split_gain', 'internal_value', 'internal_count', 'internal_weight', 'leaf_count', 'leaf_weight'.
precision : int or None, optional (default=None)
'split_gain', 'internal_value', 'internal_count', 'internal_weight',
'leaf_count', 'leaf_weight', 'data_percentage'.
precision : int or None, optional (default=3)
Used to restrict the display of floating point values to a certain precision.
**kwargs
Other parameters passed to ``Digraph`` constructor.
Expand Down
1 change: 1 addition & 0 deletions python-package/lightgbm/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,7 @@ def fit(self, X, y,
All values in categorical features should be less than int32 max value (2147483647).
Large values could be memory consuming. Consider using consecutive integers starting from zero.
All negative values in categorical features will be treated as missing values.
The output cannot be monotonically constrained with respect to a categorical feature.
callbacks : list of callback functions or None, optional (default=None)
List of callback functions that are applied at each iteration.
See Callbacks in Python API for more information.
Expand Down
1 change: 1 addition & 0 deletions src/boosting/gbdt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective
// get feature names
feature_names_ = train_data_->feature_names();
feature_infos_ = train_data_->feature_infos();
monotone_constraints_ = config->monotone_constraints;

// if need bagging, create buffer
ResetBaggingConfig(config_.get(), true);
Expand Down
1 change: 1 addition & 0 deletions src/boosting/gbdt.h
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,7 @@ class GBDT : public GBDTBase {
bool need_re_bagging_;
bool balanced_bagging_;
std::string loaded_parameter_;
std::vector<int8_t> monotone_constraints_;

Json forced_splits_json_;
};
Expand Down
24 changes: 21 additions & 3 deletions src/boosting/gbdt_model_text.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,11 @@ std::string GBDT::DumpModel(int start_iteration, int num_iteration) const {
str_buf << "\"objective\":\"" << objective_function_->ToString() << "\",\n";
}

str_buf << "\"feature_names\":[\""
<< Common::Join(feature_names_, "\",\"") << "\"],"
<< '\n';
str_buf << "\"feature_names\":[\"" << Common::Join(feature_names_, "\",\"")
<< "\"]," << '\n';

str_buf << "\"monotone_constraints\":["
<< Common::Join(monotone_constraints_, ",") << "]," << '\n';

str_buf << "\"tree_info\":[";
int num_used_model = static_cast<int>(models_.size());
Expand Down Expand Up @@ -269,6 +271,11 @@ std::string GBDT::SaveModelToString(int start_iteration, int num_iteration) cons

ss << "feature_names=" << Common::Join(feature_names_, " ") << '\n';

if (monotone_constraints_.size() != 0) {
ss << "monotone_constraints=" << Common::Join(monotone_constraints_, " ")
<< '\n';
}

ss << "feature_infos=" << Common::Join(feature_infos_, " ") << '\n';

int num_used_model = static_cast<int>(models_.size());
Expand Down Expand Up @@ -364,6 +371,8 @@ bool GBDT::LoadModelFromString(const char* buffer, size_t len) {
} else if (strs.size() > 2) {
if (strs[0] == "feature_names") {
key_vals[strs[0]] = cur_line.substr(std::strlen("feature_names="));
} else if (strs[0] == "monotone_constraints") {
key_vals[strs[0]] = cur_line.substr(std::strlen("monotone_constraints="));
} else {
// Use first 128 chars to avoid exceed the message buffer.
Log::Fatal("Wrong line at model file: %s", cur_line.substr(0, std::min<size_t>(128, cur_line.size())).c_str());
Expand Down Expand Up @@ -424,6 +433,15 @@ bool GBDT::LoadModelFromString(const char* buffer, size_t len) {
return false;
}

// get monotone_constraints
if (key_vals.count("monotone_constraints")) {
Common::SplitToIntLike(key_vals["monotone_constraints"].c_str(), ' ', monotone_constraints_);
if (monotone_constraints_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
Log::Fatal("Wrong size of monotone_constraints");
return false;
}
}

if (key_vals.count("feature_infos")) {
feature_infos_ = Common::Split(key_vals["feature_infos"].c_str(), ' ');
if (feature_infos_.size() != static_cast<size_t>(max_feature_idx_ + 1)) {
Expand Down
4 changes: 4 additions & 0 deletions src/io/dataset_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,10 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values,
BinType bin_type = BinType::NumericalBin;
if (categorical_features_.count(i)) {
bin_type = BinType::CategoricalBin;
bool feat_is_unconstrained = ((config_.monotone_constraints.size() == 0) || (config_.monotone_constraints[i] == 0));
if (!feat_is_unconstrained) {
Log::Fatal("The output cannot be monotone with respect to categorical features");
}
}
bin_mappers[i].reset(new BinMapper());
if (config_.max_bin_by_feature.empty()) {
Expand Down
Loading

0 comments on commit f52be9b

Please sign in to comment.