From 7dcc6993a35aa0c1ad4630e5c5f5945de93b56da Mon Sep 17 00:00:00 2001 From: Jim Bosch Date: Thu, 8 Jun 2023 18:30:40 -0400 Subject: [PATCH] Add Pipeline.get_data_id. --- python/lsst/pipe/base/graphBuilder.py | 15 +++---------- python/lsst/pipe/base/pipeline.py | 31 ++++++++++++++++++++++++++- 2 files changed, 33 insertions(+), 13 deletions(-) diff --git a/python/lsst/pipe/base/graphBuilder.py b/python/lsst/pipe/base/graphBuilder.py index c2658345..7d2d322e 100644 --- a/python/lsst/pipe/base/graphBuilder.py +++ b/python/lsst/pipe/base/graphBuilder.py @@ -53,7 +53,6 @@ from lsst.daf.butler.registry import MissingCollectionError, MissingDatasetTypeError from lsst.daf.butler.registry.queries import DataCoordinateQueryResults from lsst.daf.butler.registry.wildcards import CollectionWildcard -from lsst.utils import doImportType # ----------------------------- # Imports for other modules -- @@ -1606,18 +1605,10 @@ def makeGraph( scaffolding = _PipelineScaffolding(pipeline, registry=self.registry) if not collections and (scaffolding.initInputs or scaffolding.inputs or scaffolding.prerequisites): raise ValueError("Pipeline requires input datasets but no input collections provided.") - instrument_class: Optional[Any] = None - if isinstance(pipeline, Pipeline): - instrument_class_name = pipeline.getInstrument() - if instrument_class_name is not None: - instrument_class = doImportType(instrument_class_name) - pipeline = list(pipeline.toExpandedPipeline()) - if instrument_class is not None: - dataId = DataCoordinate.standardize( - dataId, instrument=instrument_class.getName(), universe=self.registry.dimensions - ) - elif dataId is None: + if dataId is None: dataId = DataCoordinate.makeEmpty(self.registry.dimensions) + if isinstance(pipeline, Pipeline): + dataId = pipeline.get_data_id(self.registry.dimensions).union(dataId) with scaffolding.connectDataIds( self.registry, collections, userQuery, dataId, datasetQueryConstraint, bind ) as commonDataIds: diff --git a/python/lsst/pipe/base/pipeline.py b/python/lsst/pipe/base/pipeline.py index f1d198d0..e5754cd7 100644 --- a/python/lsst/pipe/base/pipeline.py +++ b/python/lsst/pipe/base/pipeline.py @@ -55,7 +55,14 @@ # ----------------------------- # Imports for other modules -- -from lsst.daf.butler import DatasetType, NamedValueSet, Registry, SkyPixDimension +from lsst.daf.butler import ( + DataCoordinate, + DatasetType, + DimensionUniverse, + NamedValueSet, + Registry, + SkyPixDimension, +) from lsst.resources import ResourcePath, ResourcePathExpression from lsst.utils import doImportType from lsst.utils.introspection import get_full_type_name @@ -613,6 +620,28 @@ def getInstrument(self) -> Optional[str]: """ return self._pipelineIR.instrument + def get_data_id(self, universe: DimensionUniverse) -> DataCoordinate: + """Return a data ID with all dimension constraints embedded in the + pipeline. + + Parameters + ---------- + universe : `lsst.daf.butler.DimensionUniverse` + Object that defines all dimensions. + + Returns + ------- + data_id : `lsst.daf.butler.DataCoordinate` + Data ID with all dimension constraints embedded in the + pipeline. + """ + instrument_class_name = self._pipelineIR.instrument + if instrument_class_name is not None: + instrument_class = doImportType(instrument_class_name) + if instrument_class is not None: + return DataCoordinate.standardize(instrument=instrument_class.getName(), universe=universe) + return DataCoordinate.makeEmpty(universe) + def addTask(self, task: Union[Type[PipelineTask], str], label: str) -> None: """Add a new task to the pipeline, or replace a task that is already associated with the supplied label.