Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove create_new_model_version arg of ModelConfig #2030

Merged
merged 15 commits into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,6 @@ zenml_tutorial/
mlstacks_reset.sh

.local/

# exclude installed dashboard folder
src/zenml/zen_server/dashboard
2 changes: 1 addition & 1 deletion src/zenml/artifacts/external_artifact_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def _get_artifact_from_model(
name=self.model_name,
version=self.model_version,
)
model_version = model_config._get_model_version()
model_version = model_config.get_or_create_model_version()

for artifact_getter in [
model_version.get_artifact_object,
Expand Down
4 changes: 2 additions & 2 deletions src/zenml/cli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ def _print_artifacts_links_generic(
"""
model_version = Client().get_model_version(
model_name_or_id=model_name_or_id,
model_version_name_or_number_or_id=None
model_version_name_or_number_or_id=ModelStages.LATEST
if model_version_name_or_number_or_id == "0"
else model_version_name_or_number_or_id,
)
Expand Down Expand Up @@ -665,7 +665,7 @@ def list_model_version_pipeline_runs(
"""
model_version = Client().get_model_version(
model_name_or_id=model_name_or_id,
model_version_name_or_number_or_id=None
model_version_name_or_number_or_id=ModelStages.LATEST
if model_version_name_or_number_or_id == "0"
else model_version_name_or_number_or_id,
)
Expand Down
15 changes: 5 additions & 10 deletions src/zenml/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5086,14 +5086,15 @@ def get_model_version(
Args:
model_name_or_id: name or id of the model containing the model version.
model_version_name_or_number_or_id: name, id, stage or number of the model version to be retrieved.
If skipped latest version will be retrieved.
If skipped - latest version is retrieved.

Returns:
The model version of interest.
"""
return self.zen_store.get_model_version(
model_name_or_id=model_name_or_id,
model_version_name_or_number_or_id=model_version_name_or_number_or_id,
model_version_name_or_number_or_id=model_version_name_or_number_or_id
or ModelStages.LATEST,
)

def list_model_versions(
Expand Down Expand Up @@ -5145,16 +5146,13 @@ def list_model_version_artifact_links(
self,
model_name_or_id: Union[str, UUID],
model_version_artifact_link_filter_model: ModelVersionArtifactFilterModel,
model_version_name_or_number_or_id: Optional[
Union[str, int, UUID, ModelStages]
] = None,
model_version_name_or_number_or_id: Union[str, int, UUID, ModelStages],
) -> Page[ModelVersionArtifactResponseModel]:
"""Get model version to artifact links by filter in Model Control Plane.

Args:
model_name_or_id: name or id of the model containing the model version.
model_version_name_or_number_or_id: name, id, stage or number of the model version to be retrieved.
If skipped latest version will be retrieved.
model_version_artifact_link_filter_model: All filter parameters including pagination
params.

Expand All @@ -5181,16 +5179,13 @@ def list_model_version_pipeline_run_links(
self,
model_name_or_id: Union[str, UUID],
model_version_pipeline_run_link_filter_model: ModelVersionPipelineRunFilterModel,
model_version_name_or_number_or_id: Optional[
Union[str, int, UUID, ModelStages]
] = None,
model_version_name_or_number_or_id: Union[str, int, UUID, ModelStages],
) -> Page[ModelVersionPipelineRunResponseModel]:
"""Get all model version to pipeline run links by filter.

Args:
model_name_or_id: name or id of the model containing the model version.
model_version_name_or_number_or_id: name, id, stage or number of the model version to be retrieved.
If skipped latest version will be retrieved.
model_version_pipeline_run_link_filter_model: All filter parameters including pagination
params.

Expand Down
1 change: 1 addition & 0 deletions src/zenml/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,3 +291,4 @@ class ModelStages(StrEnum):
STAGING = "staging"
PRODUCTION = "production"
ARCHIVED = "archived"
LATEST = "latest"
5 changes: 2 additions & 3 deletions src/zenml/model/artifact_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class ArtifactConfig(BaseModel):
model_name: The name of the model to link artifact to.
model_version: The identifier of the model version to link artifact to.
It can be exact version ("23"), exact version number (42), stage
(ModelStages.PRODUCTION) or None for the latest version.
(ModelStages.PRODUCTION) or ModelStages.LATEST for the latest version.
model_stage: The stage of the model version to link artifact to.
artifact_name: The override name of a link instead of an artifact name.
overwrite: Whether to overwrite an existing link or create new versions.
Expand Down Expand Up @@ -81,7 +81,6 @@ def _model_config(self) -> "ModelConfig":
on_the_fly_config = ModelConfig(
name=self.model_name,
version=self.model_version,
create_new_model_version=False,
)
return on_the_fly_config

Expand Down Expand Up @@ -120,7 +119,7 @@ def _link_to_model_version(
# Create a ZenML client
client = Client()

model_version = model_config._get_model_version()
model_version = model_config.get_or_create_model_version()

artifact_name = self.artifact_name
if artifact_name is None:
Expand Down
173 changes: 74 additions & 99 deletions src/zenml/model/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,8 @@ class ModelConfig(BaseModel):
ethics: The ethical implications of the model.
tags: Tags associated with the model.
version: The model version name, number or stage is optional and points model context
to a specific version/stage. If skipped and `create_new_model_version` is False -
latest model version will be used.
to a specific version/stage. If skipped new model version will be created.
version_description: The description of the model version.
create_new_model_version: Whether to create a new model version during execution
save_models_to_registry: Whether to save all ModelArtifacts to Model Registry,
if available in active stack.
delete_new_version_on_failure: Whether to delete failed runs with new versions for later recovery from it.
Expand All @@ -71,11 +69,11 @@ class ModelConfig(BaseModel):
tags: Optional[List[str]]
version: Optional[Union[ModelStages, int, str]]
version_description: Optional[str]
create_new_model_version: bool = False
save_models_to_registry: bool = True
delete_new_version_on_failure: bool = True

suppress_class_validation_warnings: bool = False
was_created_in_this_run: bool = False

class Config:
"""Config class."""
Expand All @@ -91,43 +89,12 @@ def _root_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]:

Returns:
Dict of validated values.

Raises:
ValueError: If validation failed on one of the checks.
"""
create_new_model_version = values.get(
"create_new_model_version", False
)
suppress_class_validation_warnings = values.get(
"suppress_class_validation_warnings", False
)
version = values.get("version", None)

if create_new_model_version:
misuse_message = (
"`version` set to {set} cannot be used with `create_new_model_version`."
"You can leave it default or set to a non-stage and non-numeric string.\n"
"Examples:\n"
" - `version` set to 1 or '1' is interpreted as a version number\n"
" - `version` set to 'production' is interpreted as a stage\n"
" - `version` set to 'my_first_version_in_2023' is a valid version to be created\n"
" - `version` set to 'My Second Version!' is a valid version to be created\n"
)
if isinstance(version, ModelStages) or version in [
stage.value for stage in ModelStages
]:
raise ValueError(
misuse_message.format(set="a `ModelStages` instance")
)
if str(version).isnumeric():
raise ValueError(misuse_message.format(set="a numeric value"))
if version is None:
if not suppress_class_validation_warnings:
logger.info(
"Creation of new model version was requested, but no version name was explicitly provided. "
f"Setting `version` to `{RUNNING_MODEL_VERSION}`."
)
values["version"] = RUNNING_MODEL_VERSION
if (
version in [stage.value for stage in ModelStages]
and not suppress_class_validation_warnings
Expand All @@ -151,18 +118,28 @@ def _validate_config_in_runtime(self) -> None:
"""
try:
model_version = self._get_model_version()
if self.create_new_model_version:
if self.version is None or self.version == RUNNING_MODEL_VERSION:
self.version = RUNNING_MODEL_VERSION
for run_name, run in model_version.pipeline_runs.items():
if run.status == ExecutionStatus.RUNNING:
raise RuntimeError(
f"New model version was requested, but pipeline run `{run_name}` "
f"is still running with version `{model_version.name}`."
"You have configured a model context without explicit "
"`version` argument passed in, so new a unnamed model "
"version has to be created, but pipeline run "
f"`{run_name}` has not finished yet. To proceed you can:\n"
"- Wait for previous run to finish\n"
"- Provide explicit `version` in configuration"
)

if self.delete_new_version_on_failure:
raise RuntimeError(
f"Cannot create version `{self.version}` "
f"for model `{self.name}` since it already exists"
f"for model `{self.name}` since it already exists "
"and recovery mode is disabled. "
"This could happen for unforeseen reasons (e.g. unexpected "
"interruption of previous pipeline run flow).\n"
"If you would like to remove the staling version use "
"following CLI command:\n"
f"`zenml model version delete {self.name} {self.version}`"
)
except KeyError:
self.get_or_create_model_version()
Expand Down Expand Up @@ -205,43 +182,6 @@ def get_or_create_model(self) -> "ModelResponseModel":

return model

def _create_model_version(
self, model: "ModelResponseModel"
) -> "ModelVersionResponseModel":
"""This method creates a model version for Model Control Plane.

Args:
model: The model containing the model version.

Returns:
The model version based on configuration.
"""
from zenml.client import Client
from zenml.models.model_models import ModelVersionRequestModel

zenml_client = Client()
model_version_request = ModelVersionRequestModel(
user=zenml_client.active_user.id,
workspace=zenml_client.active_workspace.id,
name=self.version,
description=self.version_description,
model=model.id,
)
mv_request = ModelVersionRequestModel.parse_obj(model_version_request)
try:
mv = zenml_client.get_model_version(
model_name_or_id=self.name,
model_version_name_or_number_or_id=self.version,
)
model_version = mv
except KeyError:
model_version = zenml_client.create_model_version(
model_version=mv_request
)
logger.info(f"New model version `{self.version}` was created.")

return model_version

def _get_model_version(self) -> "ModelVersionResponseModel":
"""This method gets a model version from Model Control Plane.

Expand All @@ -251,19 +191,11 @@ def _get_model_version(self) -> "ModelVersionResponseModel":
from zenml.client import Client

zenml_client = Client()
if self.version is None:
# raise if not found
model_version = zenml_client.get_model_version(
model_name_or_id=self.name
)
else:
# by version name or stage or number
# raise if not found
model_version = zenml_client.get_model_version(
model_name_or_id=self.name,
model_version_name_or_number_or_id=self.version,
)
return model_version
return zenml_client.get_model_version(
model_name_or_id=self.name,
model_version_name_or_number_or_id=self.version
or RUNNING_MODEL_VERSION,
)

def get_or_create_model_version(self) -> "ModelVersionResponseModel":
"""This method should get or create a model and a model version from Model Control Plane.
Expand All @@ -273,23 +205,47 @@ def get_or_create_model_version(self) -> "ModelVersionResponseModel":

Model Version returned by this method is resolved based on model configuration:
- If there is an existing model version leftover from the previous failed run with
`delete_new_version_on_failure` is set to False and `create_new_model_version` is True,
`delete_new_version_on_failure` is set to False and `version` is None,
leftover model version will be reused.
- Otherwise if `create_new_model_version` is True, a new model version is created.
- If `create_new_model_version` is False a model version will be fetched based on the version:
- If `version` is not set, the latest model version will be fetched.
- Otherwise if `version` is None, a new model version is created.
- If `version` is not None a model version will be fetched based on the version:
- If `version` is set to an integer or digit string, the model version with the matching number will be fetched.
- If `version` is set to a string, the model version with the matching version will be fetched.
- If `version` is set to a `ModelStage`, the model version with the matching stage will be fetched.

Returns:
The model version based on configuration.
"""
from zenml.client import Client
from zenml.models.model_models import ModelVersionRequestModel

model = self.get_or_create_model()
if self.create_new_model_version:
mv = self._create_model_version(model)
else:
mv = self._get_model_version()
return mv

if self.version is None:
logger.info(
"Creation of new model version was requested, but no version name was explicitly provided. "
f"Setting `version` to `{RUNNING_MODEL_VERSION}`."
)
self.version = RUNNING_MODEL_VERSION

zenml_client = Client()
model_version_request = ModelVersionRequestModel(
user=zenml_client.active_user.id,
workspace=zenml_client.active_workspace.id,
name=self.version,
description=self.version_description,
model=model.id,
)
mv_request = ModelVersionRequestModel.parse_obj(model_version_request)
try:
model_version = self._get_model_version()
except KeyError:
model_version = zenml_client.create_model_version(
model_version=mv_request
)
self.was_created_in_this_run = True
logger.info(f"New model version `{self.version}` was created.")
return model_version

def _merge(self, model_config: "ModelConfig") -> None:
self.license = self.license or model_config.license
Expand All @@ -307,3 +263,22 @@ def _merge(self, model_config: "ModelConfig") -> None:
self.delete_new_version_on_failure &= (
model_config.delete_new_version_on_failure
)

def __hash__(self) -> int:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's this used for?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"""Get hash of the `ModelConfig`.

Returns:
Hash function results
"""
return hash(
"::".join(
(
str(v)
for v in (
self.name,
self.version,
self.delete_new_version_on_failure,
)
)
)
)
Loading