diff --git a/python/lsst/pipe/base/pipeline.py b/python/lsst/pipe/base/pipeline.py
index 05cc4e460..1c32e7fe9 100644
--- a/python/lsst/pipe/base/pipeline.py
+++ b/python/lsst/pipe/base/pipeline.py
@@ -35,10 +35,8 @@
 # Imports of standard modules --
 # -------------------------------
 from dataclasses import dataclass
-from types import MappingProxyType
 from typing import (
     TYPE_CHECKING,
-    AbstractSet,
     Callable,
     ClassVar,
     Dict,
@@ -56,7 +54,7 @@

 # -----------------------------
 # Imports for other modules --
-from lsst.daf.butler import DatasetType, NamedValueSet, Registry, SkyPixDimension
+from lsst.daf.butler import DatasetType, NamedValueSet, Registry
 from lsst.resources import ResourcePath, ResourcePathExpression
 from lsst.utils import doImportType
 from lsst.utils.introspection import get_full_type_name
@@ -65,8 +63,7 @@
 from . import pipeline_graph, pipelineIR
 from .config import PipelineTaskConfig
 from .configOverrides import ConfigOverrides
-from .connections import PipelineTaskConnections, iterConnections
-from .connectionTypes import Input
+from .connections import PipelineTaskConnections
 from .pipelineTask import PipelineTask

 if TYPE_CHECKING:  # Imports needed only for type annotations; may be circular.
@@ -929,7 +926,6 @@ def fromTaskDef(
         *,
         registry: Registry,
         include_configs: bool = True,
-        storage_class_mapping: Optional[Mapping[str, str]] = None,
     ) -> TaskDatasetTypes:
         """Extract and classify the dataset types from a single `PipelineTask`.

@@ -943,226 +939,83 @@
         include_configs : `bool`, optional
             If `True` (default) include config dataset types as
             ``initOutputs``.
-        storage_class_mapping : `Mapping` of `str` to `StorageClass`, optional
-            If a taskdef contains a component dataset type that is unknown
-            to the registry, its parent StorageClass will be looked up in this
-            mapping if it is supplied. If the mapping does not contain the
-            composite dataset type, or the mapping is not supplied an exception
-            will be raised.

         Returns
         -------
-        types: `TaskDatasetTypes`
+        types : `TaskDatasetTypes`
             The dataset types used by this task.

         Raises
         ------
-        ValueError
+        IncompatibleDatasetTypeError
             Raised if dataset type connection definition differs from
             registry definition.
-        LookupError
-            Raised if component parent StorageClass could not be determined
-            and storage_class_mapping does not contain the composite type, or
-            is set to None.
+        MissingDatasetTypeError
+            Raised if component parent StorageClass could not be determined.
         """
+        # Since this is no longer used by PipelineDatasetTypes, I expect it to
+        # be called only rarely, so I'm not bothered by an implementation that
+        # builds a single-task PipelineGraph.  I hope to deprecate the whole
+        # class soon, but until it is actually removed it's more important to
+        # avoid duplicating PipelineGraph's dataset type resolution logic.
+        mgraph = pipeline_graph.MutablePipelineGraph()
+        mgraph.add_task(taskDef.label, taskDef.taskClass, taskDef.config, taskDef.connections)
+        rgraph = mgraph.resolved(registry)
+        (task_node,) = rgraph.tasks.values()
+        return cls._from_graph_nodes(task_node, rgraph.dataset_types)
-        def makeDatasetTypesSet(
-            connectionType: str,
-            is_input: bool,
-            freeze: bool = True,
-        ) -> NamedValueSet[DatasetType]:
-            """Constructs a set of true `DatasetType` objects
-
-            Parameters
-            ----------
-            connectionType : `str`
-                Name of the connection type to produce a set for, corresponds
-                to an attribute of type `list` on the connection class instance
-            is_input : `bool`
-                These are input dataset types, else they are output dataset
-                types.
-            freeze : `bool`, optional
-                If `True`, call `NamedValueSet.freeze` on the object returned.
-
-            Returns
-            -------
-            datasetTypes : `NamedValueSet`
-                A set of all datasetTypes which correspond to the input
-                connection type specified in the connection class of this
-                `PipelineTask`
-
-            Raises
-            ------
-            ValueError
-                Raised if dataset type connection definition differs from
-                registry definition.
-            LookupError
-                Raised if component parent StorageClass could not be determined
-                and storage_class_mapping does not contain the composite type,
-                or is set to None.
-
-            Notes
-            -----
-            This function is a closure over the variables ``registry`` and
-            ``taskDef``, and ``storage_class_mapping``.
-            """
-            datasetTypes = NamedValueSet[DatasetType]()
-            for c in iterConnections(taskDef.connections, connectionType):
-                dimensions = set(getattr(c, "dimensions", set()))
-                if "skypix" in dimensions:
-                    try:
-                        datasetType = registry.getDatasetType(c.name)
-                    except LookupError as err:
-                        raise LookupError(
-                            f"DatasetType '{c.name}' referenced by "
-                            f"{type(taskDef.connections).__name__} uses 'skypix' as a dimension "
-                            "placeholder, but does not already exist in the registry. "
-                            "Note that reference catalog names are now used as the dataset "
-                            "type name instead of 'ref_cat'."
-                        ) from err
-                    rest1 = set(registry.dimensions.extract(dimensions - set(["skypix"])).names)
-                    rest2 = set(
-                        dim.name for dim in datasetType.dimensions if not isinstance(dim, SkyPixDimension)
-                    )
-                    if rest1 != rest2:
-                        raise ValueError(
-                            f"Non-skypix dimensions for dataset type {c.name} declared in "
-                            f"connections ({rest1}) are inconsistent with those in "
-                            f"registry's version of this dataset ({rest2})."
-                        )
-                else:
-                    # Component dataset types are not explicitly in the
-                    # registry. This complicates consistency checks with
-                    # registry and requires we work out the composite storage
-                    # class.
-                    registryDatasetType = None
-                    try:
-                        registryDatasetType = registry.getDatasetType(c.name)
-                    except KeyError:
-                        compositeName, componentName = DatasetType.splitDatasetTypeName(c.name)
-                        if componentName:
-                            if storage_class_mapping is None or compositeName not in storage_class_mapping:
-                                raise LookupError(
-                                    "Component parent class cannot be determined, and "
-                                    "composite name was not in storage class mapping, or no "
-                                    "storage_class_mapping was supplied"
-                                )
-                            else:
-                                parentStorageClass = storage_class_mapping[compositeName]
-                        else:
-                            parentStorageClass = None
-                        datasetType = c.makeDatasetType(
-                            registry.dimensions, parentStorageClass=parentStorageClass
-                        )
-                        registryDatasetType = datasetType
-                    else:
-                        datasetType = c.makeDatasetType(
-                            registry.dimensions, parentStorageClass=registryDatasetType.parentStorageClass
-                        )
-
-                if registryDatasetType and datasetType != registryDatasetType:
-                    # The dataset types differ but first check to see if
-                    # they are compatible before raising.
-                    if is_input:
-                        # This DatasetType must be compatible on get.
-                        is_compatible = datasetType.is_compatible_with(registryDatasetType)
-                    else:
-                        # Has to be able to be converted to expect type
-                        # on put.
-                        is_compatible = registryDatasetType.is_compatible_with(datasetType)
-                    if is_compatible:
-                        # For inputs we want the pipeline to use the
-                        # pipeline definition, for outputs it should use
-                        # the registry definition.
-                        if not is_input:
-                            datasetType = registryDatasetType
-                        _LOG.debug(
-                            "Dataset types differ (task %s != registry %s) but are compatible"
-                            " for %s in %s.",
-                            datasetType,
-                            registryDatasetType,
-                            "input" if is_input else "output",
-                            taskDef.label,
-                        )
-                    else:
-                        try:
-                            # Explicitly check for storage class just to
-                            # make more specific message.
-                            _ = datasetType.storageClass
-                        except KeyError:
-                            raise ValueError(
-                                "Storage class does not exist for supplied dataset type "
-                                f"{datasetType} for {taskDef.label}."
-                            ) from None
-                        raise ValueError(
-                            f"Supplied dataset type ({datasetType}) inconsistent with "
-                            f"registry definition ({registryDatasetType}) "
-                            f"for {taskDef.label}."
-                        )
-                datasetTypes.add(datasetType)
-            if freeze:
-                datasetTypes.freeze()
-            return datasetTypes
-
-        # optionally add initOutput dataset for config
-        initOutputs = makeDatasetTypesSet("initOutputs", is_input=False, freeze=False)
-        if include_configs:
-            initOutputs.add(
-                DatasetType(
-                    taskDef.configDatasetName,
-                    registry.dimensions.empty,
-                    storageClass=acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS,
-                )
-            )
-        initOutputs.freeze()
-
-        # optionally add output dataset for metadata
-        outputs = makeDatasetTypesSet("outputs", is_input=False, freeze=False)
-        if taskDef.metadataDatasetName is not None:
-            # Metadata is supposed to be of the TaskMetadata type, its
-            # dimensions correspond to a task quantum.
-            dimensions = registry.dimensions.extract(taskDef.connections.dimensions)
-
-            # Allow the storage class definition to be read from the existing
-            # dataset type definition if present.
-            try:
-                current = registry.getDatasetType(taskDef.metadataDatasetName)
-            except KeyError:
-                # No previous definition so use the default.
-                storageClass = acc.METADATA_OUTPUT_STORAGE_CLASS
-            else:
-                storageClass = current.storageClass.name
-
-            outputs.update({DatasetType(taskDef.metadataDatasetName, dimensions, storageClass)})
-        if taskDef.logOutputDatasetName is not None:
-            # Log output dimensions correspond to a task quantum.
-            dimensions = registry.dimensions.extract(taskDef.connections.dimensions)
-            outputs.update(
-                {
-                    DatasetType(
-                        taskDef.logOutputDatasetName,
-                        dimensions,
-                        acc.LOG_OUTPUT_STORAGE_CLASS,
-                    )
-                }
-            )
-
-        outputs.freeze()
+    @classmethod
+    def _from_graph_nodes(
+        cls,
+        task_node: pipeline_graph.ResolvedTaskNode,
+        dataset_type_nodes: Mapping[str, pipeline_graph.ResolvedDatasetTypeNode],
+        include_configs: bool = True,
+    ) -> TaskDatasetTypes:
+        """Construct from `PipelineGraph` nodes.
-        inputs = makeDatasetTypesSet("inputs", is_input=True)
-        queryConstraints = NamedValueSet(
-            inputs[c.name]
-            for c in cast(Iterable[Input], iterConnections(taskDef.connections, "inputs"))
-            if not c.deferGraphConstraint
-        )
+
+        Parameters
+        ----------
+        task_node : `pipeline_graph.ResolvedTaskNode`
+            Task node to extract dataset types from.
+        dataset_type_nodes : `Mapping` [ `str`, \
+                `pipeline_graph.ResolvedDatasetTypeNode` ]
+            Mapping of all dataset type nodes in the graph.
+        include_configs : `bool`, optional
+            If `True` (default) include config dataset types as
+            ``initOutputs``.
+
+        Returns
+        -------
+        types : `TaskDatasetTypes`
+            The dataset types used by this task.
+        """
         return cls(
-            initInputs=makeDatasetTypesSet("initInputs", is_input=True),
-            initOutputs=initOutputs,
-            inputs=inputs,
-            queryConstraints=queryConstraints,
-            prerequisites=makeDatasetTypesSet("prerequisiteInputs", is_input=True),
-            outputs=outputs,
+            initInputs=NamedValueSet(
+                edge.adapt_dataset_type(dataset_type_nodes[edge.parent_dataset_type_name].dataset_type)
+                for edge in task_node.init.iter_all_inputs()
+            ),
+            initOutputs=NamedValueSet(
+                edge.adapt_dataset_type(dataset_type_nodes[edge.parent_dataset_type_name].dataset_type)
+                for edge in (task_node.init.iter_all_outputs() if include_configs else task_node.init.outputs)
+            ),
+            inputs=NamedValueSet(
+                edge.adapt_dataset_type(dataset_type_nodes[edge.parent_dataset_type_name].dataset_type)
+                for edge in task_node.inputs
+            ),
+            queryConstraints=NamedValueSet(
+                edge.adapt_dataset_type(node.dataset_type)
+                for edge in task_node.inputs
+                if (node := dataset_type_nodes[edge.parent_dataset_type_name]).is_initial_query_constraint
+            ),
+            prerequisites=NamedValueSet(
+                edge.adapt_dataset_type(dataset_type_nodes[edge.parent_dataset_type_name].dataset_type)
+                for edge in task_node.prerequisite_inputs
+            ),
+            outputs=NamedValueSet(
+                edge.adapt_dataset_type(dataset_type_nodes[edge.parent_dataset_type_name].dataset_type)
+                for edge in task_node.iter_all_outputs()
+            ),
         )
@@ -1245,7 +1098,7 @@ class PipelineDatasetTypes:
     @classmethod
     def fromPipeline(
         cls,
-        pipeline: Union[Pipeline, Iterable[TaskDef]],
+        pipeline: Pipeline | Iterable[TaskDef],
         *,
         registry: Registry,
         include_configs: bool = True,
@@ -1275,122 +1128,98 @@ def fromPipeline(

         Raises
         ------
-        ValueError
+        IncompatibleDatasetTypeError
             Raised if Tasks are inconsistent about which datasets are marked
             prerequisite. This indicates that the Tasks cannot be run as part
             of the same `Pipeline`.
""" - allInputs = NamedValueSet[DatasetType]() - allOutputs = NamedValueSet[DatasetType]() - allInitInputs = NamedValueSet[DatasetType]() - allInitOutputs = NamedValueSet[DatasetType]() - prerequisites = NamedValueSet[DatasetType]() - queryConstraints = NamedValueSet[DatasetType]() + if isinstance(pipeline, Pipeline): + mgraph = pipeline.to_graph() + else: + mgraph = pipeline_graph.MutablePipelineGraph() + for task_def in pipeline: + mgraph.add_task(task_def.label, task_def.taskClass, task_def.config, task_def.connections) + rgraph = mgraph.resolved(registry) byTask = dict() + for task_node in rgraph.tasks.values(): + byTask[task_node.label] = TaskDatasetTypes._from_graph_nodes( + task_node, + rgraph.dataset_types, + include_configs=include_configs, + ) + result = cls( + initInputs=NamedValueSet(), + initOutputs=NamedValueSet(), + initIntermediates=NamedValueSet(), + inputs=NamedValueSet(), + queryConstraints=NamedValueSet(), + prerequisites=NamedValueSet(), + intermediates=NamedValueSet(), + outputs=NamedValueSet(), + byTask=byTask, + ) + # All of the logic involving components below is unfortunate, because + # downstream code would be far simpler if PipelineDatasetTypes always + # converted them to their parent dataset types and left + # component-handling to TaskDatasetTypes (which is more or less what + # PipelineGraph does, by putting components in the edge objects). But + # including all components as well is what this code has done in the + # past and changing that would break downstream code. + for dataset_type_node in rgraph.dataset_types.values(): + if consumers := rgraph.consumers_of(dataset_type_node.name): + dataset_types = [ + ( + dataset_type_node.dataset_type.makeComponentDatasetType(edge.component) + if edge.component is not None + else dataset_type_node.dataset_type + ) + for edge in consumers.values() + ] + if any(edge.is_init for edge in consumers.values()): + if rgraph.producer_of(dataset_type_node.name) is None: + result.initInputs.update(dataset_types) + else: + result.initIntermediates.update(dataset_types) + else: + if dataset_type_node.is_prerequisite: + result.prerequisites.update(dataset_types) + elif rgraph.producer_of(dataset_type_node.name) is None: + result.inputs.update(dataset_types) + if dataset_type_node.is_initial_query_constraint: + result.queryConstraints.add(dataset_type_node.dataset_type) + elif rgraph.consumers_of(dataset_type_node.name): + result.intermediates.update(dataset_types) + else: + producer = rgraph.producer_of(dataset_type_node.name) + assert ( + producer is not None + ), "Dataset type must have either a producer or consumers to be in graph." 
+                if producer.is_init:
+                    result.initOutputs.add(dataset_type_node.dataset_type)
+                else:
+                    result.outputs.add(dataset_type_node.dataset_type)
         if include_packages:
-            allInitOutputs.add(
+            result.initOutputs.add(
                 DatasetType(
                     cls.packagesDatasetName,
                     registry.dimensions.empty,
                     storageClass="Packages",
                 )
             )
-        # create a list of TaskDefs in case the input is a generator
-        pipeline = list(pipeline)
-
-        # collect all the output dataset types
-        typeStorageclassMap: Dict[str, str] = {}
-        for taskDef in pipeline:
-            for outConnection in iterConnections(taskDef.connections, "outputs"):
-                typeStorageclassMap[outConnection.name] = outConnection.storageClass
-
-        for taskDef in pipeline:
-            thisTask = TaskDatasetTypes.fromTaskDef(
-                taskDef,
-                registry=registry,
-                include_configs=include_configs,
-                storage_class_mapping=typeStorageclassMap,
-            )
-            allInitInputs.update(thisTask.initInputs)
-            allInitOutputs.update(thisTask.initOutputs)
-            allInputs.update(thisTask.inputs)
-            # Inputs are query constraints if any task considers them a query
-            # constraint.
-            queryConstraints.update(thisTask.queryConstraints)
-            prerequisites.update(thisTask.prerequisites)
-            allOutputs.update(thisTask.outputs)
-            byTask[taskDef.label] = thisTask
-        if not prerequisites.isdisjoint(allInputs):
-            raise ValueError(
-                "{} marked as both prerequisites and regular inputs".format(
-                    {dt.name for dt in allInputs & prerequisites}
-                )
-            )
-        if not prerequisites.isdisjoint(allOutputs):
-            raise ValueError(
-                "{} marked as both prerequisites and outputs".format(
-                    {dt.name for dt in allOutputs & prerequisites}
-                )
-            )
-        # Make sure that components which are marked as inputs get treated as
-        # intermediates if there is an output which produces the composite
-        # containing the component
-        intermediateComponents = NamedValueSet[DatasetType]()
-        intermediateComposites = NamedValueSet[DatasetType]()
-        outputNameMapping = {dsType.name: dsType for dsType in allOutputs}
-        for dsType in allInputs:
-            # get the name of a possible component
-            name, component = dsType.nameAndComponent()
-            # if there is a component name, that means this is a component
-            # DatasetType, if there is an output which produces the parent of
-            # this component, treat this input as an intermediate
-            if component is not None:
-                # This needs to be in this if block, because someone might have
-                # a composite that is a pure input from existing data
-                if name in outputNameMapping:
-                    intermediateComponents.add(dsType)
-                    intermediateComposites.add(outputNameMapping[name])
-
-        def checkConsistency(a: NamedValueSet, b: NamedValueSet) -> None:
-            common = a.names & b.names
-            for name in common:
-                # Any compatibility is allowed. This function does not know
-                # if a dataset type is to be used for input or output.
-                if not (a[name].is_compatible_with(b[name]) or b[name].is_compatible_with(a[name])):
-                    raise ValueError(f"Conflicting definitions for dataset type: {a[name]} != {b[name]}.")
-
-        checkConsistency(allInitInputs, allInitOutputs)
-        checkConsistency(allInputs, allOutputs)
-        checkConsistency(allInputs, intermediateComposites)
-        checkConsistency(allOutputs, intermediateComposites)
-
-        def frozen(s: AbstractSet[DatasetType]) -> NamedValueSet[DatasetType]:
-            assert isinstance(s, NamedValueSet)
-            s.freeze()
-            return s
-
-        inputs = frozen(allInputs - allOutputs - intermediateComponents)
-
-        return cls(
-            initInputs=frozen(allInitInputs - allInitOutputs),
-            initIntermediates=frozen(allInitInputs & allInitOutputs),
-            initOutputs=frozen(allInitOutputs - allInitInputs),
-            inputs=inputs,
-            queryConstraints=frozen(queryConstraints & inputs),
-            # If there are storage class differences in inputs and outputs
-            # the intermediates have to choose priority. Here choose that
-            # inputs to tasks much match the requested storage class by
-            # applying the inputs over the top of the outputs.
-            intermediates=frozen(allOutputs & allInputs | intermediateComponents),
-            outputs=frozen(allOutputs - allInputs - intermediateComposites),
-            prerequisites=frozen(prerequisites),
-            byTask=MappingProxyType(byTask),  # MappingProxyType -> frozen view of dict for immutability
-        )
+        result.initInputs.freeze()
+        result.initOutputs.freeze()
+        result.initIntermediates.freeze()
+        result.inputs.freeze()
+        result.queryConstraints.freeze()
+        result.prerequisites.freeze()
+        result.intermediates.freeze()
+        result.outputs.freeze()
+        return result

     @classmethod
     def initOutputNames(
         cls,
-        pipeline: Union[Pipeline, Iterable[TaskDef]],
+        pipeline: Pipeline | Iterable[TaskDef],
         *,
         include_configs: bool = True,
         include_packages: bool = True,
@@ -1412,19 +1241,16 @@ def initOutputNames(
         datasetTypeName : `str`
             Name of the dataset type.
         """
+        if isinstance(pipeline, Pipeline):
+            mgraph = pipeline.to_graph()
+        else:
+            mgraph = pipeline_graph.MutablePipelineGraph()
+            for task_def in pipeline:
+                mgraph.add_task(task_def.label, task_def.taskClass, task_def.config, task_def.connections)
         if include_packages:
             # Package versions dataset type
             yield cls.packagesDatasetName
-
-        if isinstance(pipeline, Pipeline):
-            pipeline = pipeline.toExpandedPipeline()
-
-        for taskDef in pipeline:
-            # all task InitOutputs
-            for name in taskDef.connections.initOutputs:
-                attribute = getattr(taskDef.connections, name)
-                yield attribute.name
-
-            # config dataset name
-            if include_configs:
-                yield taskDef.configDatasetName
+        for task_node in mgraph.tasks.values():
+            edges = task_node.init.iter_all_outputs() if include_configs else task_node.init.outputs
+            for edge in edges:
+                yield edge.dataset_type_name
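
Below is a minimal sketch (not part of the patch) of how the reworked entry points might be exercised from caller code. It uses only classes and methods that appear in this diff or in the existing lsst.pipe.base / lsst.daf.butler API; the repository path "/repo" and the pipeline file "pipeline.yaml" are hypothetical placeholders.

from lsst.daf.butler import Butler
from lsst.pipe.base import Pipeline
from lsst.pipe.base.pipeline import PipelineDatasetTypes, TaskDatasetTypes

butler = Butler("/repo")  # hypothetical data repository
pipeline = Pipeline.from_uri("pipeline.yaml")  # hypothetical pipeline definition

# Per-task classification: fromTaskDef now builds a single-task
# MutablePipelineGraph internally and resolves it against the registry.
for task_def in pipeline.toExpandedPipeline():
    task_types = TaskDatasetTypes.fromTaskDef(task_def, registry=butler.registry)
    print(task_def.label, [dt.name for dt in task_types.inputs])

# Whole-pipeline classification: one graph is built and resolved, then
# dataset types are sorted into inputs/intermediates/outputs/prerequisites.
all_types = PipelineDatasetTypes.fromPipeline(pipeline, registry=butler.registry)
print([dt.name for dt in all_types.intermediates])

# Init-output names can now be listed without touching the registry.
print(list(PipelineDatasetTypes.initOutputNames(pipeline)))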