Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add run metadata and tag indices #3310

Merged
merged 8 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 48 additions & 1 deletion src/zenml/zen_stores/migrations/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from sqlalchemy.exc import (
OperationalError,
)
from sqlalchemy.schema import CreateTable
from sqlalchemy.schema import CreateIndex, CreateTable
from sqlmodel import (
create_engine,
select,
Expand Down Expand Up @@ -249,6 +249,7 @@ def backup_database_to_storage(
# them to the create table statement.

# Extract the unique constraints from the table schema
index_create_statements = []
unique_constraints = []
for index in table.indexes:
if index.unique:
Expand All @@ -258,6 +259,38 @@ def backup_database_to_storage(
unique_constraints.append(
f"UNIQUE KEY `{index.name}` ({', '.join(unique_columns)})"
)
else:
if index.name in {
fk.name for fk in table.foreign_key_constraints
}:
# Foreign key indices are already handled by the
# table creation statement.
continue

index_create = str(CreateIndex(index)).strip() # type: ignore[no-untyped-call]
index_create = index_create.replace(
f"CREATE INDEX {index.name}",
f"CREATE INDEX `{index.name}`",
)
index_create = index_create.replace(
f"ON {table.name}", f"ON `{table.name}`"
)

for column_name in index.columns.keys():
# We need this logic here to avoid the column names
# inside the index name
index_create = index_create.replace(
f"({column_name}", f"(`{column_name}`"
)
index_create = index_create.replace(
f"{column_name},", f"`{column_name}`,"
)
index_create = index_create.replace(
f"{column_name})", f"`{column_name}`)"
)

index_create = index_create.replace('"', "") + ";"
index_create_statements.append(index_create)

# Add the unique constraints to the create table statement
if unique_constraints:
Expand Down Expand Up @@ -290,6 +323,14 @@ def backup_database_to_storage(
)
)

for stmt in index_create_statements:
store_db_info(
dict(
table=table.name,
index_create_stmt=stmt,
)
)

# 2. extract the table data in batches
order_by = [col for col in table.primary_key]

Expand Down Expand Up @@ -356,6 +397,12 @@ def restore_database_from_storage(
"self_references", False
)

if "index_create_stmt" in table_dump:
# execute the index creation statement
connection.execute(text(table_dump["index_create_stmt"]))
# Reload the database metadata after creating the index
metadata.reflect(bind=self.engine)

if "data" in table_dump:
# insert the data into the database
table = metadata.tables[table_name]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""Add run metadata and tag index [4d5524b92a30].

Revision ID: 4d5524b92a30
Revises: 0.73.0
Create Date: 2025-01-30 11:30:36.736452

"""

from alembic import op
from sqlalchemy import inspect

# revision identifiers, used by Alembic.
revision = "4d5524b92a30"
down_revision = "0.73.0"
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Upgrade database schema and/or data, creating a new revision.

    Creates the composite lookup indexes on the `run_metadata_resource` and
    `tag_resource` tables.

    Each index is checked independently: some databases had an index added
    manually to improve the speed and cache utilisation, and on those the
    existing index is kept while any index that is still missing is created.
    """
    inspector = inspect(op.get_bind())

    # (table name, index name, indexed columns)
    indexes = (
        (
            "run_metadata_resource",
            "ix_run_metadata_resource_resource_id_resource_type_run_metadata_",
            ["resource_id", "resource_type", "run_metadata_id"],
        ),
        (
            "tag_resource",
            "ix_tag_resource_resource_id_resource_type_tag_id",
            ["resource_id", "resource_type", "tag_id"],
        ),
    )

    for table_name, index_name, column_names in indexes:
        existing_names = {
            index["name"] for index in inspector.get_indexes(table_name)
        }
        if index_name in existing_names:
            # Already present (e.g. added manually): nothing to do for this
            # table, but the remaining tables must still be processed.
            continue

        with op.batch_alter_table(table_name, schema=None) as batch_op:
            batch_op.create_index(index_name, column_names, unique=False)


def downgrade() -> None:
    """Downgrade database schema and/or data back to the previous revision.

    Drops the composite lookup indexes from the `tag_resource` and
    `run_metadata_resource` tables.

    An index that does not exist on the target database (for example because
    this revision never created it there) is skipped instead of failing the
    whole downgrade.
    """
    inspector = inspect(op.get_bind())

    # (table name, index name), dropped in reverse order of creation
    indexes = (
        (
            "tag_resource",
            "ix_tag_resource_resource_id_resource_type_tag_id",
        ),
        (
            "run_metadata_resource",
            "ix_run_metadata_resource_resource_id_resource_type_run_metadata_",
        ),
    )

    for table_name, index_name in indexes:
        existing_names = {
            index["name"] for index in inspector.get_indexes(table_name)
        }
        if index_name not in existing_names:
            # Nothing to drop for this table.
            continue

        with op.batch_alter_table(table_name, schema=None) as batch_op:
            batch_op.drop_index(index_name)
17 changes: 15 additions & 2 deletions src/zenml/zen_stores/schemas/run_metadata_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""SQLModel implementation of pipeline run metadata tables."""
"""SQLModel implementation of run metadata tables."""

from typing import Optional
from uuid import UUID, uuid4
Expand All @@ -21,7 +21,10 @@

from zenml.zen_stores.schemas.base_schemas import BaseSchema
from zenml.zen_stores.schemas.component_schemas import StackComponentSchema
from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field
from zenml.zen_stores.schemas.schema_utils import (
build_foreign_key_field,
build_index,
)
from zenml.zen_stores.schemas.step_run_schemas import StepRunSchema
from zenml.zen_stores.schemas.user_schemas import UserSchema
from zenml.zen_stores.schemas.workspace_schemas import WorkspaceSchema
Expand Down Expand Up @@ -82,6 +85,16 @@ class RunMetadataResourceSchema(SQLModel, table=True):
"""Table for linking resources to run metadata entries."""

__tablename__ = "run_metadata_resource"
__table_args__ = (
build_index(
table_name=__tablename__,
column_names=[
"resource_id",
"resource_type",
"run_metadata_id",
],
),
)

id: UUID = Field(default_factory=uuid4, primary_key=True)
resource_id: UUID
Expand Down
36 changes: 34 additions & 2 deletions src/zenml/zen_stores/schemas/schema_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
# permissions and limitations under the License.
"""Utility functions for SQLModel schemas."""

from typing import Any
from typing import Any, List

from sqlalchemy import Column, ForeignKey
from sqlalchemy import Column, ForeignKey, Index
from sqlmodel import Field


Expand Down Expand Up @@ -84,3 +84,35 @@ def build_foreign_key_field(
**sa_column_kwargs,
),
)


def get_index_name(table_name: str, column_names: List[str]) -> str:
    """Derive the name of an index from its table and columns.

    The name has the form `ix_<table>_<col1>_<col2>...` and is cut off after
    64 characters.

    Args:
        table_name: The name of the table for which the index will be created.
        column_names: Names of the columns on which the index will be created.

    Returns:
        The index name.
    """
    # MySQL allows a maximum of 64 characters in identifiers
    max_identifier_length = 64
    full_name = "ix_" + table_name + "_" + "_".join(column_names)
    return full_name[:max_identifier_length]


def build_index(
    table_name: str, column_names: List[str], **kwargs: Any
) -> Index:
    """Create an index object over the given columns of a table.

    The index name is generated by `get_index_name` from the table and
    column names.

    Args:
        table_name: The name of the table for which the index will be created.
        column_names: Names of the columns on which the index will be created.
        **kwargs: Additional keyword arguments to pass to the Index.

    Returns:
        The index.
    """
    return Index(
        get_index_name(table_name=table_name, column_names=column_names),
        *column_names,
        **kwargs,
    )
15 changes: 14 additions & 1 deletion src/zenml/zen_stores/schemas/tag_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@
)
from zenml.utils.time_utils import utc_now
from zenml.zen_stores.schemas.base_schemas import BaseSchema, NamedSchema
from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field
from zenml.zen_stores.schemas.schema_utils import (
build_foreign_key_field,
build_index,
)


class TagSchema(NamedSchema, table=True):
Expand Down Expand Up @@ -111,6 +114,16 @@ class TagResourceSchema(BaseSchema, table=True):
"""SQL Model for tag resource relationship."""

__tablename__ = "tag_resource"
__table_args__ = (
build_index(
table_name=__tablename__,
column_names=[
"resource_id",
"resource_type",
"tag_id",
],
),
)

tag_id: UUID = build_foreign_key_field(
source=__tablename__,
Expand Down