Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove create_new_model_version arg of ModelConfig #2030

Merged
merged 15 commits into from
Nov 10, 2023
Merged
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,6 @@ zenml_tutorial/
mlstacks_reset.sh

.local/

# exclude installed dashboard folder
src/zenml/zen_server/dashboard
2 changes: 1 addition & 1 deletion src/zenml/artifacts/external_artifact_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def _get_artifact_from_model(
name=self.model_name,
version=self.model_version,
)
model_version = model_config._get_model_version()
model_version = model_config.get_or_create_model_version()

for artifact_getter in [
model_version.get_artifact_object,
Expand Down
4 changes: 2 additions & 2 deletions src/zenml/cli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ def _print_artifacts_links_generic(
"""
model_version = Client().get_model_version(
model_name_or_id=model_name_or_id,
model_version_name_or_number_or_id=None
model_version_name_or_number_or_id=ModelStages.LATEST
if model_version_name_or_number_or_id == "0"
else model_version_name_or_number_or_id,
)
Expand Down Expand Up @@ -665,7 +665,7 @@ def list_model_version_pipeline_runs(
"""
model_version = Client().get_model_version(
model_name_or_id=model_name_or_id,
model_version_name_or_number_or_id=None
model_version_name_or_number_or_id=ModelStages.LATEST
if model_version_name_or_number_or_id == "0"
else model_version_name_or_number_or_id,
)
Expand Down
15 changes: 3 additions & 12 deletions src/zenml/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5077,16 +5077,13 @@ def delete_model_version(
def get_model_version(
self,
model_name_or_id: Union[str, UUID],
model_version_name_or_number_or_id: Optional[
Union[str, int, UUID, ModelStages]
] = None,
model_version_name_or_number_or_id: Union[str, int, UUID, ModelStages],
avishniakov marked this conversation as resolved.
Show resolved Hide resolved
) -> ModelVersionResponseModel:
"""Get an existing model version from Model Control Plane.

Args:
model_name_or_id: name or id of the model containing the model version.
model_version_name_or_number_or_id: name, id, stage or number of the model version to be retrieved.
If skipped latest version will be retrieved.

Returns:
The model version of interest.
Expand Down Expand Up @@ -5145,16 +5142,13 @@ def list_model_version_artifact_links(
self,
model_name_or_id: Union[str, UUID],
model_version_artifact_link_filter_model: ModelVersionArtifactFilterModel,
model_version_name_or_number_or_id: Optional[
Union[str, int, UUID, ModelStages]
] = None,
model_version_name_or_number_or_id: Union[str, int, UUID, ModelStages],
) -> Page[ModelVersionArtifactResponseModel]:
"""Get model version to artifact links by filter in Model Control Plane.

Args:
model_name_or_id: name or id of the model containing the model version.
model_version_name_or_number_or_id: name, id, stage or number of the model version to be retrieved.
If skipped latest version will be retrieved.
model_version_artifact_link_filter_model: All filter parameters including pagination
params.

Expand All @@ -5181,16 +5175,13 @@ def list_model_version_pipeline_run_links(
self,
model_name_or_id: Union[str, UUID],
model_version_pipeline_run_link_filter_model: ModelVersionPipelineRunFilterModel,
model_version_name_or_number_or_id: Optional[
Union[str, int, UUID, ModelStages]
] = None,
model_version_name_or_number_or_id: Union[str, int, UUID, ModelStages],
) -> Page[ModelVersionPipelineRunResponseModel]:
"""Get all model version to pipeline run links by filter.

Args:
model_name_or_id: name or id of the model containing the model version.
model_version_name_or_number_or_id: name, id, stage or number of the model version to be retrieved.
If skipped latest version will be retrieved.
model_version_pipeline_run_link_filter_model: All filter parameters including pagination
params.

Expand Down
1 change: 1 addition & 0 deletions src/zenml/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,3 +291,4 @@ class ModelStages(StrEnum):
STAGING = "staging"
PRODUCTION = "production"
ARCHIVED = "archived"
LATEST = "latest"
5 changes: 2 additions & 3 deletions src/zenml/model/artifact_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class ArtifactConfig(BaseModel):
model_name: The name of the model to link artifact to.
model_version: The identifier of the model version to link artifact to.
It can be exact version ("23"), exact version number (42), stage
(ModelStages.PRODUCTION) or None for the latest version.
(ModelStages.PRODUCTION) or ModelStages.LATEST for the latest version.
model_stage: The stage of the model version to link artifact to.
artifact_name: The override name of a link instead of an artifact name.
overwrite: Whether to overwrite an existing link or create new versions.
Expand Down Expand Up @@ -81,7 +81,6 @@ def _model_config(self) -> "ModelConfig":
on_the_fly_config = ModelConfig(
name=self.model_name,
version=self.model_version,
create_new_model_version=False,
)
return on_the_fly_config

Expand Down Expand Up @@ -120,7 +119,7 @@ def _link_to_model_version(
# Create a ZenML client
client = Client()

model_version = model_config._get_model_version()
model_version = model_config.get_or_create_model_version()

artifact_name = self.artifact_name
if artifact_name is None:
Expand Down
170 changes: 71 additions & 99 deletions src/zenml/model/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,8 @@ class ModelConfig(BaseModel):
ethics: The ethical implications of the model.
tags: Tags associated with the model.
version: The model version name, number or stage is optional and points model context
to a specific version/stage. If skipped and `create_new_model_version` is False -
latest model version will be used.
to a specific version/stage. If skipped new model version will be created.
version_description: The description of the model version.
create_new_model_version: Whether to create a new model version during execution
save_models_to_registry: Whether to save all ModelArtifacts to Model Registry,
if available in active stack.
delete_new_version_on_failure: Whether to delete failed runs with new versions for later recovery from it.
Expand All @@ -71,11 +69,11 @@ class ModelConfig(BaseModel):
tags: Optional[List[str]]
version: Optional[Union[ModelStages, int, str]]
version_description: Optional[str]
create_new_model_version: bool = False
save_models_to_registry: bool = True
delete_new_version_on_failure: bool = True

suppress_class_validation_warnings: bool = False
was_created_in_this_run: bool = False

class Config:
"""Config class."""
Expand All @@ -91,43 +89,12 @@ def _root_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]:

Returns:
Dict of validated values.

Raises:
ValueError: If validation failed on one of the checks.
"""
create_new_model_version = values.get(
"create_new_model_version", False
)
suppress_class_validation_warnings = values.get(
"suppress_class_validation_warnings", False
)
version = values.get("version", None)

if create_new_model_version:
misuse_message = (
"`version` set to {set} cannot be used with `create_new_model_version`."
"You can leave it default or set to a non-stage and non-numeric string.\n"
"Examples:\n"
" - `version` set to 1 or '1' is interpreted as a version number\n"
" - `version` set to 'production' is interpreted as a stage\n"
" - `version` set to 'my_first_version_in_2023' is a valid version to be created\n"
" - `version` set to 'My Second Version!' is a valid version to be created\n"
)
if isinstance(version, ModelStages) or version in [
stage.value for stage in ModelStages
]:
raise ValueError(
misuse_message.format(set="a `ModelStages` instance")
)
if str(version).isnumeric():
raise ValueError(misuse_message.format(set="a numeric value"))
if version is None:
if not suppress_class_validation_warnings:
logger.info(
"Creation of new model version was requested, but no version name was explicitly provided. "
f"Setting `version` to `{RUNNING_MODEL_VERSION}`."
)
values["version"] = RUNNING_MODEL_VERSION
if (
version in [stage.value for stage in ModelStages]
and not suppress_class_validation_warnings
Expand All @@ -151,18 +118,25 @@ def _validate_config_in_runtime(self) -> None:
"""
try:
model_version = self._get_model_version()
if self.create_new_model_version:
if self.version is None or self.version == RUNNING_MODEL_VERSION:
self.version = RUNNING_MODEL_VERSION
for run_name, run in model_version.pipeline_runs.items():
if run.status == ExecutionStatus.RUNNING:
raise RuntimeError(
f"New model version was requested, but pipeline run `{run_name}` "
f"is still running with version `{model_version.name}`."
f"New unnamed model version was requested, "
f"but pipeline run `{run_name}` have not finished yet. "
avishniakov marked this conversation as resolved.
Show resolved Hide resolved
"This run also operates with unnamed model version - "
"new run will be stopped to prevent unexpected behavior."
avishniakov marked this conversation as resolved.
Show resolved Hide resolved
)

if self.delete_new_version_on_failure:
raise RuntimeError(
f"Cannot create version `{self.version}` "
f"for model `{self.name}` since it already exists"
f"for model `{self.name}` since it already exists. "
"This could happen for unforeseen reasons (e.g. unexpected "
"intteruption of previous pipeline run flow).\n"
"If you would like to remove the staling version use "
"following CLI command:\n"
f"`zenml model version delete {self.name} {self.version}`"
)
except KeyError:
self.get_or_create_model_version()
Expand Down Expand Up @@ -205,43 +179,6 @@ def get_or_create_model(self) -> "ModelResponseModel":

return model

def _create_model_version(
self, model: "ModelResponseModel"
) -> "ModelVersionResponseModel":
"""This method creates a model version for Model Control Plane.

Args:
model: The model containing the model version.

Returns:
The model version based on configuration.
"""
from zenml.client import Client
from zenml.models.model_models import ModelVersionRequestModel

zenml_client = Client()
model_version_request = ModelVersionRequestModel(
user=zenml_client.active_user.id,
workspace=zenml_client.active_workspace.id,
name=self.version,
description=self.version_description,
model=model.id,
)
mv_request = ModelVersionRequestModel.parse_obj(model_version_request)
try:
mv = zenml_client.get_model_version(
model_name_or_id=self.name,
model_version_name_or_number_or_id=self.version,
)
model_version = mv
except KeyError:
model_version = zenml_client.create_model_version(
model_version=mv_request
)
logger.info(f"New model version `{self.version}` was created.")

return model_version

def _get_model_version(self) -> "ModelVersionResponseModel":
"""This method gets a model version from Model Control Plane.

Expand All @@ -251,19 +188,11 @@ def _get_model_version(self) -> "ModelVersionResponseModel":
from zenml.client import Client

zenml_client = Client()
if self.version is None:
# raise if not found
model_version = zenml_client.get_model_version(
model_name_or_id=self.name
)
else:
# by version name or stage or number
# raise if not found
model_version = zenml_client.get_model_version(
model_name_or_id=self.name,
model_version_name_or_number_or_id=self.version,
)
return model_version
return zenml_client.get_model_version(
model_name_or_id=self.name,
model_version_name_or_number_or_id=self.version
or RUNNING_MODEL_VERSION,
)

def get_or_create_model_version(self) -> "ModelVersionResponseModel":
"""This method should get or create a model and a model version from Model Control Plane.
Expand All @@ -273,23 +202,47 @@ def get_or_create_model_version(self) -> "ModelVersionResponseModel":

Model Version returned by this method is resolved based on model configuration:
- If there is an existing model version leftover from the previous failed run with
`delete_new_version_on_failure` is set to False and `create_new_model_version` is True,
`delete_new_version_on_failure` is set to False and `version` is None,
leftover model version will be reused.
- Otherwise if `create_new_model_version` is True, a new model version is created.
- If `create_new_model_version` is False a model version will be fetched based on the version:
- If `version` is not set, the latest model version will be fetched.
- Otherwise if `version` is None, a new model version is created.
- If `version` is not None a model version will be fetched based on the version:
- If `version` is set to an integer or digit string, the model version with the matching number will be fetched.
- If `version` is set to a string, the model version with the matching version will be fetched.
- If `version` is set to a `ModelStage`, the model version with the matching stage will be fetched.

Returns:
The model version based on configuration.
"""
from zenml.client import Client
from zenml.models.model_models import ModelVersionRequestModel

model = self.get_or_create_model()
if self.create_new_model_version:
mv = self._create_model_version(model)
else:
mv = self._get_model_version()
return mv

if self.version is None:
logger.info(
"Creation of new model version was requested, but no version name was explicitly provided. "
f"Setting `version` to `{RUNNING_MODEL_VERSION}`."
)
self.version = RUNNING_MODEL_VERSION

zenml_client = Client()
model_version_request = ModelVersionRequestModel(
user=zenml_client.active_user.id,
workspace=zenml_client.active_workspace.id,
name=self.version,
description=self.version_description,
model=model.id,
)
mv_request = ModelVersionRequestModel.parse_obj(model_version_request)
try:
model_version = self._get_model_version()
except KeyError:
model_version = zenml_client.create_model_version(
model_version=mv_request
)
self.was_created_in_this_run = True
logger.info(f"New model version `{self.version}` was created.")
return model_version

def _merge(self, model_config: "ModelConfig") -> None:
self.license = self.license or model_config.license
Expand All @@ -307,3 +260,22 @@ def _merge(self, model_config: "ModelConfig") -> None:
self.delete_new_version_on_failure &= (
model_config.delete_new_version_on_failure
)

def __hash__(self) -> int:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's this used for?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"""Get hash of the `ModelConfig`.

Returns:
Hash function results
"""
return hash(
"::".join(
(
str(v)
for v in (
self.name,
self.version,
self.delete_new_version_on_failure,
)
)
)
)
Loading