diff --git a/src/sageworks/core/artifacts/model_core.py b/src/sageworks/core/artifacts/model_core.py index 87c67368d..34473ef16 100644 --- a/src/sageworks/core/artifacts/model_core.py +++ b/src/sageworks/core/artifacts/model_core.py @@ -377,6 +377,7 @@ def details(self, recompute=False) -> dict: self.log.info("Recomputing Model Details...") details = self.summary() + details["pipeline"] = self.get_pipeline() details["model_type"] = self.model_type.value details["model_package_group_arn"] = self.group_arn() details["model_package_arn"] = self.model_package_arn() @@ -414,6 +415,19 @@ def details(self, recompute=False) -> dict: # Return the details return details + # Pipeline for this model + def get_pipeline(self) -> str: + """Get the pipeline for this model""" + return self.sageworks_meta().get("sageworks_pipeline") + + def set_pipeline(self, pipeline: str): + """Set the pipeline for this model + + Args: + pipeline (str): Pipeline that was used to create this model + """ + self.upsert_sageworks_meta({"sageworks_pipeline": pipeline}) + def expected_meta(self) -> list[str]: """Metadata we expect to see for this Model when it's ready Returns: