Skip to content

Commit

Permalink
Fix mypy issues in existing migration files.
Browse files Browse the repository at this point in the history
Not tested, but mypy thinks it is OK now.
  • Loading branch information
andy-slac committed Dec 5, 2023
1 parent ba9d8fc commit 6999ca1
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def upgrade() -> None:

# drop mapping table
_LOG.debug("Dropping mapping table")
op.drop_table(ID_MAP_TABLE_NAME, schema)
op.drop_table(ID_MAP_TABLE_NAME, schema=schema)

# refresh schema from database
metadata = sa.schema.MetaData(schema=schema)
Expand Down
12 changes: 6 additions & 6 deletions migrations/datasets/4e2d7a28475b.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ def _migrate_default(
# There may be very many records in dataset table to fit everything in
# memory, so split the whole thing on dataset_type_id.
query = sa.select(table.columns["dataset_type_id"]).select_from(table).distinct()
result = bind.execute(query).scalars()
dataset_type_ids = sorted(result)
scalars = bind.execute(query).scalars()
dataset_type_ids = sorted(scalars)
_LOG.info("Found %s dataset types in dataset table", len(dataset_type_ids))

for dataset_type_id in dataset_type_ids:
Expand All @@ -140,8 +140,8 @@ def _migrate_default(
iterator = iter(rows)
count = 0
while chunk := list(itertools.islice(iterator, 1000)):
query = tmp_table.insert().values(chunk)
result = bind.execute(query)
insert = tmp_table.insert().values(chunk)
result = bind.execute(insert)
count += result.rowcount
_LOG.info("Inserted %s rows into temporary table", count)

Expand All @@ -156,12 +156,12 @@ def _migrate_default(
)

# Update ingest date from a temporary table.
query = table.update().values(
update = table.update().values(
ingest_date=sa.select(tmp_table.columns["ingest_date"])
.where(tmp_table.columns["id"] == table.columns["id"])
.scalar_subquery()
)
result = bind.execute(query)
result = bind.execute(update)
_LOG.info("Updated %s rows in dataset table", result.rowcount)

# Update manager schema version.
Expand Down
8 changes: 6 additions & 2 deletions migrations/obscore-config/4fe28ef5030f.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""

import json
from typing import TYPE_CHECKING

import yaml
from alembic import context, op
Expand All @@ -15,6 +16,9 @@
from lsst.daf.butler_migrate.registry import make_registry
from lsst.utils import doImportType

if TYPE_CHECKING:
from lsst.daf.butler.registry.obscore import ObsCoreLiveTableManager

# revision identifiers, used by Alembic.
revision = "4fe28ef5030f"
down_revision = "2daeabfb5019"
Expand Down Expand Up @@ -142,7 +146,7 @@ def _make_obscore_table(obscore_config: dict) -> None:
manager_class_name = attributes.get("config:registry.managers.obscore")
if manager_class_name is None:
raise ValueError("Registry obscore manager has to be configured in butler_attributes")
manager_class = doImportType(manager_class_name)
manager_class: type[ObsCoreLiveTableManager] = doImportType(manager_class_name)

repository = context.config.get_section_option("daf_butler_migrate", "repository")
assert repository is not None, "Need repository in configuration"
Expand All @@ -154,7 +158,7 @@ def _make_obscore_table(obscore_config: dict) -> None:
database = registry._db
managers = registry._managers
with database.declareStaticTables(create=False) as staticTablesContext:
manager = manager_class.initialize(
manager: ObsCoreLiveTableManager = manager_class.initialize( # type: ignore[assignment]
database,
staticTablesContext,
universe=registry.dimensions,
Expand Down

0 comments on commit 6999ca1

Please sign in to comment.