diff --git a/src/sageworks/core/artifacts/endpoint_core.py b/src/sageworks/core/artifacts/endpoint_core.py index c5ce19972..b80348bc3 100644 --- a/src/sageworks/core/artifacts/endpoint_core.py +++ b/src/sageworks/core/artifacts/endpoint_core.py @@ -441,6 +441,12 @@ def inference(self, eval_df: pd.DataFrame, capture_uuid: str = None) -> pd.DataF # Get the target column target_column = ModelCore(self.model_name).target() + # Sanity Check that the target column is present + if target_column not in prediction_df.columns: + self.log.warning(f"Target Column {target_column} not found in prediction_df!") + self.log.warning("In order to compute metrics, the target column must be present!") + return prediction_df + # Compute the standard performance metrics for this model model_type = self.model_type() if model_type == ModelType.REGRESSOR.value: