diff --git a/python/lsst/pipe/base/connectionTypes.py b/python/lsst/pipe/base/connectionTypes.py index f092d3c07..a9d7fb41b 100644 --- a/python/lsst/pipe/base/connectionTypes.py +++ b/python/lsst/pipe/base/connectionTypes.py @@ -30,7 +30,16 @@ from collections.abc import Callable, Iterable, Sequence from typing import Optional, Union -from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, Registry, StorageClass +from lsst.daf.butler import ( + DataCoordinate, + DatasetRef, + DatasetType, + DimensionGraph, + DimensionUniverse, + Registry, + SkyPixDimension, + StorageClass, +) @dataclasses.dataclass(frozen=True) @@ -57,6 +66,7 @@ class BaseConnection: storageClass: str doc: str = "" multiple: bool = False + isCalibration: bool = False def __get__(self, inst, klass): """Descriptor method @@ -113,6 +123,34 @@ def makeDatasetType( self.name, universe.empty, self.storageClass, parentStorageClass=parentStorageClass ) + def resolve_dimensions( + self, universe: DimensionUniverse, expected: DimensionGraph | None, task_label: str + ) -> DimensionGraph: + """Resolve the dimensions of dataset type associated with this + connnection. + + Parameters + ---------- + universe : `DimensionUniverse` + Set of all known dimensions to be used to normalize the dimension + names specified in config. + expected : `DimensionGraph` or `None` + The dimensions this dataset type is expected to have, from either + the data repository or an output connection in the same pipeline. + Implementations may use this to fill in wildcards, but do not in + general use it to check for consistency (as this is better handled + by more general code elsewhere) + task_label : `str` + Task label associated with this connection. Should be included in + any error messages. + + Returns + ------- + dimensions : `lsst.daf.butler.DimensionGraph` + Fully expanded and resolved dimensions. + """ + return universe.empty + @dataclasses.dataclass(frozen=True) class DimensionedConnection(BaseConnection): @@ -143,7 +181,6 @@ class DimensionedConnection(BaseConnection): """ dimensions: typing.Iterable[str] = () - isCalibration: bool = False def __post_init__(self): if isinstance(self.dimensions, str): @@ -179,6 +216,12 @@ def makeDatasetType( parentStorageClass=parentStorageClass, ) + def resolve_dimensions( + self, universe: DimensionUniverse, expected: DimensionGraph | None, task_label: str + ) -> DimensionGraph: + # Docstring inherited. + return universe.extract(self.dimensions) + @dataclasses.dataclass(frozen=True) class BaseInput(DimensionedConnection): @@ -315,6 +358,29 @@ class PrerequisiteInput(BaseInput): Callable[[DatasetType, Registry, DataCoordinate, Sequence[str]], Iterable[DatasetRef]] ] = None + def resolve_dimensions( + self, universe: DimensionUniverse, expected: DimensionGraph | None, task_label: str + ) -> DimensionGraph: + # Docstring inherited. + if "skypix" in self.dimensions: + if expected is None: + raise LookupError( + f"DatasetType '{self.name}' referenced by " + f"{task_label!r} uses 'skypix' as a dimension " + f"placeholder, but does not already exist in the registry. " + f"Note that reference catalog names are now used as the dataset " + f"type name instead of 'ref_cat'." + ) + rest1 = set(universe.extract(set(self.dimensions) - set(["skypix"])).names) + rest2 = set(dim.name for dim in expected.dimensions if not isinstance(dim, SkyPixDimension)) + if rest1 != rest2: + raise ValueError( + f"Non-skypix dimensions for dataset type {self.name} declared in " + f"connections ({rest1}) are inconsistent with those in " + f"registry's version of this dataset ({rest2})." + ) + return universe.extract(self.dimensions) + @dataclasses.dataclass(frozen=True) class Output(DimensionedConnection):