Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Improve the efficiency of some SQL queries #3263

Draft
wants to merge 6 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 21 additions & 21 deletions src/zenml/zen_stores/schemas/artifact_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,8 @@
from zenml.zen_stores.schemas.model_schemas import (
ModelVersionArtifactSchema,
)
from zenml.zen_stores.schemas.run_metadata_schemas import (
RunMetadataResourceSchema,
)
from zenml.zen_stores.schemas.tag_schemas import TagResourceSchema
from zenml.zen_stores.schemas.run_metadata_schemas import RunMetadataSchema
from zenml.zen_stores.schemas.tag_schemas import TagSchema


class ArtifactSchema(NamedSchema, table=True):
Expand All @@ -82,12 +80,13 @@ class ArtifactSchema(NamedSchema, table=True):
back_populates="artifact",
sa_relationship_kwargs={"cascade": "delete"},
)
tags: List["TagResourceSchema"] = Relationship(
back_populates="artifact",
tags: List["TagSchema"] = Relationship(
sa_relationship_kwargs=dict(
primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.ARTIFACT.value}', foreign(TagResourceSchema.resource_id)==ArtifactSchema.id)",
cascade="delete",
overlaps="tags",
primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.ARTIFACT.value}', foreign(TagResourceSchema.resource_id)==ArtifactSchema.id)",
secondary="tag_resource",
secondaryjoin="TagSchema.id == foreign(TagResourceSchema.tag_id)",
order_by="TagSchema.name",
viewonly=True,
),
)

Expand Down Expand Up @@ -136,7 +135,7 @@ def to_model(
body = ArtifactResponseBody(
created=self.created,
updated=self.updated,
tags=[t.tag.to_model() for t in self.tags],
tags=[tag.to_model() for tag in self.tags],
latest_version_name=latest_name,
latest_version_id=latest_id,
)
Expand Down Expand Up @@ -192,12 +191,13 @@ class ArtifactVersionSchema(BaseSchema, RunMetadataInterface, table=True):
uri: str = Field(sa_column=Column(TEXT, nullable=False))
materializer: str = Field(sa_column=Column(TEXT, nullable=False))
data_type: str = Field(sa_column=Column(TEXT, nullable=False))
tags: List["TagResourceSchema"] = Relationship(
back_populates="artifact_version",
tags: List["TagSchema"] = Relationship(
sa_relationship_kwargs=dict(
primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.ARTIFACT_VERSION.value}', foreign(TagResourceSchema.resource_id)==ArtifactVersionSchema.id)",
cascade="delete",
overlaps="tags",
primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.ARTIFACT_VERSION.value}', foreign(TagResourceSchema.resource_id)==ArtifactVersionSchema.id)",
secondary="tag_resource",
secondaryjoin="TagSchema.id == foreign(TagResourceSchema.tag_id)",
order_by="TagSchema.name",
viewonly=True,
),
)
save_type: str = Field(sa_column=Column(TEXT, nullable=False))
Expand Down Expand Up @@ -244,12 +244,12 @@ class ArtifactVersionSchema(BaseSchema, RunMetadataInterface, table=True):
workspace: "WorkspaceSchema" = Relationship(
back_populates="artifact_versions"
)
run_metadata_resources: List["RunMetadataResourceSchema"] = Relationship(
back_populates="artifact_versions",
run_metadata: List["RunMetadataSchema"] = Relationship(
sa_relationship_kwargs=dict(
primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.ARTIFACT_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ArtifactVersionSchema.id)",
cascade="delete",
overlaps="run_metadata_resources",
secondary="run_metadata_resource",
primaryjoin=f"and_(foreign(RunMetadataResourceSchema.resource_type)=='{MetadataResourceTypes.ARTIFACT_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ArtifactVersionSchema.id)",
secondaryjoin="RunMetadataSchema.id==foreign(RunMetadataResourceSchema.run_metadata_id)",
viewonly=True,
),
)
output_of_step_runs: List["StepRunOutputArtifactSchema"] = Relationship(
Expand Down Expand Up @@ -365,7 +365,7 @@ def to_model(
data_type=data_type,
created=self.created,
updated=self.updated,
tags=[t.tag.to_model() for t in self.tags],
tags=[tag.to_model() for tag in self.tags],
producer_pipeline_run_id=producer_pipeline_run_id,
save_type=ArtifactSaveType(self.save_type),
artifact_store_id=self.artifact_store_id,
Expand Down
42 changes: 21 additions & 21 deletions src/zenml/zen_stores/schemas/model_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,9 @@
from zenml.zen_stores.schemas.base_schemas import BaseSchema, NamedSchema
from zenml.zen_stores.schemas.constants import MODEL_VERSION_TABLENAME
from zenml.zen_stores.schemas.pipeline_run_schemas import PipelineRunSchema
from zenml.zen_stores.schemas.run_metadata_schemas import (
RunMetadataResourceSchema,
)
from zenml.zen_stores.schemas.run_metadata_schemas import RunMetadataSchema
from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field
from zenml.zen_stores.schemas.tag_schemas import TagResourceSchema
from zenml.zen_stores.schemas.tag_schemas import TagSchema
from zenml.zen_stores.schemas.user_schemas import UserSchema
from zenml.zen_stores.schemas.utils import (
RunMetadataInterface,
Expand Down Expand Up @@ -114,12 +112,13 @@ class ModelSchema(NamedSchema, table=True):
save_models_to_registry: bool = Field(
sa_column=Column(BOOLEAN, nullable=False)
)
tags: List["TagResourceSchema"] = Relationship(
back_populates="model",
tags: List["TagSchema"] = Relationship(
sa_relationship_kwargs=dict(
primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.MODEL.value}', foreign(TagResourceSchema.resource_id)==ModelSchema.id)",
cascade="delete",
overlaps="tags",
primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.MODEL.value}', foreign(TagResourceSchema.resource_id)==ModelSchema.id)",
secondary="tag_resource",
secondaryjoin="TagSchema.id == foreign(TagResourceSchema.tag_id)",
order_by="TagSchema.name",
viewonly=True,
),
)
model_versions: List["ModelVersionSchema"] = Relationship(
Expand Down Expand Up @@ -168,7 +167,7 @@ def to_model(
Returns:
The created `ModelResponse`.
"""
tags = [t.tag.to_model() for t in self.tags]
tags = [tag.to_model() for tag in self.tags]

if self.model_versions:
version_numbers = [mv.number for mv in self.model_versions]
Expand Down Expand Up @@ -299,12 +298,13 @@ class ModelVersionSchema(NamedSchema, RunMetadataInterface, table=True):
back_populates="model_version",
sa_relationship_kwargs={"cascade": "delete"},
)
tags: List["TagResourceSchema"] = Relationship(
back_populates="model_version",
tags: List["TagSchema"] = Relationship(
sa_relationship_kwargs=dict(
primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.MODEL_VERSION.value}', foreign(TagResourceSchema.resource_id)==ModelVersionSchema.id)",
cascade="delete",
overlaps="tags",
primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.MODEL_VERSION.value}', foreign(TagResourceSchema.resource_id)==ModelVersionSchema.id)",
secondary="tag_resource",
secondaryjoin="TagSchema.id == foreign(TagResourceSchema.tag_id)",
order_by="TagSchema.name",
viewonly=True,
),
)

Expand All @@ -316,12 +316,12 @@ class ModelVersionSchema(NamedSchema, RunMetadataInterface, table=True):
description: str = Field(sa_column=Column(TEXT, nullable=True))
stage: str = Field(sa_column=Column(TEXT, nullable=True))

run_metadata_resources: List["RunMetadataResourceSchema"] = Relationship(
back_populates="model_versions",
run_metadata: List["RunMetadataSchema"] = Relationship(
sa_relationship_kwargs=dict(
primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.MODEL_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ModelVersionSchema.id)",
cascade="delete",
overlaps="run_metadata_resources",
secondary="run_metadata_resource",
primaryjoin=f"and_(foreign(RunMetadataResourceSchema.resource_type)=='{MetadataResourceTypes.MODEL_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ModelVersionSchema.id)",
secondaryjoin="RunMetadataSchema.id==foreign(RunMetadataResourceSchema.run_metadata_id)",
viewonly=True,
),
)
pipeline_runs: List["PipelineRunSchema"] = Relationship(
Expand Down Expand Up @@ -471,7 +471,7 @@ def to_model(
data_artifact_ids=data_artifact_ids,
deployment_artifact_ids=deployment_artifact_ids,
pipeline_run_ids=pipeline_run_ids,
tags=[t.tag.to_model() for t in self.tags],
tags=[tag.to_model() for tag in self.tags],
)

return ModelVersionResponse(
Expand Down
39 changes: 19 additions & 20 deletions src/zenml/zen_stores/schemas/pipeline_run_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,10 @@
ModelVersionPipelineRunSchema,
ModelVersionSchema,
)
from zenml.zen_stores.schemas.run_metadata_schemas import (
RunMetadataResourceSchema,
)
from zenml.zen_stores.schemas.run_metadata_schemas import RunMetadataSchema
from zenml.zen_stores.schemas.service_schemas import ServiceSchema
from zenml.zen_stores.schemas.step_run_schemas import StepRunSchema
from zenml.zen_stores.schemas.tag_schemas import TagResourceSchema
from zenml.zen_stores.schemas.tag_schemas import TagSchema


class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True):
Expand Down Expand Up @@ -140,12 +138,12 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True):
)
workspace: "WorkspaceSchema" = Relationship(back_populates="runs")
user: Optional["UserSchema"] = Relationship(back_populates="runs")
run_metadata_resources: List["RunMetadataResourceSchema"] = Relationship(
back_populates="pipeline_runs",
run_metadata: List["RunMetadataSchema"] = Relationship(
sa_relationship_kwargs=dict(
primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.PIPELINE_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==PipelineRunSchema.id)",
cascade="delete",
overlaps="run_metadata_resources",
secondary="run_metadata_resource",
primaryjoin=f"and_(foreign(RunMetadataResourceSchema.resource_type)=='{MetadataResourceTypes.PIPELINE_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==PipelineRunSchema.id)",
secondaryjoin="RunMetadataSchema.id==foreign(RunMetadataResourceSchema.run_metadata_id)",
viewonly=True,
),
)
logs: Optional["LogsSchema"] = Relationship(
Expand Down Expand Up @@ -215,11 +213,13 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True):
services: List["ServiceSchema"] = Relationship(
back_populates="pipeline_run",
)
tags: List["TagResourceSchema"] = Relationship(
tags: List["TagSchema"] = Relationship(
sa_relationship_kwargs=dict(
primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.PIPELINE_RUN.value}', foreign(TagResourceSchema.resource_id)==PipelineRunSchema.id)",
cascade="delete",
overlaps="tags",
primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.PIPELINE_RUN.value}', foreign(TagResourceSchema.resource_id)==PipelineRunSchema.id)",
secondary="tag_resource",
secondaryjoin="TagSchema.id == foreign(TagResourceSchema.tag_id)",
order_by="TagSchema.name",
viewonly=True,
),
)

Expand Down Expand Up @@ -291,12 +291,6 @@ def to_model(
Raises:
RuntimeError: if the model creation fails.
"""
orchestrator_environment = (
json.loads(self.orchestrator_environment)
if self.orchestrator_environment
else {}
)

if self.deployment is not None:
deployment = self.deployment.to_model(include_metadata=True)

Expand Down Expand Up @@ -377,6 +371,11 @@ def to_model(
# in the response -> We need to reset the metadata here
step.metadata = None

orchestrator_environment = (
json.loads(self.orchestrator_environment)
if self.orchestrator_environment
else {}
)
metadata = PipelineRunResponseMetadata(
workspace=self.workspace.to_model(),
run_metadata=self.fetch_metadata(),
Expand Down Expand Up @@ -405,7 +404,7 @@ def to_model(

resources = PipelineRunResponseResources(
model_version=model_version,
tags=[t.tag.to_model() for t in self.tags],
tags=[tag.to_model() for tag in self.tags],
)

return PipelineRunResponse(
Expand Down
14 changes: 8 additions & 6 deletions src/zenml/zen_stores/schemas/pipeline_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
)
from zenml.zen_stores.schemas.pipeline_run_schemas import PipelineRunSchema
from zenml.zen_stores.schemas.schedule_schema import ScheduleSchema
from zenml.zen_stores.schemas.tag_schemas import TagResourceSchema
from zenml.zen_stores.schemas.tag_schemas import TagSchema


class PipelineSchema(NamedSchema, table=True):
Expand Down Expand Up @@ -95,11 +95,13 @@ class PipelineSchema(NamedSchema, table=True):
deployments: List["PipelineDeploymentSchema"] = Relationship(
back_populates="pipeline",
)
tags: List["TagResourceSchema"] = Relationship(
tags: List["TagSchema"] = Relationship(
sa_relationship_kwargs=dict(
primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.PIPELINE.value}', foreign(TagResourceSchema.resource_id)==PipelineSchema.id)",
cascade="delete",
overlaps="tags",
primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.PIPELINE.value}', foreign(TagResourceSchema.resource_id)==PipelineSchema.id)",
secondary="tag_resource",
secondaryjoin="TagSchema.id == foreign(TagResourceSchema.tag_id)",
order_by="TagSchema.name",
viewonly=True,
),
)

Expand Down Expand Up @@ -162,7 +164,7 @@ def to_model(
latest_run_user=latest_run_user.to_model()
if latest_run_user
else None,
tags=[t.tag.to_model() for t in self.tags],
tags=[tag.to_model() for tag in self.tags],
)

return PipelineResponse(
Expand Down
47 changes: 9 additions & 38 deletions src/zenml/zen_stores/schemas/run_metadata_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,19 @@
# permissions and limitations under the License.
"""SQLModel implementation of pipeline run metadata tables."""

from typing import TYPE_CHECKING, List, Optional
from typing import List, Optional
from uuid import UUID, uuid4

from sqlalchemy import TEXT, VARCHAR, Column
from sqlalchemy import TEXT, VARCHAR, Column, Index
from sqlmodel import Field, Relationship, SQLModel

from zenml.enums import MetadataResourceTypes
from zenml.zen_stores.schemas.base_schemas import BaseSchema
from zenml.zen_stores.schemas.component_schemas import StackComponentSchema
from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field
from zenml.zen_stores.schemas.step_run_schemas import StepRunSchema
from zenml.zen_stores.schemas.user_schemas import UserSchema
from zenml.zen_stores.schemas.workspace_schemas import WorkspaceSchema

if TYPE_CHECKING:
from zenml.zen_stores.schemas.artifact_schemas import ArtifactVersionSchema
from zenml.zen_stores.schemas.model_schemas import ModelVersionSchema
from zenml.zen_stores.schemas.pipeline_run_schemas import PipelineRunSchema


class RunMetadataSchema(BaseSchema, table=True):
"""SQL Model for run metadata."""
Expand Down Expand Up @@ -93,6 +87,13 @@ class RunMetadataResourceSchema(SQLModel, table=True):
"""Table for linking resources to run metadata entries."""

__tablename__ = "run_metadata_resource"
__table_args__ = (
Index(
"run_metadata_resource_index",
"resource_id",
"resource_type",
),
)

id: UUID = Field(default_factory=uuid4, primary_key=True)
resource_id: UUID
Expand All @@ -108,33 +109,3 @@ class RunMetadataResourceSchema(SQLModel, table=True):

# Relationship back to the base metadata table
run_metadata: RunMetadataSchema = Relationship(back_populates="resources")

# Relationship to link specific resource types
pipeline_runs: List["PipelineRunSchema"] = Relationship(
back_populates="run_metadata_resources",
sa_relationship_kwargs=dict(
primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.PIPELINE_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==PipelineRunSchema.id)",
overlaps="run_metadata_resources,step_runs,artifact_versions,model_versions",
),
)
step_runs: List["StepRunSchema"] = Relationship(
back_populates="run_metadata_resources",
sa_relationship_kwargs=dict(
primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.STEP_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==StepRunSchema.id)",
overlaps="run_metadata_resources,pipeline_runs,artifact_versions,model_versions",
),
)
artifact_versions: List["ArtifactVersionSchema"] = Relationship(
back_populates="run_metadata_resources",
sa_relationship_kwargs=dict(
primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.ARTIFACT_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ArtifactVersionSchema.id)",
overlaps="run_metadata_resources,pipeline_runs,step_runs,model_versions",
),
)
model_versions: List["ModelVersionSchema"] = Relationship(
back_populates="run_metadata_resources",
sa_relationship_kwargs=dict(
primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.MODEL_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ModelVersionSchema.id)",
overlaps="run_metadata_resources,pipeline_runs,step_runs,artifact_versions",
),
)
Loading