diff --git a/doc/changes/DM-38498.feature.md b/doc/changes/DM-38498.feature.md new file mode 100644 index 0000000000..c9b3b0c309 --- /dev/null +++ b/doc/changes/DM-38498.feature.md @@ -0,0 +1 @@ +Improve support for finding calibrations and spatially-joined datasets as follow-ups to data ID queries. diff --git a/python/lsst/daf/butler/core/dimensions/_coordinate.py b/python/lsst/daf/butler/core/dimensions/_coordinate.py index afe5985fe0..f73a44be9a 100644 --- a/python/lsst/daf/butler/core/dimensions/_coordinate.py +++ b/python/lsst/daf/butler/core/dimensions/_coordinate.py @@ -340,12 +340,12 @@ def fromFullValues(graph: DimensionGraph, values: tuple[DataIdValue, ...]) -> Da return _BasicTupleDataCoordinate(graph, values) def __hash__(self) -> int: - return hash((self.graph,) + tuple(self[d.name] for d in self.graph.required)) + return hash((self.graph,) + self.values_tuple()) def __eq__(self, other: Any) -> bool: if not isinstance(other, DataCoordinate): other = DataCoordinate.standardize(other, universe=self.universe) - return self.graph == other.graph and all(self[d.name] == other[d.name] for d in self.graph.required) + return self.graph == other.graph and self.values_tuple() == other.values_tuple() def __repr__(self) -> str: # We can't make repr yield something that could be exec'd here without @@ -513,6 +513,7 @@ def hasFull(self) -> bool: raise NotImplementedError() @property + @abstractmethod def full(self) -> NamedKeyMapping[Dimension, DataIdValue]: """Return mapping for all dimensions in ``self.graph``. @@ -524,8 +525,17 @@ def full(self) -> NamedKeyMapping[Dimension, DataIdValue]: when implied keys are accessed via the returned mapping, depending on the implementation and whether assertions are enabled. """ - assert self.hasFull(), "full may only be accessed if hasFull() returns True." - return _DataCoordinateFullView(self) + raise NotImplementedError() + + @abstractmethod + def values_tuple(self) -> tuple[DataIdValue, ...]: + """Return the required values (only) of this data ID as a tuple. + + In contexts where all data IDs have the same dimensions, comparing and + hashing these tuples can be *much* faster than comparing the original + `DataCoordinate` instances. + """ + raise NotImplementedError() @abstractmethod def hasRecords(self) -> bool: @@ -779,7 +789,7 @@ class _DataCoordinateFullView(NamedKeyMapping[Dimension, DataIdValue]): The `DataCoordinate` instance this object provides a view of. """ - def __init__(self, target: DataCoordinate): + def __init__(self, target: _BasicTupleDataCoordinate): self._target = target __slots__ = ("_target",) @@ -892,6 +902,13 @@ def __getitem__(self, key: DataIdKey) -> DataIdValue: # values for the required ones. raise KeyError(key) from None + def byName(self) -> dict[str, DataIdValue]: + # Docstring inheritance. + # Reimplementation is for optimization; `values_tuple()` is much faster + # to iterate over than values() because it doesn't go through + # `__getitem__`. + return dict(zip(self.names, self.values_tuple(), strict=True)) + def subset(self, graph: DimensionGraph) -> DataCoordinate: # Docstring inherited from DataCoordinate. if self._graph == graph: @@ -933,6 +950,12 @@ def union(self, other: DataCoordinate) -> DataCoordinate: values.update(other.full.byName() if other.hasFull() else other.byName()) return DataCoordinate.standardize(values, graph=graph) + @property + def full(self) -> NamedKeyMapping[Dimension, DataIdValue]: + # Docstring inherited. + assert self.hasFull(), "full may only be accessed if hasFull() returns True." 
+ return _DataCoordinateFullView(self) + def expanded( self, records: NameLookupMapping[DimensionElement, DimensionRecord | None] ) -> DataCoordinate: @@ -954,6 +977,10 @@ def hasRecords(self) -> bool: # Docstring inherited from DataCoordinate. return False + def values_tuple(self) -> tuple[DataIdValue, ...]: + # Docstring inherited from DataCoordinate. + return self._values[: len(self._graph.required)] + def _record(self, name: str) -> DimensionRecord | None: # Docstring inherited from DataCoordinate. raise AssertionError() diff --git a/python/lsst/daf/butler/core/dimensions/_records.py b/python/lsst/daf/butler/core/dimensions/_records.py index 6de6a99954..22e9cebaf6 100644 --- a/python/lsst/daf/butler/core/dimensions/_records.py +++ b/python/lsst/daf/butler/core/dimensions/_records.py @@ -342,6 +342,10 @@ def to_simple(self, minimal: bool = False) -> SerializedDimensionRecord: # query. This may not be overly useful since to reconstruct # a collection of records will require repeated registry queries. # For now do not implement minimal form. + key = (id(self.definition), self.dataId) + cache = PersistenceContextVars.serializedDimensionRecordMapping.get() + if cache is not None and (result := cache.get(key)) is not None: + return result mapping = {name: getattr(self, name) for name in self.__slots__} # If the item in mapping supports simplification update it @@ -360,7 +364,10 @@ def to_simple(self, minimal: bool = False) -> SerializedDimensionRecord: # hash objects, encode it here to a hex string mapping[k] = v.hex() definition = self.definition.to_simple(minimal=minimal) - return SerializedDimensionRecord(definition=definition, record=mapping) + dimRec = SerializedDimensionRecord(definition=definition, record=mapping) + if cache is not None: + cache[key] = dimRec + return dimRec @classmethod def from_simple( diff --git a/python/lsst/daf/butler/core/persistenceContext.py b/python/lsst/daf/butler/core/persistenceContext.py index 7542754704..bc39b899b4 100644 --- a/python/lsst/daf/butler/core/persistenceContext.py +++ b/python/lsst/daf/butler/core/persistenceContext.py @@ -94,7 +94,7 @@ class PersistenceContextVars: """ serializedDimensionRecordMapping: ContextVar[ - dict[tuple[str, frozenset], SerializedDimensionRecord] | None + dict[tuple[str, frozenset] | tuple[int, DataCoordinate], SerializedDimensionRecord] | None ] = ContextVar("serializedDimensionRecordMapping", default=None) r"""A cache of `SerializedDimensionRecord`\ s. 
""" diff --git a/python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py b/python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py index d08530d8da..ac8dc4747e 100644 --- a/python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py +++ b/python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py @@ -728,9 +728,9 @@ def _validateImport(self, tmp_tags: sqlalchemy.schema.Table, run: RunRecord) -> sqlalchemy.sql.select( dataset.columns.id.label("dataset_id"), dataset.columns.dataset_type_id.label("dataset_type_id"), - tmp_tags.columns.dataset_type_id.label("new dataset_type_id"), + tmp_tags.columns.dataset_type_id.label("new_dataset_type_id"), dataset.columns[self._runKeyColumn].label("run"), - tmp_tags.columns[collFkName].label("new run"), + tmp_tags.columns[collFkName].label("new_run"), ) .select_from(dataset.join(tmp_tags, dataset.columns.id == tmp_tags.columns.dataset_id)) .where( @@ -742,21 +742,38 @@ def _validateImport(self, tmp_tags: sqlalchemy.schema.Table, run: RunRecord) -> .limit(1) ) with self._db.query(query) as result: + # Only include the first one in the exception message if (row := result.first()) is not None: - # Only include the first one in the exception message - raise ConflictingDefinitionError( - f"Existing dataset type or run do not match new dataset: {row._asdict()}" - ) + existing_run = self._collections[row.run].name + new_run = self._collections[row.new_run].name + if row.dataset_type_id == self._dataset_type_id: + if row.new_dataset_type_id == self._dataset_type_id: + raise ConflictingDefinitionError( + f"Current run {existing_run!r} and new run {new_run!r} do not agree for " + f"dataset {row.dataset_id}." + ) + else: + raise ConflictingDefinitionError( + f"Dataset {row.dataset_id} was provided with type {self.datasetType.name!r} " + f"in run {new_run!r}, but was already defined with type ID {row.dataset_type_id} " + f"in run {run!r}." + ) + else: + raise ConflictingDefinitionError( + f"Dataset {row.dataset_id} was provided with type ID {row.new_dataset_type_id} " + f"in run {new_run!r}, but was already defined with type {self.datasetType.name!r} " + f"in run {run!r}." + ) # Check that matching dataset in tags table has the same DataId. query = ( sqlalchemy.sql.select( tags.columns.dataset_id, tags.columns.dataset_type_id.label("type_id"), - tmp_tags.columns.dataset_type_id.label("new type_id"), + tmp_tags.columns.dataset_type_id.label("new_type_id"), *[tags.columns[dim] for dim in self.datasetType.dimensions.required.names], *[ - tmp_tags.columns[dim].label(f"new {dim}") + tmp_tags.columns[dim].label(f"new_{dim}") for dim in self.datasetType.dimensions.required.names ], ) @@ -783,12 +800,11 @@ def _validateImport(self, tmp_tags: sqlalchemy.schema.Table, run: RunRecord) -> # Check that matching run+dataId have the same dataset ID. 
query = ( sqlalchemy.sql.select( - tags.columns.dataset_type_id.label("dataset_type_id"), *[tags.columns[dim] for dim in self.datasetType.dimensions.required.names], tags.columns.dataset_id, - tmp_tags.columns.dataset_id.label("new dataset_id"), + tmp_tags.columns.dataset_id.label("new_dataset_id"), tags.columns[collFkName], - tmp_tags.columns[collFkName].label(f"new {collFkName}"), + tmp_tags.columns[collFkName].label(f"new_{collFkName}"), ) .select_from( tags.join( @@ -807,8 +823,13 @@ def _validateImport(self, tmp_tags: sqlalchemy.schema.Table, run: RunRecord) -> .limit(1) ) with self._db.query(query) as result: + # only include the first one in the exception message if (row := result.first()) is not None: - # only include the first one in the exception message + data_id = {dim: getattr(row, dim) for dim in self.datasetType.dimensions.required.names} + existing_collection = self._collections[getattr(row, collFkName)].name + new_collection = self._collections[getattr(row, f"new_{collFkName}")].name raise ConflictingDefinitionError( - f"Existing dataset type and dataId does not match new dataset: {row._asdict()}" + f"Dataset with type {self.datasetType.name!r} and data ID {data_id} " + f"has ID {row.dataset_id} in existing collection {existing_collection!r} " + f"but ID {row.new_dataset_id} in new collection {new_collection!r}." ) diff --git a/python/lsst/daf/butler/registry/queries/_query.py b/python/lsst/daf/butler/registry/queries/_query.py index c129f5d3a6..3b79afd8f7 100644 --- a/python/lsst/daf/butler/registry/queries/_query.py +++ b/python/lsst/daf/butler/registry/queries/_query.py @@ -22,6 +22,7 @@ __all__ = () +import itertools from collections.abc import Iterable, Iterator, Mapping, Sequence, Set from contextlib import contextmanager from typing import Any, cast, final @@ -38,7 +39,9 @@ DimensionGraph, DimensionKeyColumnTag, DimensionRecord, + DimensionRecordColumnTag, ) +from .._collectionType import CollectionType from ..wildcards import CollectionWildcard from ._query_backend import QueryBackend from ._query_context import QueryContext @@ -237,6 +240,52 @@ def iter_dataset_refs( else: yield parent_ref.makeComponentRef(component) + def iter_data_ids_and_dataset_refs( + self, dataset_type: DatasetType, dimensions: DimensionGraph | None = None + ) -> Iterator[tuple[DataCoordinate, DatasetRef]]: + """Iterate over pairs of data IDs and dataset refs. + + This permits the data ID dimensions to differ from the dataset + dimensions. + + Parameters + ---------- + dataset_type : `DatasetType` + The parent dataset type to yield references for. + dimensions : `DimensionGraph`, optional + Dimensions of the data IDs to return. If not provided, + ``self.dimensions`` is used. + + Returns + ------- + pairs : `~collections.abc.Iterable` [ `tuple` [ `DataCoordinate`, + `DatasetRef` ] ] + An iterator over (data ID, dataset reference) pairs. + """ + if dimensions is None: + dimensions = self._dimensions + data_id_reader = DataCoordinateReader.make( + dimensions, records=self._has_record_columns is True, record_caches=self._record_caches + ) + dataset_reader = DatasetRefReader( + dataset_type, + translate_collection=self._backend.get_collection_name, + records=self._has_record_columns is True, + record_caches=self._record_caches, + ) + if not (data_id_reader.columns_required <= self.relation.columns): + raise ColumnError( + f"Missing column(s) {set(data_id_reader.columns_required - self.relation.columns)} " + f"for data IDs with dimensions {dimensions}." 
+ ) + if not (dataset_reader.columns_required <= self.relation.columns): + raise ColumnError( + f"Missing column(s) {set(dataset_reader.columns_required - self.relation.columns)} " + f"for datasets with type {dataset_type.name} and dimensions {dataset_type.dimensions}." + ) + for row in self: + yield (data_id_reader.read(row), dataset_reader.read(row)) + def iter_dimension_records(self, element: DimensionElement | None = None) -> Iterator[DimensionRecord]: """Return an iterator that converts result rows to dimension records. @@ -603,9 +652,6 @@ def find_datasets( lsst.daf.relation.ColumnError Raised if a dataset search is already present in this query and this is a find-first search. - ValueError - Raised if the given dataset type's dimensions are not a subset of - the current query's dimensions. """ if find_first and DatasetColumnTag.filter_from(self._relation.columns): raise ColumnError( @@ -613,7 +659,7 @@ def find_datasets( "on a query that already includes dataset columns." ) # - # TODO: it'd nice to do a QueryContext.restore_columns call here or + # TODO: it'd be nice to do a QueryContext.restore_columns call here or # similar, to look for dataset-constraint joins already present in the # relation and expand them to include dataset-result columns as well, # instead of doing a possibly-redundant join here. But that would @@ -635,14 +681,6 @@ def find_datasets( # where we materialize the initial data ID query into a temp table # and hence can't go back and "recover" those dataset columns anyway; # - if not (dataset_type.dimensions <= self._dimensions): - raise ValueError( - "Cannot find datasets from a query unless the dataset types's dimensions " - f"({dataset_type.dimensions}, for {dataset_type.name}) are a subset of the query's " - f"({self._dimensions})." - ) - columns = set(columns) - columns.add("dataset_id") collections = CollectionWildcard.from_expression(collections) if find_first: collections.require_ordered() @@ -651,23 +689,107 @@ def find_datasets( dataset_type, collections, governor_constraints=self._governor_constraints, - allow_calibration_collections=False, # TODO + allow_calibration_collections=True, rejections=rejections, ) + # If the dataset type has dimensions not in the current query, or we + # need a temporal join for a calibration collection, either restore + # those columns or join them in. + full_dimensions = dataset_type.dimensions.union(self._dimensions) + relation = self._relation + record_caches = self._record_caches + base_columns_required: set[ColumnTag] = { + DimensionKeyColumnTag(name) for name in full_dimensions.names + } + spatial_joins: list[tuple[str, str]] = [] + if not (dataset_type.dimensions <= self._dimensions): + if self._has_record_columns is True: + # This query is for expanded data IDs, so if we add new + # dimensions to the query we need to be able to get records for + # the new dimensions. + record_caches = dict(self._record_caches) + for element in full_dimensions.elements: + if element in record_caches: + continue + if ( + cache := self._backend.get_dimension_record_cache(element.name, self._context) + ) is not None: + record_caches[element] = cache + else: + base_columns_required.update(element.RecordClass.fields.columns.keys()) + # See if we need spatial joins between the current query and the + # dataset type's dimensions. The logic here is for multiple + # spatial joins in general, but in practice it'll be exceedingly + # rare for there to be more than one. We start by figuring out + # which spatial "families" (observations vs. 
skymaps, skypix + # systems) are present on only one side and not the other. + lhs_spatial_families = self._dimensions.spatial - dataset_type.dimensions.spatial + rhs_spatial_families = dataset_type.dimensions.spatial - self._dimensions.spatial + # Now we iterate over the Cartesian product of those, so e.g. + # if the query has {tract, patch, visit} and the dataset type + # has {htm7} dimensions, the iterations of this loop + # correspond to: (skymap, htm), (observations, htm). + for lhs_spatial_family, rhs_spatial_family in itertools.product( + lhs_spatial_families, rhs_spatial_families + ): + # For each pair we add a join between the most-precise element + # present in each family (e.g. patch beats tract). + spatial_joins.append( + ( + lhs_spatial_family.choose(full_dimensions.elements).name, + rhs_spatial_family.choose(full_dimensions.elements).name, + ) + ) + # Set up any temporal join between the query dimensions and CALIBRATION + # collection's validity ranges. + temporal_join_on: set[ColumnTag] = set() + if any(r.type is CollectionType.CALIBRATION for r in collection_records): + for family in self._dimensions.temporal: + endpoint = family.choose(self._dimensions.elements) + temporal_join_on.add(DimensionRecordColumnTag(endpoint.name, "timespan")) + base_columns_required.update(temporal_join_on) + # Note which of the many kinds of potentially-missing columns we have + # and add the rest. + base_columns_required.difference_update(relation.columns) + if base_columns_required: + relation = self._backend.make_dimension_relation( + full_dimensions, + base_columns_required, + self._context, + initial_relation=relation, + # Don't permit joins to use any columns beyond those in the + # original relation, as that would change what this + # operation does. + initial_join_max_columns=frozenset(self._relation.columns), + governor_constraints=self._governor_constraints, + spatial_joins=spatial_joins, + ) + # Finally we can join in the search for the dataset query. 
+ columns = set(columns) + columns.add("dataset_id") if not collection_records: - relation = self._relation.join( + relation = relation.join( self._backend.make_doomed_dataset_relation(dataset_type, columns, rejections, self._context) ) elif find_first: relation = self._backend.make_dataset_search_relation( - dataset_type, collection_records, columns, self._context, join_to=self._relation + dataset_type, + collection_records, + columns, + self._context, + join_to=relation, + temporal_join_on=temporal_join_on, ) else: - dataset_relation = self._backend.make_dataset_query_relation( - dataset_type, collection_records, columns, self._context + relation = self._backend.make_dataset_query_relation( + dataset_type, + collection_records, + columns, + self._context, + join_to=relation, + temporal_join_on=temporal_join_on, ) - relation = self.relation.join(dataset_relation) - return self._chain(relation, defer=defer) + return self._chain(relation, dimensions=full_dimensions, record_caches=record_caches, defer=defer) def sliced( self, diff --git a/python/lsst/daf/butler/registry/queries/_query_backend.py b/python/lsst/daf/butler/registry/queries/_query_backend.py index 426e790a1c..542cc8a613 100644 --- a/python/lsst/daf/butler/registry/queries/_query_backend.py +++ b/python/lsst/daf/butler/registry/queries/_query_backend.py @@ -28,9 +28,11 @@ from lsst.daf.relation import ( BinaryOperationRelation, + ColumnExpression, ColumnTag, LeafRelation, MarkerRelation, + Predicate, Relation, UnaryOperationRelation, ) @@ -43,6 +45,7 @@ DimensionKeyColumnTag, DimensionRecord, DimensionUniverse, + timespan, ) from .._collectionType import CollectionType from .._exceptions import DatasetTypeError, MissingDatasetTypeError @@ -411,7 +414,7 @@ def resolve_dataset_collections( return supported_collection_records @abstractmethod - def make_dataset_query_relation( + def _make_dataset_query_relation_impl( self, dataset_type: DatasetType, collections: Sequence[CollectionRecord], @@ -438,9 +441,87 @@ def make_dataset_query_relation( ------- relation : `lsst.daf.relation.Relation` Relation representing a dataset query. + + Notes + ----- + This method must be implemented by derived classes but is not + responsible for joining the resulting relation to an existing relation. """ raise NotImplementedError() + def make_dataset_query_relation( + self, + dataset_type: DatasetType, + collections: Sequence[CollectionRecord], + columns: Set[str], + context: _C, + *, + join_to: Relation | None = None, + temporal_join_on: Set[ColumnTag] = frozenset(), + ) -> Relation: + """Construct a relation that represents an unordered query for datasets + that returns matching results from all given collections. + + Parameters + ---------- + dataset_type : `DatasetType` + Type for the datasets being queried. + collections : `~collections.abc.Sequence` [ `CollectionRecord` ] + Records for collections to query. Should generally be the result + of a call to `resolve_dataset_collections`, and must not be empty. + context : `QueryContext` + Context that manages per-query state. + columns : `~collections.abc.Set` [ `str` ] + Columns to include in the relation. See `Query.find_datasets` for + details. + join_to : `Relation`, optional + Another relation to join with the query for datasets in all + collections. + temporal_join_on: `~collections.abc.Set` [ `ColumnTag` ], optional + Timespan columns in ``join_to`` that calibration dataset timespans + must overlap. Must already be present in ``join_to``. 
Ignored if + ``join_to`` is `None` or if there are no calibration collections. + + Returns + ------- + relation : `lsst.daf.relation.Relation` + Relation representing a dataset query. + """ + # If we need to do a temporal join to a calibration collection, we need + # to include the timespan column in the base query and prepare the join + # predicate. + join_predicates: list[Predicate] = [] + base_timespan_tag: ColumnTag | None = None + full_columns: set[str] = set(columns) + if ( + temporal_join_on + and join_to is not None + and any(r.type is CollectionType.CALIBRATION for r in collections) + ): + base_timespan_tag = DatasetColumnTag(dataset_type.name, "timespan") + rhs = ColumnExpression.reference(base_timespan_tag, dtype=timespan.Timespan) + full_columns.add("timespan") + for timespan_tag in temporal_join_on: + lhs = ColumnExpression.reference(timespan_tag, dtype=timespan.Timespan) + join_predicates.append(lhs.predicate_method("overlaps", rhs)) + # Delegate to the concrete QueryBackend subclass to do most of the + # work. + result = self._make_dataset_query_relation_impl( + dataset_type, + collections, + full_columns, + context=context, + ) + if join_to is not None: + result = join_to.join( + result, predicate=Predicate.logical_and(*join_predicates) if join_predicates else None + ) + if join_predicates and "timespan" not in columns: + # Drop the timespan column we added for the join only if the + # timespan wasn't requested in its own right. + result = result.with_only_columns(result.columns - {base_timespan_tag}) + return result + def make_dataset_search_relation( self, dataset_type: DatasetType, @@ -449,10 +530,11 @@ def make_dataset_search_relation( context: _C, *, join_to: Relation | None = None, + temporal_join_on: Set[ColumnTag] = frozenset(), ) -> Relation: """Construct a relation that represents an order query for datasets - that returns results from the first matching collection for each - data ID. + that returns results from the first matching collection for each data + ID. Parameters ---------- @@ -462,13 +544,17 @@ def make_dataset_search_relation( Records for collections to search. Should generally be the result of a call to `resolve_dataset_collections`, and must not be empty. columns : `~collections.abc.Set` [ `str` ] - Columns to include in the `relation. See + Columns to include in the ``relation``. See `make_dataset_query_relation` for options. context : `QueryContext` Context that manages per-query state. join_to : `Relation`, optional Another relation to join with the query for datasets in all collections before filtering out out shadowed datasets. + temporal_join_on: `~collections.abc.Set` [ `ColumnTag` ], optional + Timespan columns in ``join_to`` that calibration dataset timespans + must overlap. Must already be present in ``join_to``. Ignored if + ``join_to`` is `None` or if there are no calibration collections. Returns ------- @@ -480,9 +566,9 @@ def make_dataset_search_relation( collections, columns | {"rank"}, context=context, + join_to=join_to, + temporal_join_on=temporal_join_on, ) - if join_to is not None: - base = join_to.join(base) # Query-simplification shortcut: if there is only one collection, a # find-first search is just a regular result subquery. Same if there # are no collections. 
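The calibration support added to `make_dataset_query_relation` above reduces to an interval-overlap join: each data ID's timespan (for example an exposure's observation window) must overlap the validity range attached to a dataset in a CALIBRATION collection. The snippet below is a minimal, self-contained sketch of that overlap logic in plain Python; the `Timespan` class and `match_calibrations` helper are illustrative stand-ins, not the daf_butler or daf_relation API.

from dataclasses import dataclass


@dataclass(frozen=True)
class Timespan:
    """Hypothetical half-open [begin, end) interval, standing in for the real
    lsst.daf.butler Timespan only to keep this sketch self-contained."""

    begin: int
    end: int

    def overlaps(self, other: "Timespan") -> bool:
        # Half-open intervals overlap when each begins before the other ends.
        return self.begin < other.end and other.begin < self.end


def match_calibrations(exposures, certifications):
    """Yield (exposure_id, dataset_id) pairs whose timespans overlap; this is
    the same test the relation join expresses with
    ``lhs.predicate_method("overlaps", rhs)``."""
    for exposure_id, observation_span in exposures:
        for dataset_id, validity_range in certifications:
            if observation_span.overlaps(validity_range):
                yield exposure_id, dataset_id


# Exposure 1 spans [2, 3), so a bias certified for [2, 4) matches it, while
# exposure 0 spanning [1, 2) does not.
exposures = [(0, Timespan(1, 2)), (1, Timespan(2, 3))]
certifications = [("bias2a", Timespan(2, 4))]
print(list(match_calibrations(exposures, certifications)))  # [(1, 'bias2a')]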
diff --git a/python/lsst/daf/butler/registry/queries/_results.py b/python/lsst/daf/butler/registry/queries/_results.py index c024037455..b128368ff3 100644 --- a/python/lsst/daf/butler/registry/queries/_results.py +++ b/python/lsst/daf/butler/registry/queries/_results.py @@ -254,8 +254,6 @@ def findDatasets( Raises ------ - ValueError - Raised if ``datasetType.dimensions.issubset(self.graph) is False``. MissingDatasetTypeError Raised if the given dataset type is not registered. """ @@ -268,6 +266,63 @@ def findDatasets( components_found, ) + def findRelatedDatasets( + self, + datasetType: DatasetType | str, + collections: Any, + *, + findFirst: bool = True, + dimensions: DimensionGraph | None = None, + ) -> Iterable[tuple[DataCoordinate, DatasetRef]]: + """Find datasets using the data IDs identified by this query, and + return them along with the original data IDs. + + This is a variant of `findDatasets` that is often more useful when + the target dataset type does not have all of the dimensions of the + original data ID query, as is generally the case with calibration + lookups. + + Parameters + ---------- + datasetType : `DatasetType` or `str` + Dataset type or the name of one to search for. Its dimensions + need not be a subset of ``self.graph``; any extra dimensions are + joined into the query as needed. + collections : `Any` + An expression that fully or partially identifies the collections + to search for the dataset, such as a `str`, `re.Pattern`, or + iterable thereof. ``...`` can be used to return all collections. + See :ref:`daf_butler_collection_expressions` for more information. + findFirst : `bool`, optional + If `True` (default), for each data ID in ``self``, only yield one + `DatasetRef`, from the first collection in which a dataset of that + dataset type appears (according to the order of ``collections`` + passed in). If `True`, ``collections`` must not contain regular + expressions and may not be ``...``. Note that this is not the + same as yielding one `DatasetRef` for each yielded data ID if + ``dimensions`` is not `None`. + dimensions : `DimensionGraph`, optional + The dimensions of the data IDs returned. Must be a subset of + ``self.graph``. + + Returns + ------- + pairs : `~collections.abc.Iterable` [ `tuple` [ `DataCoordinate`, \ + `DatasetRef` ] ] + An iterable of (data ID, dataset reference) pairs. + + Raises + ------ + MissingDatasetTypeError + Raised if the given dataset type is not registered. + """ + if dimensions is None: + dimensions = self.graph + parent_dataset_type, _ = self._query.backend.resolve_single_dataset_type_wildcard( + datasetType, components=False, explicit_only=True + ) + query = self._query.find_datasets(parent_dataset_type, collections, find_first=findFirst, defer=True) + return query.iter_data_ids_and_dataset_refs(parent_dataset_type, dimensions) + def count(self, *, exact: bool = True, discard: bool = False) -> int: """Count the number of rows this query would return.
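To show how the new `findRelatedDatasets` method is meant to be called, the sketch below mirrors the calibration lookups exercised in the registry tests further down; `lookup_biases` is a hypothetical helper, and the instrument name and calibration collection are placeholders taken from the test fixtures.

def lookup_biases(registry, calib_collection):
    """Hypothetical helper: for every {instrument, exposure, detector} data ID,
    yield the bias whose certified validity range overlaps the exposure's
    timespan, keeping the wider original data ID alongside the DatasetRef."""
    data_ids = registry.queryDataIds(["exposure", "detector"], instrument="Cam1")
    for data_id, bias_ref in data_ids.findRelatedDatasets(
        "bias", collections=[calib_collection], findFirst=True
    ):
        # The data ID keeps its exposure dimension even though "bias" only has
        # {instrument, detector}.
        yield data_id["exposure"], data_id["detector"], bias_ref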
diff --git a/python/lsst/daf/butler/registry/queries/_sql_query_backend.py b/python/lsst/daf/butler/registry/queries/_sql_query_backend.py index c4710a11a5..b3ea3ebe57 100644 --- a/python/lsst/daf/butler/registry/queries/_sql_query_backend.py +++ b/python/lsst/daf/butler/registry/queries/_sql_query_backend.py @@ -141,7 +141,7 @@ def filter_dataset_collections( filtered_collections.append(collection_record) return result - def make_dataset_query_relation( + def _make_dataset_query_relation_impl( self, dataset_type: DatasetType, collections: Sequence[CollectionRecord], @@ -245,6 +245,10 @@ def make_dimension_relation( "it is part of a dataset subquery, spatial join, or other initial relation." ) + # Before joining in new tables to provide columns, attempt to restore + # them from the given relation by weakening projections applied to it. + relation, _ = context.restore_columns(relation, columns_required) + # Categorize columns not yet included in the relation to associate them # with dimension elements and detect bad inputs. missing_columns = ColumnCategorization.from_iterable(columns_required - relation.columns) diff --git a/python/lsst/daf/butler/registry/tests/_registry.py b/python/lsst/daf/butler/registry/tests/_registry.py index 87b9a0f1c9..3f84dc9af4 100644 --- a/python/lsst/daf/butler/registry/tests/_registry.py +++ b/python/lsst/daf/butler/registry/tests/_registry.py @@ -1490,9 +1490,12 @@ def testQueryResults(self): expectedDeduplicatedBiases, ) - # Check dimensions match. - with self.assertRaises(ValueError): - subsetDataIds.findDatasets("flat", collections=["imported_r", "imported_g"], findFirst=True) + # Searching for a dataset with dimensions we had projected away + # restores those dimensions. + self.assertCountEqual( + list(subsetDataIds.findDatasets("flat", collections=["imported_r"], findFirst=True)), + expectedFlats, + ) # Use a component dataset type. self.assertCountEqual( @@ -2027,8 +2030,9 @@ def range_set_hull( def testCalibrationCollections(self): """Test operations on `~CollectionType.CALIBRATION` collections, - including `Registry.certify`, `Registry.decertify`, and - `Registry.findDataset`. + including `Registry.certify`, `Registry.decertify`, + `Registry.findDataset`, and + `DataCoordinateQueryResults.findRelatedDatasets`. """ # Setup - make a Registry, fill it with some datasets in # non-calibration collections. @@ -2044,6 +2048,39 @@ def testCalibrationCollections(self): allTimespans = [ Timespan(a, b) for a, b in itertools.combinations([None, t1, t2, t3, t4, t5, None], r=2) ] + # Insert some exposure records with timespans between each sequential + # pair of those. + registry.insertDimensionData( + "exposure", + { + "instrument": "Cam1", + "id": 0, + "obs_id": "zero", + "physical_filter": "Cam1-G", + "timespan": Timespan(t1, t2), + }, + { + "instrument": "Cam1", + "id": 1, + "obs_id": "one", + "physical_filter": "Cam1-G", + "timespan": Timespan(t2, t3), + }, + { + "instrument": "Cam1", + "id": 2, + "obs_id": "two", + "physical_filter": "Cam1-G", + "timespan": Timespan(t3, t4), + }, + { + "instrument": "Cam1", + "id": 3, + "obs_id": "three", + "physical_filter": "Cam1-G", + "timespan": Timespan(t4, t5), + }, + ) # Get references to some datasets. bias2a = registry.findDataset("bias", instrument="Cam1", detector=2, collections="imported_g") bias3a = registry.findDataset("bias", instrument="Cam1", detector=3, collections="imported_g") @@ -2058,8 +2095,7 @@ def testCalibrationCollections(self): # Certify 2a dataset with [t2, t4) validity. 
registry.certify(collection, [bias2a], Timespan(begin=t2, end=t4)) # Test that we can query for this dataset via the new collection, both - # on its own and with a RUN collection, as long as we don't try to join - # in temporal dimensions or use findFirst=True. + # on its own and with a RUN collection. self.assertEqual( set(registry.queryDatasets("bias", findFirst=False, collections=collection)), {bias2a}, @@ -2085,6 +2121,30 @@ def testCalibrationCollections(self): registry.expandDataId(instrument="Cam1", detector=4), }, ) + self.assertEqual( + set( + registry.queryDataIds(["exposure", "detector"]).findRelatedDatasets( + "bias", findFirst=True, collections=[collection] + ) + ), + { + (registry.expandDataId(instrument="Cam1", detector=2, exposure=1), bias2a), + (registry.expandDataId(instrument="Cam1", detector=2, exposure=2), bias2a), + }, + ) + self.assertEqual( + set( + registry.queryDataIds( + ["exposure", "detector"], instrument="Cam1", detector=2 + ).findRelatedDatasets("bias", findFirst=True, collections=[collection, "imported_r"]) + ), + { + (registry.expandDataId(instrument="Cam1", detector=2, exposure=1), bias2a), + (registry.expandDataId(instrument="Cam1", detector=2, exposure=2), bias2a), + (registry.expandDataId(instrument="Cam1", detector=2, exposure=0), bias2b), + (registry.expandDataId(instrument="Cam1", detector=2, exposure=3), bias2b), + }, + ) # We should not be able to certify 2b with anything overlapping that # window. @@ -2217,6 +2277,58 @@ def assertLookup( assertLookup(detector=3, timespan=Timespan(t4, None), expected=bias3b) assertLookup(detector=3, timespan=Timespan(t5, None), expected=bias3b) + # Test lookups via temporal joins to exposures. + self.assertEqual( + set( + registry.queryDataIds( + ["exposure", "detector"], instrument="Cam1", detector=2 + ).findRelatedDatasets("bias", collections=[collection]) + ), + { + (registry.expandDataId(instrument="Cam1", exposure=1, detector=2), bias2a), + (registry.expandDataId(instrument="Cam1", exposure=2, detector=2), bias2a), + (registry.expandDataId(instrument="Cam1", exposure=3, detector=2), bias2b), + }, + ) + self.assertEqual( + set( + registry.queryDataIds( + ["exposure", "detector"], instrument="Cam1", detector=3 + ).findRelatedDatasets("bias", collections=[collection]) + ), + { + (registry.expandDataId(instrument="Cam1", exposure=0, detector=3), bias3a), + (registry.expandDataId(instrument="Cam1", exposure=1, detector=3), bias3a), + (registry.expandDataId(instrument="Cam1", exposure=3, detector=3), bias3b), + }, + ) + self.assertEqual( + set( + registry.queryDataIds( + ["exposure", "detector"], instrument="Cam1", detector=2 + ).findRelatedDatasets("bias", collections=[collection, "imported_g"]) + ), + { + (registry.expandDataId(instrument="Cam1", exposure=0, detector=2), bias2a), + (registry.expandDataId(instrument="Cam1", exposure=1, detector=2), bias2a), + (registry.expandDataId(instrument="Cam1", exposure=2, detector=2), bias2a), + (registry.expandDataId(instrument="Cam1", exposure=3, detector=2), bias2b), + }, + ) + self.assertEqual( + set( + registry.queryDataIds( + ["exposure", "detector"], instrument="Cam1", detector=3 + ).findRelatedDatasets("bias", collections=[collection, "imported_g"]) + ), + { + (registry.expandDataId(instrument="Cam1", exposure=0, detector=3), bias3a), + (registry.expandDataId(instrument="Cam1", exposure=1, detector=3), bias3a), + (registry.expandDataId(instrument="Cam1", exposure=2, detector=3), bias3a), + (registry.expandDataId(instrument="Cam1", exposure=3, detector=3), bias3b), + 
}, + ) + # Decertify [t3, t5) for all data IDs, and do test lookups again. # This should truncate bias2a to [t2, t3), leave bias3a unchanged at # [t1, t3), and truncate bias2b and bias3b to [t5, ∞). @@ -3522,3 +3634,43 @@ def test_query_empty_collections(self) -> None: messages = list(result.explain_no_results()) self.assertTrue(messages) self.assertTrue(any("because collection list is empty" in message for message in messages)) + + def test_dataset_followup_spatial_joins(self) -> None: + """Test queryDataIds(...).findRelatedDatasets(...) where a spatial join + is involved. + """ + registry = self.makeRegistry() + self.loadData(registry, "base.yaml") + self.loadData(registry, "spatial.yaml") + pvi_dataset_type = DatasetType( + "pvi", {"visit", "detector"}, storageClass="StructuredDataDict", universe=registry.dimensions + ) + registry.registerDatasetType(pvi_dataset_type) + collection = "datasets" + registry.registerRun(collection) + (pvi1,) = registry.insertDatasets( + pvi_dataset_type, [{"instrument": "Cam1", "visit": 1, "detector": 1}], run=collection + ) + (pvi2,) = registry.insertDatasets( + pvi_dataset_type, [{"instrument": "Cam1", "visit": 1, "detector": 2}], run=collection + ) + (pvi3,) = registry.insertDatasets( + pvi_dataset_type, [{"instrument": "Cam1", "visit": 1, "detector": 3}], run=collection + ) + self.assertEqual( + set( + registry.queryDataIds(["patch"], skymap="SkyMap1", tract=0) + .expanded() + .findRelatedDatasets("pvi", [collection]) + ), + { + (registry.expandDataId(skymap="SkyMap1", tract=0, patch=0), pvi1), + (registry.expandDataId(skymap="SkyMap1", tract=0, patch=0), pvi2), + (registry.expandDataId(skymap="SkyMap1", tract=0, patch=1), pvi2), + (registry.expandDataId(skymap="SkyMap1", tract=0, patch=2), pvi1), + (registry.expandDataId(skymap="SkyMap1", tract=0, patch=2), pvi2), + (registry.expandDataId(skymap="SkyMap1", tract=0, patch=2), pvi3), + (registry.expandDataId(skymap="SkyMap1", tract=0, patch=3), pvi2), + (registry.expandDataId(skymap="SkyMap1", tract=0, patch=4), pvi3), + }, + ) diff --git a/tests/data/registry/spatial.py b/tests/data/registry/spatial.py index d18ff224fb..393e63fdb5 100644 --- a/tests/data/registry/spatial.py +++ b/tests/data/registry/spatial.py @@ -252,7 +252,7 @@ def make_plots(detector_grid: bool, patch_grid: bool): index_labels(color="black", alpha=0.5), ) colors = iter(["red", "blue", "cyan", "green"]) - for (visit_id, visit_data), color in zip(VISIT_DATA.items(), colors, strict=True): + for (visit_id, visit_data), color in zip(VISIT_DATA.items(), colors, strict=False): for detector_id, pixel_indices in visit_data["detector_regions"].items(): label = f"visit={visit_id}" if label in labels_used:
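The spatial case looks the same from the caller's side. Mirroring `test_dataset_followup_spatial_joins` above, a rough sketch of grouping the `pvi` datasets by the tract-0 patches their regions overlap could read as follows; `pvis_by_patch` is a hypothetical helper, and the skymap, tract, and run names come from the test data.

def pvis_by_patch(registry, run="datasets"):
    """Hypothetical helper: group the {visit, detector} "pvi" datasets in
    ``run`` by the tract-0 patches whose regions they overlap, relying on the
    spatial join between patches and visit-detector regions."""
    patches = registry.queryDataIds(["patch"], skymap="SkyMap1", tract=0).expanded()
    grouped = {}
    for patch_data_id, pvi_ref in patches.findRelatedDatasets("pvi", [run]):
        grouped.setdefault(patch_data_id["patch"], set()).add(pvi_ref)
    return grouped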