From a003f68314d97c584842f5b0a088557cf200a5a6 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Tue, 7 Nov 2023 13:33:20 +0100 Subject: [PATCH 01/12] remove `create_new_model_version` --- .gitignore | 3 + .../artifacts/external_artifact_config.py | 2 +- src/zenml/cli/model.py | 4 +- src/zenml/client.py | 15 +- src/zenml/enums.py | 1 + src/zenml/model/artifact_config.py | 5 +- src/zenml/model/model_config.py | 162 ++++++-------- src/zenml/models/model_models.py | 19 +- src/zenml/new/pipelines/model_utils.py | 9 - src/zenml/new/pipelines/pipeline.py | 151 +++++++------ src/zenml/orchestrators/step_runner.py | 5 +- src/zenml/zen_stores/rest_zen_store.py | 4 +- src/zenml/zen_stores/sql_zen_store.py | 11 +- src/zenml/zen_stores/zen_store_interface.py | 5 +- tests/integration/functional/cli/conftest.py | 2 +- .../functional/model/test_artifact_config.py | 59 ++--- .../functional/model/test_model_config.py | 29 +-- .../pipelines/test_pipeline_config.py | 7 +- .../steps/test_external_artifact.py | 31 +-- .../functional/steps/test_model_config.py | 209 ++++++------------ tests/unit/model/test_model_config_init.py | 40 +--- 21 files changed, 294 insertions(+), 479 deletions(-) diff --git a/.gitignore b/.gitignore index 632b6ec76b4..d5008461fb0 100644 --- a/.gitignore +++ b/.gitignore @@ -197,3 +197,6 @@ zenml_tutorial/ mlstacks_reset.sh .local/ + +# exclude installed dashboard folder +src/zenml/zen_server/dashboard diff --git a/src/zenml/artifacts/external_artifact_config.py b/src/zenml/artifacts/external_artifact_config.py index 4c32493472e..0547c1b2b83 100644 --- a/src/zenml/artifacts/external_artifact_config.py +++ b/src/zenml/artifacts/external_artifact_config.py @@ -113,7 +113,7 @@ def _get_artifact_from_model( name=self.model_name, version=self.model_version, ) - model_version = model_config._get_model_version() + model_version = model_config.get_or_create_model_version() for artifact_getter in [ model_version.get_artifact_object, diff --git a/src/zenml/cli/model.py b/src/zenml/cli/model.py index 5f72018e43d..8784691210f 100644 --- a/src/zenml/cli/model.py +++ b/src/zenml/cli/model.py @@ -511,7 +511,7 @@ def _print_artifacts_links_generic( """ model_version = Client().get_model_version( model_name_or_id=model_name_or_id, - model_version_name_or_number_or_id=None + model_version_name_or_number_or_id=ModelStages.LATEST if model_version_name_or_number_or_id == "0" else model_version_name_or_number_or_id, ) @@ -665,7 +665,7 @@ def list_model_version_pipeline_runs( """ model_version = Client().get_model_version( model_name_or_id=model_name_or_id, - model_version_name_or_number_or_id=None + model_version_name_or_number_or_id=ModelStages.LATEST if model_version_name_or_number_or_id == "0" else model_version_name_or_number_or_id, ) diff --git a/src/zenml/client.py b/src/zenml/client.py index 4b66dbe9a98..7eb021bec46 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -5077,16 +5077,13 @@ def delete_model_version( def get_model_version( self, model_name_or_id: Union[str, UUID], - model_version_name_or_number_or_id: Optional[ - Union[str, int, UUID, ModelStages] - ] = None, + model_version_name_or_number_or_id: Union[str, int, UUID, ModelStages], ) -> ModelVersionResponseModel: """Get an existing model version from Model Control Plane. Args: model_name_or_id: name or id of the model containing the model version. model_version_name_or_number_or_id: name, id, stage or number of the model version to be retrieved. - If skipped latest version will be retrieved. Returns: The model version of interest. @@ -5145,16 +5142,13 @@ def list_model_version_artifact_links( self, model_name_or_id: Union[str, UUID], model_version_artifact_link_filter_model: ModelVersionArtifactFilterModel, - model_version_name_or_number_or_id: Optional[ - Union[str, int, UUID, ModelStages] - ] = None, + model_version_name_or_number_or_id: Union[str, int, UUID, ModelStages], ) -> Page[ModelVersionArtifactResponseModel]: """Get model version to artifact links by filter in Model Control Plane. Args: model_name_or_id: name or id of the model containing the model version. model_version_name_or_number_or_id: name, id, stage or number of the model version to be retrieved. - If skipped latest version will be retrieved. model_version_artifact_link_filter_model: All filter parameters including pagination params. @@ -5181,16 +5175,13 @@ def list_model_version_pipeline_run_links( self, model_name_or_id: Union[str, UUID], model_version_pipeline_run_link_filter_model: ModelVersionPipelineRunFilterModel, - model_version_name_or_number_or_id: Optional[ - Union[str, int, UUID, ModelStages] - ] = None, + model_version_name_or_number_or_id: Union[str, int, UUID, ModelStages], ) -> Page[ModelVersionPipelineRunResponseModel]: """Get all model version to pipeline run links by filter. Args: model_name_or_id: name or id of the model containing the model version. model_version_name_or_number_or_id: name, id, stage or number of the model version to be retrieved. - If skipped latest version will be retrieved. model_version_pipeline_run_link_filter_model: All filter parameters including pagination params. diff --git a/src/zenml/enums.py b/src/zenml/enums.py index 1835daf887d..73eb2bfd6ac 100644 --- a/src/zenml/enums.py +++ b/src/zenml/enums.py @@ -291,3 +291,4 @@ class ModelStages(StrEnum): STAGING = "staging" PRODUCTION = "production" ARCHIVED = "archived" + LATEST = "latest" diff --git a/src/zenml/model/artifact_config.py b/src/zenml/model/artifact_config.py index e9ad32b2f45..21cbca1f580 100644 --- a/src/zenml/model/artifact_config.py +++ b/src/zenml/model/artifact_config.py @@ -35,7 +35,7 @@ class ArtifactConfig(BaseModel): model_name: The name of the model to link artifact to. model_version: The identifier of the model version to link artifact to. It can be exact version ("23"), exact version number (42), stage - (ModelStages.PRODUCTION) or None for the latest version. + (ModelStages.PRODUCTION) or ModelStages.LATEST for the latest version. model_stage: The stage of the model version to link artifact to. artifact_name: The override name of a link instead of an artifact name. overwrite: Whether to overwrite an existing link or create new versions. @@ -81,7 +81,6 @@ def _model_config(self) -> "ModelConfig": on_the_fly_config = ModelConfig( name=self.model_name, version=self.model_version, - create_new_model_version=False, ) return on_the_fly_config @@ -120,7 +119,7 @@ def _link_to_model_version( # Create a ZenML client client = Client() - model_version = model_config._get_model_version() + model_version = model_config.get_or_create_model_version() artifact_name = self.artifact_name if artifact_name is None: diff --git a/src/zenml/model/model_config.py b/src/zenml/model/model_config.py index 4fc00954e35..ada80b00eb5 100644 --- a/src/zenml/model/model_config.py +++ b/src/zenml/model/model_config.py @@ -51,10 +51,8 @@ class ModelConfig(BaseModel): ethics: The ethical implications of the model. tags: Tags associated with the model. version: The model version name, number or stage is optional and points model context - to a specific version/stage. If skipped and `create_new_model_version` is False - - latest model version will be used. + to a specific version/stage. If skipped new model version will be created. version_description: The description of the model version. - create_new_model_version: Whether to create a new model version during execution save_models_to_registry: Whether to save all ModelArtifacts to Model Registry, if available in active stack. delete_new_version_on_failure: Whether to delete failed runs with new versions for later recovery from it. @@ -71,11 +69,11 @@ class ModelConfig(BaseModel): tags: Optional[List[str]] version: Optional[Union[ModelStages, int, str]] version_description: Optional[str] - create_new_model_version: bool = False save_models_to_registry: bool = True delete_new_version_on_failure: bool = True suppress_class_validation_warnings: bool = False + was_created_in_this_run: bool = False class Config: """Config class.""" @@ -91,43 +89,12 @@ def _root_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]: Returns: Dict of validated values. - - Raises: - ValueError: If validation failed on one of the checks. """ - create_new_model_version = values.get( - "create_new_model_version", False - ) suppress_class_validation_warnings = values.get( "suppress_class_validation_warnings", False ) version = values.get("version", None) - if create_new_model_version: - misuse_message = ( - "`version` set to {set} cannot be used with `create_new_model_version`." - "You can leave it default or set to a non-stage and non-numeric string.\n" - "Examples:\n" - " - `version` set to 1 or '1' is interpreted as a version number\n" - " - `version` set to 'production' is interpreted as a stage\n" - " - `version` set to 'my_first_version_in_2023' is a valid version to be created\n" - " - `version` set to 'My Second Version!' is a valid version to be created\n" - ) - if isinstance(version, ModelStages) or version in [ - stage.value for stage in ModelStages - ]: - raise ValueError( - misuse_message.format(set="a `ModelStages` instance") - ) - if str(version).isnumeric(): - raise ValueError(misuse_message.format(set="a numeric value")) - if version is None: - if not suppress_class_validation_warnings: - logger.info( - "Creation of new model version was requested, but no version name was explicitly provided. " - f"Setting `version` to `{RUNNING_MODEL_VERSION}`." - ) - values["version"] = RUNNING_MODEL_VERSION if ( version in [stage.value for stage in ModelStages] and not suppress_class_validation_warnings @@ -151,12 +118,15 @@ def _validate_config_in_runtime(self) -> None: """ try: model_version = self._get_model_version() - if self.create_new_model_version: + if self.version is None or self.version == RUNNING_MODEL_VERSION: + self.version = RUNNING_MODEL_VERSION for run_name, run in model_version.pipeline_runs.items(): if run.status == ExecutionStatus.RUNNING: raise RuntimeError( - f"New model version was requested, but pipeline run `{run_name}` " - f"is still running with version `{model_version.name}`." + f"New unnamed model version was requested, " + f"but pipeline run `{run_name}` have not finished yet. " + "This run also operates with unnamed model version - " + "new run will be stopped to prevent unexpected behavior." ) if self.delete_new_version_on_failure: @@ -205,43 +175,6 @@ def get_or_create_model(self) -> "ModelResponseModel": return model - def _create_model_version( - self, model: "ModelResponseModel" - ) -> "ModelVersionResponseModel": - """This method creates a model version for Model Control Plane. - - Args: - model: The model containing the model version. - - Returns: - The model version based on configuration. - """ - from zenml.client import Client - from zenml.models.model_models import ModelVersionRequestModel - - zenml_client = Client() - model_version_request = ModelVersionRequestModel( - user=zenml_client.active_user.id, - workspace=zenml_client.active_workspace.id, - name=self.version, - description=self.version_description, - model=model.id, - ) - mv_request = ModelVersionRequestModel.parse_obj(model_version_request) - try: - mv = zenml_client.get_model_version( - model_name_or_id=self.name, - model_version_name_or_number_or_id=self.version, - ) - model_version = mv - except KeyError: - model_version = zenml_client.create_model_version( - model_version=mv_request - ) - logger.info(f"New model version `{self.version}` was created.") - - return model_version - def _get_model_version(self) -> "ModelVersionResponseModel": """This method gets a model version from Model Control Plane. @@ -251,19 +184,11 @@ def _get_model_version(self) -> "ModelVersionResponseModel": from zenml.client import Client zenml_client = Client() - if self.version is None: - # raise if not found - model_version = zenml_client.get_model_version( - model_name_or_id=self.name - ) - else: - # by version name or stage or number - # raise if not found - model_version = zenml_client.get_model_version( - model_name_or_id=self.name, - model_version_name_or_number_or_id=self.version, - ) - return model_version + return zenml_client.get_model_version( + model_name_or_id=self.name, + model_version_name_or_number_or_id=self.version + or RUNNING_MODEL_VERSION, + ) def get_or_create_model_version(self) -> "ModelVersionResponseModel": """This method should get or create a model and a model version from Model Control Plane. @@ -273,23 +198,47 @@ def get_or_create_model_version(self) -> "ModelVersionResponseModel": Model Version returned by this method is resolved based on model configuration: - If there is an existing model version leftover from the previous failed run with - `delete_new_version_on_failure` is set to False and `create_new_model_version` is True, + `delete_new_version_on_failure` is set to False and `version` is None, leftover model version will be reused. - - Otherwise if `create_new_model_version` is True, a new model version is created. - - If `create_new_model_version` is False a model version will be fetched based on the version: - - If `version` is not set, the latest model version will be fetched. + - Otherwise if `version` is None, a new model version is created. + - If `version` is not None a model version will be fetched based on the version: + - If `version` is set to an integer or digit string, the model version with the matching number will be fetched. - If `version` is set to a string, the model version with the matching version will be fetched. - If `version` is set to a `ModelStage`, the model version with the matching stage will be fetched. Returns: The model version based on configuration. """ + from zenml.client import Client + from zenml.models.model_models import ModelVersionRequestModel + model = self.get_or_create_model() - if self.create_new_model_version: - mv = self._create_model_version(model) - else: - mv = self._get_model_version() - return mv + + if self.version is None: + logger.info( + "Creation of new model version was requested, but no version name was explicitly provided. " + f"Setting `version` to `{RUNNING_MODEL_VERSION}`." + ) + self.version = RUNNING_MODEL_VERSION + + zenml_client = Client() + model_version_request = ModelVersionRequestModel( + user=zenml_client.active_user.id, + workspace=zenml_client.active_workspace.id, + name=self.version, + description=self.version_description, + model=model.id, + ) + mv_request = ModelVersionRequestModel.parse_obj(model_version_request) + try: + model_version = self._get_model_version() + except KeyError: + model_version = zenml_client.create_model_version( + model_version=mv_request + ) + logger.info(f"New model version `{self.version}` was created.") + self.was_created_in_this_run = True + return model_version def _merge(self, model_config: "ModelConfig") -> None: self.license = self.license or model_config.license @@ -307,3 +256,22 @@ def _merge(self, model_config: "ModelConfig") -> None: self.delete_new_version_on_failure &= ( model_config.delete_new_version_on_failure ) + + def __hash__(self) -> int: + """Get hash of the `ModelConfig`. + + Returns: + Hash function results + """ + return hash( + "::".join( + ( + str(v) + for v in ( + self.name, + self.version, + self.delete_new_version_on_failure, + ) + ) + ) + ) diff --git a/src/zenml/models/model_models.py b/src/zenml/models/model_models.py index 0f11176ac20..47c7a42b08d 100644 --- a/src/zenml/models/model_models.py +++ b/src/zenml/models/model_models.py @@ -659,27 +659,24 @@ def versions(self) -> List[ModelVersionResponseModel]: ) def get_version( - self, version: Optional[Union[str, int, ModelStages]] = None + self, version: Union[str, int, ModelStages] ) -> ModelVersionResponseModel: """Get specific version of the model. Args: - version: version name, number, stage or None for latest version. + version: version name, number, stage Returns: The requested model version. """ from zenml.client import Client - if version is None: - return Client().get_model_version(model_name_or_id=self.name) - else: - return Client().get_model_version( - model_name_or_id=self.name, - model_version_name_or_number_or_id=getattr( - version, "value", version - ), - ) + return Client().get_model_version( + model_name_or_id=self.name, + model_version_name_or_number_or_id=getattr( + version, "value", version + ), + ) class ModelFilterModel(WorkspaceScopedFilterModel): diff --git a/src/zenml/new/pipelines/model_utils.py b/src/zenml/new/pipelines/model_utils.py index 2ff509adaaf..63a54c7d12e 100644 --- a/src/zenml/new/pipelines/model_utils.py +++ b/src/zenml/new/pipelines/model_utils.py @@ -64,18 +64,9 @@ def update_request( Args: model_config: Model Config Model object. requester: Requester of a new model version. - - Raises: - ValueError: If the model version name is configured differently by different requesters. """ self.requesters.append(requester) if self._model_config is None: self._model_config = model_config - if self._model_config.version != model_config.version: - raise ValueError( - f"A mismatch of `version` name in model configurations provided for `{model_config.name} detected." - "Since a new model version is requested for this model, all `version` names must match or left default." - ) - self._model_config._merge(model_config) diff --git a/src/zenml/new/pipelines/pipeline.py b/src/zenml/new/pipelines/pipeline.py index 0a24bd341b7..485c361a55f 100644 --- a/src/zenml/new/pipelines/pipeline.py +++ b/src/zenml/new/pipelines/pipeline.py @@ -638,9 +638,6 @@ def _run( stack.validate() new_version_requests = self.get_new_version_requests(deployment) - deployment = self.update_new_versions_requests( - deployment, new_version_requests - ) local_repo_context = ( code_repository_utils.find_active_code_repository() @@ -805,7 +802,7 @@ def log_pipeline_deployment_metadata( def get_new_version_requests( self, deployment: "PipelineDeploymentBaseModel" - ) -> Dict[str, NewModelVersionRequest]: + ) -> Dict[Tuple[str, Optional[str]], NewModelVersionRequest]: """Get the running versions of the models that are used in the pipeline run. Args: @@ -815,9 +812,9 @@ def get_new_version_requests( A dict of new model version request objects. """ new_versions_requested: Dict[ - str, NewModelVersionRequest + Tuple[str, Optional[str]], NewModelVersionRequest ] = defaultdict(NewModelVersionRequest) - other_model_configs: List["ModelConfig"] = [] + other_model_configs: Set["ModelConfig"] = set() all_steps_have_own_config = True for step in deployment.step_configurations.values(): step_model_config = step.config.model_config @@ -826,33 +823,71 @@ def get_new_version_requests( and step.config.model_config is not None ) if step_model_config: - if step_model_config.create_new_model_version: + try: + step_model_config._get_model_version() + version_existed = True + except KeyError: + version_existed = False + if not version_existed: new_versions_requested[ - step_model_config.name + ( + step_model_config.name, + str(step_model_config.version) or None, + ) ].update_request( step_model_config, NewModelVersionRequest.Requester( source="step", name=step.config.name ), ) + if ( + step_model_config.version is None + and ( + step_model_config.name, + str(step_model_config.version) or None, + ) + in new_versions_requested + ): + step_model_config.version = ( + constants.RUNNING_MODEL_VERSION + ) else: - other_model_configs.append(step_model_config) + other_model_configs.add(step_model_config) if not all_steps_have_own_config: pipeline_model_config = ( deployment.pipeline_configuration.model_config ) if pipeline_model_config: - if pipeline_model_config.create_new_model_version: + try: + pipeline_model_config._get_model_version() + version_existed = True + except KeyError: + version_existed = False + if not version_existed: new_versions_requested[ - pipeline_model_config.name + ( + pipeline_model_config.name, + str(pipeline_model_config.version) or None, + ) ].update_request( pipeline_model_config, NewModelVersionRequest.Requester( source="pipeline", name=self.name ), ) + if ( + pipeline_model_config.version is None + and ( + pipeline_model_config.name, + str(pipeline_model_config.version) or None, + ) + in new_versions_requested + ): + pipeline_model_config.version = ( + constants.RUNNING_MODEL_VERSION + ) else: - other_model_configs.append(pipeline_model_config) + other_model_configs.add(pipeline_model_config) elif deployment.pipeline_configuration.model_config is not None: logger.warning( f"ModelConfig of pipeline `{self.name}` is overridden in all steps. " @@ -867,94 +902,70 @@ def get_new_version_requests( def _validate_new_version_requests( self, - new_versions_requested: Dict[str, NewModelVersionRequest], + new_versions_requested: Dict[ + Tuple[str, Optional[str]], NewModelVersionRequest + ], ) -> None: """Validate the model configurations that are used in the pipeline run. Args: new_versions_requested: A dict of new model version request objects. """ - for model_name, data in new_versions_requested.items(): + for key, data in new_versions_requested.items(): + model_name, model_version = key if len(data.requesters) > 1: logger.warning( - f"New version of model `{model_name}` requested in multiple decorators:\n" - f"{data.requesters}\n We recommend that `create_new_model_version` is configured " - "only in one place of the pipeline." + f"New version of model version `{model_name}::{model_version or 'NEW'}` " + f"requested in multiple decorators:\n{data.requesters}\n We recommend " + "that `ModelConfig` requesting new version is configured only in one " + "place of the pipeline." ) data.model_config._validate_config_in_runtime() - def update_new_versions_requests( - self, - deployment: "PipelineDeploymentBaseModel", - new_version_requests: Dict[str, NewModelVersionRequest], - ) -> "PipelineDeploymentBaseModel": - """Update model configurations that are used in the pipeline run. - - This method is updating create_new_model_version for all model configurations in the pipeline, - who deal with model name with existing request to create a new mode version. - - Args: - deployment: The pipeline deployment configuration. - new_version_requests: Dict of models requesting new versions and their definition points. - - Returns: - Updated pipeline deployment configuration. - """ - for step_name in deployment.step_configurations: - step_model_config = deployment.step_configurations[ - step_name - ].config.model_config - if ( - step_model_config is not None - and step_model_config.name in new_version_requests - ): - step_model_config.version = new_version_requests[ - step_model_config.name - ].model_config.version - step_model_config.create_new_model_version = True - pipeline_model_config = deployment.pipeline_configuration.model_config - if ( - pipeline_model_config is not None - and pipeline_model_config.name in new_version_requests - ): - pipeline_model_config.version = new_version_requests[ - pipeline_model_config.name - ].model_config.version - pipeline_model_config.create_new_model_version = True - return deployment - def register_running_versions( - self, new_version_requests: Dict[str, NewModelVersionRequest] + self, + new_versions_requested: Dict[ + Tuple[str, Optional[str]], NewModelVersionRequest + ], ) -> None: """Registers the running versions of the models used in the given pipeline run. Args: - new_version_requests: Dict of models requesting new versions and their definition points. + new_versions_requested: Dict of models requesting new versions and their definition points. """ - for model_name, new_version_request in new_version_requests.items(): - if new_version_request.model_config.delete_new_version_on_failure: - mv = Client().get_model_version( - model_name_or_id=model_name, - model_version_name_or_number_or_id=new_version_request.model_config.version, - ) - mv._update_default_running_version_name() + for key, new_version_request in new_versions_requested.items(): + model_name, model_version = key + if not model_version: + if ( + new_version_request.model_config.delete_new_version_on_failure + ): + mv = Client().get_model_version( + model_name_or_id=model_name, + model_version_name_or_number_or_id=constants.RUNNING_MODEL_VERSION, + ) + mv._update_default_running_version_name() def delete_running_versions_without_recovery( - self, new_version_requests: Dict[str, NewModelVersionRequest] + self, + new_versions_requested: Dict[ + Tuple[str, Optional[str]], NewModelVersionRequest + ], ) -> None: """Delete the running versions of the models without `restore` after fail. Args: - new_version_requests: Dict of models requesting new versions and their definition points. + new_versions_requested: Dict of models requesting new versions and their definition points. """ - for model_name, new_version_request in new_version_requests.items(): + for key, new_version_request in new_versions_requested.items(): + model_name, model_version = key if ( new_version_request.model_config.delete_new_version_on_failure and new_version_request.model_config.version is not None ): model = Client().get_model_version( model_name_or_id=model_name, - model_version_name_or_number_or_id=new_version_request.model_config.version, + model_version_name_or_number_or_id=model_version + or constants.RUNNING_MODEL_VERSION, ) Client().delete_model_version( model_name_or_id=model_name, diff --git a/src/zenml/orchestrators/step_runner.py b/src/zenml/orchestrators/step_runner.py index d9385ab3a6b..46957b526db 100644 --- a/src/zenml/orchestrators/step_runner.py +++ b/src/zenml/orchestrators/step_runner.py @@ -695,7 +695,7 @@ def _get_model_versions_from_artifacts( if artifact_config is not None: try: model_version = ( - artifact_config._model_config._get_model_version() + artifact_config._model_config.get_or_create_model_version() ) models.add((model_version.model.id, model_version.id)) except RuntimeError: @@ -719,6 +719,7 @@ def _get_model_versions_from_external_artifacts( if ( external_artifact.model_artifact_name is not None and external_artifact.model_name is not None + and external_artifact.model_version is not None ): model_version = client.get_model_version( model_name_or_id=external_artifact.model_name, @@ -735,7 +736,7 @@ def _get_model_versions_from_config(self) -> Set[Tuple[UUID, UUID]]: """ try: mc = get_step_context().model_config - model_version = mc._get_model_version() + model_version = mc.get_or_create_model_version() return {(model_version.model.id, model_version.id)} except StepContextError: return set() diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index 1c74dc1e26a..fa225620f27 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -2430,9 +2430,7 @@ def delete_model_version( def get_model_version( self, model_name_or_id: Union[str, UUID], - model_version_name_or_number_or_id: Optional[ - Union[str, int, UUID, ModelStages] - ] = None, + model_version_name_or_number_or_id: Union[str, int, UUID, ModelStages], ) -> ModelVersionResponseModel: """Get an existing model version. diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index f4a504b0163..9f4d9b0704f 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -58,7 +58,6 @@ from zenml.config.store_config import StoreConfiguration from zenml.constants import ( ENV_ZENML_DISABLE_DATABASE_MIGRATION, - LATEST_MODEL_VERSION_PLACEHOLDER, ) from zenml.enums import ( LoggingLevels, @@ -5739,16 +5738,13 @@ def create_model_version( def get_model_version( self, model_name_or_id: Union[str, UUID], - model_version_name_or_number_or_id: Optional[ - Union[str, int, UUID, ModelStages] - ] = None, + model_version_name_or_number_or_id: Union[str, int, UUID, ModelStages], ) -> ModelVersionResponseModel: """Get an existing model version. Args: model_name_or_id: name or id of the model containing the model version. model_version_name_or_number_or_id: name, id, stage or number of the model version to be retrieved. - If skipped latest version will be retrieved. Returns: The model version of interest. @@ -5762,9 +5758,8 @@ def get_model_version( ModelVersionSchema.model_id == model.id ) if ( - model_version_name_or_number_or_id is None - or model_version_name_or_number_or_id - == LATEST_MODEL_VERSION_PLACEHOLDER + str(model_version_name_or_number_or_id) + == ModelStages.LATEST.value ): query = query.order_by(ModelVersionSchema.created.desc()) # type: ignore[attr-defined] elif model_version_name_or_number_or_id in [ diff --git a/src/zenml/zen_stores/zen_store_interface.py b/src/zenml/zen_stores/zen_store_interface.py index bc0512a9340..a72ed70812e 100644 --- a/src/zenml/zen_stores/zen_store_interface.py +++ b/src/zenml/zen_stores/zen_store_interface.py @@ -1810,16 +1810,13 @@ def delete_model_version( def get_model_version( self, model_name_or_id: Union[str, UUID], - model_version_name_or_number_or_id: Optional[ - Union[str, int, UUID, ModelStages] - ] = None, + model_version_name_or_number_or_id: Union[str, int, UUID, ModelStages], ) -> ModelVersionResponseModel: """Get an existing model version. Args: model_name_or_id: name or id of the model containing the model version. model_version_name_or_number_or_id: name, id, stage or number of the model version to be retrieved. - If skipped latest version will be retrieved. Returns: The model version of interest. diff --git a/tests/integration/functional/cli/conftest.py b/tests/integration/functional/cli/conftest.py index 05ca3474811..058a10cd72c 100644 --- a/tests/integration/functional/cli/conftest.py +++ b/tests/integration/functional/cli/conftest.py @@ -90,7 +90,7 @@ def step_2() -> ( @pipeline( - model_config=ModelConfig(name=NAME, create_new_model_version=True), + model_config=ModelConfig(name=NAME), name=NAME, ) def pipeline(): diff --git a/tests/integration/functional/model/test_artifact_config.py b/tests/integration/functional/model/test_artifact_config.py index 8e56ba1e6d7..e4389fc6186 100644 --- a/tests/integration/functional/model/test_artifact_config.py +++ b/tests/integration/functional/model/test_artifact_config.py @@ -38,13 +38,13 @@ MODEL_NAME = "foo" -@step(model_config=ModelConfig(name=MODEL_NAME, create_new_model_version=True)) +@step(model_config=ModelConfig(name=MODEL_NAME)) def single_output_step_from_context() -> Annotated[int, ArtifactConfig()]: """Untyped single output linked as Artifact from step context.""" return 1 -@step(model_config=ModelConfig(name=MODEL_NAME, create_new_model_version=True)) +@step(model_config=ModelConfig(name=MODEL_NAME)) def single_output_step_from_context_model() -> ( Annotated[int, ModelArtifactConfig(save_to_model_registry=True)] ): @@ -52,7 +52,7 @@ def single_output_step_from_context_model() -> ( return 1 -@step(model_config=ModelConfig(name=MODEL_NAME, create_new_model_version=True)) +@step(model_config=ModelConfig(name=MODEL_NAME)) def single_output_step_from_context_deployment() -> ( Annotated[int, DeploymentArtifactConfig()] ): @@ -83,7 +83,7 @@ def test_link_minimalistic(): model = client.get_model(MODEL_NAME) assert model.name == MODEL_NAME - mv = client.get_model_version(MODEL_NAME) + mv = client.get_model_version(MODEL_NAME, ModelStages.LATEST) assert mv.name == "1" links = client.list_model_version_artifact_links( model_name_or_id=model.id, @@ -115,7 +115,7 @@ def test_link_minimalistic(): assert one_is_artifact -@step(model_config=ModelConfig(name=MODEL_NAME, create_new_model_version=True)) +@step(model_config=ModelConfig(name=MODEL_NAME)) def multi_named_output_step_from_context() -> ( Tuple[ Annotated[int, "1", ArtifactConfig()], @@ -144,7 +144,7 @@ def test_link_multiple_named_outputs(): model = client.get_model(MODEL_NAME) assert model.name == MODEL_NAME - mv = client.get_model_version(MODEL_NAME) + mv = client.get_model_version(MODEL_NAME, ModelStages.LATEST) assert mv.name == "1" al = client.list_model_version_artifact_links( model_name_or_id=model.id, @@ -161,7 +161,7 @@ def test_link_multiple_named_outputs(): assert {al.name for al in al} == {"1", "2", "3"} -@step(model_config=ModelConfig(name=MODEL_NAME, create_new_model_version=True)) +@step(model_config=ModelConfig(name=MODEL_NAME)) def multi_named_output_step_not_tracked() -> ( Tuple[ Annotated[int, "1"], @@ -190,7 +190,7 @@ def test_link_multiple_named_outputs_without_links(): model = client.get_model(MODEL_NAME) assert model.name == MODEL_NAME - mv = client.get_model_version(MODEL_NAME) + mv = client.get_model_version(MODEL_NAME, ModelStages.LATEST) assert mv.name == "1" artifact_links = client.list_model_version_artifact_links( model_name_or_id=model.id, @@ -496,7 +496,7 @@ def single_output_step_with_versioning() -> ( @pipeline( enable_cache=False, - model_config=ModelConfig(name=MODEL_NAME, stage=ModelStages.PRODUCTION), + model_config=ModelConfig(name=MODEL_NAME, version=ModelStages.PRODUCTION), ) def simple_pipeline_with_versioning(): """Single output with overwrite disabled and step context.""" @@ -574,7 +574,7 @@ def step_with_manual_linkage() -> ( @pipeline( enable_cache=False, - model_config=ModelConfig(name=MODEL_NAME), + model_config=ModelConfig(name=MODEL_NAME, version=ModelStages.LATEST), ) def simple_pipeline_with_manual_linkage(): """Multi output linking by function.""" @@ -594,7 +594,7 @@ def step_with_manual_and_implicit_linkage() -> ( @pipeline( enable_cache=False, - model_config=ModelConfig(name=MODEL_NAME), + model_config=ModelConfig(name=MODEL_NAME, version=ModelStages.LATEST), ) def simple_pipeline_with_manual_and_implicit_linkage(): """Multi output: 2 is linked by function, 1 is linked implicitly.""" @@ -697,27 +697,6 @@ def simple_pipeline_with_manual_linkage_fail_on_override(): def test_link_with_manual_linkage_fail_on_override(): """Test that step fails on manual linkage, cause Annotated provided.""" with model_killer(): - client = Client() - user = client.active_user.id - ws = client.active_workspace.id - - # manual creation needed, as we work with specific versions - model = client.create_model( - ModelRequestModel( - name=MODEL_NAME, - user=user, - workspace=ws, - ) - ) - client.create_model_version( - ModelVersionRequestModel( - user=user, - workspace=ws, - name="good_one", - model=model.id, - ) - ) - simple_pipeline_with_manual_linkage_fail_on_override() @@ -745,7 +724,9 @@ def simple_pipeline_with_manual_linkage_flexible_config( ArtifactConfig( model_name=MODEL_NAME, model_version=ModelStages.PRODUCTION ), - ArtifactConfig(model_name=MODEL_NAME), + ArtifactConfig( + model_name=MODEL_NAME, model_version=ModelStages.LATEST + ), ArtifactConfig(model_name=MODEL_NAME, model_version=1), ), ids=("exact_version", "exact_stage", "latest_version", "exact_number"), @@ -824,7 +805,7 @@ def test_artifacts_linked_from_cache_steps(): """Test that artifacts are linked from cache steps.""" @pipeline( - model_config=ModelConfig(name="foo", create_new_model_version=True), + model_config=ModelConfig(name="foo"), enable_cache=False, ) def _inner_pipeline(force_disable_cache: bool = False): @@ -844,7 +825,7 @@ def _inner_pipeline(force_disable_cache: bool = False): for i in range(1, 3): fake_version = ModelConfig( - name="bar", create_new_model_version=True + name="bar" ).get_or_create_model_version() _inner_pipeline(i != 1) @@ -888,7 +869,7 @@ def test_artifacts_linked_from_cache_steps_same_id(): """ @pipeline( - model_config=ModelConfig(name="foo", create_new_model_version=True), + model_config=ModelConfig(name="foo"), enable_cache=False, ) def _inner_pipeline(force_disable_cache: bool = False): @@ -901,8 +882,8 @@ def _inner_pipeline(force_disable_cache: bool = False): client = Client() for i in range(1, 3): - ModelConfig( - name="bar", create_new_model_version=True + fake_version = ModelConfig( + name="bar" ).get_or_create_model_version() _inner_pipeline(i != 1) @@ -922,3 +903,5 @@ def _inner_pipeline(force_disable_cache: bool = False): ) == 1 ), f"Failed on {i} run" + + fake_version._update_default_running_version_name() diff --git a/tests/integration/functional/model/test_model_config.py b/tests/integration/functional/model/test_model_config.py index 3a38052af91..ce3cbe4da9c 100644 --- a/tests/integration/functional/model/test_model_config.py +++ b/tests/integration/functional/model/test_model_config.py @@ -99,7 +99,7 @@ def test_model_exists(self): def test_model_create_model_and_version(self): """Test if model and version are created, not existing before.""" with ModelContext(create_model=False): - mc = ModelConfig(name=MODEL_NAME, create_new_model_version=True) + mc = ModelConfig(name=MODEL_NAME) with mock.patch("zenml.model.model_config.logger.info") as logger: mv = mc.get_or_create_model_version() logger.assert_called() @@ -123,14 +123,14 @@ def test_model_fetch_model_and_version_by_number_not_found(self): with ModelContext(): mc = ModelConfig(name=MODEL_NAME, version="1.0.0") with pytest.raises(KeyError): - mc.get_or_create_model_version() + mc._get_model_version() def test_model_fetch_model_and_version_by_stage(self): """Test model and model version retrieval by exact stage number.""" with ModelContext( model_version="1.0.0", stage=ModelStages.PRODUCTION ) as (model, mv): - mc = ModelConfig(name=MODEL_NAME, stage=ModelStages.PRODUCTION) + mc = ModelConfig(name=MODEL_NAME, version=ModelStages.PRODUCTION) with mock.patch( "zenml.model.model_config.logger.warning" ) as logger: @@ -144,33 +144,16 @@ def test_model_fetch_model_and_version_by_stage_not_found(self): with ModelContext(model_version="1.0.0"): mc = ModelConfig(name=MODEL_NAME, version=ModelStages.PRODUCTION) with pytest.raises(KeyError): - mc.get_or_create_model_version() + mc._get_model_version() def test_model_fetch_model_and_version_latest(self): """Test model and model version retrieval by latest version.""" with ModelContext(model_version="1.0.0"): - mc = ModelConfig(name=MODEL_NAME) + mc = ModelConfig(name=MODEL_NAME, version=ModelStages.LATEST) mv = mc.get_or_create_model_version() assert mv.name == "1.0.0" - def test_init_create_new_version_with_version_fails(self): - """Test that it is not possible to use `version` as ModelStages and `create_new_model_version` together.""" - with pytest.raises(ValueError): - ModelConfig( - name=MODEL_NAME, - version=ModelStages.PRODUCTION, - create_new_model_version=True, - ) - - mc = ModelConfig( - name=MODEL_NAME, - create_new_model_version=True, - ) - assert mc.name == MODEL_NAME - assert mc.create_new_model_version - assert mc.version == RUNNING_MODEL_VERSION - def test_init_stage_logic(self): """Test that if version is set to string contained in ModelStages user is informed about it.""" with mock.patch("zenml.model.model_config.logger.info") as logger: @@ -189,7 +172,6 @@ def test_recovery_flow(self): with ModelContext(): mc = ModelConfig( name=MODEL_NAME, - create_new_model_version=True, delete_new_version_on_failure=False, ) mv1 = mc.get_or_create_model_version() @@ -197,7 +179,6 @@ def test_recovery_flow(self): mc = ModelConfig( name=MODEL_NAME, - create_new_model_version=True, delete_new_version_on_failure=False, ) mv2 = mc.get_or_create_model_version() diff --git a/tests/integration/functional/pipelines/test_pipeline_config.py b/tests/integration/functional/pipelines/test_pipeline_config.py index 5eccf78bfc3..fd6ed03f548 100644 --- a/tests/integration/functional/pipelines/test_pipeline_config.py +++ b/tests/integration/functional/pipelines/test_pipeline_config.py @@ -29,7 +29,6 @@ def assert_model_config_step(): assert model_config is not None assert model_config.name == "foo" assert model_config.version == RUNNING_MODEL_VERSION - assert model_config.create_new_model_version assert not model_config.delete_new_version_on_failure assert model_config.description == "description" assert model_config.license == "MIT" @@ -54,7 +53,6 @@ def test_pipeline_with_model_config_from_yaml(clean_workspace, tmp_path): """Test that the pipeline can be configured with a model config from a yaml file.""" model_config = ModelConfig( name="foo", - create_new_model_version=True, delete_new_version_on_failure=False, description="description", license="MIT", @@ -117,7 +115,6 @@ def test_pipeline_config_from_file_not_overridden_for_model_config( """ initial_model_config = ModelConfig( name="bar", - create_new_model_version=True, ) config_path = tmp_path / "config.yaml" @@ -138,7 +135,6 @@ def assert_model_config_pipeline(): p.configure( model_config=ModelConfig( name="foo", - create_new_model_version=True, delete_new_version_on_failure=False, description="description", license="MIT", @@ -156,8 +152,7 @@ def assert_model_config_pipeline(): assert p.configuration.model_config is not None assert p.configuration.model_config.name == "foo" - assert p.configuration.model_config.version == RUNNING_MODEL_VERSION - assert p.configuration.model_config.create_new_model_version + assert p.configuration.model_config.version is None assert not p.configuration.model_config.delete_new_version_on_failure assert p.configuration.model_config.description == "description" assert p.configuration.model_config.license == "MIT" diff --git a/tests/integration/functional/steps/test_external_artifact.py b/tests/integration/functional/steps/test_external_artifact.py index 2f6b92f77d0..72216c16200 100644 --- a/tests/integration/functional/steps/test_external_artifact.py +++ b/tests/integration/functional/steps/test_external_artifact.py @@ -20,6 +20,7 @@ from zenml import pipeline, step from zenml.artifacts.external_artifact import ExternalArtifact from zenml.client import Client +from zenml.enums import ModelStages from zenml.model import ArtifactConfig, ModelConfig @@ -45,7 +46,11 @@ def producer_pipeline(run_count: int): producer(run_count) -@pipeline(name="bar", enable_cache=False, model_config=ModelConfig(name="foo")) +@pipeline( + name="bar", + enable_cache=False, + model_config=ModelConfig(name="foo", version=ModelStages.LATEST), +) def consumer_pipeline( model_artifact_version: int, model_artifact_pipeline_name: str = None, @@ -84,7 +89,7 @@ def consumer_pipeline_with_external_artifact_from_another_model( @pipeline( name="bar", enable_cache=False, - model_config=ModelConfig(name="foo", create_new_model_version=True), + model_config=ModelConfig(name="foo"), ) def two_step_producer_pipeline(): producer(1) @@ -102,10 +107,10 @@ def two_step_producer_pipeline(): def test_exchange_of_model_artifacts_between_pipelines(consumer_pipeline): """Test that ExternalArtifact helps to exchange data from Model between pipelines.""" with model_killer(): + producer_pipeline.with_options(model_config=ModelConfig(name="foo"))(1) producer_pipeline.with_options( - model_config=ModelConfig(name="foo", create_new_model_version=True) - )(1) - producer_pipeline.with_options(model_config=ModelConfig(name="foo"))( + model_config=ModelConfig(name="foo", version=ModelStages.LATEST) + )( 2 ) # add to latest version consumer_pipeline(1) @@ -159,10 +164,10 @@ def test_external_artifact_pass_on_name_collision_with_pipeline_and_step(): def test_exchange_of_model_artifacts_between_pipelines_by_model_version_number(): """Test that ExternalArtifact helps to exchange data from Model between pipelines using model version number.""" with model_killer(): + producer_pipeline.with_options(model_config=ModelConfig(name="foo"))(1) producer_pipeline.with_options( - model_config=ModelConfig(name="foo", create_new_model_version=True) - )(1) - producer_pipeline.with_options(model_config=ModelConfig(name="foo"))( + model_config=ModelConfig(name="foo", version=ModelStages.LATEST) + )( 2 ) # add to latest version consumer_pipeline.with_options( @@ -185,13 +190,11 @@ def test_exchange_of_model_artifacts_between_pipelines_by_model_version_number() def test_direct_consumption(model_version_name, expected): """Test that ExternalArtifact can fetch data by full config with model version name/number combinations.""" with model_killer(): + producer_pipeline.with_options(model_config=ModelConfig(name="foo"))( + 42 + ) producer_pipeline.with_options( - model_config=ModelConfig(name="foo", create_new_model_version=True) - )(42) - producer_pipeline.with_options( - model_config=ModelConfig( - name="foo", create_new_model_version=True, version="foo" - ) + model_config=ModelConfig(name="foo", version="foo") )(23) artifact_id = ExternalArtifact( model_name="foo", diff --git a/tests/integration/functional/steps/test_model_config.py b/tests/integration/functional/steps/test_model_config.py index 85081d26695..6db4b86fe79 100644 --- a/tests/integration/functional/steps/test_model_config.py +++ b/tests/integration/functional/steps/test_model_config.py @@ -24,7 +24,7 @@ from zenml.artifacts.external_artifact import ExternalArtifact from zenml.client import Client from zenml.constants import RUNNING_MODEL_VERSION -from zenml.enums import ExecutionStatus +from zenml.enums import ExecutionStatus, ModelStages from zenml.model import ArtifactConfig, ModelConfig, link_output_to_model from zenml.models import ( ModelRequestModel, @@ -46,9 +46,7 @@ def test_model_config_passed_to_step_context_via_step(): @pipeline(name="bar", enable_cache=False) def _simple_step_pipeline(): _assert_that_model_config_set.with_options( - model_config=ModelConfig( - name="foo", create_new_model_version=True - ), + model_config=ModelConfig(name="foo"), )() with model_killer(): @@ -60,7 +58,7 @@ def test_model_config_passed_to_step_context_via_pipeline(): @pipeline( name="bar", - model_config=ModelConfig(name="foo", create_new_model_version=True), + model_config=ModelConfig(name="foo"), enable_cache=False, ) def _simple_step_pipeline(): @@ -75,14 +73,12 @@ def test_model_config_passed_to_step_context_via_step_and_pipeline(): @pipeline( name="bar", - model_config=ModelConfig(name="bar", create_new_model_version=True), + model_config=ModelConfig(name="bar"), enable_cache=False, ) def _simple_step_pipeline(): _assert_that_model_config_set.with_options( - model_config=ModelConfig( - name="foo", create_new_model_version=True - ), + model_config=ModelConfig(name="foo"), )() with model_killer(): @@ -94,30 +90,26 @@ def test_model_config_passed_to_step_context_and_switches(): @pipeline( name="bar", - model_config=ModelConfig(name="bar", create_new_model_version=True), + model_config=ModelConfig(name="bar"), enable_cache=False, ) def _simple_step_pipeline(): # this step will use ModelConfig from itself _assert_that_model_config_set.with_options( - model_config=ModelConfig( - name="foo", create_new_model_version=True - ), + model_config=ModelConfig(name="foo"), )() # this step will use ModelConfig from pipeline _assert_that_model_config_set(name="bar") # and another switch of context _assert_that_model_config_set.with_options( - model_config=ModelConfig( - name="foobar", create_new_model_version=True - ), + model_config=ModelConfig(name="foobar"), )(name="foobar") with model_killer(): _simple_step_pipeline() -@step(model_config=ModelConfig(name="foo", create_new_model_version=True)) +@step(model_config=ModelConfig(name="foo")) def _this_step_creates_a_version(): return 1 @@ -133,9 +125,7 @@ def test_create_new_versions_both_pipeline_and_step(): @pipeline( name="bar", - model_config=ModelConfig( - name="bar", create_new_model_version=True, version_description=desc - ), + model_config=ModelConfig(name="bar", version_description=desc), enable_cache=False, ) def _this_pipeline_creates_a_version(): @@ -149,21 +139,21 @@ def _this_pipeline_creates_a_version(): foo = client.get_model("foo") assert foo.name == "foo" - foo_version = client.get_model_version("foo") + foo_version = client.get_model_version("foo", ModelStages.LATEST) assert foo_version.name == "1" bar = client.get_model("bar") assert bar.name == "bar" - bar_version = client.get_model_version("bar") + bar_version = client.get_model_version("bar", ModelStages.LATEST) assert bar_version.name == "1" assert bar_version.description == desc _this_pipeline_creates_a_version() - foo_version = client.get_model_version("foo") + foo_version = client.get_model_version("foo", ModelStages.LATEST) assert foo_version.name == "2" - bar_version = client.get_model_version("bar") + bar_version = client.get_model_version("bar", ModelStages.LATEST) assert bar_version.name == "2" assert bar_version.description == desc @@ -183,12 +173,12 @@ def _this_pipeline_does_not_create_a_version(): bar = client.get_model("foo") assert bar.name == "foo" - bar_version = client.get_model_version("foo") + bar_version = client.get_model_version("foo", ModelStages.LATEST) assert bar_version.name == "1" _this_pipeline_does_not_create_a_version() - bar_version = client.get_model_version("foo") + bar_version = client.get_model_version("foo", ModelStages.LATEST) assert bar_version.name == "2" @@ -197,7 +187,7 @@ def test_create_new_version_only_in_pipeline(): @pipeline( name="bar", - model_config=ModelConfig(name="bar", create_new_model_version=True), + model_config=ModelConfig(name="bar"), enable_cache=False, ) def _this_pipeline_creates_a_version(): @@ -210,12 +200,12 @@ def _this_pipeline_creates_a_version(): foo = client.get_model("bar") assert foo.name == "bar" - foo_version = client.get_model_version("bar") + foo_version = client.get_model_version("bar", ModelStages.LATEST) assert foo_version.name == "1" _this_pipeline_creates_a_version() - foo_version = client.get_model_version("bar") + foo_version = client.get_model_version("bar", ModelStages.LATEST) assert foo_version.name == "2" @@ -228,7 +218,7 @@ def _this_step_produces_output() -> ( @step def _this_step_tries_to_recover(run_number: int): - mv = get_step_context().model_config._get_model_version() + mv = get_step_context().model_config.get_or_create_model_version() assert ( len(mv.artifact_object_ids["bar::_this_step_produces_output::data"]) == run_number @@ -242,13 +232,11 @@ def _this_step_tries_to_recover(run_number: int): [ ModelConfig( name="foo", - create_new_model_version=True, delete_new_version_on_failure=False, ), ModelConfig( name="foo", version="test running version", - create_new_model_version=True, delete_new_version_on_failure=False, ), ], @@ -299,13 +287,11 @@ def _this_pipeline_will_recover(run_number: int): [ ModelConfig( name="foo", - create_new_model_version=True, delete_new_version_on_failure=True, ), ModelConfig( name="foo", version="test running version", - create_new_model_version=True, delete_new_version_on_failure=True, ), ], @@ -342,7 +328,7 @@ def _this_pipeline_will_not_recover(run_number: int): ) -@step(model_config=ModelConfig(name="foo", create_new_model_version=True)) +@step(model_config=ModelConfig(name="foo")) def _new_version_step(): return 1 @@ -354,7 +340,7 @@ def _no_model_config_step(): @pipeline( enable_cache=False, - model_config=ModelConfig(name="foo", create_new_model_version=True), + model_config=ModelConfig(name="foo"), ) def _new_version_pipeline_overridden_warns(): _new_version_step() @@ -362,7 +348,7 @@ def _new_version_pipeline_overridden_warns(): @pipeline( enable_cache=False, - model_config=ModelConfig(name="foo", create_new_model_version=True), + model_config=ModelConfig(name="foo"), ) def _new_version_pipeline_not_warns(): _no_model_config_step() @@ -381,7 +367,7 @@ def _no_new_version_pipeline_warns_on_steps(): @pipeline( enable_cache=False, - model_config=ModelConfig(name="foo", create_new_model_version=True), + model_config=ModelConfig(name="foo"), ) def _new_version_pipeline_warns_on_steps(): _new_version_step() @@ -399,11 +385,11 @@ def _new_version_pipeline_warns_on_steps(): (_no_new_version_pipeline_not_warns, ""), ( _no_new_version_pipeline_warns_on_steps, - "`create_new_model_version` is configured only in one", + "is configured only in one place of the pipeline", ), ( _new_version_pipeline_warns_on_steps, - "`create_new_model_version` is configured only in one", + "is configured only in one place of the pipeline", ), ], ids=[ @@ -459,21 +445,19 @@ def test_pipeline_run_link_attached_from_pipeline_context(pipeline): run_name=run_name_1, model_config=ModelConfig( name="foo", - create_new_model_version=True, delete_new_version_on_failure=True, ), )() run_name_2 = f"bar_run_{uuid4()}" pipeline.with_options( run_name=run_name_2, - model_config=ModelConfig( - name="foo", - ), + model_config=ModelConfig(name="foo", version=ModelStages.LATEST), )() model = client.get_model("foo") mv = client.get_model_version( model_name_or_id=model.id, + model_version_name_or_number_or_id=ModelStages.LATEST, ) assert len(mv.pipeline_run_ids) == 2 @@ -515,22 +499,18 @@ def test_pipeline_run_link_attached_from_step_context(pipeline): )( ModelConfig( name="foo", - create_new_model_version=True, delete_new_version_on_failure=True, ) ) run_name_2 = f"bar_run_{uuid4()}" pipeline.with_options( run_name=run_name_2, - )( - ModelConfig( - name="foo", - ) - ) + )(ModelConfig(name="foo", version=ModelStages.LATEST)) model = client.get_model("foo") mv = client.get_model_version( model_name_or_id=model.id, + model_version_name_or_number_or_id=ModelStages.LATEST, ) assert len(mv.pipeline_run_ids) == 2 @@ -566,26 +546,32 @@ def _pipeline_run_link_attached_from_artifact_context_multiple_step(): _this_step_has_model_config_on_artifact_level() -@pipeline(enable_cache=False, model_config=ModelConfig(name="pipeline")) +@pipeline( + enable_cache=False, + model_config=ModelConfig(name="pipeline", version=ModelStages.LATEST), +) def _pipeline_run_link_attached_from_mixed_context_single_step(): _this_step_has_model_config_on_artifact_level() _this_step_produces_output() _this_step_produces_output.with_options( - model_config=ModelConfig(name="step"), + model_config=ModelConfig(name="step", version=ModelStages.LATEST), )() -@pipeline(enable_cache=False, model_config=ModelConfig(name="pipeline")) +@pipeline( + enable_cache=False, + model_config=ModelConfig(name="pipeline", version=ModelStages.LATEST), +) def _pipeline_run_link_attached_from_mixed_context_multiple_step(): _this_step_has_model_config_on_artifact_level() _this_step_produces_output() _this_step_produces_output.with_options( - model_config=ModelConfig(name="step"), + model_config=ModelConfig(name="step", version=ModelStages.LATEST), )() _this_step_has_model_config_on_artifact_level() _this_step_produces_output() _this_step_produces_output.with_options( - model_config=ModelConfig(name="step"), + model_config=ModelConfig(name="step", version=ModelStages.LATEST), )() @@ -656,6 +642,7 @@ def test_pipeline_run_link_attached_from_mixed_context(pipeline, model_names): for model in models: mv = client.get_model_version( model_name_or_id=model.id, + model_version_name_or_number_or_id=ModelStages.LATEST, ) assert len(mv.pipeline_run_ids) == 2 assert {run_name for run_name in mv.pipeline_run_ids} == { @@ -669,29 +656,35 @@ def _consumer_step(a: int, b: int): assert a == b -@step(model_config=ModelConfig(name="step", create_new_model_version=True)) +@step(model_config=ModelConfig(name="step")) def _producer_step() -> Tuple[int, int, int]: return 1, 2, 3 @pipeline def _consumer_pipeline_with_step_context(): - _consumer_step.with_options(model_config=ModelConfig(name="step"))( - ExternalArtifact(model_artifact_name="output_0"), 1 - ) + _consumer_step.with_options( + model_config=ModelConfig(name="step", version=ModelStages.LATEST) + )(ExternalArtifact(model_artifact_name="output_0"), 1) @pipeline def _consumer_pipeline_with_artifact_context(): _consumer_step( - ExternalArtifact(model_artifact_name="output_1", model_name="step"), 2 + ExternalArtifact( + model_artifact_name="output_1", + model_name="step", + model_version=ModelStages.LATEST, + ), + 2, ) -@pipeline(model_config=ModelConfig(name="step")) +@pipeline(model_config=ModelConfig(name="step", version=ModelStages.LATEST)) def _consumer_pipeline_with_pipeline_context(): _consumer_step( - ExternalArtifact(model_artifact_name="output_2", model_name="step"), 3 + ExternalArtifact(model_artifact_name="output_2"), + 3, ) @@ -724,6 +717,7 @@ def test_that_consumption_also_registers_run_in_model_version(): model = client.get_model(model_name_or_id="step") mv = client.get_model_version( model_name_or_id=model.id, + model_version_name_or_number_or_id=ModelStages.LATEST, ) assert len(mv.pipeline_run_ids) == 4 assert {run_name for run_name in mv.pipeline_run_ids} == { @@ -738,13 +732,13 @@ def test_that_if_some_steps_request_new_version_but_cached_new_version_is_still_ """Test that if one of the steps requests a new version but was cached a new version is still created for other steps.""" with model_killer(): - @pipeline(model_config=ModelConfig(name="step")) + @pipeline( + model_config=ModelConfig(name="step", version=ModelStages.LATEST) + ) def _inner_pipeline(): # this step requests a new version, but can be cached _this_step_produces_output.with_options( - model_config=ModelConfig( - name="step", create_new_model_version=True - ) + model_config=ModelConfig(name="step") )() # this is an always run step _this_step_produces_output.with_options(enable_cache=False)() @@ -770,12 +764,13 @@ def test_that_pipeline_run_is_removed_on_deletion_of_pipeline_run(): """Test that if pipeline run gets deleted - it is removed from model version.""" with model_killer(): - @pipeline(model_config=ModelConfig(name="step"), enable_cache=False) + @pipeline( + model_config=ModelConfig(name="step", version=ModelStages.LATEST), + enable_cache=False, + ) def _inner_pipeline(): _this_step_produces_output.with_options( - model_config=ModelConfig( - name="step", create_new_model_version=True - ) + model_config=ModelConfig(name="step") )() run_1 = f"run_{uuid4()}" @@ -793,15 +788,13 @@ def test_that_pipeline_run_is_removed_on_deletion_of_pipeline(): with model_killer(): @pipeline( - model_config=ModelConfig(name="step"), + model_config=ModelConfig(name="step", version=ModelStages.LATEST), enable_cache=False, name="test_that_pipeline_run_is_removed_on_deletion_of_pipeline", ) def _inner_pipeline(): _this_step_produces_output.with_options( - model_config=ModelConfig( - name="step", create_new_model_version=True - ) + model_config=ModelConfig(name="step") )() run_1 = f"run_{uuid4()}" @@ -821,14 +814,12 @@ def test_that_artifact_is_removed_on_deletion(): with model_killer(): @pipeline( - model_config=ModelConfig(name="step"), + model_config=ModelConfig(name="step", version=ModelStages.LATEST), enable_cache=False, ) def _inner_pipeline(): _this_step_produces_output.with_options( - model_config=ModelConfig( - name="step", create_new_model_version=True - ) + model_config=ModelConfig(name="step") )() run_1 = f"run_{uuid4()}" @@ -850,14 +841,7 @@ def _this_step_fails(): raise Exception("make pipeline fail") -@pytest.mark.parametrize( - "version", - ("test running version", None), - ids=["custom_running_name", "default_running_name"], -) -def test_that_two_pipelines_cannot_run_at_the_same_time_requesting_new_version_and_with_recovery( - version, -): +def test_that_two_pipelines_cannot_run_at_the_same_time_requesting_new_unnamed_version_and_with_recovery(): """Test that if second pipeline for same new version is started in parallel - it will fail.""" @pipeline( @@ -865,8 +849,6 @@ def test_that_two_pipelines_cannot_run_at_the_same_time_requesting_new_version_a enable_cache=False, model_config=ModelConfig( name="multi_run", - version=version, - create_new_model_version=True, delete_new_version_on_failure=False, ), ) @@ -888,57 +870,8 @@ def _this_pipeline_will_fail(): ) with pytest.raises( RuntimeError, - match="New model version was requested, but pipeline run", + match="New unnamed model version was requested", ): _this_pipeline_will_fail.with_options( run_name=f"multi_run_{uuid4()}" )() - - -def test_that_two_pipelines_cannot_create_same_specified_version(): - """Test that if second pipeline for same new version is started after completion of first one - it will fail.""" - - @pipeline( - model_config=ModelConfig( - name="step", - version="test running version", - create_new_model_version=True, - ), - enable_cache=False, - ) - def _inner_pipeline(): - _this_step_produces_output() - - with model_killer(): - _inner_pipeline() - with pytest.raises(RuntimeError, match="Cannot create version"): - _inner_pipeline() - - -def test_that_two_decorators_cannot_request_different_specific_new_version(): - """Test that if multiple decorators request different new versions - it will fail.""" - - @pipeline( - model_config=ModelConfig( - name="step", - version="test running version", - create_new_model_version=True, - ), - enable_cache=False, - ) - def _inner_pipeline(): - _this_step_produces_output() - _this_step_produces_output.with_options( - model_config=ModelConfig( - name="step", - version="test running version 2", - create_new_model_version=True, - ), - )() - - with model_killer(): - with pytest.raises( - ValueError, - match="A mismatch of `version` name in model configurations provided", - ): - _inner_pipeline() diff --git a/tests/unit/model/test_model_config_init.py b/tests/unit/model/test_model_config_init.py index adb68efa578..a40ce09ef01 100644 --- a/tests/unit/model/test_model_config_init.py +++ b/tests/unit/model/test_model_config_init.py @@ -2,20 +2,17 @@ import pytest -from zenml.enums import ModelStages from zenml.model import ModelConfig @pytest.mark.parametrize( - "version_name,create_new_model_version,delete_new_version_on_failure,logger", + "version_name,delete_new_version_on_failure,logger", [ - [None, True, False, "info"], - ["staging", False, False, "info"], - ["1", False, False, "info"], - [1, False, False, "info"], + ["staging", False, "info"], + ["1", False, "info"], + [1, False, "info"], ], ids=[ - "Default running version", "Pick model by text stage", "Pick model by text version number", "Pick model by integer version number", @@ -23,7 +20,6 @@ ) def test_init_warns( version_name, - create_new_model_version, delete_new_version_on_failure, logger, ): @@ -31,34 +27,6 @@ def test_init_warns( ModelConfig( name="foo", version=version_name, - create_new_model_version=create_new_model_version, delete_new_version_on_failure=delete_new_version_on_failure, ) logger.assert_called_once() - - -@pytest.mark.parametrize( - "version_name,create_new_model_version", - [ - [1, True], - ["1", True], - [ModelStages.PRODUCTION, True], - ["production", True], - ], - ids=[ - "Version number as integer and new version request", - "Version number as string and new version request", - "Version stage as instance and new version request", - "Version stage as string and new version request", - ], -) -def test_init_raises( - version_name, - create_new_model_version, -): - with pytest.raises(ValueError): - ModelConfig( - name="foo", - version=version_name, - create_new_model_version=create_new_model_version, - ) From dcd5f6535e632c437772621966ab29adf66e7bbb Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Tue, 7 Nov 2023 15:10:32 +0100 Subject: [PATCH 02/12] better error message --- src/zenml/model/model_config.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/zenml/model/model_config.py b/src/zenml/model/model_config.py index ada80b00eb5..fd720cc2506 100644 --- a/src/zenml/model/model_config.py +++ b/src/zenml/model/model_config.py @@ -132,7 +132,12 @@ def _validate_config_in_runtime(self) -> None: if self.delete_new_version_on_failure: raise RuntimeError( f"Cannot create version `{self.version}` " - f"for model `{self.name}` since it already exists" + f"for model `{self.name}` since it already exists. " + "This could happen for unforseen reasons (e.g. unexpected " + "intteruption of previous pipeline run flow).\n" + "If you would like to remove the staling version use " + "following CLI command:\n" + f"`zenml model version delete {self.name} {self.version}`" ) except KeyError: self.get_or_create_model_version() From ebdace91c8d3ea60ef7095ba0757d13e9208e701 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Tue, 7 Nov 2023 17:12:20 +0100 Subject: [PATCH 03/12] refactor --- src/zenml/model/model_config.py | 5 +- src/zenml/new/pipelines/pipeline.py | 123 ++++++++++++---------------- 2 files changed, 55 insertions(+), 73 deletions(-) diff --git a/src/zenml/model/model_config.py b/src/zenml/model/model_config.py index fd720cc2506..58de967b20a 100644 --- a/src/zenml/model/model_config.py +++ b/src/zenml/model/model_config.py @@ -128,12 +128,11 @@ def _validate_config_in_runtime(self) -> None: "This run also operates with unnamed model version - " "new run will be stopped to prevent unexpected behavior." ) - if self.delete_new_version_on_failure: raise RuntimeError( f"Cannot create version `{self.version}` " f"for model `{self.name}` since it already exists. " - "This could happen for unforseen reasons (e.g. unexpected " + "This could happen for unforeseen reasons (e.g. unexpected " "intteruption of previous pipeline run flow).\n" "If you would like to remove the staling version use " "following CLI command:\n" @@ -241,8 +240,8 @@ def get_or_create_model_version(self) -> "ModelVersionResponseModel": model_version = zenml_client.create_model_version( model_version=mv_request ) - logger.info(f"New model version `{self.version}` was created.") self.was_created_in_this_run = True + logger.info(f"New model version `{self.version}` was created.") return model_version def _merge(self, model_config: "ModelConfig") -> None: diff --git a/src/zenml/new/pipelines/pipeline.py b/src/zenml/new/pipelines/pipeline.py index 485c361a55f..0ba1bf9742e 100644 --- a/src/zenml/new/pipelines/pipeline.py +++ b/src/zenml/new/pipelines/pipeline.py @@ -800,6 +800,40 @@ def log_pipeline_deployment_metadata( except Exception as e: logger.debug(f"Logging pipeline deployment metadata failed: {e}") + def _update_new_requesters( + self, + requester_name: str, + model_config: "ModelConfig", + new_versions_requested: Dict[ + Tuple[str, Optional[str]], NewModelVersionRequest + ], + other_model_configs: Set["ModelConfig"], + ) -> None: + key = ( + model_config.name, + str(model_config.version) if model_config.version else None, + ) + if model_config.version is None: + version_existed = False + else: + try: + model_config._get_model_version() + version_existed = key not in new_versions_requested + except KeyError: + version_existed = False + if not version_existed: + model_config.was_created_in_this_run = True + new_versions_requested[key].update_request( + model_config, + NewModelVersionRequest.Requester( + source="step", name=requester_name + ), + ) + if model_config.version is None and key in new_versions_requested: + model_config.version = constants.RUNNING_MODEL_VERSION + else: + other_model_configs.add(model_config) + def get_new_version_requests( self, deployment: "PipelineDeploymentBaseModel" ) -> Dict[Tuple[str, Optional[str]], NewModelVersionRequest]: @@ -823,71 +857,23 @@ def get_new_version_requests( and step.config.model_config is not None ) if step_model_config: - try: - step_model_config._get_model_version() - version_existed = True - except KeyError: - version_existed = False - if not version_existed: - new_versions_requested[ - ( - step_model_config.name, - str(step_model_config.version) or None, - ) - ].update_request( - step_model_config, - NewModelVersionRequest.Requester( - source="step", name=step.config.name - ), - ) - if ( - step_model_config.version is None - and ( - step_model_config.name, - str(step_model_config.version) or None, - ) - in new_versions_requested - ): - step_model_config.version = ( - constants.RUNNING_MODEL_VERSION - ) - else: - other_model_configs.add(step_model_config) + self._update_new_requesters( + model_config=step_model_config, + requester_name=step.config.name, + new_versions_requested=new_versions_requested, + other_model_configs=other_model_configs, + ) if not all_steps_have_own_config: pipeline_model_config = ( deployment.pipeline_configuration.model_config ) if pipeline_model_config: - try: - pipeline_model_config._get_model_version() - version_existed = True - except KeyError: - version_existed = False - if not version_existed: - new_versions_requested[ - ( - pipeline_model_config.name, - str(pipeline_model_config.version) or None, - ) - ].update_request( - pipeline_model_config, - NewModelVersionRequest.Requester( - source="pipeline", name=self.name - ), - ) - if ( - pipeline_model_config.version is None - and ( - pipeline_model_config.name, - str(pipeline_model_config.version) or None, - ) - in new_versions_requested - ): - pipeline_model_config.version = ( - constants.RUNNING_MODEL_VERSION - ) - else: - other_model_configs.add(pipeline_model_config) + self._update_new_requesters( + model_config=pipeline_model_config, + requester_name=self.name, + new_versions_requested=new_versions_requested, + other_model_configs=other_model_configs, + ) elif deployment.pipeline_configuration.model_config is not None: logger.warning( f"ModelConfig of pipeline `{self.name}` is overridden in all steps. " @@ -933,17 +919,14 @@ def register_running_versions( Args: new_versions_requested: Dict of models requesting new versions and their definition points. """ - for key, new_version_request in new_versions_requested.items(): + for key, _ in new_versions_requested.items(): model_name, model_version = key if not model_version: - if ( - new_version_request.model_config.delete_new_version_on_failure - ): - mv = Client().get_model_version( - model_name_or_id=model_name, - model_version_name_or_number_or_id=constants.RUNNING_MODEL_VERSION, - ) - mv._update_default_running_version_name() + mv = Client().get_model_version( + model_name_or_id=model_name, + model_version_name_or_number_or_id=constants.RUNNING_MODEL_VERSION, + ) + mv._update_default_running_version_name() def delete_running_versions_without_recovery( self, @@ -960,7 +943,7 @@ def delete_running_versions_without_recovery( model_name, model_version = key if ( new_version_request.model_config.delete_new_version_on_failure - and new_version_request.model_config.version is not None + and new_version_request.model_config.was_created_in_this_run ): model = Client().get_model_version( model_name_or_id=model_name, From 54f8d52b5f0bd482723eb4ffa56284e5b295e6cc Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Tue, 7 Nov 2023 18:39:41 +0100 Subject: [PATCH 04/12] port lint fix --- src/zenml/client.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/zenml/client.py b/src/zenml/client.py index 3ed6128f118..2a7678628a6 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -5384,6 +5384,8 @@ def _get_entity_by_id_or_name_or_prefix( entity_label = get_method.__name__.replace("get_", "") + "s" formatted_entity_items = [ f"- {item.name}: (id: {item.id})\n" + if hasattr(item, "name") + else f"- {item.id}\n" for item in entity.items ] raise ZenKeyError( From 4458a0cd1bada580784eabea7f78ff6fa6ef65af Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Wed, 8 Nov 2023 06:54:20 +0100 Subject: [PATCH 05/12] update signature in tests --- tests/integration/functional/zen_stores/test_zen_store.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index b70f1cea330..60093add611 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -2683,6 +2683,7 @@ def test_latest_not_found(self): with pytest.raises(KeyError): zs.get_model_version( model_name_or_id=model.id, + model_version_name_or_number_or_id=ModelStages.LATEST, ) def test_latest_found(self): @@ -2708,6 +2709,7 @@ def test_latest_found(self): ) found_latest = zs.get_model_version( model_name_or_id=model.id, + model_version_name_or_number_or_id=ModelStages.LATEST, ) assert latest.id == found_latest.id From fc49aed49b9bb02c49e1988c9c414fa24defa030 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Wed, 8 Nov 2023 06:57:41 +0100 Subject: [PATCH 06/12] fix skipped test --- src/zenml/zen_stores/schemas/model_schemas.py | 11 ++++++++--- .../functional/zen_stores/test_zen_store.py | 1 - 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py index 8a0b0fc3c64..cf196f7de67 100644 --- a/src/zenml/zen_stores/schemas/model_schemas.py +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -119,6 +119,13 @@ def to_model(self) -> ModelResponseModel: Returns: The created `ModelResponseModel`. """ + if self.model_versions: + version_numbers = [mv.number for mv in self.model_versions] + latest_version = self.model_versions[ + version_numbers.index(max(version_numbers)) + ].name + else: + latest_version = None return ModelResponseModel( id=self.id, name=self.name, @@ -134,9 +141,7 @@ def to_model(self) -> ModelResponseModel: trade_offs=self.trade_offs, ethics=self.ethics, tags=json.loads(self.tags) if self.tags else None, - latest_version=self.model_versions[-1].name - if self.model_versions - else None, + latest_version=latest_version, ) def update( diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index 60093add611..258c6179047 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -2435,7 +2435,6 @@ def test_connector_validation(): class TestModel: - @pytest.mark.skip("TODO: Fix to come from Andrei") def test_latest_version_properly_fetched(self): """Test that latest version can be properly fetched.""" with ModelVersionContext() as model: From b0ca24558cc523361a309dd2c97a233e5022b561 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Thu, 9 Nov 2023 16:15:49 +0100 Subject: [PATCH 07/12] bring back default latest to some interfaces --- src/zenml/client.py | 8 ++++++-- src/zenml/models/model_models.py | 7 ++++--- src/zenml/zen_stores/rest_zen_store.py | 9 +++++---- src/zenml/zen_stores/sql_zen_store.py | 7 ++++++- src/zenml/zen_stores/zen_store_interface.py | 5 ++++- .../integration/functional/zen_stores/test_zen_store.py | 1 - 6 files changed, 25 insertions(+), 12 deletions(-) diff --git a/src/zenml/client.py b/src/zenml/client.py index 2a7678628a6..8e2fdd283d0 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -5077,20 +5077,24 @@ def delete_model_version( def get_model_version( self, model_name_or_id: Union[str, UUID], - model_version_name_or_number_or_id: Union[str, int, UUID, ModelStages], + model_version_name_or_number_or_id: Optional[ + Union[str, int, UUID, ModelStages] + ] = None, ) -> ModelVersionResponseModel: """Get an existing model version from Model Control Plane. Args: model_name_or_id: name or id of the model containing the model version. model_version_name_or_number_or_id: name, id, stage or number of the model version to be retrieved. + If skipped - latest version is retrieved. Returns: The model version of interest. """ return self.zen_store.get_model_version( model_name_or_id=model_name_or_id, - model_version_name_or_number_or_id=model_version_name_or_number_or_id, + model_version_name_or_number_or_id=model_version_name_or_number_or_id + or ModelStages.LATEST, ) def list_model_versions( diff --git a/src/zenml/models/model_models.py b/src/zenml/models/model_models.py index 47c7a42b08d..6c1ab4b0bfb 100644 --- a/src/zenml/models/model_models.py +++ b/src/zenml/models/model_models.py @@ -659,12 +659,13 @@ def versions(self) -> List[ModelVersionResponseModel]: ) def get_version( - self, version: Union[str, int, ModelStages] + self, version: Optional[Union[str, int, ModelStages]] = None ) -> ModelVersionResponseModel: """Get specific version of the model. Args: - version: version name, number, stage + version: version name, number, stage. + If skipped - latest version is retrieved. Returns: The requested model version. @@ -674,7 +675,7 @@ def get_version( return Client().get_model_version( model_name_or_id=self.name, model_version_name_or_number_or_id=getattr( - version, "value", version + version, "value", version or ModelStages.LATEST ), ) diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index fa225620f27..f3b9a273e23 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -51,7 +51,6 @@ FLAVORS, GET_OR_CREATE, INFO, - LATEST_MODEL_VERSION_PLACEHOLDER, LOGIN, MODEL_VERSIONS, MODELS, @@ -2430,21 +2429,23 @@ def delete_model_version( def get_model_version( self, model_name_or_id: Union[str, UUID], - model_version_name_or_number_or_id: Union[str, int, UUID, ModelStages], + model_version_name_or_number_or_id: Optional[ + Union[str, int, UUID, ModelStages] + ] = None, ) -> ModelVersionResponseModel: """Get an existing model version. Args: model_name_or_id: name or id of the model containing the model version. model_version_name_or_number_or_id: name, id, stage or number of the model version to be retrieved. - If skipped latest version will be retrieved. + If skipped - latest is retrieved. Returns: The model version of interest. """ return self._get_resource( resource_id=model_version_name_or_number_or_id - or LATEST_MODEL_VERSION_PLACEHOLDER, + or ModelStages.LATEST, route=f"{MODELS}/{model_name_or_id}{MODEL_VERSIONS}", response_model=ModelVersionResponseModel, params={ diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 9f4d9b0704f..e7c10e721a3 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -5738,13 +5738,16 @@ def create_model_version( def get_model_version( self, model_name_or_id: Union[str, UUID], - model_version_name_or_number_or_id: Union[str, int, UUID, ModelStages], + model_version_name_or_number_or_id: Optional[ + Union[str, int, UUID, ModelStages] + ] = None, ) -> ModelVersionResponseModel: """Get an existing model version. Args: model_name_or_id: name or id of the model containing the model version. model_version_name_or_number_or_id: name, id, stage or number of the model version to be retrieved. + If skipped - latest is retrieved. Returns: The model version of interest. @@ -5757,6 +5760,8 @@ def get_model_version( query = select(ModelVersionSchema).where( ModelVersionSchema.model_id == model.id ) + if model_version_name_or_number_or_id is None: + model_version_name_or_number_or_id = ModelStages.LATEST if ( str(model_version_name_or_number_or_id) == ModelStages.LATEST.value diff --git a/src/zenml/zen_stores/zen_store_interface.py b/src/zenml/zen_stores/zen_store_interface.py index a72ed70812e..44ac5e4e15b 100644 --- a/src/zenml/zen_stores/zen_store_interface.py +++ b/src/zenml/zen_stores/zen_store_interface.py @@ -1810,13 +1810,16 @@ def delete_model_version( def get_model_version( self, model_name_or_id: Union[str, UUID], - model_version_name_or_number_or_id: Union[str, int, UUID, ModelStages], + model_version_name_or_number_or_id: Optional[ + Union[str, int, UUID, ModelStages] + ] = None, ) -> ModelVersionResponseModel: """Get an existing model version. Args: model_name_or_id: name or id of the model containing the model version. model_version_name_or_number_or_id: name, id, stage or number of the model version to be retrieved. + If skipped - latest is retrieved. Returns: The model version of interest. diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index 258c6179047..a1217c6aca6 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -2708,7 +2708,6 @@ def test_latest_found(self): ) found_latest = zs.get_model_version( model_name_or_id=model.id, - model_version_name_or_number_or_id=ModelStages.LATEST, ) assert latest.id == found_latest.id From 5821ce74d92012ecca0b7dbca95dd1333cc5d193 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Thu, 9 Nov 2023 16:17:41 +0100 Subject: [PATCH 08/12] better error message --- src/zenml/model/model_config.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/zenml/model/model_config.py b/src/zenml/model/model_config.py index 58de967b20a..1abef0e76c5 100644 --- a/src/zenml/model/model_config.py +++ b/src/zenml/model/model_config.py @@ -123,10 +123,12 @@ def _validate_config_in_runtime(self) -> None: for run_name, run in model_version.pipeline_runs.items(): if run.status == ExecutionStatus.RUNNING: raise RuntimeError( - f"New unnamed model version was requested, " + f"You configure model context with explicit `version` passed, " + "so new unnamed model version has to be created, " f"but pipeline run `{run_name}` have not finished yet. " - "This run also operates with unnamed model version - " - "new run will be stopped to prevent unexpected behavior." + "To proceed you can:\n" + "- Wait for previous run to finish\n" + "- Provide explicit `version` in configuration" ) if self.delete_new_version_on_failure: raise RuntimeError( From 9ea1beb3df3222e4b084ed9cc4c5855c101d0d5a Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Thu, 9 Nov 2023 16:19:06 +0100 Subject: [PATCH 09/12] Update src/zenml/model/model_config.py Co-authored-by: Felix Altenberger --- src/zenml/model/model_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/model/model_config.py b/src/zenml/model/model_config.py index 1abef0e76c5..207d1cf5a4e 100644 --- a/src/zenml/model/model_config.py +++ b/src/zenml/model/model_config.py @@ -125,7 +125,7 @@ def _validate_config_in_runtime(self) -> None: raise RuntimeError( f"You configure model context with explicit `version` passed, " "so new unnamed model version has to be created, " - f"but pipeline run `{run_name}` have not finished yet. " + f"but pipeline run `{run_name}` has not finished yet. " "To proceed you can:\n" "- Wait for previous run to finish\n" "- Provide explicit `version` in configuration" From c9f4118e94f207d4b7cef7c08a7f201c3dbbe330 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Thu, 9 Nov 2023 16:22:29 +0100 Subject: [PATCH 10/12] spelling --- src/zenml/model/model_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/model/model_config.py b/src/zenml/model/model_config.py index 1abef0e76c5..60ff1f3f8e5 100644 --- a/src/zenml/model/model_config.py +++ b/src/zenml/model/model_config.py @@ -135,7 +135,7 @@ def _validate_config_in_runtime(self) -> None: f"Cannot create version `{self.version}` " f"for model `{self.name}` since it already exists. " "This could happen for unforeseen reasons (e.g. unexpected " - "intteruption of previous pipeline run flow).\n" + "interruption of previous pipeline run flow).\n" "If you would like to remove the staling version use " "following CLI command:\n" f"`zenml model version delete {self.name} {self.version}`" From d155f7481829256f39d4431a7007e10fc1852c86 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Fri, 10 Nov 2023 07:53:49 +0100 Subject: [PATCH 11/12] better error message --- src/zenml/model/model_config.py | 8 ++++---- tests/integration/functional/steps/test_model_config.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/zenml/model/model_config.py b/src/zenml/model/model_config.py index 2faeb70d0c2..3c5b3f121ae 100644 --- a/src/zenml/model/model_config.py +++ b/src/zenml/model/model_config.py @@ -123,10 +123,10 @@ def _validate_config_in_runtime(self) -> None: for run_name, run in model_version.pipeline_runs.items(): if run.status == ExecutionStatus.RUNNING: raise RuntimeError( - f"You configure model context with explicit `version` passed, " - "so new unnamed model version has to be created, " - f"but pipeline run `{run_name}` has not finished yet. " - "To proceed you can:\n" + "You have configured a model context without explicit " + "`version` argument passed in, so new a unnamed model " + "version has to be created, but pipeline run " + f"`{run_name}` has not finished yet. To proceed you can:\n" "- Wait for previous run to finish\n" "- Provide explicit `version` in configuration" ) diff --git a/tests/integration/functional/steps/test_model_config.py b/tests/integration/functional/steps/test_model_config.py index 6db4b86fe79..fc707ab61c4 100644 --- a/tests/integration/functional/steps/test_model_config.py +++ b/tests/integration/functional/steps/test_model_config.py @@ -870,7 +870,7 @@ def _this_pipeline_will_fail(): ) with pytest.raises( RuntimeError, - match="New unnamed model version was requested", + match="You have configured a model context without explicit `version`", ): _this_pipeline_will_fail.with_options( run_name=f"multi_run_{uuid4()}" From a6d42841c1a1369953790b156741245184bd9901 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Fri, 10 Nov 2023 13:43:38 +0100 Subject: [PATCH 12/12] slightly improve error message --- src/zenml/model/model_config.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/zenml/model/model_config.py b/src/zenml/model/model_config.py index 3c5b3f121ae..37f1d4b19b7 100644 --- a/src/zenml/model/model_config.py +++ b/src/zenml/model/model_config.py @@ -133,7 +133,8 @@ def _validate_config_in_runtime(self) -> None: if self.delete_new_version_on_failure: raise RuntimeError( f"Cannot create version `{self.version}` " - f"for model `{self.name}` since it already exists. " + f"for model `{self.name}` since it already exists " + "and recovery mode is disabled. " "This could happen for unforeseen reasons (e.g. unexpected " "interruption of previous pipeline run flow).\n" "If you would like to remove the staling version use "