Skip to content

Commit

Permalink
changing some of the method names to reduce confusion/overlap
Browse files Browse the repository at this point in the history
  • Loading branch information
brifordwylie committed May 28, 2024
1 parent 1794432 commit b6bea8b
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
12 changes: 6 additions & 6 deletions src/sageworks/core/artifacts/model_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def confusion_matrix(self, capture_uuid: str = "latest") -> Union[pd.DataFrame,
self.log.warning(f"Confusion Matrix {capture_uuid} not found for {self.model_name}!")
return None

def predictions(self, capture_uuid: str = "training_holdout") -> Union[pd.DataFrame, None]:
def get_predictions(self, capture_uuid: str = "training_holdout") -> Union[pd.DataFrame, None]:
"""Retrieve the predictions for this model
Args:
Expand All @@ -222,10 +222,10 @@ def predictions(self, capture_uuid: str = "training_holdout") -> Union[pd.DataFr
pd.DataFrame: DataFrame of the Predictions (might be None)
"""
# Grab the metrics from the SageWorks Metadata (try inference first, then training)
inference_preds = self.inference_predictions(capture_uuid)
inference_preds = self.get_inference_predictions(capture_uuid)
if inference_preds is not None:
return inference_preds
return self.validation_predictions()
return self._get_validation_predictions()

def set_input(self, input: str, force: bool = False):
"""Override: Set the input data for this artifact
Expand Down Expand Up @@ -420,7 +420,7 @@ def details(self, recompute=False) -> dict:
details["predictions"] = None
else:
details["confusion_matrix"] = None
details["predictions"] = self.predictions()
details["predictions"] = self.get_predictions()

# Grab the inference metadata
details["inference_meta"] = self.inference_metadata()
Expand Down Expand Up @@ -733,7 +733,7 @@ def inference_metadata(self, capture_uuid: str = "training_holdout") -> Union[pd
self.log.info(f"Could not find model inference meta at {s3_path}...")
return None

def inference_predictions(self, capture_uuid: str = "training_holdout") -> Union[pd.DataFrame, None]:
def get_inference_predictions(self, capture_uuid: str = "training_holdout") -> Union[pd.DataFrame, None]:
"""Retrieve the captured prediction results for this model
Args:
Expand All @@ -746,7 +746,7 @@ def inference_predictions(self, capture_uuid: str = "training_holdout") -> Union
s3_path = f"{self.endpoint_inference_path}/{capture_uuid}/inference_predictions.csv"
return pull_s3_data(s3_path)

def validation_predictions(self) -> Union[pd.DataFrame, None]:
def _get_validation_predictions(self) -> Union[pd.DataFrame, None]:
"""Internal: Retrieve the captured prediction results for this model
Returns:
Expand Down
12 changes: 6 additions & 6 deletions tests/transforms/model_metrics_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,20 @@ def test_retrieval_with_capture_uuid():

def test_validation_predictions():
print("\n\n*** Validation Predictions ***")
pprint(model_reg.validation_predictions().head())
pprint(model_class.validation_predictions().head())
pprint(model_reg._get_validation_predictions().head())
pprint(model_class._get_validation_predictions().head())


def test_inference_predictions():
print("\n\n*** Inference Predictions ***")
if model_reg.inference_predictions() is None:
if model_reg.get_inference_predictions() is None:
print(f"Model {model_reg.uuid} has no inference predictions!")
exit(1)
pprint(model_reg.inference_predictions().head())
if model_class.inference_predictions() is None:
pprint(model_reg.get_inference_predictions().head())
if model_class.get_inference_predictions() is None:
print(f"Model {model_class.uuid} has no inference predictions!")
exit(1)
pprint(model_class.inference_predictions().head())
pprint(model_class.get_inference_predictions().head())


def test_confusion_matrix():
Expand Down

0 comments on commit b6bea8b

Please sign in to comment.