From c2fe19eb039c133a53ba68730251209a84876285 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bar=C4=B1=C5=9F=20Can=20Durak?= <36421093+bcdurak@users.noreply.github.com> Date: Thu, 9 Jan 2025 11:43:55 +0100 Subject: [PATCH] Fixed and improved sorting (#3266) * first fixes * final fixes * adding tag list back * fixing the failures * fixing the problem with the tags and adjusting the tests * adding a small comment * formatting --- src/zenml/constants.py | 1 + src/zenml/models/v2/base/scoped.py | 60 ++++---- src/zenml/models/v2/core/artifact.py | 88 ++++++++++- src/zenml/models/v2/core/model.py | 83 +++++++++- src/zenml/models/v2/core/pipeline.py | 27 ++-- src/zenml/models/v2/core/pipeline_run.py | 2 + src/zenml/zen_stores/sql_zen_store.py | 7 + .../functional/models/test_sorting.py | 143 ++++++++++++++++++ 8 files changed, 367 insertions(+), 44 deletions(-) create mode 100644 tests/integration/functional/models/test_sorting.py diff --git a/src/zenml/constants.py b/src/zenml/constants.py index 183b1acce16..a7a13edb614 100644 --- a/src/zenml/constants.py +++ b/src/zenml/constants.py @@ -429,6 +429,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int: ) FILTERING_DATETIME_FORMAT: str = "%Y-%m-%d %H:%M:%S" SORT_PIPELINES_BY_LATEST_RUN_KEY = "latest_run" +SORT_BY_LATEST_VERSION_KEY = "latest_version" # Metadata constants METADATA_ORCHESTRATOR_URL = "orchestrator_url" diff --git a/src/zenml/models/v2/base/scoped.py b/src/zenml/models/v2/base/scoped.py index 16ca14f4b5c..4eefc97f4f0 100644 --- a/src/zenml/models/v2/base/scoped.py +++ b/src/zenml/models/v2/base/scoped.py @@ -245,6 +245,8 @@ def apply_sorting( UserSchema, getattr(table, "user_id") == UserSchema.id ) + query = query.add_columns(UserSchema.name) + if operand == SorterOps.ASCENDING: query = query.order_by(asc(column)) else: @@ -449,6 +451,8 @@ def apply_sorting( getattr(table, "workspace_id") == WorkspaceSchema.id, ) + query = query.add_columns(WorkspaceSchema.name) + if operand == SorterOps.ASCENDING: query 
= query.order_by(asc(column)) else: @@ -470,10 +474,9 @@ class WorkspaceScopedTaggableFilter(WorkspaceScopedFilter): *WorkspaceScopedFilter.FILTER_EXCLUDE_FIELDS, "tag", ] - CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [ *WorkspaceScopedFilter.CUSTOM_SORTING_OPTIONS, - "tag", + "tags", ] def apply_filter( @@ -540,8 +543,8 @@ def apply_sorting( """ sort_by, operand = self.sorting_params - if sort_by == "tag": - from sqlmodel import and_, asc, desc, func + if sort_by == "tags": + from sqlmodel import asc, desc, func, select from zenml.enums import SorterOps, TaggableResourceTypes from zenml.zen_stores.schemas import ( @@ -566,35 +569,36 @@ def apply_sorting( RunTemplateSchema: TaggableResourceTypes.RUN_TEMPLATE, } - query = ( - query.outerjoin( - TagResourceSchema, - and_( - table.id == TagResourceSchema.resource_id, - TagResourceSchema.resource_type - == resource_type_mapping[table], - ), + sorted_tags = ( + select(TagResourceSchema.resource_id, TagSchema.name) + .join(TagSchema, TagResourceSchema.tag_id == TagSchema.id) # type: ignore[arg-type] + .filter( + TagResourceSchema.resource_type # type: ignore[arg-type] + == resource_type_mapping[table] + ) + .order_by( + asc(TagResourceSchema.resource_id), asc(TagSchema.name) ) - .outerjoin(TagSchema, TagResourceSchema.tag_id == TagSchema.id) - .group_by(table.id) + ).alias("sorted_tags") + + tags_subquery = ( + select( + sorted_tags.c.resource_id, + func.group_concat(sorted_tags.c.name, ", ").label( + "tags_list" + ), + ).group_by(sorted_tags.c.resource_id) + ).alias("tags_subquery") + + query = query.add_columns(tags_subquery.c.tags_list).outerjoin( + tags_subquery, table.id == tags_subquery.c.resource_id ) + # Apply ordering based on the tags list if operand == SorterOps.ASCENDING: - query = query.order_by( - asc( - func.group_concat(TagSchema.name, ",").label( - "tags_list" - ) - ) - ) + query = query.order_by(asc("tags_list")) else: - query = query.order_by( - desc( - func.group_concat(TagSchema.name, ",").label( - 
"tags_list" - ) - ) - ) + query = query.order_by(desc("tags_list")) return query diff --git a/src/zenml/models/v2/core/artifact.py b/src/zenml/models/v2/core/artifact.py index c62a7cee1a5..e36b602b6e3 100644 --- a/src/zenml/models/v2/core/artifact.py +++ b/src/zenml/models/v2/core/artifact.py @@ -13,12 +13,21 @@ # permissions and limitations under the License. """Models representing artifacts.""" -from typing import TYPE_CHECKING, Dict, List, Optional +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Dict, + List, + Optional, + Type, + TypeVar, +) from uuid import UUID from pydantic import BaseModel, Field -from zenml.constants import STR_FIELD_MAX_LENGTH +from zenml.constants import SORT_BY_LATEST_VERSION_KEY, STR_FIELD_MAX_LENGTH from zenml.models.v2.base.base import ( BaseDatedResponseBody, BaseIdentifiedResponse, @@ -31,6 +40,11 @@ if TYPE_CHECKING: from zenml.models.v2.core.artifact_version import ArtifactVersionResponse + from zenml.zen_stores.schemas import BaseSchema + + AnySchema = TypeVar("AnySchema", bound=BaseSchema) + +AnyQuery = TypeVar("AnyQuery", bound=Any) # ------------------ Request Model ------------------ @@ -174,3 +188,73 @@ class ArtifactFilter(WorkspaceScopedTaggableFilter): name: Optional[str] = None has_custom_name: Optional[bool] = None + + CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [ + *WorkspaceScopedTaggableFilter.CUSTOM_SORTING_OPTIONS, + SORT_BY_LATEST_VERSION_KEY, + ] + + def apply_sorting( + self, + query: AnyQuery, + table: Type["AnySchema"], + ) -> AnyQuery: + """Apply sorting to the query for Artifacts. + + Args: + query: The query to which to apply the sorting. + table: The query table. + + Returns: + The query with sorting applied. 
+ """ + from sqlmodel import asc, case, col, desc, func, select + + from zenml.enums import SorterOps + from zenml.zen_stores.schemas import ( + ArtifactSchema, + ArtifactVersionSchema, + ) + + sort_by, operand = self.sorting_params + + if sort_by == SORT_BY_LATEST_VERSION_KEY: + # Subquery to find the latest version per artifact + latest_version_subquery = ( + select( + ArtifactSchema.id, + case( + ( + func.max(ArtifactVersionSchema.created).is_(None), + ArtifactSchema.created, + ), + else_=func.max(ArtifactVersionSchema.created), + ).label("latest_version_created"), + ) + .outerjoin( + ArtifactVersionSchema, + ArtifactSchema.id == ArtifactVersionSchema.artifact_id, # type: ignore[arg-type] + ) + .group_by(col(ArtifactSchema.id)) + .subquery() + ) + + query = query.add_columns( + latest_version_subquery.c.latest_version_created, + ).where(ArtifactSchema.id == latest_version_subquery.c.id) + + # Apply sorting based on the operand + if operand == SorterOps.ASCENDING: + query = query.order_by( + asc(latest_version_subquery.c.latest_version_created), + asc(ArtifactSchema.id), + ) + else: + query = query.order_by( + desc(latest_version_subquery.c.latest_version_created), + desc(ArtifactSchema.id), + ) + return query + + # For other sorting cases, delegate to the parent class + return super().apply_sorting(query=query, table=table) diff --git a/src/zenml/models/v2/core/model.py b/src/zenml/models/v2/core/model.py index 0b5272ab7e6..bb341c1d5a1 100644 --- a/src/zenml/models/v2/core/model.py +++ b/src/zenml/models/v2/core/model.py @@ -13,12 +13,16 @@ # permissions and limitations under the License. 
"""Models representing models.""" -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Any, ClassVar, List, Optional, Type, TypeVar from uuid import UUID from pydantic import BaseModel, Field -from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH +from zenml.constants import ( + SORT_BY_LATEST_VERSION_KEY, + STR_FIELD_MAX_LENGTH, + TEXT_FIELD_MAX_LENGTH, +) from zenml.models.v2.base.scoped import ( WorkspaceScopedRequest, WorkspaceScopedResponse, @@ -32,6 +36,11 @@ if TYPE_CHECKING: from zenml.model.model import Model from zenml.models.v2.core.tag import TagResponse + from zenml.zen_stores.schemas import BaseSchema + + AnySchema = TypeVar("AnySchema", bound=BaseSchema) + +AnyQuery = TypeVar("AnyQuery", bound=Any) # ------------------ Request Model ------------------ @@ -320,3 +329,73 @@ class ModelFilter(WorkspaceScopedTaggableFilter): default=None, description="Name of the Model", ) + + CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [ + *WorkspaceScopedTaggableFilter.CUSTOM_SORTING_OPTIONS, + SORT_BY_LATEST_VERSION_KEY, + ] + + def apply_sorting( + self, + query: AnyQuery, + table: Type["AnySchema"], + ) -> AnyQuery: + """Apply sorting to the query for Models. + + Args: + query: The query to which to apply the sorting. + table: The query table. + + Returns: + The query with sorting applied. 
+ """ + from sqlmodel import asc, case, col, desc, func, select + + from zenml.enums import SorterOps + from zenml.zen_stores.schemas import ( + ModelSchema, + ModelVersionSchema, + ) + + sort_by, operand = self.sorting_params + + if sort_by == SORT_BY_LATEST_VERSION_KEY: + # Subquery to find the latest version per model + latest_version_subquery = ( + select( + ModelSchema.id, + case( + ( + func.max(ModelVersionSchema.created).is_(None), + ModelSchema.created, + ), + else_=func.max(ModelVersionSchema.created), + ).label("latest_version_created"), + ) + .outerjoin( + ModelVersionSchema, + ModelSchema.id == ModelVersionSchema.model_id, # type: ignore[arg-type] + ) + .group_by(col(ModelSchema.id)) + .subquery() + ) + + query = query.add_columns( + latest_version_subquery.c.latest_version_created, + ).where(ModelSchema.id == latest_version_subquery.c.id) + + # Apply sorting based on the operand + if operand == SorterOps.ASCENDING: + query = query.order_by( + asc(latest_version_subquery.c.latest_version_created), + asc(ModelSchema.id), + ) + else: + query = query.order_by( + desc(latest_version_subquery.c.latest_version_created), + desc(ModelSchema.id), + ) + return query + + # For other sorting cases, delegate to the parent class + return super().apply_sorting(query=query, table=table) diff --git a/src/zenml/models/v2/core/pipeline.py b/src/zenml/models/v2/core/pipeline.py index 199e9cce959..43ea80dc7e0 100644 --- a/src/zenml/models/v2/core/pipeline.py +++ b/src/zenml/models/v2/core/pipeline.py @@ -357,7 +357,7 @@ def apply_sorting( # Subquery to find the latest run per pipeline latest_run_subquery = ( select( - PipelineRunSchema.pipeline_id, + PipelineSchema.id, case( ( func.max(PipelineRunSchema.created).is_(None), @@ -366,25 +366,28 @@ def apply_sorting( else_=func.max(PipelineRunSchema.created), ).label("latest_run"), ) - .group_by(col(PipelineRunSchema.pipeline_id)) + .outerjoin( + PipelineRunSchema, + PipelineSchema.id == PipelineRunSchema.pipeline_id, # type: 
ignore[arg-type] + ) + .group_by(col(PipelineSchema.id)) .subquery() ) - # Join the subquery with the pipelines - query = query.outerjoin( - latest_run_subquery, - PipelineSchema.id == latest_run_subquery.c.pipeline_id, - ) + query = query.add_columns( + latest_run_subquery.c.latest_run, + ).where(PipelineSchema.id == latest_run_subquery.c.id) if operand == SorterOps.ASCENDING: query = query.order_by( - asc(latest_run_subquery.c.latest_run) - ).order_by(col(PipelineSchema.id)) + asc(latest_run_subquery.c.latest_run), + asc(PipelineSchema.id), + ) else: query = query.order_by( - desc(latest_run_subquery.c.latest_run) - ).order_by(col(PipelineSchema.id)) - + desc(latest_run_subquery.c.latest_run), + desc(PipelineSchema.id), + ) return query else: return super().apply_sorting(query=query, table=table) diff --git a/src/zenml/models/v2/core/pipeline_run.py b/src/zenml/models/v2/core/pipeline_run.py index 3a22f642953..740fbc6711e 100644 --- a/src/zenml/models/v2/core/pipeline_run.py +++ b/src/zenml/models/v2/core/pipeline_run.py @@ -982,6 +982,8 @@ def apply_sorting( else: return super().apply_sorting(query=query, table=table) + query = query.add_columns(column) + if operand == SorterOps.ASCENDING: query = query.order_by(asc(column)) else: diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 19bdda8b28f..1483a782de7 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -64,6 +64,13 @@ ) from sqlalchemy.orm import Mapped, noload from sqlalchemy.util import immutabledict + +# Important to note: The select function of SQLModel works slightly differently +# from the select function of sqlalchemy. If you input only one entity on the +# select function of SQLModel, it automatically maps it to a SelectOfScalar. +# As a result, it will not return a tuple as a result, but the first entity in +# the tuple. 
While this is convenient in most cases, in unique cases like using +# the "add_columns" functionality, one might encounter unexpected results. from sqlmodel import ( Session, SQLModel, diff --git a/tests/integration/functional/models/test_sorting.py b/tests/integration/functional/models/test_sorting.py new file mode 100644 index 00000000000..138468fe867 --- /dev/null +++ b/tests/integration/functional/models/test_sorting.py @@ -0,0 +1,143 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+
+from typing import Annotated
+
+from zenml import Model, pipeline, step
+from zenml.constants import (
+    SORT_BY_LATEST_VERSION_KEY,
+    SORT_PIPELINES_BY_LATEST_RUN_KEY,
+)
+
+
+@step
+def first_step() -> Annotated[int, "int_artifact"]:
+    """Step to return an int."""
+    return 3
+
+
+@pipeline(enable_cache=False)
+def first_pipeline():
+    """Pipeline definition to test the different sorting mechanisms."""
+    _ = first_step()
+
+
+@step
+def second_step() -> Annotated[str, "str_artifact"]:
+    """Step to return a string."""
+    return "3"
+
+
+@pipeline(enable_cache=False)
+def second_pipeline():
+    """Pipeline definition to test the different sorting mechanisms."""
+    _ = second_step()
+
+
+def test_sorting_entities(clean_client):
+    """Testing different sorting functionalities."""
+    first_pipeline_first_run = first_pipeline.with_options(
+        tags=["tag_1", "a", "z"],
+        model=Model(name="Model2"),
+    )()
+    first_pipeline_second_run = first_pipeline.with_options(
+        tags=["tag_2", "z"],
+        model=Model(name="Model1", version="second"),
+    )()
+    first_pipeline_third_run = first_pipeline.with_options(
+        tags=["tag_3", "a", "ab"],
+        model=Model(name="Model1", version="first"),
+    )()
+    second_pipeline_first_run = second_pipeline()
+
+    # Sorting runs by the name of the user
+    clean_client.list_pipeline_runs(sort_by="user")
+    clean_client.list_pipeline_runs(sort_by="asc:user")
+    clean_client.list_pipeline_runs(sort_by="desc:user")
+
+    # Sorting runs by the name of the workspace
+    clean_client.list_pipeline_runs(sort_by="workspace")
+    clean_client.list_pipeline_runs(sort_by="asc:workspace")
+    clean_client.list_pipeline_runs(sort_by="desc:workspace")
+
+    # Sorting any taggable entity by tags
+    results = clean_client.list_pipeline_runs(sort_by="asc:tags")
+    assert results[0].id == second_pipeline_first_run.id
+    assert results[1].id == first_pipeline_third_run.id
+    assert results[-1].id == first_pipeline_second_run.id
+    clean_client.list_pipeline_runs(sort_by="desc:tags")
+
+    # Sorting pipelines by latest run
+    results = clean_client.list_pipelines(
+        sort_by=f"{SORT_PIPELINES_BY_LATEST_RUN_KEY}"
+    )
+    assert results[0].id == first_pipeline_first_run.pipeline.id
+    assert results[1].id == second_pipeline_first_run.pipeline.id
+
+    results = clean_client.list_pipelines(
+        sort_by=f"asc:{SORT_PIPELINES_BY_LATEST_RUN_KEY}"
+    )
+    assert results[0].id == first_pipeline_first_run.pipeline.id
+    assert results[1].id == second_pipeline_first_run.pipeline.id
+
+    results = clean_client.list_pipelines(
+        sort_by=f"desc:{SORT_PIPELINES_BY_LATEST_RUN_KEY}"
+    )
+    assert results[0].id == second_pipeline_first_run.pipeline.id
+    assert results[1].id == first_pipeline_first_run.pipeline.id
+
+    # Sorting runs by name (run names start with the pipeline name)
+    results = clean_client.list_pipeline_runs(sort_by="asc:name")
+    assert results[0].name.startswith("first_")
+    assert results[-1].name.startswith("second_")
+
+    # Sorting runs by stack name
+    clean_client.list_pipeline_runs(sort_by="asc:stack")
+    clean_client.list_pipeline_runs(sort_by="desc:stack")
+
+    # Sorting runs by model name
+    results = clean_client.list_pipeline_runs(sort_by="asc:model")
+    assert results[0].model_version.model.name == "Model1"
+    assert results[-1].model_version.model.name == "Model2"
+    clean_client.list_pipeline_runs(sort_by="desc:model")
+
+    # Sorting runs by model version
+    results = clean_client.list_pipeline_runs(sort_by="asc:model_version")
+
+    assert results[0].model_version.name == "1"
+    assert results[-1].model_version.name == "second"
+
+    clean_client.list_pipeline_runs(sort_by="desc:model_version")
+
+    # Sorting artifacts by latest version
+    results = clean_client.list_artifacts(
+        sort_by=f"asc:{SORT_BY_LATEST_VERSION_KEY}"
+    )
+    assert results[0].name == "int_artifact"
+
+    results = clean_client.list_artifacts(
+        sort_by=f"desc:{SORT_BY_LATEST_VERSION_KEY}"
+    )
+    assert results[0].name == "str_artifact"
+
+    # Sorting models by latest version
+    results = clean_client.list_models(
+        sort_by=f"asc:{SORT_BY_LATEST_VERSION_KEY}"
+    )
+    assert results[0].name == "Model2"
+
+    results = clean_client.list_models(
+        sort_by=f"desc:{SORT_BY_LATEST_VERSION_KEY}"
+    )
+    assert results[0].name == "Model1"