From aea476de7242e5b2c4f1a18726e930d5c9bb3c5e Mon Sep 17 00:00:00 2001 From: Andy Salnikov Date: Mon, 5 Aug 2024 10:55:36 -0700 Subject: [PATCH 01/10] Fix query joiner construction to avoid cartesian joins. The code was joining calib table twice, with and without alias, which caused a cartesian join with that table. --- .../daf/butler/registry/datasets/byDimensions/_storage.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py b/python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py index 8d409d6921..d66e114f45 100644 --- a/python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py +++ b/python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py @@ -622,16 +622,15 @@ def make_query_joiner(self, collections: Sequence[CollectionRecord], fields: Set assert ( self._calibs is not None ), "DatasetTypes with isCalibration() == False can never be found in a CALIBRATION collection." + calibs_table = self._calibs.alias(f"{self.datasetType.name}_calibs") calibs_builder = self._finish_query_builder( - QueryJoiner(self._db, self._calibs.alias(f"{self.datasetType.name}_calibs")).to_builder( - columns - ), + QueryJoiner(self._db, calibs_table).to_builder(columns), [record for record in collections if record.type is CollectionType.CALIBRATION], fields, ) if "timespan" in fields: calibs_builder.joiner.timespans[self.datasetType.name] = ( - self._db.getTimespanRepresentation().from_columns(self._calibs.columns) + self._db.getTimespanRepresentation().from_columns(calibs_table.columns) ) # In calibration collections, we need timespan as well as data ID From 318c00ab1e1c43c66acb73790784b217a8c5852b Mon Sep 17 00:00:00 2001 From: Andy Salnikov Date: Mon, 5 Aug 2024 15:32:06 -0700 Subject: [PATCH 02/10] Improve handling of Timespan in queries. Postgres sometimes needs help in guessing column type when doing UNION of several queries. --- python/lsst/daf/butler/registry/databases/postgresql.py | 6 +++++- .../daf/butler/registry/datasets/byDimensions/_storage.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/python/lsst/daf/butler/registry/databases/postgresql.py b/python/lsst/daf/butler/registry/databases/postgresql.py index fcca457a01..51df4e9f3c 100644 --- a/python/lsst/daf/butler/registry/databases/postgresql.py +++ b/python/lsst/daf/butler/registry/databases/postgresql.py @@ -504,7 +504,11 @@ def extract(cls, mapping: Mapping[str, Any], name: str | None = None) -> Timespa def fromLiteral(cls, timespan: Timespan | None) -> _RangeTimespanRepresentation: # Docstring inherited. if timespan is None: - return cls(column=sqlalchemy.sql.null(), name=cls.NAME) + # Cast NULL to an expected type, helps Postgres to figure out + # column type when doing UNION. 
+ return cls( + column=sqlalchemy.func.cast(sqlalchemy.sql.null(), type_=_RangeTimespanType), name=cls.NAME + ) return cls( column=sqlalchemy.sql.cast( sqlalchemy.sql.literal(timespan, type_=_RangeTimespanType), type_=_RangeTimespanType diff --git a/python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py b/python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py index d66e114f45..9c518f1026 100644 --- a/python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py +++ b/python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py @@ -612,7 +612,7 @@ def make_query_joiner(self, collections: Sequence[CollectionRecord], fields: Set ) if "timespan" in fields: tags_builder.joiner.timespans[self.datasetType.name] = ( - self._db.getTimespanRepresentation().fromLiteral(Timespan(None, None)) + self._db.getTimespanRepresentation().fromLiteral(None) ) calibs_builder: QueryBuilder | None = None if CollectionType.CALIBRATION in collection_types: From 6a5060293071d67366acbcce9f4265d48f37b355 Mon Sep 17 00:00:00 2001 From: Andy Salnikov Date: Mon, 5 Aug 2024 15:38:15 -0700 Subject: [PATCH 03/10] Initial implementation of a general query result (DM-45429) Registry method `queryDatasetAssociations` is reimplemented (for both direct and remote butler) to use new query system and new general query result class. --- .../butler/direct_butler/_direct_butler.py | 21 +-- .../daf/butler/direct_query_driver/_driver.py | 3 + .../_result_page_converter.py | 89 ++++++++++++- python/lsst/daf/butler/queries/__init__.py | 1 + .../butler/queries/_general_query_results.py | 125 ++++++++++++++++++ python/lsst/daf/butler/queries/_query.py | 40 +++++- python/lsst/daf/butler/queries/driver.py | 3 +- .../lsst/daf/butler/registry/sql_registry.py | 95 ++++++------- .../daf/butler/remote_butler/_registry.py | 16 ++- 9 files changed, 317 insertions(+), 76 deletions(-) create mode 100644 python/lsst/daf/butler/queries/_general_query_results.py diff --git a/python/lsst/daf/butler/direct_butler/_direct_butler.py b/python/lsst/daf/butler/direct_butler/_direct_butler.py index 6bd55858a3..54d7620853 100644 --- a/python/lsst/daf/butler/direct_butler/_direct_butler.py +++ b/python/lsst/daf/butler/direct_butler/_direct_butler.py @@ -2193,32 +2193,19 @@ def dimensions(self) -> DimensionUniverse: # Docstring inherited. return self._registry.dimensions - @contextlib.contextmanager - def _query(self) -> Iterator[Query]: + def _query(self) -> contextlib.AbstractContextManager[Query]: # Docstring inherited. - with self._query_driver(self._registry.defaults.collections, self.registry.defaults.dataId) as driver: - yield Query(driver) + return self._registry._query() - @contextlib.contextmanager def _query_driver( self, default_collections: Iterable[str], default_data_id: DataCoordinate, - ) -> Iterator[DirectQueryDriver]: + ) -> contextlib.AbstractContextManager[DirectQueryDriver]: """Set up a QueryDriver instance for use with this Butler. Although this is marked as a private method, it is also used by Butler server. 
""" - with self._caching_context(): - driver = DirectQueryDriver( - self._registry._db, - self.dimensions, - self._registry._managers, - self._registry.dimension_record_cache, - default_collections=default_collections, - default_data_id=default_data_id, - ) - with driver: - yield driver + return self._registry._query_driver(default_collections, default_data_id) def _preload_cache(self) -> None: """Immediately load caches that are used for common operations.""" diff --git a/python/lsst/daf/butler/direct_query_driver/_driver.py b/python/lsst/daf/butler/direct_query_driver/_driver.py index 89ce006057..0aaa848b2c 100644 --- a/python/lsst/daf/butler/direct_query_driver/_driver.py +++ b/python/lsst/daf/butler/direct_query_driver/_driver.py @@ -79,6 +79,7 @@ DataCoordinateResultPageConverter, DatasetRefResultPageConverter, DimensionRecordResultPageConverter, + GeneralResultPageConverter, ResultPageConverter, ResultPageConverterContext, ) @@ -271,6 +272,8 @@ def _create_result_page_converter(self, spec: ResultSpec, builder: QueryBuilder) return DatasetRefResultPageConverter( spec, self.get_dataset_type(spec.dataset_type_name), context ) + case GeneralResultSpec(): + return GeneralResultPageConverter(spec, context) case _: raise NotImplementedError(f"Result type '{spec.result_type}' not yet implemented") diff --git a/python/lsst/daf/butler/direct_query_driver/_result_page_converter.py b/python/lsst/daf/butler/direct_query_driver/_result_page_converter.py index 22044994f1..f1419bcff4 100644 --- a/python/lsst/daf/butler/direct_query_driver/_result_page_converter.py +++ b/python/lsst/daf/butler/direct_query_driver/_result_page_converter.py @@ -30,7 +30,7 @@ from abc import abstractmethod from collections.abc import Iterable from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import sqlalchemy @@ -50,9 +50,16 @@ DataCoordinateResultPage, DatasetRefResultPage, DimensionRecordResultPage, + GeneralResultPage, ResultPage, ) -from ..queries.result_specs import DataCoordinateResultSpec, DatasetRefResultSpec, DimensionRecordResultSpec +from ..queries.result_specs import ( + DataCoordinateResultSpec, + DatasetRefResultSpec, + DimensionRecordResultSpec, + GeneralResultSpec, +) +from ..timespan_database_representation import TimespanDatabaseRepresentation if TYPE_CHECKING: from ..registry.interfaces import Database @@ -310,3 +317,81 @@ def convert(self, row: sqlalchemy.Row) -> dict[str, DimensionRecord]: # numpydo the dimensions in the database row. 
""" return {name: converter.convert(row) for name, converter in self._record_converters.items()} + + +class GeneralResultPageConverter(ResultPageConverter): # numpydoc ignore=PR01 + """Converts raw SQL rows into pages of `GeneralResult` query results.""" + + def __init__(self, spec: GeneralResultSpec, ctx: ResultPageConverterContext) -> None: + self.spec = spec + + result_columns = spec.get_result_columns() + self.converters: list[_GeneralColumnConverter] = [] + for column in result_columns: + column_name = qt.ColumnSet.get_qualified_name(column.logical_table, column.field) + if column.field == TimespanDatabaseRepresentation.NAME: + self.converters.append(_TimespanGeneralColumnConverter(column_name, ctx.db)) + else: + self.converters.append(_DefaultGeneralColumnConverter(column_name)) + + def convert(self, raw_rows: Iterable[sqlalchemy.Row]) -> GeneralResultPage: + rows = [tuple(cvt.convert(row) for cvt in self.converters) for row in raw_rows] + return GeneralResultPage(spec=self.spec, rows=rows) + + +class _GeneralColumnConverter: + """Interface for converting one or more columns in a result row to a single + column value in output row. + """ + + @abstractmethod + def convert(self, row: sqlalchemy.Row) -> Any: + """Convert one or more columns in the row into single value. + + Parameters + ---------- + row : `sqlalchemy.Row` + Row of values. + + Returns + ------- + value : `Any` + Result of the conversion. + """ + raise NotImplementedError() + + +class _DefaultGeneralColumnConverter(_GeneralColumnConverter): + """Converter that returns column value without conversion. + + Parameters + ---------- + name : `str` + Column name + """ + + def __init__(self, name: str): + self.name = name + + def convert(self, row: sqlalchemy.Row) -> Any: + return row._mapping[self.name] + + +class _TimespanGeneralColumnConverter(_GeneralColumnConverter): + """Converter that extracts timespan from the row. + + Parameters + ---------- + name : `str` + Column name or prefix. + db : `Database` + Database instance. + """ + + def __init__(self, name: str, db: Database): + self.timespan_class = db.getTimespanRepresentation() + self.name = name + + def convert(self, row: sqlalchemy.Row) -> Any: + timespan = self.timespan_class.extract(row._mapping, self.name) + return timespan diff --git a/python/lsst/daf/butler/queries/__init__.py b/python/lsst/daf/butler/queries/__init__.py index 15743f291f..720e4ca6d1 100644 --- a/python/lsst/daf/butler/queries/__init__.py +++ b/python/lsst/daf/butler/queries/__init__.py @@ -29,4 +29,5 @@ from ._data_coordinate_query_results import * from ._dataset_query_results import * from ._dimension_record_query_results import * +from ._general_query_results import * from ._query import * diff --git a/python/lsst/daf/butler/queries/_general_query_results.py b/python/lsst/daf/butler/queries/_general_query_results.py new file mode 100644 index 0000000000..ae99e6ecdf --- /dev/null +++ b/python/lsst/daf/butler/queries/_general_query_results.py @@ -0,0 +1,125 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. 
If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ("GeneralQueryResults",) + +from collections.abc import Iterator +from typing import Any, final + +from .._dataset_ref import DatasetRef +from .._dataset_type import DatasetType +from ..dimensions import DataCoordinate, DimensionGroup +from ._base import QueryResultsBase +from .driver import QueryDriver +from .result_specs import GeneralResultSpec +from .tree import QueryTree, ResultColumn + + +@final +class GeneralQueryResults(QueryResultsBase): + """A query for `DatasetRef` results with a single dataset type. + + Parameters + ---------- + driver : `QueryDriver` + Implementation object that knows how to actually execute queries. + tree : `QueryTree` + Description of the query as a tree of joins and column expressions. The + instance returned directly by the `Butler._query` entry point should be + constructed via `make_unit_query_tree`. + spec : `GeneralResultSpec` + Specification of the query result rows, including output columns, + ordering, and slicing. + + Notes + ----- + This class should never be constructed directly by users; use `Query` + methods instead. + """ + + def __init__(self, driver: QueryDriver, tree: QueryTree, spec: GeneralResultSpec): + spec.validate_tree(tree) + super().__init__(driver, tree) + self._spec = spec + + def __iter__(self) -> Iterator[dict[ResultColumn, Any]]: + """Iterate over result rows. + + Yields + ------ + row_dict : `dict` [`ResultColumn`, `Any`] + Result row as dictionary, the keys are `ResultColumn` instances. + """ + for page in self._driver.execute(self._spec, self._tree): + columns = tuple(page.spec.get_result_columns()) + for row in page.rows: + yield dict(zip(columns, row)) + + def iter_refs(self, dataset_type: DatasetType) -> Iterator[tuple[DatasetRef, dict[ResultColumn, Any]]]: + """Iterate over result rows and return DatasetRef constructed from each + row and an original row. + + Parameters + ---------- + dataset_type : `DatasetType` + Type of the dataset to return. + + Yields + ------ + dataset_ref : `DatasetRef` + Dataset reference. + row_dict : `dict` [`ResultColumn`, `Any`] + Result row as dictionary, the keys are `ResultColumn` instances. 
+ """ + dimensions = dataset_type.dimensions + id_key = ResultColumn(logical_table=dataset_type.name, field="dataset_id") + run_key = ResultColumn(logical_table=dataset_type.name, field="run") + data_id_keys = [ResultColumn(logical_table=element, field=None) for element in dimensions.required] + for row in self: + values = tuple(row[key] for key in data_id_keys) + data_id = DataCoordinate.from_required_values(dimensions, values) + ref = DatasetRef(dataset_type, data_id, row[run_key], id=row[id_key]) + yield ref, row + + @property + def dimensions(self) -> DimensionGroup: + # Docstring inherited + return self._spec.dimensions + + def count(self, *, exact: bool = True, discard: bool = False) -> int: + # Docstring inherited. + return self._driver.count(self._tree, self._spec, exact=exact, discard=discard) + + def _copy(self, tree: QueryTree, **kwargs: Any) -> GeneralQueryResults: + # Docstring inherited. + return GeneralQueryResults(self._driver, tree, self._spec.model_copy(update=kwargs)) + + def _get_datasets(self) -> frozenset[str]: + # Docstring inherited. + return frozenset(self._spec.dataset_fields) diff --git a/python/lsst/daf/butler/queries/_query.py b/python/lsst/daf/butler/queries/_query.py index cda8edc75b..2bf349642f 100644 --- a/python/lsst/daf/butler/queries/_query.py +++ b/python/lsst/daf/butler/queries/_query.py @@ -43,10 +43,16 @@ from ._data_coordinate_query_results import DataCoordinateQueryResults from ._dataset_query_results import DatasetRefQueryResults from ._dimension_record_query_results import DimensionRecordQueryResults +from ._general_query_results import GeneralQueryResults from .convert_args import convert_where_args from .driver import QueryDriver from .expression_factory import ExpressionFactory -from .result_specs import DataCoordinateResultSpec, DatasetRefResultSpec, DimensionRecordResultSpec +from .result_specs import ( + DataCoordinateResultSpec, + DatasetRefResultSpec, + DimensionRecordResultSpec, + GeneralResultSpec, +) from .tree import DatasetSearch, Predicate, QueryTree, make_identity_query_tree @@ -292,6 +298,38 @@ def dimension_records(self, element: str) -> DimensionRecordQueryResults: result_spec = DimensionRecordResultSpec(element=self._driver.universe[element]) return DimensionRecordQueryResults(self._driver, tree, result_spec) + def dataset_associations( + self, + dataset_type: DatasetType, + collections: Iterable[str], + ) -> GeneralQueryResults: + """Iterate over dataset-collection combinations where the dataset is in + the collection. + + Parameters + ---------- + dataset_type : `DatasetType` + A dataset type object. + collections : `~collections.abc.Iterable` [`str`] + Names of the collections to search. Chained collections are + ignored. + + Returns + ------- + result : `GeneralQueryResults` + Query result that can be iterated over. The result includes all + columns needed to construct `DatasetRef`, plus ``collection`` and + ``timespan`` columns. 
+ """ + _, _, query = self._join_dataset_search_impl(dataset_type, collections) + result_spec = GeneralResultSpec( + dimensions=dataset_type.dimensions, + dimension_fields={}, + dataset_fields={dataset_type.name: {"dataset_id", "run", "collection", "timespan"}}, + find_first=False, + ) + return GeneralQueryResults(self._driver, tree=query._tree, spec=result_spec) + def materialize( self, *, diff --git a/python/lsst/daf/butler/queries/driver.py b/python/lsst/daf/butler/queries/driver.py index 6df857b218..f7a6af352a 100644 --- a/python/lsst/daf/butler/queries/driver.py +++ b/python/lsst/daf/butler/queries/driver.py @@ -116,7 +116,8 @@ class GeneralResultPage: spec: GeneralResultSpec - # Raw tabular data, with columns in the same order as spec.columns. + # Raw tabular data, with columns in the same order as + # spec.get_result_columns(). rows: list[tuple[Any, ...]] diff --git a/python/lsst/daf/butler/registry/sql_registry.py b/python/lsst/daf/butler/registry/sql_registry.py index 23c6355f9b..6428224f6b 100644 --- a/python/lsst/daf/butler/registry/sql_registry.py +++ b/python/lsst/daf/butler/registry/sql_registry.py @@ -65,7 +65,10 @@ DimensionUniverse, ) from ..dimensions.record_cache import DimensionRecordCache +from ..direct_query_driver import DirectQueryDriver from ..progress import Progress +from ..queries import Query +from ..queries.tree import ResultColumn from ..registry import ( ArgumentError, CollectionExpressionError, @@ -2342,6 +2345,33 @@ def queryDimensionRecords( query = builder.finish().with_record_columns(element.name) return queries.DatabaseDimensionRecordQueryResults(query, element) + @contextlib.contextmanager + def _query(self) -> Iterator[Query]: + """Context manager returning a `Query` object used for construction + and execution of complex queries. + """ + with self._query_driver(self.defaults.collections, self.defaults.dataId) as driver: + yield Query(driver) + + @contextlib.contextmanager + def _query_driver( + self, + default_collections: Iterable[str], + default_data_id: DataCoordinate, + ) -> Iterator[DirectQueryDriver]: + """Set up a `QueryDriver` instance for query execution.""" + with self.caching_context(): + driver = DirectQueryDriver( + self._db, + self.dimensions, + self._managers, + self.dimension_record_cache, + default_collections=default_collections, + default_data_id=default_data_id, + ) + with driver: + yield driver + def queryDatasetAssociations( self, datasetType: str | DatasetType, @@ -2392,59 +2422,18 @@ def queryDatasetAssociations( lsst.daf.butler.registry.CollectionExpressionError Raised when ``collections`` expression is invalid. """ - if collections is None: - if not self.defaults.collections: - raise NoDefaultCollectionError( - "No collections provided to queryDatasetAssociations, " - "and no defaults from registry construction." 
- ) - collections = self.defaults.collections - collection_wildcard = CollectionWildcard.from_expression(collections) - backend = queries.SqlQueryBackend(self._db, self._managers, self.dimension_record_cache) - parent_dataset_type = backend.resolve_single_dataset_type_wildcard(datasetType) - timespan_tag = DatasetColumnTag(parent_dataset_type.name, "timespan") - collection_tag = DatasetColumnTag(parent_dataset_type.name, "collection") - for parent_collection_record in backend.resolve_collection_wildcard( - collection_wildcard, - collection_types=frozenset(collectionTypes), - flatten_chains=flattenChains, - ): - # Resolve this possibly-chained collection into a list of - # non-CHAINED collections that actually hold datasets of this - # type. - candidate_collection_records = backend.resolve_dataset_collections( - parent_dataset_type, - CollectionWildcard.from_names([parent_collection_record.name]), - allow_calibration_collections=True, - governor_constraints={}, - ) - if not candidate_collection_records: - continue - with backend.context() as context: - relation = backend.make_dataset_query_relation( - parent_dataset_type, - candidate_collection_records, - columns={"dataset_id", "run", "timespan", "collection"}, - context=context, - ) - reader = queries.DatasetRefReader( - parent_dataset_type, - translate_collection=lambda k: self._managers.collections[k].name, - full=False, - ) - for row in context.fetch_iterable(relation): - ref = reader.read(row) - collection_record = self._managers.collections[row[collection_tag]] - if collection_record.type is CollectionType.CALIBRATION: - timespan = row[timespan_tag] - else: - # For backwards compatibility and (possibly?) user - # convenience we continue to define the timespan of a - # DatasetAssociation row for a non-CALIBRATION - # collection to be None rather than a fully unbounded - # timespan. - timespan = None - yield DatasetAssociation(ref=ref, collection=collection_record.name, timespan=timespan) + if isinstance(datasetType, str): + datasetType = self.getDatasetType(datasetType) + resolved_collections = self.queryCollections( + collections, datasetType, collectionTypes=collectionTypes, flattenChains=flattenChains + ) + with self._query() as query: + result = query.dataset_associations(datasetType, resolved_collections) + timespan_key = ResultColumn(logical_table=datasetType.name, field="timespan") + collection_key = ResultColumn(logical_table=datasetType.name, field="collection") + for ref, row_dict in result.iter_refs(datasetType): + _LOG.debug("row_dict: %s", row_dict) + yield DatasetAssociation(ref, row_dict[collection_key], row_dict[timespan_key]) def get_datastore_records(self, ref: DatasetRef) -> DatasetRef: """Retrieve datastore records for given ref. 
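
The new `Query.dataset_associations` / `GeneralQueryResults.iter_refs` pair added
in this commit is intended to be used roughly as follows (a minimal sketch only;
`butler`, `bias_type` and the collection name are placeholders rather than code
from this patch, and `ResultColumn` comes from `lsst.daf.butler.queries.tree`):

    with butler._query() as query:
        results = query.dataset_associations(bias_type, ["calib_collection"])
        timespan_key = ResultColumn(logical_table=bias_type.name, field="timespan")
        for ref, row in results.iter_refs(bias_type):
            print(ref, row[timespan_key])
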
diff --git a/python/lsst/daf/butler/remote_butler/_registry.py b/python/lsst/daf/butler/remote_butler/_registry.py index f7dee118bd..c5fd5a5721 100644 --- a/python/lsst/daf/butler/remote_butler/_registry.py +++ b/python/lsst/daf/butler/remote_butler/_registry.py @@ -46,6 +46,7 @@ DimensionRecord, DimensionUniverse, ) +from ..queries.tree import ResultColumn from ..registry import ( CollectionArgType, CollectionSummary, @@ -65,12 +66,12 @@ DimensionRecordQueryResults, ) from ..registry.wildcards import CollectionWildcard, DatasetTypeWildcard -from ..remote_butler import RemoteButler from ._collection_args import ( convert_collection_arg_to_glob_string_list, convert_dataset_type_arg_to_glob_string_list, ) from ._http_connection import RemoteButlerHttpConnection, parse_model +from ._remote_butler import RemoteButler from .registry._query_common import CommonQueryArguments from .registry._query_data_coordinates import QueryDriverDataCoordinateQueryResults from .registry._query_datasets import QueryDriverDatasetRefQueryResults @@ -513,7 +514,18 @@ def queryDatasetAssociations( collectionTypes: Iterable[CollectionType] = CollectionType.all(), flattenChains: bool = False, ) -> Iterator[DatasetAssociation]: - raise NotImplementedError() + # queryCollections only accepts DatasetType. + if isinstance(datasetType, str): + datasetType = self.getDatasetType(datasetType) + resolved_collections = self.queryCollections( + collections, datasetType=datasetType, collectionTypes=collectionTypes, flattenChains=flattenChains + ) + with self._butler._query() as query: + result = query.dataset_associations(datasetType, resolved_collections) + timespan_key = ResultColumn(logical_table=datasetType.name, field="timespan") + collection_key = ResultColumn(logical_table=datasetType.name, field="collection") + for ref, row_dict in result.iter_refs(datasetType): + yield DatasetAssociation(ref, row_dict[collection_key], row_dict[timespan_key]) @property def storageClasses(self) -> StorageClassFactory: From b4aed891f000434646eeacede3c494babf4c8023 Mon Sep 17 00:00:00 2001 From: Andy Salnikov Date: Tue, 6 Aug 2024 15:28:10 -0700 Subject: [PATCH 04/10] Switch hybrid registry to use remote registry for queryDatasetAssociations. This enables testing of the remote registry implementation, which uncovered a problem with deserialization of general result pages. Had to add additional code to serialize/deserialize that thing correctly. --- python/lsst/daf/butler/column_spec.py | 49 +++++++++++++++++++ .../daf/butler/remote_butler/_query_driver.py | 20 ++++++++ .../server/handlers/_query_serialization.py | 21 ++++++++ .../daf/butler/remote_butler/server_models.py | 15 +++++- .../butler/tests/hybrid_butler_registry.py | 2 +- 5 files changed, 104 insertions(+), 3 deletions(-) diff --git a/python/lsst/daf/butler/column_spec.py b/python/lsst/daf/butler/column_spec.py index 28f93d55f2..0a53031fbb 100644 --- a/python/lsst/daf/butler/column_spec.py +++ b/python/lsst/daf/butler/column_spec.py @@ -53,6 +53,7 @@ from . import arrow_utils, ddl from ._timespan import Timespan +from .pydantic_utils import SerializableRegion, SerializableTime if TYPE_CHECKING: from .name_shrinker import NameShrinker @@ -134,6 +135,18 @@ def to_arrow(self) -> arrow_utils.ToArrow: """ raise NotImplementedError() + @abstractmethod + def type_adapter(self) -> pydantic.TypeAdapter: + """Return pydantic type adapter that converts values of this column to + or from serializable format. + + Returns + ------- + type_adapter : `pydantic.TypeAdapter` + A converter instance. 
+ """ + raise NotImplementedError() + def display(self, level: int = 0, tab: str = " ") -> list[str]: """Return a human-reader-focused string description of this column as a list of lines. @@ -178,6 +191,10 @@ def to_arrow(self) -> arrow_utils.ToArrow: # Docstring inherited. return arrow_utils.ToArrow.for_primitive(self.name, pa.uint64(), nullable=self.nullable) + def type_adapter(self) -> pydantic.TypeAdapter: + # Docstring inherited. + return pydantic.TypeAdapter(self.pytype) + @final class StringColumnSpec(_BaseColumnSpec): @@ -198,6 +215,10 @@ def to_arrow(self) -> arrow_utils.ToArrow: # Docstring inherited. return arrow_utils.ToArrow.for_primitive(self.name, pa.string(), nullable=self.nullable) + def type_adapter(self) -> pydantic.TypeAdapter: + # Docstring inherited. + return pydantic.TypeAdapter(self.pytype) + @final class HashColumnSpec(_BaseColumnSpec): @@ -224,6 +245,10 @@ def to_arrow(self) -> arrow_utils.ToArrow: nullable=self.nullable, ) + def type_adapter(self) -> pydantic.TypeAdapter: + # Docstring inherited. + return pydantic.TypeAdapter(self.pytype) + @final class FloatColumnSpec(_BaseColumnSpec): @@ -238,6 +263,10 @@ def to_arrow(self) -> arrow_utils.ToArrow: assert self.nullable is not None, "nullable=None should be resolved by validators" return arrow_utils.ToArrow.for_primitive(self.name, pa.float64(), nullable=self.nullable) + def type_adapter(self) -> pydantic.TypeAdapter: + # Docstring inherited. + return pydantic.TypeAdapter(self.pytype) + @final class BoolColumnSpec(_BaseColumnSpec): @@ -251,6 +280,10 @@ def to_arrow(self) -> arrow_utils.ToArrow: # Docstring inherited. return arrow_utils.ToArrow.for_primitive(self.name, pa.bool_(), nullable=self.nullable) + def type_adapter(self) -> pydantic.TypeAdapter: + # Docstring inherited. + return pydantic.TypeAdapter(self.pytype) + @final class UUIDColumnSpec(_BaseColumnSpec): @@ -265,6 +298,10 @@ def to_arrow(self) -> arrow_utils.ToArrow: assert self.nullable is not None, "nullable=None should be resolved by validators" return arrow_utils.ToArrow.for_uuid(self.name, nullable=self.nullable) + def type_adapter(self) -> pydantic.TypeAdapter: + # Docstring inherited. + return pydantic.TypeAdapter(self.pytype) + @final class RegionColumnSpec(_BaseColumnSpec): @@ -284,6 +321,10 @@ def to_arrow(self) -> arrow_utils.ToArrow: assert self.nullable is not None, "nullable=None should be resolved by validators" return arrow_utils.ToArrow.for_region(self.name, nullable=self.nullable) + def type_adapter(self) -> pydantic.TypeAdapter: + # Docstring inherited. + return pydantic.TypeAdapter(SerializableRegion) + @final class TimespanColumnSpec(_BaseColumnSpec): @@ -299,6 +340,10 @@ def to_arrow(self) -> arrow_utils.ToArrow: # Docstring inherited. return arrow_utils.ToArrow.for_timespan(self.name, nullable=self.nullable) + def type_adapter(self) -> pydantic.TypeAdapter: + # Docstring inherited. + return pydantic.TypeAdapter(self.pytype) + @final class DateTimeColumnSpec(_BaseColumnSpec): @@ -315,6 +360,10 @@ def to_arrow(self) -> arrow_utils.ToArrow: assert self.nullable is not None, "nullable=None should be resolved by validators" return arrow_utils.ToArrow.for_datetime(self.name, nullable=self.nullable) + def type_adapter(self) -> pydantic.TypeAdapter: + # Docstring inherited. 
+ return pydantic.TypeAdapter(SerializableTime) + ColumnSpec = Annotated[ Union[ diff --git a/python/lsst/daf/butler/remote_butler/_query_driver.py b/python/lsst/daf/butler/remote_butler/_query_driver.py index 9106fa86c9..6151e06c51 100644 --- a/python/lsst/daf/butler/remote_butler/_query_driver.py +++ b/python/lsst/daf/butler/remote_butler/_query_driver.py @@ -66,6 +66,7 @@ from .server_models import ( AdditionalQueryInput, DataCoordinateUpload, + GeneralResultModel, MaterializedQuery, QueryAnyRequestModel, QueryAnyResponseModel, @@ -260,5 +261,24 @@ def _convert_query_result_page( spec=result_spec, rows=[DatasetRef.from_simple(r, universe) for r in result.rows], ) + elif result_spec.result_type == "general": + assert result.type == "general" + return _convert_general_result(result_spec, result) else: raise NotImplementedError(f"Unhandled result type {result_spec.result_type}") + + +def _convert_general_result(spec: GeneralResultSpec, model: GeneralResultModel) -> GeneralResultPage: + """Convert GeneralResultModel to a general result page.""" + columns = spec.get_result_columns() + type_adapters = [ + columns.get_column_spec(column.logical_table, column.field).type_adapter() for column in columns + ] + rows = [ + tuple( + value if value is None else type_adapter.validate_python(value) + for value, type_adapter in zip(row, type_adapters) + ) + for row in model.rows + ] + return GeneralResultPage(spec=spec, rows=rows) diff --git a/python/lsst/daf/butler/remote_butler/server/handlers/_query_serialization.py b/python/lsst/daf/butler/remote_butler/server/handlers/_query_serialization.py index 6511abec24..0c8f4e1dc1 100644 --- a/python/lsst/daf/butler/remote_butler/server/handlers/_query_serialization.py +++ b/python/lsst/daf/butler/remote_butler/server/handlers/_query_serialization.py @@ -34,6 +34,7 @@ DataCoordinateResultPage, DatasetRefResultPage, DimensionRecordResultPage, + GeneralResultPage, ResultPage, ResultSpec, ) @@ -42,6 +43,7 @@ DataCoordinateResultModel, DatasetRefResultModel, DimensionRecordsResultModel, + GeneralResultModel, QueryErrorResultModel, QueryExecuteResultData, ) @@ -86,5 +88,24 @@ def _convert_query_page(spec: ResultSpec, page: ResultPage) -> QueryExecuteResul case "dataset_ref": assert isinstance(page, DatasetRefResultPage) return DatasetRefResultModel(rows=[ref.to_simple() for ref in page.rows]) + case "general": + assert isinstance(page, GeneralResultPage) + return _convert_general_result(page) case _: raise NotImplementedError(f"Unhandled query result type {spec.result_type}") + + +def _convert_general_result(page: GeneralResultPage) -> GeneralResultModel: + """Convert GeneralResultPage to a serializable model.""" + columns = page.spec.get_result_columns() + type_adapters = [ + columns.get_column_spec(column.logical_table, column.field).type_adapter() for column in columns + ] + rows = [ + tuple( + value if value is None else type_adapter.dump_python(value) + for value, type_adapter in zip(row, type_adapters) + ) + for row in page.rows + ] + return GeneralResultModel(rows=rows) diff --git a/python/lsst/daf/butler/remote_butler/server_models.py b/python/lsst/daf/butler/remote_butler/server_models.py index f8b9bf261c..876bfe9839 100644 --- a/python/lsst/daf/butler/remote_butler/server_models.py +++ b/python/lsst/daf/butler/remote_butler/server_models.py @@ -37,7 +37,7 @@ "GetCollectionSummaryResponseModel", ] -from typing import Annotated, Literal, NewType, TypeAlias +from typing import Annotated, Any, Literal, NewType, TypeAlias from uuid import UUID import pydantic 
@@ -276,6 +276,13 @@ class DatasetRefResultModel(pydantic.BaseModel): rows: list[SerializedDatasetRef] +class GeneralResultModel(pydantic.BaseModel): + """Result model for /query/execute/ when user requested general results.""" + + type: Literal["general"] = "general" + rows: list[tuple[Any, ...]] + + class QueryErrorResultModel(pydantic.BaseModel): """Result model for /query/execute when an error occurs part-way through returning rows. @@ -293,7 +300,11 @@ class QueryErrorResultModel(pydantic.BaseModel): QueryExecuteResultData: TypeAlias = Annotated[ - DataCoordinateResultModel | DimensionRecordsResultModel | DatasetRefResultModel | QueryErrorResultModel, + DataCoordinateResultModel + | DimensionRecordsResultModel + | DatasetRefResultModel + | GeneralResultModel + | QueryErrorResultModel, pydantic.Field(discriminator="type"), ] diff --git a/python/lsst/daf/butler/tests/hybrid_butler_registry.py b/python/lsst/daf/butler/tests/hybrid_butler_registry.py index e838d0a258..167c78bc73 100644 --- a/python/lsst/daf/butler/tests/hybrid_butler_registry.py +++ b/python/lsst/daf/butler/tests/hybrid_butler_registry.py @@ -372,7 +372,7 @@ def queryDatasetAssociations( collectionTypes: Iterable[CollectionType] = CollectionType.all(), flattenChains: bool = False, ) -> Iterator[DatasetAssociation]: - return self._direct.queryDatasetAssociations( + return self._remote.queryDatasetAssociations( datasetType, collections, collectionTypes=collectionTypes, flattenChains=flattenChains ) From 1d8674fe0cfa022f1ef73a3145911f346145a068 Mon Sep 17 00:00:00 2001 From: Andy Salnikov Date: Wed, 7 Aug 2024 16:41:37 -0700 Subject: [PATCH 05/10] Remove `Query.dataset_associations` and replace it with `general` method. General result iterator now returns dictionaries indexed by strings. --- .../butler/queries/_general_query_results.py | 26 +++++----- python/lsst/daf/butler/queries/_query.py | 47 +++++++++++-------- .../daf/butler/queries/tree/_column_set.py | 3 ++ .../lsst/daf/butler/registry/sql_registry.py | 11 +++-- .../daf/butler/remote_butler/_registry.py | 11 +++-- 5 files changed, 59 insertions(+), 39 deletions(-) diff --git a/python/lsst/daf/butler/queries/_general_query_results.py b/python/lsst/daf/butler/queries/_general_query_results.py index ae99e6ecdf..20319ca0d7 100644 --- a/python/lsst/daf/butler/queries/_general_query_results.py +++ b/python/lsst/daf/butler/queries/_general_query_results.py @@ -38,7 +38,7 @@ from ._base import QueryResultsBase from .driver import QueryDriver from .result_specs import GeneralResultSpec -from .tree import QueryTree, ResultColumn +from .tree import QueryTree @final @@ -68,20 +68,22 @@ def __init__(self, driver: QueryDriver, tree: QueryTree, spec: GeneralResultSpec super().__init__(driver, tree) self._spec = spec - def __iter__(self) -> Iterator[dict[ResultColumn, Any]]: + def __iter__(self) -> Iterator[dict[str, Any]]: """Iterate over result rows. Yields ------ - row_dict : `dict` [`ResultColumn`, `Any`] - Result row as dictionary, the keys are `ResultColumn` instances. + row_dict : `dict` [`str`, `Any`] + Result row as dictionary, the keys the names of the dimensions, + dimension fields (separated from dimension by dot) or dataset type + fields (separated from dataset type name by dot). 
""" for page in self._driver.execute(self._spec, self._tree): - columns = tuple(page.spec.get_result_columns()) + columns = tuple(str(column) for column in page.spec.get_result_columns()) for row in page.rows: yield dict(zip(columns, row)) - def iter_refs(self, dataset_type: DatasetType) -> Iterator[tuple[DatasetRef, dict[ResultColumn, Any]]]: + def iter_refs(self, dataset_type: DatasetType) -> Iterator[tuple[DatasetRef, dict[str, Any]]]: """Iterate over result rows and return DatasetRef constructed from each row and an original row. @@ -94,13 +96,15 @@ def iter_refs(self, dataset_type: DatasetType) -> Iterator[tuple[DatasetRef, dic ------ dataset_ref : `DatasetRef` Dataset reference. - row_dict : `dict` [`ResultColumn`, `Any`] - Result row as dictionary, the keys are `ResultColumn` instances. + row_dict : `dict` [`str`, `Any`] + Result row as dictionary, the keys the names of the dimensions, + dimension fields (separated from dimension by dot) or dataset type + fields (separated from dataset type name by dot). """ dimensions = dataset_type.dimensions - id_key = ResultColumn(logical_table=dataset_type.name, field="dataset_id") - run_key = ResultColumn(logical_table=dataset_type.name, field="run") - data_id_keys = [ResultColumn(logical_table=element, field=None) for element in dimensions.required] + id_key = f"{dataset_type.name}.dataset_id" + run_key = f"{dataset_type.name}.run" + data_id_keys = dimensions.required for row in self: values = tuple(row[key] for key in data_id_keys) data_id = DataCoordinate.from_required_values(dimensions, values) diff --git a/python/lsst/daf/butler/queries/_query.py b/python/lsst/daf/butler/queries/_query.py index 2bf349642f..23516a122d 100644 --- a/python/lsst/daf/butler/queries/_query.py +++ b/python/lsst/daf/butler/queries/_query.py @@ -53,7 +53,7 @@ DimensionRecordResultSpec, GeneralResultSpec, ) -from .tree import DatasetSearch, Predicate, QueryTree, make_identity_query_tree +from .tree import DatasetFieldName, DatasetSearch, Predicate, QueryTree, make_identity_query_tree @final @@ -298,37 +298,44 @@ def dimension_records(self, element: str) -> DimensionRecordQueryResults: result_spec = DimensionRecordResultSpec(element=self._driver.universe[element]) return DimensionRecordQueryResults(self._driver, tree, result_spec) - def dataset_associations( + def general( self, - dataset_type: DatasetType, - collections: Iterable[str], + dimensions: DimensionGroup, + dimension_fields: Mapping[str, set[str]] = {}, + dataset_fields: Mapping[str, set[DatasetFieldName]] = {}, + find_first: bool = False, ) -> GeneralQueryResults: - """Iterate over dataset-collection combinations where the dataset is in - the collection. + """Execute query returning general result. Parameters ---------- - dataset_type : `DatasetType` - A dataset type object. - collections : `~collections.abc.Iterable` [`str`] - Names of the collections to search. Chained collections are - ignored. + dimensions : `DimensionGroup` + The dimensions that span all fields returned by this query. + dimension_fields : `~collections.abc.Mapping` [`str`, `set`[`str`]], \ + optional + Dimension record fields included in this query, the key in the + mapping is dimension name. + dataset_fields : `~collections.abc.Mapping` \ + [`str`, `set`[`DatasetFieldName`]], optional + Dataset fields included in this query, the key in the mapping is + dataset type name. + find_first : bool, optional + Whether this query requires find-first resolution for a dataset. 
+ This can only be `True` if exactly one dataset type's fields are + included in the results. Returns ------- result : `GeneralQueryResults` - Query result that can be iterated over. The result includes all - columns needed to construct `DatasetRef`, plus ``collection`` and - ``timespan`` columns. + Query result that can be iterated over. """ - _, _, query = self._join_dataset_search_impl(dataset_type, collections) result_spec = GeneralResultSpec( - dimensions=dataset_type.dimensions, - dimension_fields={}, - dataset_fields={dataset_type.name: {"dataset_id", "run", "collection", "timespan"}}, - find_first=False, + dimensions=dimensions, + dimension_fields=dimension_fields, + dataset_fields=dataset_fields, + find_first=find_first, ) - return GeneralQueryResults(self._driver, tree=query._tree, spec=result_spec) + return GeneralQueryResults(self._driver, tree=self._tree, spec=result_spec) def materialize( self, diff --git a/python/lsst/daf/butler/queries/tree/_column_set.py b/python/lsst/daf/butler/queries/tree/_column_set.py index 5c29109f20..9729f31005 100644 --- a/python/lsst/daf/butler/queries/tree/_column_set.py +++ b/python/lsst/daf/butler/queries/tree/_column_set.py @@ -376,6 +376,9 @@ class ResultColumn(NamedTuple): """Column associated with the dimension element or dataset type, or `None` if it is a dimension key column.""" + def __str__(self) -> str: + return self.logical_table if self.field is None else f"{self.logical_table}.{self.field}" + class ColumnOrder: """Defines the position of columns within a result row and provides helper diff --git a/python/lsst/daf/butler/registry/sql_registry.py b/python/lsst/daf/butler/registry/sql_registry.py index 6428224f6b..aa7fcd1170 100644 --- a/python/lsst/daf/butler/registry/sql_registry.py +++ b/python/lsst/daf/butler/registry/sql_registry.py @@ -68,7 +68,6 @@ from ..direct_query_driver import DirectQueryDriver from ..progress import Progress from ..queries import Query -from ..queries.tree import ResultColumn from ..registry import ( ArgumentError, CollectionExpressionError, @@ -2428,9 +2427,13 @@ def queryDatasetAssociations( collections, datasetType, collectionTypes=collectionTypes, flattenChains=flattenChains ) with self._query() as query: - result = query.dataset_associations(datasetType, resolved_collections) - timespan_key = ResultColumn(logical_table=datasetType.name, field="timespan") - collection_key = ResultColumn(logical_table=datasetType.name, field="collection") + query = query.join_dataset_search(datasetType, resolved_collections) + result = query.general( + datasetType.dimensions, + dataset_fields={datasetType.name: {"dataset_id", "run", "collection", "timespan"}}, + ) + timespan_key = f"{datasetType.name}.timespan" + collection_key = f"{datasetType.name}.collection" for ref, row_dict in result.iter_refs(datasetType): _LOG.debug("row_dict: %s", row_dict) yield DatasetAssociation(ref, row_dict[collection_key], row_dict[timespan_key]) diff --git a/python/lsst/daf/butler/remote_butler/_registry.py b/python/lsst/daf/butler/remote_butler/_registry.py index c5fd5a5721..cb1a7aaee5 100644 --- a/python/lsst/daf/butler/remote_butler/_registry.py +++ b/python/lsst/daf/butler/remote_butler/_registry.py @@ -46,7 +46,6 @@ DimensionRecord, DimensionUniverse, ) -from ..queries.tree import ResultColumn from ..registry import ( CollectionArgType, CollectionSummary, @@ -521,9 +520,13 @@ def queryDatasetAssociations( collections, datasetType=datasetType, collectionTypes=collectionTypes, flattenChains=flattenChains ) with self._butler._query() 
as query: - result = query.dataset_associations(datasetType, resolved_collections) - timespan_key = ResultColumn(logical_table=datasetType.name, field="timespan") - collection_key = ResultColumn(logical_table=datasetType.name, field="collection") + query = query.join_dataset_search(datasetType, resolved_collections) + result = query.general( + datasetType.dimensions, + dataset_fields={datasetType.name: {"dataset_id", "run", "collection", "timespan"}}, + ) + timespan_key = f"{datasetType.name}.timespan" + collection_key = f"{datasetType.name}.collection" for ref, row_dict in result.iter_refs(datasetType): yield DatasetAssociation(ref, row_dict[collection_key], row_dict[timespan_key]) From 83fa69c0f28845f86ae5173eba627a36001ab33b Mon Sep 17 00:00:00 2001 From: Andy Salnikov Date: Thu, 8 Aug 2024 11:47:51 -0700 Subject: [PATCH 06/10] General query result returns tuples from `iter_tuples` method. Special NamedTuple class documents the items in the returned tuples. --- .../butler/queries/_general_query_results.py | 69 +++++++++++++------ .../lsst/daf/butler/registry/sql_registry.py | 5 +- .../daf/butler/remote_butler/_registry.py | 4 +- 3 files changed, 51 insertions(+), 27 deletions(-) diff --git a/python/lsst/daf/butler/queries/_general_query_results.py b/python/lsst/daf/butler/queries/_general_query_results.py index 20319ca0d7..2cee5be2da 100644 --- a/python/lsst/daf/butler/queries/_general_query_results.py +++ b/python/lsst/daf/butler/queries/_general_query_results.py @@ -27,10 +27,11 @@ from __future__ import annotations -__all__ = ("GeneralQueryResults",) +__all__ = ("GeneralQueryResults", "GeneralResultTuple") +import itertools from collections.abc import Iterator -from typing import Any, final +from typing import Any, NamedTuple, final from .._dataset_ref import DatasetRef from .._dataset_type import DatasetType @@ -41,6 +42,25 @@ from .tree import QueryTree +class GeneralResultTuple(NamedTuple): + """Helper class for general result that represents the result row as a + data coordinate and optionally a set of dataset refs extracted from a row. + """ + + data_id: DataCoordinate + """Data coordinate for current row.""" + + refs: list[DatasetRef] + """Dataset refs extracted from the current row, the order matches the order + of arguments in ``iter_tuples`` call.""" + + raw_row: dict[str, Any] + """Original result row, the keys are the names of the dimensions, + dimension fields (separated from dimension by dot) or dataset type fields + (separated from dataset type name by dot). + """ + + @final class GeneralQueryResults(QueryResultsBase): """A query for `DatasetRef` results with a single dataset type. @@ -74,7 +94,7 @@ def __iter__(self) -> Iterator[dict[str, Any]]: Yields ------ row_dict : `dict` [`str`, `Any`] - Result row as dictionary, the keys the names of the dimensions, + Result row as dictionary, the keys are the names of the dimensions, dimension fields (separated from dimension by dot) or dataset type fields (separated from dataset type name by dot). """ @@ -83,33 +103,38 @@ def __iter__(self) -> Iterator[dict[str, Any]]: for row in page.rows: yield dict(zip(columns, row)) - def iter_refs(self, dataset_type: DatasetType) -> Iterator[tuple[DatasetRef, dict[str, Any]]]: - """Iterate over result rows and return DatasetRef constructed from each - row and an original row. + def iter_tuples(self, *dataset_types: DatasetType) -> Iterator[GeneralResultTuple]: + """Iterate over result rows and return data coordinate, and dataset + refs constructed from each row, and an original row. 
Parameters ---------- - dataset_type : `DatasetType` - Type of the dataset to return. + *dataset_types : `DatasetType` + Zero or more types of the datasets to return. Yields ------ - dataset_ref : `DatasetRef` - Dataset reference. - row_dict : `dict` [`str`, `Any`] - Result row as dictionary, the keys the names of the dimensions, - dimension fields (separated from dimension by dot) or dataset type - fields (separated from dataset type name by dot). + row_tuple : `GeneralResultTuple` + Structure containing data coordinate, refs, and a copy of the row. """ - dimensions = dataset_type.dimensions - id_key = f"{dataset_type.name}.dataset_id" - run_key = f"{dataset_type.name}.run" - data_id_keys = dimensions.required + all_dimensions = self._spec.dimensions + dataset_keys: list[tuple[DimensionGroup, str, str]] = [] + for dataset_type in dataset_types: + dimensions = dataset_type.dimensions + id_key = f"{dataset_type.name}.dataset_id" + run_key = f"{dataset_type.name}.run" + dataset_keys.append((dimensions, id_key, run_key)) for row in self: - values = tuple(row[key] for key in data_id_keys) - data_id = DataCoordinate.from_required_values(dimensions, values) - ref = DatasetRef(dataset_type, data_id, row[run_key], id=row[id_key]) - yield ref, row + values = tuple( + row[key] for key in itertools.chain(all_dimensions.required, all_dimensions.implied) + ) + data_coordinate = DataCoordinate.from_full_values(all_dimensions, values) + refs = [] + for dimensions, id_key, run_key in dataset_keys: + values = tuple(row[key] for key in itertools.chain(dimensions.required, dimensions.implied)) + data_id = DataCoordinate.from_full_values(dimensions, values) + refs.append(DatasetRef(dataset_type, data_id, row[run_key], id=row[id_key])) + yield GeneralResultTuple(data_id=data_coordinate, refs=refs, raw_row=row) @property def dimensions(self) -> DimensionGroup: diff --git a/python/lsst/daf/butler/registry/sql_registry.py b/python/lsst/daf/butler/registry/sql_registry.py index aa7fcd1170..990616743f 100644 --- a/python/lsst/daf/butler/registry/sql_registry.py +++ b/python/lsst/daf/butler/registry/sql_registry.py @@ -2434,9 +2434,8 @@ def queryDatasetAssociations( ) timespan_key = f"{datasetType.name}.timespan" collection_key = f"{datasetType.name}.collection" - for ref, row_dict in result.iter_refs(datasetType): - _LOG.debug("row_dict: %s", row_dict) - yield DatasetAssociation(ref, row_dict[collection_key], row_dict[timespan_key]) + for _, refs, row_dict in result.iter_tuples(datasetType): + yield DatasetAssociation(refs[0], row_dict[collection_key], row_dict[timespan_key]) def get_datastore_records(self, ref: DatasetRef) -> DatasetRef: """Retrieve datastore records for given ref. 
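
Iterating the `GeneralResultTuple` rows produced by the new `iter_tuples` method
looks roughly like this (a sketch only; `results` is assumed to be a
`GeneralQueryResults` whose query included the dataset fields shown, and
`bias_type` is a placeholder dataset type):

    for data_id, refs, row in results.iter_tuples(bias_type):
        timespan = row[f"{bias_type.name}.timespan"]
        collection = row[f"{bias_type.name}.collection"]
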
diff --git a/python/lsst/daf/butler/remote_butler/_registry.py b/python/lsst/daf/butler/remote_butler/_registry.py index cb1a7aaee5..a74e6da454 100644 --- a/python/lsst/daf/butler/remote_butler/_registry.py +++ b/python/lsst/daf/butler/remote_butler/_registry.py @@ -527,8 +527,8 @@ def queryDatasetAssociations( ) timespan_key = f"{datasetType.name}.timespan" collection_key = f"{datasetType.name}.collection" - for ref, row_dict in result.iter_refs(datasetType): - yield DatasetAssociation(ref, row_dict[collection_key], row_dict[timespan_key]) + for _, refs, row_dict in result.iter_tuples(datasetType): + yield DatasetAssociation(refs[0], row_dict[collection_key], row_dict[timespan_key]) @property def storageClasses(self) -> StorageClassFactory: From a14cc9004e9facc42b9de98b8ec7b6fac738c4f1 Mon Sep 17 00:00:00 2001 From: Andy Salnikov Date: Thu, 8 Aug 2024 13:40:10 -0700 Subject: [PATCH 07/10] Allow Query.general to take positional parameters with field names. --- python/lsst/daf/butler/queries/_query.py | 70 ++++++++++++++++++++---- 1 file changed, 58 insertions(+), 12 deletions(-) diff --git a/python/lsst/daf/butler/queries/_query.py b/python/lsst/daf/butler/queries/_query.py index 23516a122d..4641f21d5a 100644 --- a/python/lsst/daf/butler/queries/_query.py +++ b/python/lsst/daf/butler/queries/_query.py @@ -30,6 +30,7 @@ __all__ = ("Query",) from collections.abc import Iterable, Mapping, Set +from types import EllipsisType from typing import Any, final from lsst.utils.iteration import ensure_iterable @@ -44,6 +45,7 @@ from ._dataset_query_results import DatasetRefQueryResults from ._dimension_record_query_results import DimensionRecordQueryResults from ._general_query_results import GeneralQueryResults +from ._identifiers import IdentifierContext, interpret_identifier from .convert_args import convert_where_args from .driver import QueryDriver from .expression_factory import ExpressionFactory @@ -53,7 +55,17 @@ DimensionRecordResultSpec, GeneralResultSpec, ) -from .tree import DatasetFieldName, DatasetSearch, Predicate, QueryTree, make_identity_query_tree +from .tree import ( + DATASET_FIELD_NAMES, + DatasetFieldName, + DatasetFieldReference, + DatasetSearch, + DimensionFieldReference, + DimensionKeyReference, + Predicate, + QueryTree, + make_identity_query_tree, +) @final @@ -301,8 +313,9 @@ def dimension_records(self, element: str) -> DimensionRecordQueryResults: def general( self, dimensions: DimensionGroup, - dimension_fields: Mapping[str, set[str]] = {}, - dataset_fields: Mapping[str, set[DatasetFieldName]] = {}, + *names: str, + dimension_fields: Mapping[str, Set[str]] = {}, + dataset_fields: Mapping[str, Set[DatasetFieldName] | EllipsisType] = {}, find_first: bool = False, ) -> GeneralQueryResults: """Execute query returning general result. @@ -311,14 +324,18 @@ def general( ---------- dimensions : `DimensionGroup` The dimensions that span all fields returned by this query. - dimension_fields : `~collections.abc.Mapping` [`str`, `set`[`str`]], \ - optional + *names : `str` + Names of dimensions fields (in "dimension.field" format), dataset + fields (in "dataset_type.field" format) to include in this query. + dimension_fields : `~collections.abc.Mapping` [`str`, \ + `~collections.abc.Set`[`str`]], optional Dimension record fields included in this query, the key in the mapping is dimension name. 
- dataset_fields : `~collections.abc.Mapping` \ - [`str`, `set`[`DatasetFieldName`]], optional + dataset_fields : `~collections.abc.Mapping` [`str`, \ + `~collections.abc.Set`[`DatasetFieldName`] | ...], optional Dataset fields included in this query, the key in the mapping is - dataset type name. + dataset type name. Ellipsis (``...``) can be used for value + to include all dataset fields. find_first : bool, optional Whether this query requires find-first resolution for a dataset. This can only be `True` if exactly one dataset type's fields are @@ -329,13 +346,42 @@ def general( result : `GeneralQueryResults` Query result that can be iterated over. """ + dimension_fields_dict = {name: set(fields) for name, fields in dimension_fields.items()} + dataset_fields_dict = { + name: set(DATASET_FIELD_NAMES) if fields is ... else set(fields) + for name, fields in dataset_fields.items() + } + # Parse all names. + context = IdentifierContext(dimensions, set(self._tree.datasets)) + extra_dimension_names: set[str] = set() + for name in names: + identifier = interpret_identifier(context, name) + match identifier: + case DimensionKeyReference(dimension=dimension): + # Could be because someone asked for the key field. + extra_dimension_names.add(dimension.name) + case DimensionFieldReference(element=element, field=field): + extra_dimension_names.add(element.name) + dimension_fields_dict.setdefault(element.name, set()).add(field) + case DatasetFieldReference(dataset_type=dataset_type, field=dataset_field): + dataset_fields_dict.setdefault(dataset_type, set()).add(dataset_field) + case _: + raise TypeError(f"Unexpected type of identifier ({name}): {identifier}") + + extra_dimensions = dimensions.universe.conform(extra_dimension_names) + + # Merge missing dimensions into the tree. + tree = self._tree + if not extra_dimensions <= tree.dimensions: + tree = tree.join_dimensions(extra_dimensions) + result_spec = GeneralResultSpec( - dimensions=dimensions, - dimension_fields=dimension_fields, - dataset_fields=dataset_fields, + dimensions=dimensions.union(extra_dimensions), + dimension_fields=dimension_fields_dict, + dataset_fields=dataset_fields_dict, find_first=find_first, ) - return GeneralQueryResults(self._driver, tree=self._tree, spec=result_spec) + return GeneralQueryResults(self._driver, tree=tree, spec=result_spec) def materialize( self, From 75f7b0a1750fd76529c7cd02106e2ada152e1614 Mon Sep 17 00:00:00 2001 From: Andy Salnikov Date: Thu, 8 Aug 2024 16:55:58 -0700 Subject: [PATCH 08/10] Add unit test for general queries and fix serialization of ingest_date. Adding unit tests exposed the issue with serialization of ingest_time. We still create schema with a native time type for that column, while its corresponding ColumnSpec says it should be astropy time. The code that I added recently used SerializableTime with type adapter to serialize it which did not work with datetime. I had to reimplement serialization methods to allow more flexible handling of ingest_date types. 
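
Roughly, the new serializer round-trips both representations like this (an
illustrative sketch rather than code from this patch; it assumes `datetime`,
`astropy.time` and the `_DateTimeColumnValueSerializer` class added below):

    serializer = _DateTimeColumnValueSerializer()

    native = datetime.datetime(2024, 8, 8, 12, 0, 0)
    assert serializer.deserialize(serializer.serialize(native)) == native

    t = astropy.time.Time("2024-08-08T12:00:00", scale="tai")
    nanos = serializer.serialize(t)           # integer nanoseconds via SerializableTime
    restored = serializer.deserialize(nanos)  # back to an astropy Time
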
--- python/lsst/daf/butler/column_spec.py | 141 +++++++++++--- .../daf/butler/remote_butler/_query_driver.py | 9 +- .../server/handlers/_query_serialization.py | 10 +- .../lsst/daf/butler/tests/butler_queries.py | 172 +++++++++++++++++- 4 files changed, 296 insertions(+), 36 deletions(-) diff --git a/python/lsst/daf/butler/column_spec.py b/python/lsst/daf/butler/column_spec.py index 0a53031fbb..19209aacda 100644 --- a/python/lsst/daf/butler/column_spec.py +++ b/python/lsst/daf/butler/column_spec.py @@ -41,6 +41,7 @@ "COLLECTION_NAME_MAX_LENGTH", ) +import datetime import textwrap import uuid from abc import ABC, abstractmethod @@ -88,6 +89,102 @@ # that actually changing the value is a (minor) schema change. +class ColumnValueSerializer(ABC): + """Class that knows how to serialize and deserialize column values.""" + + @abstractmethod + def serialize(self, value: Any) -> Any: + """Convert column value to something that can be serialized. + + Parameters + ---------- + value : `Any` + Column value to be serialized. + + Returns + ------- + value : `Any` + Column value in serializable format. + """ + raise NotImplementedError + + @abstractmethod + def deserialize(self, value: Any) -> Any: + """Convert serialized value to column value. + + Parameters + ---------- + value : `Any` + Serialized column value. + + Returns + ------- + value : `Any` + Deserialized column value. + """ + raise NotImplementedError + + +class _DefaultColumnValueSerializer(ColumnValueSerializer): + """Default implementation of serializer for basic types.""" + + def serialize(self, value: Any) -> Any: + # Docstring inherited. + return value + + def deserialize(self, value: Any) -> Any: + # Docstring inherited. + return value + + +class _TypeAdapterColumnValueSerializer(ColumnValueSerializer): + """Implementation of serializer that uses pydantic type adapter.""" + + def __init__(self, type_adapter: pydantic.TypeAdapter): + # Docstring inherited. + self._type_adapter = type_adapter + + def serialize(self, value: Any) -> Any: + # Docstring inherited. + return value if value is None else self._type_adapter.dump_python(value) + + def deserialize(self, value: Any) -> Any: + # Docstring inherited. + return value if value is None else self._type_adapter.validate_python(value) + + +class _DateTimeColumnValueSerializer(ColumnValueSerializer): + """Implementation of serializer for ingest_time column. That column can be + either in native database time appearing as `datetime.datetime` on Python + side or integer nanoseconds appearing as astropy.time.Time. We use pydantic + type adapter for astropy time, which serializes it into integer + nanoseconds. datetime is converted to string representation to distinguish + it from integer nanoseconds (timezone handling depends entirely on what + database returns). + """ + + def __init__(self) -> None: + self._astropy_adapter = pydantic.TypeAdapter(SerializableTime) + + def serialize(self, value: Any) -> Any: + # Docstring inherited. + if value is None: + return None + elif isinstance(value, datetime.datetime): + return value.isoformat() + else: + return self._astropy_adapter.dump_python(value) + + def deserialize(self, value: Any) -> Any: + # Docstring inherited. 
+ if value is None: + return None + elif isinstance(value, str): + return datetime.datetime.fromisoformat(value) + else: + return self._astropy_adapter.validate_python(value) + + class _BaseColumnSpec(pydantic.BaseModel, ABC): """Base class for descriptions of table columns.""" @@ -136,13 +233,13 @@ def to_arrow(self) -> arrow_utils.ToArrow: raise NotImplementedError() @abstractmethod - def type_adapter(self) -> pydantic.TypeAdapter: - """Return pydantic type adapter that converts values of this column to - or from serializable format. + def serializer(self) -> ColumnValueSerializer: + """Return object that converts values of this column to or from + serializable format. Returns ------- - type_adapter : `pydantic.TypeAdapter` + serializer : `ColumnValueSerializer` A converter instance. """ raise NotImplementedError() @@ -191,9 +288,9 @@ def to_arrow(self) -> arrow_utils.ToArrow: # Docstring inherited. return arrow_utils.ToArrow.for_primitive(self.name, pa.uint64(), nullable=self.nullable) - def type_adapter(self) -> pydantic.TypeAdapter: + def serializer(self) -> ColumnValueSerializer: # Docstring inherited. - return pydantic.TypeAdapter(self.pytype) + return _DefaultColumnValueSerializer() @final @@ -215,9 +312,9 @@ def to_arrow(self) -> arrow_utils.ToArrow: # Docstring inherited. return arrow_utils.ToArrow.for_primitive(self.name, pa.string(), nullable=self.nullable) - def type_adapter(self) -> pydantic.TypeAdapter: + def serializer(self) -> ColumnValueSerializer: # Docstring inherited. - return pydantic.TypeAdapter(self.pytype) + return _DefaultColumnValueSerializer() @final @@ -245,9 +342,9 @@ def to_arrow(self) -> arrow_utils.ToArrow: nullable=self.nullable, ) - def type_adapter(self) -> pydantic.TypeAdapter: + def serializer(self) -> ColumnValueSerializer: # Docstring inherited. - return pydantic.TypeAdapter(self.pytype) + return _DefaultColumnValueSerializer() @final @@ -263,9 +360,9 @@ def to_arrow(self) -> arrow_utils.ToArrow: assert self.nullable is not None, "nullable=None should be resolved by validators" return arrow_utils.ToArrow.for_primitive(self.name, pa.float64(), nullable=self.nullable) - def type_adapter(self) -> pydantic.TypeAdapter: + def serializer(self) -> ColumnValueSerializer: # Docstring inherited. - return pydantic.TypeAdapter(self.pytype) + return _DefaultColumnValueSerializer() @final @@ -280,9 +377,9 @@ def to_arrow(self) -> arrow_utils.ToArrow: # Docstring inherited. return arrow_utils.ToArrow.for_primitive(self.name, pa.bool_(), nullable=self.nullable) - def type_adapter(self) -> pydantic.TypeAdapter: + def serializer(self) -> ColumnValueSerializer: # Docstring inherited. - return pydantic.TypeAdapter(self.pytype) + return _DefaultColumnValueSerializer() @final @@ -298,9 +395,9 @@ def to_arrow(self) -> arrow_utils.ToArrow: assert self.nullable is not None, "nullable=None should be resolved by validators" return arrow_utils.ToArrow.for_uuid(self.name, nullable=self.nullable) - def type_adapter(self) -> pydantic.TypeAdapter: + def serializer(self) -> ColumnValueSerializer: # Docstring inherited. - return pydantic.TypeAdapter(self.pytype) + return _TypeAdapterColumnValueSerializer(pydantic.TypeAdapter(self.pytype)) @final @@ -321,9 +418,9 @@ def to_arrow(self) -> arrow_utils.ToArrow: assert self.nullable is not None, "nullable=None should be resolved by validators" return arrow_utils.ToArrow.for_region(self.name, nullable=self.nullable) - def type_adapter(self) -> pydantic.TypeAdapter: + def serializer(self) -> ColumnValueSerializer: # Docstring inherited. 
- return pydantic.TypeAdapter(SerializableRegion) + return _TypeAdapterColumnValueSerializer(pydantic.TypeAdapter(SerializableRegion)) @final @@ -340,9 +437,9 @@ def to_arrow(self) -> arrow_utils.ToArrow: # Docstring inherited. return arrow_utils.ToArrow.for_timespan(self.name, nullable=self.nullable) - def type_adapter(self) -> pydantic.TypeAdapter: + def serializer(self) -> ColumnValueSerializer: # Docstring inherited. - return pydantic.TypeAdapter(self.pytype) + return _TypeAdapterColumnValueSerializer(pydantic.TypeAdapter(self.pytype)) @final @@ -360,9 +457,9 @@ def to_arrow(self) -> arrow_utils.ToArrow: assert self.nullable is not None, "nullable=None should be resolved by validators" return arrow_utils.ToArrow.for_datetime(self.name, nullable=self.nullable) - def type_adapter(self) -> pydantic.TypeAdapter: + def serializer(self) -> ColumnValueSerializer: # Docstring inherited. - return pydantic.TypeAdapter(SerializableTime) + return _DateTimeColumnValueSerializer() ColumnSpec = Annotated[ diff --git a/python/lsst/daf/butler/remote_butler/_query_driver.py b/python/lsst/daf/butler/remote_butler/_query_driver.py index 6151e06c51..5cd2d0a793 100644 --- a/python/lsst/daf/butler/remote_butler/_query_driver.py +++ b/python/lsst/daf/butler/remote_butler/_query_driver.py @@ -271,14 +271,11 @@ def _convert_query_result_page( def _convert_general_result(spec: GeneralResultSpec, model: GeneralResultModel) -> GeneralResultPage: """Convert GeneralResultModel to a general result page.""" columns = spec.get_result_columns() - type_adapters = [ - columns.get_column_spec(column.logical_table, column.field).type_adapter() for column in columns + serializers = [ + columns.get_column_spec(column.logical_table, column.field).serializer() for column in columns ] rows = [ - tuple( - value if value is None else type_adapter.validate_python(value) - for value, type_adapter in zip(row, type_adapters) - ) + tuple(serializer.deserialize(value) for value, serializer in zip(row, serializers)) for row in model.rows ] return GeneralResultPage(spec=spec, rows=rows) diff --git a/python/lsst/daf/butler/remote_butler/server/handlers/_query_serialization.py b/python/lsst/daf/butler/remote_butler/server/handlers/_query_serialization.py index 0c8f4e1dc1..e2db3399d5 100644 --- a/python/lsst/daf/butler/remote_butler/server/handlers/_query_serialization.py +++ b/python/lsst/daf/butler/remote_butler/server/handlers/_query_serialization.py @@ -98,14 +98,10 @@ def _convert_query_page(spec: ResultSpec, page: ResultPage) -> QueryExecuteResul def _convert_general_result(page: GeneralResultPage) -> GeneralResultModel: """Convert GeneralResultPage to a serializable model.""" columns = page.spec.get_result_columns() - type_adapters = [ - columns.get_column_spec(column.logical_table, column.field).type_adapter() for column in columns + serializers = [ + columns.get_column_spec(column.logical_table, column.field).serializer() for column in columns ] rows = [ - tuple( - value if value is None else type_adapter.dump_python(value) - for value, type_adapter in zip(row, type_adapters) - ) - for row in page.rows + tuple(serializer.serialize(value) for value, serializer in zip(row, serializers)) for row in page.rows ] return GeneralResultModel(rows=rows) diff --git a/python/lsst/daf/butler/tests/butler_queries.py b/python/lsst/daf/butler/tests/butler_queries.py index 13513a9a91..1d598c0a27 100644 --- a/python/lsst/daf/butler/tests/butler_queries.py +++ b/python/lsst/daf/butler/tests/butler_queries.py @@ -44,7 +44,7 @@ from .._dataset_type 
import DatasetType from .._exceptions import InvalidQueryError from .._timespan import Timespan -from ..dimensions import DataCoordinate, DimensionRecord +from ..dimensions import DataCoordinate, DimensionGroup, DimensionRecord from ..direct_query_driver import DirectQueryDriver from ..queries import DimensionRecordQueryResults from ..queries.tree import Predicate @@ -209,6 +209,176 @@ def test_simple_dataset_query(self) -> None: self.assertEqual(ref.dataId["detector"], detector) self.assertEqual(ref.run, "imported_g") + def test_general_query(self) -> None: + """Test Query.general and its result.""" + butler = self.make_butler("base.yaml", "datasets.yaml") + dimensions = butler.dimensions["detector"].minimal_group + + # Do simple dimension queries. + with butler._query() as query: + query = query.join_dimensions(dimensions) + rows = list(query.general(dimensions).order_by("detector")) + self.assertEqual( + rows, + [ + {"instrument": "Cam1", "detector": 1}, + {"instrument": "Cam1", "detector": 2}, + {"instrument": "Cam1", "detector": 3}, + {"instrument": "Cam1", "detector": 4}, + ], + ) + rows = list( + query.general(dimensions, "detector.full_name", "purpose").order_by( + "-detector.purpose", "full_name" + ) + ) + self.assertEqual( + rows, + [ + { + "instrument": "Cam1", + "detector": 4, + "detector.full_name": "Bb", + "detector.purpose": "WAVEFRONT", + }, + { + "instrument": "Cam1", + "detector": 1, + "detector.full_name": "Aa", + "detector.purpose": "SCIENCE", + }, + { + "instrument": "Cam1", + "detector": 2, + "detector.full_name": "Ab", + "detector.purpose": "SCIENCE", + }, + { + "instrument": "Cam1", + "detector": 3, + "detector.full_name": "Ba", + "detector.purpose": "SCIENCE", + }, + ], + ) + rows = list( + query.general(dimensions, "detector.full_name", "purpose").where( + "instrument = 'Cam1' AND purpose = 'WAVEFRONT'" + ) + ) + self.assertEqual( + rows, + [ + { + "instrument": "Cam1", + "detector": 4, + "detector.full_name": "Bb", + "detector.purpose": "WAVEFRONT", + }, + ], + ) + result = query.general(dimensions, dimension_fields={"detector": {"full_name"}}) + self.assertEqual(set(row["detector.full_name"] for row in result), {"Aa", "Ab", "Ba", "Bb"}) + + # Use "flat" whose dimension group includes implied dimension. + flat = butler.get_dataset_type("flat") + dimensions = DimensionGroup(butler.dimensions, ["detector", "physical_filter"]) + + # Do simple dataset queries in RUN collection. + with butler._query() as query: + query = query.join_dataset_search("flat", "imported_g") + # This just returns data IDs. 
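+            # Since "physical_filter" implies "band", the implied dimension
+            # value is included in each returned data ID as well.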
+ rows = list(query.general(dimensions).order_by("detector")) + self.assertEqual( + rows, + [ + {"instrument": "Cam1", "detector": 2, "physical_filter": "Cam1-G", "band": "g"}, + {"instrument": "Cam1", "detector": 3, "physical_filter": "Cam1-G", "band": "g"}, + {"instrument": "Cam1", "detector": 4, "physical_filter": "Cam1-G", "band": "g"}, + ], + ) + + result = query.general(dimensions, dataset_fields={"flat": ...}).order_by("detector") + ids = {row["flat.dataset_id"] for row in result} + self.assertEqual( + ids, + { + UUID("60c8a65c-7290-4c38-b1de-e3b1cdcf872d"), + UUID("84239e7f-c41f-46d5-97b9-a27976b98ceb"), + UUID("fd51bce1-2848-49d6-a378-f8a122f5139a"), + }, + ) + + # Check what iter_tuples() returns + row_tuples = list(result.iter_tuples(flat)) + self.assertEqual(len(row_tuples), 3) + for row_tuple in row_tuples: + self.assertEqual(len(row_tuple.refs), 1) + self.assertEqual(row_tuple.refs[0].datasetType, flat) + self.assertTrue(row_tuple.refs[0].dataId.hasFull()) + self.assertTrue(row_tuple.data_id.hasFull()) + self.assertEqual(row_tuple.data_id.dimensions, dimensions) + self.assertEqual(row_tuple.raw_row["flat.run"], "imported_g") + + flat1, flat2, flat3 = (row_tuple.refs[0] for row_tuple in row_tuples) + + # Query datasets CALIBRATION/TAGGED collections. + butler.registry.registerCollection("tagged", CollectionType.TAGGED) + butler.registry.registerCollection("calib", CollectionType.CALIBRATION) + + # Add two refs to tagged collection. + butler.registry.associate("tagged", [flat1, flat2]) + + # Certify some calibs. + t1 = astropy.time.Time("2020-01-01T01:00:00", format="isot", scale="tai") + t2 = astropy.time.Time("2020-01-01T02:00:00", format="isot", scale="tai") + t3 = astropy.time.Time("2020-01-01T03:00:00", format="isot", scale="tai") + butler.registry.certify("calib", [flat1], Timespan(t1, t2)) + butler.registry.certify("calib", [flat3], Timespan(t2, t3)) + butler.registry.certify("calib", [flat1], Timespan(t3, None)) + butler.registry.certify("calib", [flat2], Timespan.makeEmpty()) + + # Query tagged collection. + with butler._query() as query: + query = query.join_dataset_search("flat", ["tagged"]) + + result = query.general(dimensions, "flat.dataset_id", "flat.run", "flat.collection") + row_tuples = list(result.iter_tuples(flat)) + self.assertEqual(len(row_tuples), 2) + self.assertEqual({row_tuple.refs[0] for row_tuple in row_tuples}, {flat1, flat2}) + self.assertEqual({row_tuple.raw_row["flat.collection"] for row_tuple in row_tuples}, {"tagged"}) + + # Query calib collection. + with butler._query() as query: + query = query.join_dataset_search("flat", ["calib"]) + result = query.general( + dimensions, "flat.dataset_id", "flat.run", "flat.collection", "flat.timespan" + ) + row_tuples = list(result.iter_tuples(flat)) + self.assertEqual(len(row_tuples), 4) + self.assertEqual({row_tuple.refs[0] for row_tuple in row_tuples}, {flat1, flat2, flat3}) + self.assertEqual({row_tuple.raw_row["flat.collection"] for row_tuple in row_tuples}, {"calib"}) + self.assertEqual( + {row_tuple.raw_row["flat.timespan"] for row_tuple in row_tuples}, + {Timespan(t1, t2), Timespan(t2, t3), Timespan(t3, None), Timespan.makeEmpty()}, + ) + + # Query both tagged and calib collection. 
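+        # Rows coming from the TAGGED collection carry no validity range, so
+        # a None timespan is expected alongside the certified ones.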
+ with butler._query() as query: + query = query.join_dataset_search("flat", ["tagged", "calib"]) + result = query.general( + dimensions, "flat.dataset_id", "flat.run", "flat.collection", "flat.timespan" + ) + row_tuples = list(result.iter_tuples(flat)) + self.assertEqual(len(row_tuples), 6) + self.assertEqual( + {row_tuple.raw_row["flat.collection"] for row_tuple in row_tuples}, {"calib", "tagged"} + ) + self.assertEqual( + {row_tuple.raw_row["flat.timespan"] for row_tuple in row_tuples}, + {Timespan(t1, t2), Timespan(t2, t3), Timespan(t3, None), Timespan.makeEmpty(), None}, + ) + def test_implied_union_record_query(self) -> None: """Test queries for a dimension ('band') that uses "implied union" storage, in which its values are the union of the values for it in a From 746888a956f523b1c00f82d85593157e9709830e Mon Sep 17 00:00:00 2001 From: Andy Salnikov Date: Thu, 8 Aug 2024 17:22:14 -0700 Subject: [PATCH 09/10] Simplify construction of DatasetAssociation from general query result. --- .../lsst/daf/butler/_dataset_association.py | 27 ++++++++++++++++++- .../lsst/daf/butler/registry/sql_registry.py | 5 +--- .../daf/butler/remote_butler/_registry.py | 5 +--- 3 files changed, 28 insertions(+), 9 deletions(-) diff --git a/python/lsst/daf/butler/_dataset_association.py b/python/lsst/daf/butler/_dataset_association.py index a836a50682..6572fe0c38 100644 --- a/python/lsst/daf/butler/_dataset_association.py +++ b/python/lsst/daf/butler/_dataset_association.py @@ -29,12 +29,17 @@ __all__ = ("DatasetAssociation",) +from collections.abc import Iterator from dataclasses import dataclass -from typing import Any +from typing import TYPE_CHECKING, Any from ._dataset_ref import DatasetRef +from ._dataset_type import DatasetType from ._timespan import Timespan +if TYPE_CHECKING: + from .queries._general_query_results import GeneralQueryResults + @dataclass(frozen=True, eq=True) class DatasetAssociation: @@ -59,6 +64,26 @@ class DatasetAssociation: collection (`Timespan` or `None`). """ + @classmethod + def from_query_result( + cls, result: GeneralQueryResults, dataset_type: DatasetType + ) -> Iterator[DatasetAssociation]: + """Construct dataset associations from the result of general query. + + Parameters + ---------- + result : `GeneralQueryResults` + General query result returned by `Query.general` method. The result + has to include "{dataset_type.name}.timespan" and + "{dataset_type.name}.collection" columns. + dataset_type : `DatasetType` + Dataset type, query has to include this dataset type. 
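+
+        Yields
+        ------
+        association : `DatasetAssociation`
+            Dataset association constructed from a single result row.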
+ """ + timespan_key = f"{dataset_type.name}.timespan" + collection_key = f"{dataset_type.name}.collection" + for _, refs, row_dict in result.iter_tuples(dataset_type): + yield DatasetAssociation(refs[0], row_dict[collection_key], row_dict[timespan_key]) + def __lt__(self, other: Any) -> bool: # Allow sorting of associations if not isinstance(other, type(self)): diff --git a/python/lsst/daf/butler/registry/sql_registry.py b/python/lsst/daf/butler/registry/sql_registry.py index 990616743f..fb6f51503c 100644 --- a/python/lsst/daf/butler/registry/sql_registry.py +++ b/python/lsst/daf/butler/registry/sql_registry.py @@ -2432,10 +2432,7 @@ def queryDatasetAssociations( datasetType.dimensions, dataset_fields={datasetType.name: {"dataset_id", "run", "collection", "timespan"}}, ) - timespan_key = f"{datasetType.name}.timespan" - collection_key = f"{datasetType.name}.collection" - for _, refs, row_dict in result.iter_tuples(datasetType): - yield DatasetAssociation(refs[0], row_dict[collection_key], row_dict[timespan_key]) + yield from DatasetAssociation.from_query_result(result, datasetType) def get_datastore_records(self, ref: DatasetRef) -> DatasetRef: """Retrieve datastore records for given ref. diff --git a/python/lsst/daf/butler/remote_butler/_registry.py b/python/lsst/daf/butler/remote_butler/_registry.py index a74e6da454..bc7515a4c4 100644 --- a/python/lsst/daf/butler/remote_butler/_registry.py +++ b/python/lsst/daf/butler/remote_butler/_registry.py @@ -525,10 +525,7 @@ def queryDatasetAssociations( datasetType.dimensions, dataset_fields={datasetType.name: {"dataset_id", "run", "collection", "timespan"}}, ) - timespan_key = f"{datasetType.name}.timespan" - collection_key = f"{datasetType.name}.collection" - for _, refs, row_dict in result.iter_tuples(datasetType): - yield DatasetAssociation(refs[0], row_dict[collection_key], row_dict[timespan_key]) + yield from DatasetAssociation.from_query_result(result, datasetType) @property def storageClasses(self) -> StorageClassFactory: From 0fbe45858ad236364efbdb0b02874f633cb963ab Mon Sep 17 00:00:00 2001 From: Andy Salnikov Date: Tue, 13 Aug 2024 14:41:56 -0700 Subject: [PATCH 10/10] Update RepoExportContext to use new query system instead of queryDatasetAssociations. 
--- python/lsst/daf/butler/transfers/_context.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/python/lsst/daf/butler/transfers/_context.py b/python/lsst/daf/butler/transfers/_context.py index 7a9ad3a146..b880cec6e7 100644 --- a/python/lsst/daf/butler/transfers/_context.py +++ b/python/lsst/daf/butler/transfers/_context.py @@ -358,13 +358,20 @@ def _computeDatasetAssociations(self) -> dict[str, list[DatasetAssociation]]: collectionTypes = {CollectionType.TAGGED} if datasetType.isCalibration(): collectionTypes.add(CollectionType.CALIBRATION) - associationIter = self._butler.registry.queryDatasetAssociations( - datasetType, - collections=self._collections.keys(), + resolved_collections = self._butler._registry.queryCollections( + self._collections.keys(), + datasetType=datasetType, collectionTypes=collectionTypes, flattenChains=False, ) - for association in associationIter: - if association.ref.id in self._dataset_ids: - results[association.collection].append(association) + with self._butler._query() as query: + query = query.join_dataset_search(datasetType, resolved_collections) + result = query.general( + datasetType.dimensions, + dataset_fields={datasetType.name: {"dataset_id", "run", "collection", "timespan"}}, + ) + for association in DatasetAssociation.from_query_result(result, datasetType): + if association.ref.id in self._dataset_ids: + results[association.collection].append(association) + return results