Skip to content

Commit

Permalink
Fixed and improved sorting (#3266)
Browse files Browse the repository at this point in the history
* first fixes

* final fixes

* adding tag list back

* fixing the failures

* fixing the problem with the tags and adjusting the tests

* adding a small comment

* formatting
  • Loading branch information
bcdurak authored Jan 9, 2025
1 parent 6f337a8 commit c2fe19e
Show file tree
Hide file tree
Showing 8 changed files with 367 additions and 44 deletions.
1 change: 1 addition & 0 deletions src/zenml/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int:
)
FILTERING_DATETIME_FORMAT: str = "%Y-%m-%d %H:%M:%S"
SORT_PIPELINES_BY_LATEST_RUN_KEY = "latest_run"
SORT_BY_LATEST_VERSION_KEY = "latest_version"

# Metadata constants
METADATA_ORCHESTRATOR_URL = "orchestrator_url"
Expand Down
60 changes: 32 additions & 28 deletions src/zenml/models/v2/base/scoped.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,8 @@ def apply_sorting(
UserSchema, getattr(table, "user_id") == UserSchema.id
)

query = query.add_columns(UserSchema.name)

if operand == SorterOps.ASCENDING:
query = query.order_by(asc(column))
else:
Expand Down Expand Up @@ -449,6 +451,8 @@ def apply_sorting(
getattr(table, "workspace_id") == WorkspaceSchema.id,
)

query = query.add_columns(WorkspaceSchema.name)

if operand == SorterOps.ASCENDING:
query = query.order_by(asc(column))
else:
Expand All @@ -470,10 +474,9 @@ class WorkspaceScopedTaggableFilter(WorkspaceScopedFilter):
*WorkspaceScopedFilter.FILTER_EXCLUDE_FIELDS,
"tag",
]

CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [
*WorkspaceScopedFilter.CUSTOM_SORTING_OPTIONS,
"tag",
"tags",
]

def apply_filter(
Expand Down Expand Up @@ -540,8 +543,8 @@ def apply_sorting(
"""
sort_by, operand = self.sorting_params

if sort_by == "tag":
from sqlmodel import and_, asc, desc, func
if sort_by == "tags":
from sqlmodel import asc, desc, func, select

from zenml.enums import SorterOps, TaggableResourceTypes
from zenml.zen_stores.schemas import (
Expand All @@ -566,35 +569,36 @@ def apply_sorting(
RunTemplateSchema: TaggableResourceTypes.RUN_TEMPLATE,
}

query = (
query.outerjoin(
TagResourceSchema,
and_(
table.id == TagResourceSchema.resource_id,
TagResourceSchema.resource_type
== resource_type_mapping[table],
),
sorted_tags = (
select(TagResourceSchema.resource_id, TagSchema.name)
.join(TagSchema, TagResourceSchema.tag_id == TagSchema.id) # type: ignore[arg-type]
.filter(
TagResourceSchema.resource_type # type: ignore[arg-type]
== resource_type_mapping[table]
)
.order_by(
asc(TagResourceSchema.resource_id), asc(TagSchema.name)
)
.outerjoin(TagSchema, TagResourceSchema.tag_id == TagSchema.id)
.group_by(table.id)
).alias("sorted_tags")

tags_subquery = (
select(
sorted_tags.c.resource_id,
func.group_concat(sorted_tags.c.name, ", ").label(
"tags_list"
),
).group_by(sorted_tags.c.resource_id)
).alias("tags_subquery")

query = query.add_columns(tags_subquery.c.tags_list).outerjoin(
tags_subquery, table.id == tags_subquery.c.resource_id
)

# Apply ordering based on the tags list
if operand == SorterOps.ASCENDING:
query = query.order_by(
asc(
func.group_concat(TagSchema.name, ",").label(
"tags_list"
)
)
)
query = query.order_by(asc("tags_list"))
else:
query = query.order_by(
desc(
func.group_concat(TagSchema.name, ",").label(
"tags_list"
)
)
)
query = query.order_by(desc("tags_list"))

return query

Expand Down
88 changes: 86 additions & 2 deletions src/zenml/models/v2/core/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,21 @@
# permissions and limitations under the License.
"""Models representing artifacts."""

from typing import TYPE_CHECKING, Dict, List, Optional
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Dict,
List,
Optional,
Type,
TypeVar,
)
from uuid import UUID

from pydantic import BaseModel, Field

from zenml.constants import STR_FIELD_MAX_LENGTH
from zenml.constants import SORT_BY_LATEST_VERSION_KEY, STR_FIELD_MAX_LENGTH
from zenml.models.v2.base.base import (
BaseDatedResponseBody,
BaseIdentifiedResponse,
Expand All @@ -31,6 +40,11 @@

if TYPE_CHECKING:
from zenml.models.v2.core.artifact_version import ArtifactVersionResponse
from zenml.zen_stores.schemas import BaseSchema

AnySchema = TypeVar("AnySchema", bound=BaseSchema)

AnyQuery = TypeVar("AnyQuery", bound=Any)

# ------------------ Request Model ------------------

Expand Down Expand Up @@ -174,3 +188,73 @@ class ArtifactFilter(WorkspaceScopedTaggableFilter):

name: Optional[str] = None
has_custom_name: Optional[bool] = None

CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [
*WorkspaceScopedTaggableFilter.CUSTOM_SORTING_OPTIONS,
SORT_BY_LATEST_VERSION_KEY,
]

def apply_sorting(
self,
query: AnyQuery,
table: Type["AnySchema"],
) -> AnyQuery:
"""Apply sorting to the query for Artifacts.
Args:
query: The query to which to apply the sorting.
table: The query table.
Returns:
The query with sorting applied.
"""
from sqlmodel import asc, case, col, desc, func, select

from zenml.enums import SorterOps
from zenml.zen_stores.schemas import (
ArtifactSchema,
ArtifactVersionSchema,
)

sort_by, operand = self.sorting_params

if sort_by == SORT_BY_LATEST_VERSION_KEY:
# Subquery to find the latest version per artifact
latest_version_subquery = (
select(
ArtifactSchema.id,
case(
(
func.max(ArtifactVersionSchema.created).is_(None),
ArtifactSchema.created,
),
else_=func.max(ArtifactVersionSchema.created),
).label("latest_version_created"),
)
.outerjoin(
ArtifactVersionSchema,
ArtifactSchema.id == ArtifactVersionSchema.artifact_id, # type: ignore[arg-type]
)
.group_by(col(ArtifactSchema.id))
.subquery()
)

query = query.add_columns(
latest_version_subquery.c.latest_version_created,
).where(ArtifactSchema.id == latest_version_subquery.c.id)

# Apply sorting based on the operand
if operand == SorterOps.ASCENDING:
query = query.order_by(
asc(latest_version_subquery.c.latest_version_created),
asc(ArtifactSchema.id),
)
else:
query = query.order_by(
desc(latest_version_subquery.c.latest_version_created),
desc(ArtifactSchema.id),
)
return query

# For other sorting cases, delegate to the parent class
return super().apply_sorting(query=query, table=table)
83 changes: 81 additions & 2 deletions src/zenml/models/v2/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,16 @@
# permissions and limitations under the License.
"""Models representing models."""

from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Any, ClassVar, List, Optional, Type, TypeVar
from uuid import UUID

from pydantic import BaseModel, Field

from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH
from zenml.constants import (
SORT_BY_LATEST_VERSION_KEY,
STR_FIELD_MAX_LENGTH,
TEXT_FIELD_MAX_LENGTH,
)
from zenml.models.v2.base.scoped import (
WorkspaceScopedRequest,
WorkspaceScopedResponse,
Expand All @@ -32,6 +36,11 @@
if TYPE_CHECKING:
from zenml.model.model import Model
from zenml.models.v2.core.tag import TagResponse
from zenml.zen_stores.schemas import BaseSchema

AnySchema = TypeVar("AnySchema", bound=BaseSchema)

AnyQuery = TypeVar("AnyQuery", bound=Any)

# ------------------ Request Model ------------------

Expand Down Expand Up @@ -320,3 +329,73 @@ class ModelFilter(WorkspaceScopedTaggableFilter):
default=None,
description="Name of the Model",
)

CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [
*WorkspaceScopedTaggableFilter.CUSTOM_SORTING_OPTIONS,
SORT_BY_LATEST_VERSION_KEY,
]

def apply_sorting(
self,
query: AnyQuery,
table: Type["AnySchema"],
) -> AnyQuery:
"""Apply sorting to the query for Models.
Args:
query: The query to which to apply the sorting.
table: The query table.
Returns:
The query with sorting applied.
"""
from sqlmodel import asc, case, col, desc, func, select

from zenml.enums import SorterOps
from zenml.zen_stores.schemas import (
ModelSchema,
ModelVersionSchema,
)

sort_by, operand = self.sorting_params

if sort_by == SORT_BY_LATEST_VERSION_KEY:
# Subquery to find the latest version per model
latest_version_subquery = (
select(
ModelSchema.id,
case(
(
func.max(ModelVersionSchema.created).is_(None),
ModelSchema.created,
),
else_=func.max(ModelVersionSchema.created),
).label("latest_version_created"),
)
.outerjoin(
ModelVersionSchema,
ModelSchema.id == ModelVersionSchema.model_id, # type: ignore[arg-type]
)
.group_by(col(ModelSchema.id))
.subquery()
)

query = query.add_columns(
latest_version_subquery.c.latest_version_created,
).where(ModelSchema.id == latest_version_subquery.c.id)

# Apply sorting based on the operand
if operand == SorterOps.ASCENDING:
query = query.order_by(
asc(latest_version_subquery.c.latest_version_created),
asc(ModelSchema.id),
)
else:
query = query.order_by(
desc(latest_version_subquery.c.latest_version_created),
desc(ModelSchema.id),
)
return query

# For other sorting cases, delegate to the parent class
return super().apply_sorting(query=query, table=table)
27 changes: 15 additions & 12 deletions src/zenml/models/v2/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def apply_sorting(
# Subquery to find the latest run per pipeline
latest_run_subquery = (
select(
PipelineRunSchema.pipeline_id,
PipelineSchema.id,
case(
(
func.max(PipelineRunSchema.created).is_(None),
Expand All @@ -366,25 +366,28 @@ def apply_sorting(
else_=func.max(PipelineRunSchema.created),
).label("latest_run"),
)
.group_by(col(PipelineRunSchema.pipeline_id))
.outerjoin(
PipelineRunSchema,
PipelineSchema.id == PipelineRunSchema.pipeline_id, # type: ignore[arg-type]
)
.group_by(col(PipelineSchema.id))
.subquery()
)

# Join the subquery with the pipelines
query = query.outerjoin(
latest_run_subquery,
PipelineSchema.id == latest_run_subquery.c.pipeline_id,
)
query = query.add_columns(
latest_run_subquery.c.latest_run,
).where(PipelineSchema.id == latest_run_subquery.c.id)

if operand == SorterOps.ASCENDING:
query = query.order_by(
asc(latest_run_subquery.c.latest_run)
).order_by(col(PipelineSchema.id))
asc(latest_run_subquery.c.latest_run),
asc(PipelineSchema.id),
)
else:
query = query.order_by(
desc(latest_run_subquery.c.latest_run)
).order_by(col(PipelineSchema.id))

desc(latest_run_subquery.c.latest_run),
desc(PipelineSchema.id),
)
return query
else:
return super().apply_sorting(query=query, table=table)
2 changes: 2 additions & 0 deletions src/zenml/models/v2/core/pipeline_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,6 +982,8 @@ def apply_sorting(
else:
return super().apply_sorting(query=query, table=table)

query = query.add_columns(column)

if operand == SorterOps.ASCENDING:
query = query.order_by(asc(column))
else:
Expand Down
7 changes: 7 additions & 0 deletions src/zenml/zen_stores/sql_zen_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@
)
from sqlalchemy.orm import Mapped, noload
from sqlalchemy.util import immutabledict

# Important to note: The select function of SQLModel works slightly differently
# from the select function of sqlalchemy. If you input only one entity on the
# select function of SQLModel, it automatically maps it to a SelectOfScalar.
# As a result, it will not return a tuple as a result, but the first entity in
# the tuple. While this is convenient in most cases, in unique cases like using
# the "add_columns" functionality, one might encounter unexpected results.
from sqlmodel import (
Session,
SQLModel,
Expand Down
Loading

0 comments on commit c2fe19e

Please sign in to comment.