From 1db4a837407eaef050d6b82ba02072bed045dfd3 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Sun, 15 Dec 2024 22:46:23 +0100 Subject: [PATCH 1/7] first fixes --- src/zenml/constants.py | 1 + src/zenml/models/v2/base/scoped.py | 86 ----------- src/zenml/models/v2/core/artifact.py | 86 ++++++++++- src/zenml/models/v2/core/model.py | 81 ++++++++++- src/zenml/models/v2/core/pipeline.py | 15 +- .../functional/models/test_sorting.py | 133 ++++++++++++++++++ 6 files changed, 305 insertions(+), 97 deletions(-) create mode 100644 tests/integration/functional/models/test_sorting.py diff --git a/src/zenml/constants.py b/src/zenml/constants.py index 183b1acce16..a7a13edb614 100644 --- a/src/zenml/constants.py +++ b/src/zenml/constants.py @@ -429,6 +429,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int: ) FILTERING_DATETIME_FORMAT: str = "%Y-%m-%d %H:%M:%S" SORT_PIPELINES_BY_LATEST_RUN_KEY = "latest_run" +SORT_BY_LATEST_VERSION_KEY = "latest_version" # Metadata constants METADATA_ORCHESTRATOR_URL = "orchestrator_url" diff --git a/src/zenml/models/v2/base/scoped.py b/src/zenml/models/v2/base/scoped.py index f5267f4840d..573116a19ea 100644 --- a/src/zenml/models/v2/base/scoped.py +++ b/src/zenml/models/v2/base/scoped.py @@ -466,16 +466,6 @@ class WorkspaceScopedTaggableFilter(WorkspaceScopedFilter): description="Tag to apply to the filter query.", default=None ) - FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ - *WorkspaceScopedFilter.FILTER_EXCLUDE_FIELDS, - "tag", - ] - - CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [ - *WorkspaceScopedFilter.CUSTOM_SORTING_OPTIONS, - "tag", - ] - def apply_filter( self, query: AnyQuery, @@ -524,79 +514,3 @@ def get_custom_filters( ) return custom_filters - - def apply_sorting( - self, - query: AnyQuery, - table: Type["AnySchema"], - ) -> AnyQuery: - """Apply sorting to the query. - - Args: - query: The query to which to apply the sorting. - table: The query table. - - Returns: - The query with sorting applied. - """ - sort_by, operand = self.sorting_params - - if sort_by == "tag": - from sqlmodel import and_, asc, desc, func - - from zenml.enums import SorterOps, TaggableResourceTypes - from zenml.zen_stores.schemas import ( - ArtifactSchema, - ArtifactVersionSchema, - ModelSchema, - ModelVersionSchema, - PipelineRunSchema, - PipelineSchema, - RunTemplateSchema, - TagResourceSchema, - TagSchema, - ) - - resource_type_mapping = { - ArtifactSchema: TaggableResourceTypes.ARTIFACT, - ArtifactVersionSchema: TaggableResourceTypes.ARTIFACT_VERSION, - ModelSchema: TaggableResourceTypes.MODEL, - ModelVersionSchema: TaggableResourceTypes.MODEL_VERSION, - PipelineSchema: TaggableResourceTypes.PIPELINE, - PipelineRunSchema: TaggableResourceTypes.PIPELINE_RUN, - RunTemplateSchema: TaggableResourceTypes.RUN_TEMPLATE, - } - - query = ( - query.outerjoin( - TagResourceSchema, - and_( - table.id == TagResourceSchema.resource_id, - TagResourceSchema.resource_type - == resource_type_mapping[table], - ), - ) - .outerjoin(TagSchema, TagResourceSchema.tag_id == TagSchema.id) - .group_by(table.id) - ) - - if operand == SorterOps.ASCENDING: - query = query.order_by( - asc( - func.group_concat(TagSchema.name, ",").label( - "tags_list" - ) - ) - ) - else: - query = query.order_by( - desc( - func.group_concat(TagSchema.name, ",").label( - "tags_list" - ) - ) - ) - - return query - - return super().apply_sorting(query=query, table=table) diff --git a/src/zenml/models/v2/core/artifact.py b/src/zenml/models/v2/core/artifact.py index c62a7cee1a5..8cebcba6e42 100644 --- a/src/zenml/models/v2/core/artifact.py +++ b/src/zenml/models/v2/core/artifact.py @@ -13,12 +13,21 @@ # permissions and limitations under the License. """Models representing artifacts.""" -from typing import TYPE_CHECKING, Dict, List, Optional +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Dict, + List, + Optional, + Type, + TypeVar, +) from uuid import UUID from pydantic import BaseModel, Field -from zenml.constants import STR_FIELD_MAX_LENGTH +from zenml.constants import SORT_BY_LATEST_VERSION_KEY, STR_FIELD_MAX_LENGTH from zenml.models.v2.base.base import ( BaseDatedResponseBody, BaseIdentifiedResponse, @@ -31,6 +40,11 @@ if TYPE_CHECKING: from zenml.models.v2.core.artifact_version import ArtifactVersionResponse + from zenml.zen_stores.schemas import BaseSchema + + AnySchema = TypeVar("AnySchema", bound=BaseSchema) + +AnyQuery = TypeVar("AnyQuery", bound=Any) # ------------------ Request Model ------------------ @@ -174,3 +188,71 @@ class ArtifactFilter(WorkspaceScopedTaggableFilter): name: Optional[str] = None has_custom_name: Optional[bool] = None + + CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [ + *WorkspaceScopedTaggableFilter.CUSTOM_SORTING_OPTIONS, + SORT_BY_LATEST_VERSION_KEY, + ] + + def apply_sorting( + self, + query: AnyQuery, + table: Type["AnySchema"], + ) -> AnyQuery: + """Apply sorting to the query for Artifacts. + + Args: + query: The query to which to apply the sorting. + table: The query table. + + Returns: + The query with sorting applied. + """ + from sqlmodel import asc, case, desc, func, select + + from zenml.enums import SorterOps + from zenml.zen_stores.schemas import ( + ArtifactSchema, + ArtifactVersionSchema, + ) + + sort_by, operand = self.sorting_params + + if sort_by == SORT_BY_LATEST_VERSION_KEY: + # Subquery to find the latest version per artifact + latest_version_subquery = ( + select( + ArtifactVersionSchema.artifact_id, + case( + ( + func.max(ArtifactVersionSchema.created).is_(None), + ArtifactSchema.created, + ), + else_=func.max(ArtifactVersionSchema.created), + ).label("latest_version_created"), + ) + .group_by(ArtifactVersionSchema.artifact_id) + .subquery() + ) + + # Join the subquery with the main artifacts query + query = query.outerjoin( + latest_version_subquery, + ArtifactSchema.id == latest_version_subquery.c.artifact_id, + ) + + # Apply sorting based on the operand + if operand == SorterOps.ASCENDING: + query = query.order_by( + asc(latest_version_subquery.c.latest_version_created), + asc(ArtifactSchema.id), + ) + else: + query = query.order_by( + desc(latest_version_subquery.c.latest_version_created), + desc(ArtifactSchema.id), + ) + return query + + # For other sorting cases, delegate to the parent class + return super().apply_sorting(query=query, table=table) diff --git a/src/zenml/models/v2/core/model.py b/src/zenml/models/v2/core/model.py index 0b5272ab7e6..e8b781ecd2c 100644 --- a/src/zenml/models/v2/core/model.py +++ b/src/zenml/models/v2/core/model.py @@ -13,12 +13,16 @@ # permissions and limitations under the License. """Models representing models.""" -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Any, ClassVar, List, Optional, Type, TypeVar from uuid import UUID from pydantic import BaseModel, Field -from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH +from zenml.constants import ( + SORT_BY_LATEST_VERSION_KEY, + STR_FIELD_MAX_LENGTH, + TEXT_FIELD_MAX_LENGTH, +) from zenml.models.v2.base.scoped import ( WorkspaceScopedRequest, WorkspaceScopedResponse, @@ -32,6 +36,11 @@ if TYPE_CHECKING: from zenml.model.model import Model from zenml.models.v2.core.tag import TagResponse + from zenml.zen_stores.schemas import BaseSchema + + AnySchema = TypeVar("AnySchema", bound=BaseSchema) + +AnyQuery = TypeVar("AnyQuery", bound=Any) # ------------------ Request Model ------------------ @@ -320,3 +329,71 @@ class ModelFilter(WorkspaceScopedTaggableFilter): default=None, description="Name of the Model", ) + + CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [ + *WorkspaceScopedTaggableFilter.CUSTOM_SORTING_OPTIONS, + SORT_BY_LATEST_VERSION_KEY, + ] + + def apply_sorting( + self, + query: AnyQuery, + table: Type["AnySchema"], + ) -> AnyQuery: + """Apply sorting to the query for Models. + + Args: + query: The query to which to apply the sorting. + table: The query table. + + Returns: + The query with sorting applied. + """ + from sqlmodel import asc, case, desc, func, select + + from zenml.enums import SorterOps + from zenml.zen_stores.schemas import ( + ModelSchema, + ModelVersionSchema, + ) + + sort_by, operand = self.sorting_params + + if sort_by == SORT_BY_LATEST_VERSION_KEY: + # Subquery to find the latest version per model + latest_version_subquery = ( + select( + ModelVersionSchema.model_id, + case( + ( + func.max(ModelVersionSchema.created).is_(None), + ModelSchema.created, + ), + else_=func.max(ModelVersionSchema.created), + ).label("latest_version_created"), + ) + .group_by(ModelVersionSchema.model_id) + .subquery() + ) + + # Join the subquery with the main artifacts query + query = query.outerjoin( + latest_version_subquery, + ModelSchema.id == latest_version_subquery.c.model_id, + ) + + # Apply sorting based on the operand + if operand == SorterOps.ASCENDING: + query = query.order_by( + asc(latest_version_subquery.c.latest_version_created), + asc(ModelSchema.id), + ) + else: + query = query.order_by( + desc(latest_version_subquery.c.latest_version_created), + desc(ModelSchema.id), + ) + return query + + # For other sorting cases, delegate to the parent class + return super().apply_sorting(query=query, table=table) diff --git a/src/zenml/models/v2/core/pipeline.py b/src/zenml/models/v2/core/pipeline.py index 199e9cce959..707e9a86104 100644 --- a/src/zenml/models/v2/core/pipeline.py +++ b/src/zenml/models/v2/core/pipeline.py @@ -346,7 +346,7 @@ def apply_sorting( Returns: The query with sorting applied. """ - from sqlmodel import asc, case, col, desc, func, select + from sqlmodel import asc, case, desc, func, select from zenml.enums import SorterOps from zenml.zen_stores.schemas import PipelineRunSchema, PipelineSchema @@ -366,7 +366,7 @@ def apply_sorting( else_=func.max(PipelineRunSchema.created), ).label("latest_run"), ) - .group_by(col(PipelineRunSchema.pipeline_id)) + .group_by(PipelineRunSchema.pipeline_id) .subquery() ) @@ -378,13 +378,14 @@ def apply_sorting( if operand == SorterOps.ASCENDING: query = query.order_by( - asc(latest_run_subquery.c.latest_run) - ).order_by(col(PipelineSchema.id)) + asc(latest_run_subquery.c.latest_run), + asc(PipelineSchema.id), + ) else: query = query.order_by( - desc(latest_run_subquery.c.latest_run) - ).order_by(col(PipelineSchema.id)) - + desc(latest_run_subquery.c.latest_run), + desc(PipelineSchema.id), + ) return query else: return super().apply_sorting(query=query, table=table) diff --git a/tests/integration/functional/models/test_sorting.py b/tests/integration/functional/models/test_sorting.py new file mode 100644 index 00000000000..67f17ce9b39 --- /dev/null +++ b/tests/integration/functional/models/test_sorting.py @@ -0,0 +1,133 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. + +from typing import Annotated + +from zenml import Model, pipeline, step +from zenml.constants import ( + SORT_BY_LATEST_VERSION_KEY, + SORT_PIPELINES_BY_LATEST_RUN_KEY, +) + + +@step +def first_step() -> Annotated[int, "int_artifact"]: + """Step to return an int.""" + return 3 + + +@pipeline(enable_cache=False) +def first_pipeline(): + """Pipeline definition to test the different sorting mechanisms.""" + _ = first_step() + + +@step +def second_step() -> Annotated[str, "str_artifact"]: + """Step to return a string.""" + return "3" + + +@pipeline(enable_cache=False) +def second_pipeline(): + """Pipeline definition to test the different sorting mechanisms.""" + _ = second_step() + + +def test_sorting_entities(clean_client): + """Testing different sorting functionalities.""" + first_pipeline_first_run = first_pipeline.with_options( + model=Model(name="Model2"), + )() + _ = first_pipeline.with_options( + model=Model(name="Model1", version="second"), + )() + _ = first_pipeline.with_options( + model=Model(name="Model1", version="first"), + )() + second_pipeline_first_run = second_pipeline() + + # Sorting runs by the name of the user + clean_client.list_pipeline_runs(sort_by="user") + clean_client.list_pipeline_runs(sort_by="asc:user") + clean_client.list_pipeline_runs(sort_by="desc:user") + + # Sorting runs by the name of the workspace + clean_client.list_pipeline_runs(sort_by="workspace") + clean_client.list_pipeline_runs(sort_by="asc:workspace") + clean_client.list_pipeline_runs(sort_by="desc:workspace") + + # Sorting pipelines by latest run + results = clean_client.list_pipelines( + sort_by=f"{SORT_PIPELINES_BY_LATEST_RUN_KEY}" + ) + assert results[0].id == first_pipeline_first_run.pipeline.id + assert results[1].id == second_pipeline_first_run.pipeline.id + + results = clean_client.list_pipelines( + sort_by=f"asc:{SORT_PIPELINES_BY_LATEST_RUN_KEY}" + ) + assert results[0].id == first_pipeline_first_run.pipeline.id + assert results[1].id == second_pipeline_first_run.pipeline.id + + results = clean_client.list_pipelines( + sort_by=f"desc:{SORT_PIPELINES_BY_LATEST_RUN_KEY}" + ) + assert results[0].id == second_pipeline_first_run.pipeline.id + assert results[1].id == first_pipeline_first_run.pipeline.id + + # Sorting runs by pipeline name + results = clean_client.list_pipeline_runs(sort_by="asc:name") + assert results[0].name.startswith("first_") + assert results[-1].name.startswith("second_") + + # Sorting runs by stack name + clean_client.list_pipeline_runs(sort_by="asc:stack") + clean_client.list_pipeline_runs(sort_by="desc:stack") + + # Sorting runs by model name + results = clean_client.list_pipeline_runs(sort_by="asc:model") + assert results[0].model_version.model.name == "Model1" + assert results[-1].model_version.model.name == "Model2" + clean_client.list_pipeline_runs(sort_by="desc:model") + + # Sorting runs by model version + results = clean_client.list_pipeline_runs(sort_by="asc:model_version") + + assert results[0].model_version.name == "1" + assert results[-1].model_version.name == "second" + + clean_client.list_pipeline_runs(sort_by="desc:model") + + # Sorting artifacts by latest version + results = clean_client.list_artifacts( + sort_by=f"asc:{SORT_BY_LATEST_VERSION_KEY}" + ) + assert results[0].name == "int_artifact" + + results = clean_client.list_artifacts( + sort_by=f"desc:{SORT_BY_LATEST_VERSION_KEY}" + ) + assert results[0].name == "str_artifact" + + # Sorting models by latest version + results = clean_client.list_models( + sort_by=f"asc:{SORT_BY_LATEST_VERSION_KEY}" + ) + assert results[0].name == "Model2" + + results = clean_client.list_models( + sort_by=f"desc:{SORT_BY_LATEST_VERSION_KEY}" + ) + assert results[0].name == "Model1" From 37d2f0c761cee88e77c2f20494057927bcbefa86 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Mon, 16 Dec 2024 01:12:57 +0100 Subject: [PATCH 2/7] final fixes --- src/zenml/models/v2/base/scoped.py | 4 ++++ src/zenml/models/v2/core/artifact.py | 16 +++++++++------- src/zenml/models/v2/core/model.py | 16 +++++++++------- src/zenml/models/v2/core/pipeline.py | 16 +++++++++------- src/zenml/models/v2/core/pipeline_run.py | 2 ++ 5 files changed, 33 insertions(+), 21 deletions(-) diff --git a/src/zenml/models/v2/base/scoped.py b/src/zenml/models/v2/base/scoped.py index 573116a19ea..00c41769882 100644 --- a/src/zenml/models/v2/base/scoped.py +++ b/src/zenml/models/v2/base/scoped.py @@ -245,6 +245,8 @@ def apply_sorting( UserSchema, getattr(table, "user_id") == UserSchema.id ) + query = query.add_columns(UserSchema.name) + if operand == SorterOps.ASCENDING: query = query.order_by(asc(column)) else: @@ -449,6 +451,8 @@ def apply_sorting( getattr(table, "workspace_id") == WorkspaceSchema.id, ) + query = query.add_columns(WorkspaceSchema.name) + if operand == SorterOps.ASCENDING: query = query.order_by(asc(column)) else: diff --git a/src/zenml/models/v2/core/artifact.py b/src/zenml/models/v2/core/artifact.py index 8cebcba6e42..bd87395b173 100644 --- a/src/zenml/models/v2/core/artifact.py +++ b/src/zenml/models/v2/core/artifact.py @@ -208,7 +208,7 @@ def apply_sorting( Returns: The query with sorting applied. """ - from sqlmodel import asc, case, desc, func, select + from sqlmodel import asc, case, col, desc, func, select from zenml.enums import SorterOps from zenml.zen_stores.schemas import ( @@ -222,7 +222,7 @@ def apply_sorting( # Subquery to find the latest version per artifact latest_version_subquery = ( select( - ArtifactVersionSchema.artifact_id, + ArtifactSchema.id, case( ( func.max(ArtifactVersionSchema.created).is_(None), @@ -231,14 +231,16 @@ def apply_sorting( else_=func.max(ArtifactVersionSchema.created), ).label("latest_version_created"), ) - .group_by(ArtifactVersionSchema.artifact_id) + .outerjoin( + ArtifactVersionSchema, + ArtifactSchema.id == ArtifactVersionSchema.artifact_id, # type: ignore[arg-type] + ) + .group_by(col(ArtifactSchema.id)) .subquery() ) - # Join the subquery with the main artifacts query - query = query.outerjoin( - latest_version_subquery, - ArtifactSchema.id == latest_version_subquery.c.artifact_id, + query = query.add_columns( + latest_version_subquery.c.latest_version_created, ) # Apply sorting based on the operand diff --git a/src/zenml/models/v2/core/model.py b/src/zenml/models/v2/core/model.py index e8b781ecd2c..5669a9d2237 100644 --- a/src/zenml/models/v2/core/model.py +++ b/src/zenml/models/v2/core/model.py @@ -349,7 +349,7 @@ def apply_sorting( Returns: The query with sorting applied. """ - from sqlmodel import asc, case, desc, func, select + from sqlmodel import asc, case, col, desc, func, select from zenml.enums import SorterOps from zenml.zen_stores.schemas import ( @@ -363,7 +363,7 @@ def apply_sorting( # Subquery to find the latest version per model latest_version_subquery = ( select( - ModelVersionSchema.model_id, + ModelSchema.id, case( ( func.max(ModelVersionSchema.created).is_(None), @@ -372,14 +372,16 @@ def apply_sorting( else_=func.max(ModelVersionSchema.created), ).label("latest_version_created"), ) - .group_by(ModelVersionSchema.model_id) + .outerjoin( + ModelVersionSchema, + ModelSchema.id == ModelVersionSchema.model_id, # type: ignore[arg-type] + ) + .group_by(col(ModelSchema.id)) .subquery() ) - # Join the subquery with the main artifacts query - query = query.outerjoin( - latest_version_subquery, - ModelSchema.id == latest_version_subquery.c.model_id, + query = query.add_columns( + latest_version_subquery.c.latest_version_created, ) # Apply sorting based on the operand diff --git a/src/zenml/models/v2/core/pipeline.py b/src/zenml/models/v2/core/pipeline.py index 707e9a86104..7cfd4109a95 100644 --- a/src/zenml/models/v2/core/pipeline.py +++ b/src/zenml/models/v2/core/pipeline.py @@ -346,7 +346,7 @@ def apply_sorting( Returns: The query with sorting applied. """ - from sqlmodel import asc, case, desc, func, select + from sqlmodel import asc, case, col, desc, func, select from zenml.enums import SorterOps from zenml.zen_stores.schemas import PipelineRunSchema, PipelineSchema @@ -357,7 +357,7 @@ def apply_sorting( # Subquery to find the latest run per pipeline latest_run_subquery = ( select( - PipelineRunSchema.pipeline_id, + PipelineSchema.id, case( ( func.max(PipelineRunSchema.created).is_(None), @@ -366,14 +366,16 @@ def apply_sorting( else_=func.max(PipelineRunSchema.created), ).label("latest_run"), ) - .group_by(PipelineRunSchema.pipeline_id) + .outerjoin( + PipelineRunSchema, + PipelineSchema.id == PipelineRunSchema.pipeline_id, # type: ignore[arg-type] + ) + .group_by(col(PipelineSchema.id)) .subquery() ) - # Join the subquery with the pipelines - query = query.outerjoin( - latest_run_subquery, - PipelineSchema.id == latest_run_subquery.c.pipeline_id, + query = query.add_columns( + latest_run_subquery.c.latest_run, ) if operand == SorterOps.ASCENDING: diff --git a/src/zenml/models/v2/core/pipeline_run.py b/src/zenml/models/v2/core/pipeline_run.py index 3a22f642953..740fbc6711e 100644 --- a/src/zenml/models/v2/core/pipeline_run.py +++ b/src/zenml/models/v2/core/pipeline_run.py @@ -982,6 +982,8 @@ def apply_sorting( else: return super().apply_sorting(query=query, table=table) + query = query.add_columns(column) + if operand == SorterOps.ASCENDING: query = query.order_by(asc(column)) else: From b6d37e16c9f46b79dd295892459bda6963322bb2 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Mon, 16 Dec 2024 02:07:32 +0100 Subject: [PATCH 3/7] adding tag list back --- src/zenml/models/v2/base/scoped.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/zenml/models/v2/base/scoped.py b/src/zenml/models/v2/base/scoped.py index 00c41769882..56ea3202347 100644 --- a/src/zenml/models/v2/base/scoped.py +++ b/src/zenml/models/v2/base/scoped.py @@ -466,6 +466,11 @@ def apply_sorting( class WorkspaceScopedTaggableFilter(WorkspaceScopedFilter): """Model to enable advanced scoping with workspace and tagging.""" + FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ + *WorkspaceScopedFilter.FILTER_EXCLUDE_FIELDS, + "tag", + ] + tag: Optional[str] = Field( description="Tag to apply to the filter query.", default=None ) From 47ed477cb09a9e6eb4a4b077785067f0fe9b23b0 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Mon, 16 Dec 2024 02:27:16 +0100 Subject: [PATCH 4/7] fixing the failures --- src/zenml/models/v2/core/artifact.py | 2 +- src/zenml/models/v2/core/model.py | 2 +- src/zenml/models/v2/core/pipeline.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/zenml/models/v2/core/artifact.py b/src/zenml/models/v2/core/artifact.py index bd87395b173..e36b602b6e3 100644 --- a/src/zenml/models/v2/core/artifact.py +++ b/src/zenml/models/v2/core/artifact.py @@ -241,7 +241,7 @@ def apply_sorting( query = query.add_columns( latest_version_subquery.c.latest_version_created, - ) + ).where(ArtifactSchema.id == latest_version_subquery.c.id) # Apply sorting based on the operand if operand == SorterOps.ASCENDING: diff --git a/src/zenml/models/v2/core/model.py b/src/zenml/models/v2/core/model.py index 5669a9d2237..bb341c1d5a1 100644 --- a/src/zenml/models/v2/core/model.py +++ b/src/zenml/models/v2/core/model.py @@ -382,7 +382,7 @@ def apply_sorting( query = query.add_columns( latest_version_subquery.c.latest_version_created, - ) + ).where(ModelSchema.id == latest_version_subquery.c.id) # Apply sorting based on the operand if operand == SorterOps.ASCENDING: diff --git a/src/zenml/models/v2/core/pipeline.py b/src/zenml/models/v2/core/pipeline.py index 7cfd4109a95..43ea80dc7e0 100644 --- a/src/zenml/models/v2/core/pipeline.py +++ b/src/zenml/models/v2/core/pipeline.py @@ -376,7 +376,7 @@ def apply_sorting( query = query.add_columns( latest_run_subquery.c.latest_run, - ) + ).where(PipelineSchema.id == latest_run_subquery.c.id) if operand == SorterOps.ASCENDING: query = query.order_by( From 7cfed6798778d79218f3f01a0d38f6b932509a33 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Tue, 7 Jan 2025 14:46:57 +0100 Subject: [PATCH 5/7] fixing the problem with the tags and adjusting the tests --- src/zenml/models/v2/base/scoped.py | 89 ++++++++++++++++++- .../functional/models/test_sorting.py | 14 ++- 2 files changed, 97 insertions(+), 6 deletions(-) diff --git a/src/zenml/models/v2/base/scoped.py b/src/zenml/models/v2/base/scoped.py index 56ea3202347..8bd8f1415d5 100644 --- a/src/zenml/models/v2/base/scoped.py +++ b/src/zenml/models/v2/base/scoped.py @@ -466,14 +466,18 @@ def apply_sorting( class WorkspaceScopedTaggableFilter(WorkspaceScopedFilter): """Model to enable advanced scoping with workspace and tagging.""" + tag: Optional[str] = Field( + description="Tag to apply to the filter query.", default=None + ) + FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ *WorkspaceScopedFilter.FILTER_EXCLUDE_FIELDS, "tag", ] - - tag: Optional[str] = Field( - description="Tag to apply to the filter query.", default=None - ) + CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [ + *WorkspaceScopedFilter.CUSTOM_SORTING_OPTIONS, + "tags", + ] def apply_filter( self, @@ -523,3 +527,80 @@ def get_custom_filters( ) return custom_filters + + def apply_sorting( + self, + query: AnyQuery, + table: Type["AnySchema"], + ) -> AnyQuery: + """Apply sorting to the query. + + Args: + query: The query to which to apply the sorting. + table: The query table. + + Returns: + The query with sorting applied. + """ + sort_by, operand = self.sorting_params + + if sort_by == "tags": + from sqlmodel import asc, desc, func, select + + from zenml.enums import SorterOps, TaggableResourceTypes + from zenml.zen_stores.schemas import ( + ArtifactSchema, + ArtifactVersionSchema, + ModelSchema, + ModelVersionSchema, + PipelineRunSchema, + PipelineSchema, + RunTemplateSchema, + TagResourceSchema, + TagSchema, + ) + + resource_type_mapping = { + ArtifactSchema: TaggableResourceTypes.ARTIFACT, + ArtifactVersionSchema: TaggableResourceTypes.ARTIFACT_VERSION, + ModelSchema: TaggableResourceTypes.MODEL, + ModelVersionSchema: TaggableResourceTypes.MODEL_VERSION, + PipelineSchema: TaggableResourceTypes.PIPELINE, + PipelineRunSchema: TaggableResourceTypes.PIPELINE_RUN, + RunTemplateSchema: TaggableResourceTypes.RUN_TEMPLATE, + } + + sorted_tags = ( + select(TagResourceSchema.resource_id, TagSchema.name) + .join(TagSchema, TagResourceSchema.tag_id == TagSchema.id) # type: ignore[arg-type] + .filter( + TagResourceSchema.resource_type # type: ignore[arg-type] + == resource_type_mapping[table] + ) + .order_by( + asc(TagResourceSchema.resource_id), asc(TagSchema.name) + ) + ).alias("sorted_tags") + + tags_subquery = ( + select( + sorted_tags.c.resource_id, + func.group_concat(sorted_tags.c.name, ", ").label( + "tags_list" + ), + ).group_by(sorted_tags.c.resource_id) + ).alias("tags_subquery") + + query = query.add_columns(tags_subquery.c.tags_list).outerjoin( + tags_subquery, table.id == tags_subquery.c.resource_id + ) + + # Apply ordering based on the tags list + if operand == SorterOps.ASCENDING: + query = query.order_by(asc("tags_list")) + else: + query = query.order_by(desc("tags_list")) + + return query + + return super().apply_sorting(query=query, table=table) diff --git a/tests/integration/functional/models/test_sorting.py b/tests/integration/functional/models/test_sorting.py index 67f17ce9b39..138468fe867 100644 --- a/tests/integration/functional/models/test_sorting.py +++ b/tests/integration/functional/models/test_sorting.py @@ -48,12 +48,15 @@ def second_pipeline(): def test_sorting_entities(clean_client): """Testing different sorting functionalities.""" first_pipeline_first_run = first_pipeline.with_options( + tags=["tag_1", "a", "z"], model=Model(name="Model2"), )() - _ = first_pipeline.with_options( + first_pipeline_second_run = first_pipeline.with_options( + tags=["tag_2", "z"], model=Model(name="Model1", version="second"), )() - _ = first_pipeline.with_options( + first_pipeline_third_run = first_pipeline.with_options( + tags=["tag_3", "a", "ab"], model=Model(name="Model1", version="first"), )() second_pipeline_first_run = second_pipeline() @@ -68,6 +71,13 @@ def test_sorting_entities(clean_client): clean_client.list_pipeline_runs(sort_by="asc:workspace") clean_client.list_pipeline_runs(sort_by="desc:workspace") + # Sorting any taggable entity by tags + results = clean_client.list_pipeline_runs(sort_by="asc:tags") + assert results[0].id == second_pipeline_first_run.id + assert results[1].id == first_pipeline_third_run.id + assert results[-1].id == first_pipeline_second_run.id + clean_client.list_pipeline_runs(sort_by="desc:tags") + # Sorting pipelines by latest run results = clean_client.list_pipelines( sort_by=f"{SORT_PIPELINES_BY_LATEST_RUN_KEY}" From abad2ec9b968495ed68bf574f3ed0116af1481c1 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Tue, 7 Jan 2025 16:08:28 +0100 Subject: [PATCH 6/7] adding a small comment --- src/zenml/zen_stores/sql_zen_store.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 19bdda8b28f..a253af777bd 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -64,6 +64,12 @@ ) from sqlalchemy.orm import Mapped, noload from sqlalchemy.util import immutabledict +# Important to note: The select function of SQLModel works slightly differently +# from the select function of sqlalchemy. If you input only one entity on the +# select function of SQLModel, it automatically maps it to a SelectOfScalar. +# As a result, it will not return a tuple as a result, but the first entity in +# the tuple. While this is convenient in most cases, in unique cases like using +# the "add_columns" functionality, one might encounter unexpected results. from sqlmodel import ( Session, SQLModel, @@ -1048,7 +1054,7 @@ def filter_and_paginate( item_schemas = custom_fetch_result # select the items in the current page item_schemas = item_schemas[ - filter_model.offset : filter_model.offset + filter_model.size + filter_model.offset: filter_model.offset + filter_model.size ] else: item_schemas = session.exec( From 0c9e8064823fc77a8ebf11354a57d576c92d5de4 Mon Sep 17 00:00:00 2001 From: Baris Can Durak Date: Tue, 7 Jan 2025 16:08:58 +0100 Subject: [PATCH 7/7] formatting --- src/zenml/zen_stores/sql_zen_store.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index a253af777bd..1483a782de7 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -64,6 +64,7 @@ ) from sqlalchemy.orm import Mapped, noload from sqlalchemy.util import immutabledict + # Important to note: The select function of SQLModel works slightly differently # from the select function of sqlalchemy. If you input only one entity on the # select function of SQLModel, it automatically maps it to a SelectOfScalar. @@ -1054,7 +1055,7 @@ def filter_and_paginate( item_schemas = custom_fetch_result # select the items in the current page item_schemas = item_schemas[ - filter_model.offset: filter_model.offset + filter_model.size + filter_model.offset : filter_model.offset + filter_model.size ] else: item_schemas = session.exec(