Skip to content

Commit

Permalink
Merge pull request #61 from tlapusan/leaf_samples_distribution
Browse files Browse the repository at this point in the history
Leaf samples distribution
  • Loading branch information
parrt authored Oct 28, 2019
2 parents 945b31b + 7afd78f commit 6a74b03
Show file tree
Hide file tree
Showing 3 changed files with 34,136 additions and 13,416 deletions.
17 changes: 13 additions & 4 deletions dtreeviz/shadow.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,19 +230,28 @@ def get_node_type(_tree_model):
return node_type

@staticmethod
def get_leaf_sample_counts(_tree_model, min_samples=0, max_samples=None):
    """Get the number of samples for each leaf.

    There is the option to filter out the leaves with fewer than min_samples
    or more than max_samples.

    :param _tree_model: fitted tree model
        Must expose ``tree_.n_node_samples`` and ``tree_.node_count``.
    :param min_samples: int
        Min number of samples for a leaf (inclusive).
    :param max_samples: int
        Max number of samples for a leaf (inclusive). ``None`` means no upper
        bound: the largest value in ``n_node_samples`` is used.
    :return: tuple
        Contains a numpy array of leaf ids and a numpy array of leaf samples.
        Both arrays are empty when no leaf falls inside the requested range.
    """

    # Nodes with a truthy entry in node_type are the ones treated as leaves here.
    node_type = ShadowDecTree.get_node_type(_tree_model)
    n_node_samples = _tree_model.tree_.n_node_samples

    # Compare against None explicitly so that an explicit max_samples=0 is
    # honoured instead of being silently replaced by "no limit".
    if max_samples is None:
        max_samples = n_node_samples.max()

    leaf_samples = [(i, n_node_samples[i])
                    for i in range(_tree_model.tree_.node_count)
                    if node_type[i] and min_samples <= n_node_samples[i] <= max_samples]

    # zip(*[]) yields nothing, so unpacking it into two names would raise
    # ValueError; return two empty arrays when the filter excludes every leaf.
    if not leaf_samples:
        return np.array([], dtype=int), np.array([], dtype=int)

    x, y = zip(*leaf_samples)
    return np.array(x), np.array(y)

@staticmethod
def get_leaf_sample_counts_by_class(_tree_model):
Expand Down
33 changes: 31 additions & 2 deletions dtreeviz/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -1257,12 +1257,20 @@ def viz_leaf_samples(tree_model: (tree.DecisionTreeRegressor, tree.DecisionTreeC
colors: dict = None,
fontsize: int = 14,
fontname: str = "Arial",
grid: bool = False):
grid: bool = False,
bins: int = 10,
min_samples: int = 0,
max_samples: int = None):
"""Visualize the number of training samples from each leaf.
There is the option to filter the leaves with less than min_samples or more than max_samples. This is helpful
especially when you want to investigate leaves with number of samples from a specific range.
If display_type = 'plot' it will show leaf samples using a plot.
If display_type = 'text' it will show leaf samples as plain text. This method is preferred if number
of leaves is very large and the plot become very big and hard to interpret.
If display_type = 'hist' it will show leaf sample histogram. Useful when you want to easily see the general
distribution of leaf samples.
:param tree_model: sklearn.tree
The tree to interpret
Expand All @@ -1278,9 +1286,15 @@ def viz_leaf_samples(tree_model: (tree.DecisionTreeRegressor, tree.DecisionTreeC
Plot labels font name
:param grid: bool
Whether to show the grid lines
:param bins: int
Number of histogram bins
:param min_samples: int
Min number of samples for a leaf
:param max_samples: int
Max number of samples for a leaf
"""

leaf_id, leaf_samples = ShadowDecTree.get_leaf_sample_counts(tree_model)
leaf_id, leaf_samples = ShadowDecTree.get_leaf_sample_counts(tree_model, min_samples, max_samples)

if display_type == "plot":
colors = adjust_colors(colors)
Expand All @@ -1303,6 +1317,21 @@ def viz_leaf_samples(tree_model: (tree.DecisionTreeRegressor, tree.DecisionTreeC
elif display_type == "text":
for leaf, samples in zip(leaf_id, leaf_samples):
print(f"leaf {leaf} has {samples} samples")
elif display_type == "hist":
colors = adjust_colors(colors)

fig, ax = plt.subplots(figsize=figsize)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_linewidth(.3)
ax.spines['bottom'].set_linewidth(.3)
n, bins, patches = ax.hist(leaf_samples, bins=bins, color=colors["hist_bar"])
for rect in patches:
rect.set_linewidth(.5)
rect.set_edgecolor(colors['rect_edge'])
ax.set_xlabel("leaf sample", fontsize=fontsize, fontname=fontname, color=colors['axis_label'])
ax.set_ylabel("leaf count", fontsize=fontsize, fontname=fontname, color=colors['axis_label'])
ax.grid(b=grid)


def ctreeviz_leaf_samples(tree_model: (tree.DecisionTreeClassifier),
Expand Down
Loading

0 comments on commit 6a74b03

Please sign in to comment.