Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a method to export decision paths #2

Merged
merged 1 commit into from
Nov 15, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 28 additions & 11 deletions treefarms/model/treefarms.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(self, configuration={}):
self.model_set = None
self.dataset = None

# TODO: implement this
# TODO: implement this
def load(self, path):
"""
Parameters
Expand Down Expand Up @@ -87,7 +87,7 @@ def __train__(self, X, y):
self.model_set = ModelSetContainer(result)

print(f"training completed. Number of trees in the Rashomon set: {self.model_set.get_tree_count()}")


def fit(self, X, y):
"""
Expand All @@ -104,7 +104,7 @@ def fit(self, X, y):
self.__train__(X, y)
return self

# TODO: implement this
# TODO: implement this
def predict(self, X):
"""
Parameters
Expand Down Expand Up @@ -133,7 +133,7 @@ def __getitem__(self, idx):
if self.model_set is None:
raise Exception("Error: Model not yet trained")
return self.model_set.__getitem__(idx)

def get_tree_count(self):
"""Returns the number of trees in the Rashomon set

Expand All @@ -145,30 +145,47 @@ def get_tree_count(self):
if self.model_set is None:
raise Exception("Error: Model not yet trained")
return self.model_set.get_tree_count()

def visualize(self, feature_names=None, feature_description=None, *, width=500, height=650):
"""Generates a visualization of the Rashomon set using `timbertrek`

def get_decision_paths(self, feature_names=None, feature_description=None):
"""Create a hierarchical dictionary describing the decision paths in the
Rashomon set using `timbertrek`.
Parameters
---
feature_names : matrix-like, shape = [m_features + 1]
a matrix where each row is a sample to be predicted and each column is a feature to be used for prediction
"""
if self.model_set is None:
raise Exception("Error: Model not yet trained")

# Convert the trie structure to decision paths
trie = self.model_set.to_trie()
df = self.dataset
if feature_names is None:
feature_names = df.columns

decision_paths = timbertrek.transform_trie_to_rules(
trie,
df,
feature_names=feature_names,
feature_description=feature_description,
)

# return decision_paths


return decision_paths

def visualize(self, feature_names=None, feature_description=None, *, width=500, height=650):
"""Generates a visualization of the Rashomon set using `timbertrek`
Parameters
---
feature_names : matrix-like, shape = [m_features + 1]
a matrix where each row is a sample to be predicted and each column is a feature to be used for prediction
"""
# Get the decision paths
decision_paths = self.get_decision_paths(
feature_names=feature_names,
feature_description=feature_description
)

# Show in the in-notebook visualization
timbertrek.visualize(decision_paths, width=width, height=height)

def __translate__(self, leaves):
Expand Down