Skip to content

Commit

Permalink
changing some of the method names to reduce confusion/overlap
Browse files Browse the repository at this point in the history
  • Loading branch information
brifordwylie committed May 28, 2024
1 parent b6bea8b commit 4d2f414
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion examples/model/model_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@

# Get the model metrics and regression predictions
print(model.performance_metrics())
print(model.predictions())
print(model.get_predictions())
2 changes: 1 addition & 1 deletion src/sageworks/core/artifacts/model_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,7 +876,7 @@ def shapley_values(self, capture_uuid: str = "training_holdout") -> Union[list[p

# Grab our regression predictions from S3
print("Captured Predictions: (might be None)")
print(my_model.predictions())
print(my_model.get_predictions())

# Grab our Shapley values from S3
print("Shapley Values: (might be None)")
Expand Down
2 changes: 1 addition & 1 deletion src/sageworks/web_components/regression_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def create_component(self, component_id: str) -> dcc.Graph:

def update_properties(self, model: Model, inference_run: str = None) -> go.Figure:
# Get predictions for specific inference
df = model.predictions(inference_run)
df = model.get_predictions(inference_run)

if df is None:
return self.display_text("No Data")
Expand Down
2 changes: 1 addition & 1 deletion tests/transforms/model_metrics_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_retrieval_with_capture_uuid():
print(f"\n\n*** Retrieval with Capture UUID ({capture_uuid}) ***")
pprint(model_class.inference_metadata(capture_uuid).head()) # Needed
pprint(model_class.performance_metrics(capture_uuid).head()) # Might be deprecated
pprint(model_class.predictions(capture_uuid).head()) # Needed
pprint(model_class.get_predictions(capture_uuid).head()) # Needed
pprint(model_class.confusion_matrix(capture_uuid)) # Might be deprecated
# Classifiers have a list of dataframes for shap values
shap_list = model_class.shapley_values(capture_uuid)
Expand Down

0 comments on commit 4d2f414

Please sign in to comment.