Skip to content

Commit

Permalink
More scalability improvements (#3206)
Browse files Browse the repository at this point in the history
* More scalability improvements

* Reduce DB and rbac calls for pipeline run creation

* Cleanup model version URL logging

* Remove remainders of rbac for workspaces

* Auto-update of E2E template

* Remove user RBAC

* Reduce retries back

* Call get method unhydrated

* Fix docstring

---------

Co-authored-by: GitHub Actions <[email protected]>
  • Loading branch information
schustmi and actions-user authored Nov 26, 2024
1 parent c6af690 commit ae7f0e2
Show file tree
Hide file tree
Showing 10 changed files with 97 additions and 212 deletions.
82 changes: 0 additions & 82 deletions src/zenml/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,12 @@
# permissions and limitations under the License.
"""Model user facing interface to pass into pipeline or step."""

import datetime
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Optional,
Tuple,
Union,
)
from uuid import UUID
Expand All @@ -41,7 +39,6 @@
ModelResponse,
ModelVersionResponse,
PipelineRunResponse,
StepRunResponse,
)

logger = get_logger(__name__)
Expand Down Expand Up @@ -743,85 +740,6 @@ def __hash__(self) -> int:
)
)

def _prepare_model_version_before_step_launch(
self,
pipeline_run: "PipelineRunResponse",
step_run: Optional["StepRunResponse"],
return_logs: bool,
) -> Tuple[str, "PipelineRunResponse", Optional["StepRunResponse"]]:
"""Prepares model version inside pipeline run.
Args:
pipeline_run: pipeline run
step_run: step run (passed only if model version is defined in a step explicitly)
return_logs: whether to return logs or not
Returns:
Logs related to the Dashboard URL to show later.
"""
from zenml.client import Client
from zenml.models import PipelineRunUpdate, StepRunUpdate

logs = ""

# copy Model instance to prevent corrupting configs of the
# subsequent runs, if they share the same config object
self_copy = self.model_copy()

# in case request is within the step and no self-configuration is provided
# try reuse what's in the pipeline run first
if step_run is None and pipeline_run.model_version is not None:
self_copy.version = pipeline_run.model_version.name
self_copy.model_version_id = pipeline_run.model_version.id
# otherwise try to fill the templated name, if needed
elif isinstance(self_copy.version, str):
if pipeline_run.start_time:
start_time = pipeline_run.start_time
else:
start_time = datetime.datetime.now(datetime.timezone.utc)
self_copy.version = format_name_template(
self_copy.version,
date=start_time.strftime("%Y_%m_%d"),
time=start_time.strftime("%H_%M_%S_%f"),
)

# if exact model not yet defined - try to get/create and update it
# back to the run accordingly
if self_copy.model_version_id is None:
model_version_response = self_copy._get_or_create_model_version()

client = Client()
# update the configured model version id in runs accordingly
if step_run:
step_run = client.zen_store.update_run_step(
step_run_id=step_run.id,
step_run_update=StepRunUpdate(
model_version_id=model_version_response.id
),
)
else:
pipeline_run = client.zen_store.update_run(
run_id=pipeline_run.id,
run_update=PipelineRunUpdate(
model_version_id=model_version_response.id
),
)

if return_logs:
from zenml.utils.cloud_utils import try_get_model_version_url

if logs_to_show := try_get_model_version_url(
model_version_response
):
logs = logs_to_show
else:
logs = (
"Models can be viewed in the dashboard using ZenML Pro. Sign up "
"for a free trial at https://www.zenml.io/pro/"
)
self.model_version_id = self_copy.model_version_id
return logs, pipeline_run, step_run

@property
def _lazy_version(self) -> Optional[str]:
"""Get version name for lazy loader.
Expand Down
11 changes: 8 additions & 3 deletions src/zenml/orchestrators/step_run_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,10 +518,15 @@ def log_model_version_dashboard_url(
Args:
model_version: The model version for which to log the dashboard URL.
"""
from zenml.utils.cloud_utils import try_get_model_version_url
from zenml.utils.dashboard_utils import get_model_version_url

if model_version_url_logs := try_get_model_version_url(model_version):
logger.info(model_version_url_logs)
if model_version_url := get_model_version_url(model_version.id):
logger.info(
"Dashboard URL for Model Version `%s (%s)`:\n%s",
model_version.model.name,
model_version.name,
model_version_url,
)
else:
logger.info(
"Models can be viewed in the dashboard using ZenML Pro. Sign up "
Expand Down
40 changes: 0 additions & 40 deletions src/zenml/utils/cloud_utils.py

This file was deleted.

4 changes: 3 additions & 1 deletion src/zenml/zen_server/cloud_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,9 @@ def session(self) -> requests.Session:
token = self._fetch_auth_token()
self._session.headers.update({"Authorization": "Bearer " + token})

retries = Retry(total=5, backoff_factor=0.1)
retries = Retry(
total=5, backoff_factor=0.1, status_forcelist=[502, 504]
)
self._session.mount(
"https://",
HTTPAdapter(
Expand Down
10 changes: 6 additions & 4 deletions src/zenml/zen_server/rbac/endpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def verify_permissions_and_list_entities(
def verify_permissions_and_update_entity(
id: UUIDOrStr,
update_model: AnyUpdate,
get_method: Callable[[UUIDOrStr], AnyResponse],
get_method: Callable[[UUIDOrStr, bool], AnyResponse],
update_method: Callable[[UUIDOrStr, AnyUpdate], AnyResponse],
) -> AnyResponse:
"""Verify permissions and update an entity.
Expand All @@ -203,15 +203,16 @@ def verify_permissions_and_update_entity(
Returns:
A model of the updated entity.
"""
model = get_method(id)
# We don't need the hydrated version here
model = get_method(id, False)
verify_permission_for_model(model, action=Action.UPDATE)
updated_model = update_method(model.id, update_model)
return dehydrate_response_model(updated_model)


def verify_permissions_and_delete_entity(
id: UUIDOrStr,
get_method: Callable[[UUIDOrStr], AnyResponse],
get_method: Callable[[UUIDOrStr, bool], AnyResponse],
delete_method: Callable[[UUIDOrStr], None],
) -> AnyResponse:
"""Verify permissions and delete an entity.
Expand All @@ -224,7 +225,8 @@ def verify_permissions_and_delete_entity(
Returns:
The deleted entity.
"""
model = get_method(id)
# We don't need the hydrated version here
model = get_method(id, False)
verify_permission_for_model(model, action=Action.DELETE)
delete_method(model.id)

Expand Down
5 changes: 3 additions & 2 deletions src/zenml/zen_server/rbac/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ class ResourceType(StrEnum):
PIPELINE_DEPLOYMENT = "pipeline_deployment"
PIPELINE_BUILD = "pipeline_build"
RUN_TEMPLATE = "run_template"
USER = "user"
SERVICE = "service"
RUN_METADATA = "run_metadata"
SECRET = "secret"
Expand All @@ -70,7 +69,9 @@ class ResourceType(StrEnum):
TAG = "tag"
TRIGGER = "trigger"
TRIGGER_EXECUTION = "trigger_execution"
WORKSPACE = "workspace"
# Deactivated for now
# USER = "user"
# WORKSPACE = "workspace"


class Resource(BaseModel):
Expand Down
11 changes: 4 additions & 7 deletions src/zenml/zen_server/rbac/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,8 +413,6 @@ def get_resource_type_for_model(
TagResponse,
TriggerExecutionResponse,
TriggerResponse,
UserResponse,
WorkspaceResponse,
)

mapping: Dict[
Expand All @@ -434,8 +432,8 @@ def get_resource_type_for_model(
ModelVersionResponse: ResourceType.MODEL_VERSION,
ArtifactResponse: ResourceType.ARTIFACT,
ArtifactVersionResponse: ResourceType.ARTIFACT_VERSION,
WorkspaceResponse: ResourceType.WORKSPACE,
UserResponse: ResourceType.USER,
# WorkspaceResponse: ResourceType.WORKSPACE,
# UserResponse: ResourceType.USER,
PipelineDeploymentResponse: ResourceType.PIPELINE_DEPLOYMENT,
PipelineBuildResponse: ResourceType.PIPELINE_BUILD,
PipelineRunResponse: ResourceType.PIPELINE_RUN,
Expand Down Expand Up @@ -570,7 +568,6 @@ def get_schema_for_resource_type(
TriggerExecutionSchema,
TriggerSchema,
UserSchema,
WorkspaceSchema,
)

mapping: Dict[ResourceType, Type["BaseSchema"]] = {
Expand All @@ -588,13 +585,13 @@ def get_schema_for_resource_type(
ResourceType.SERVICE: ServiceSchema,
ResourceType.TAG: TagSchema,
ResourceType.SERVICE_ACCOUNT: UserSchema,
ResourceType.WORKSPACE: WorkspaceSchema,
# ResourceType.WORKSPACE: WorkspaceSchema,
ResourceType.PIPELINE_RUN: PipelineRunSchema,
ResourceType.PIPELINE_DEPLOYMENT: PipelineDeploymentSchema,
ResourceType.PIPELINE_BUILD: PipelineBuildSchema,
ResourceType.RUN_TEMPLATE: RunTemplateSchema,
ResourceType.RUN_METADATA: RunMetadataSchema,
ResourceType.USER: UserSchema,
# ResourceType.USER: UserSchema,
ResourceType.ACTION: ActionSchema,
ResourceType.EVENT_SOURCE: EventSourceSchema,
ResourceType.TRIGGER: TriggerSchema,
Expand Down
Loading

0 comments on commit ae7f0e2

Please sign in to comment.