Reimplement PipelineDatasetTypes via PipelineGraph.
The long-term plan is to get rid of PipelineDatasetTypes, but for now,
and while it's deprecated, this removes a lot of duplication.

This *slightly* changes the interface for TaskDatasetTypes by dropping
the storage_class_mapping argument to fromTaskDef. But that argument was
really only usable by PipelineDatasetTypes, which no longer calls that
method.
TallJimbo committed Mar 2, 2023
1 parent ab2d4b1 commit 7a893cb
Showing 1 changed file with 155 additions and 326 deletions.
481 changes: 155 additions & 326 deletions python/lsst/pipe/base/pipeline.py
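For illustration, a minimal sketch of the simplified call this commit leaves behind (an assumption-laden example, not part of the diff itself: it presumes a `TaskDef` bound to `task_def` and a butler `Registry` bound to `registry`; the `storage_class_mapping` keyword is simply gone):

from lsst.pipe.base.pipeline import TaskDatasetTypes

# Previously a storage_class_mapping could be passed to resolve component
# dataset types missing from the registry; after this commit, PipelineGraph
# resolution against the registry handles that instead.
types = TaskDatasetTypes.fromTaskDef(task_def, registry=registry)
print(types.inputs.names, types.outputs.names)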
@@ -35,10 +35,8 @@
# Imports of standard modules --
# -------------------------------
from dataclasses import dataclass
from types import MappingProxyType
from typing import (
TYPE_CHECKING,
AbstractSet,
Callable,
ClassVar,
Dict,
@@ -56,16 +54,15 @@

# -----------------------------
# Imports for other modules --
from lsst.daf.butler import DatasetType, NamedValueSet, Registry, SkyPixDimension
from lsst.daf.butler import DatasetType, NamedValueSet, Registry
from lsst.resources import ResourcePath, ResourcePathExpression
from lsst.utils import doImportType
from lsst.utils.introspection import get_full_type_name

from . import pipeline_graph, pipelineIR
from .config import PipelineTaskConfig
from .configOverrides import ConfigOverrides
from .connections import PipelineTaskConnections, iterConnections
from .connectionTypes import Input
from .connections import PipelineTaskConnections
from .pipelineTask import PipelineTask

if TYPE_CHECKING: # Imports needed only for type annotations; may be circular.
@@ -924,7 +921,6 @@ def fromTaskDef(
*,
registry: Registry,
include_configs: bool = True,
storage_class_mapping: Optional[Mapping[str, str]] = None,
) -> TaskDatasetTypes:
"""Extract and classify the dataset types from a single `PipelineTask`.
@@ -938,226 +934,83 @@ def fromTaskDef(
include_configs : `bool`, optional
If `True` (default) include config dataset types as
``initOutputs``.
storage_class_mapping : `Mapping` of `str` to `StorageClass`, optional
If a taskdef contains a component dataset type that is unknown
to the registry, its parent StorageClass will be looked up in this
mapping if it is supplied. If the mapping does not contain the
composite dataset type, or the mapping is not supplied, an exception
will be raised.

Returns
-------
types: `TaskDatasetTypes`
types : `TaskDatasetTypes`
The dataset types used by this task.

Raises
------
ValueError
IncompatibleDatasetTypeError
Raised if dataset type connection definition differs from
registry definition.
LookupError
Raised if component parent StorageClass could not be determined
and storage_class_mapping does not contain the composite type, or
is set to None.
MissingDatasetTypeError
Raised if component parent StorageClass could not be determined.
"""
# Since it's no longer used by PipelineDatasetTypes, I expect this to
# be extremely rarely used, so I'm not bothered by the implementation
# involving making a single-task PipelineGraph. I hope to deprecate
# the whole class soon, but for now and before it's actually removed
# it's more important to avoid duplication with PipelineGraph's dataset
# type resolution logic.
mgraph = pipeline_graph.MutablePipelineGraph()
mgraph.add_task(taskDef.label, taskDef.taskClass, taskDef.config, taskDef.connections)
rgraph = mgraph.resolved(registry)
(task_node,) = rgraph.tasks.values()
return cls._from_graph_nodes(task_node, rgraph.dataset_types)

def makeDatasetTypesSet(
connectionType: str,
is_input: bool,
freeze: bool = True,
) -> NamedValueSet[DatasetType]:
"""Constructs a set of true `DatasetType` objects
Parameters
----------
connectionType : `str`
Name of the connection type to produce a set for, corresponds
to an attribute of type `list` on the connection class instance
is_input : `bool`
These are input dataset types, else they are output dataset
types.
freeze : `bool`, optional
If `True`, call `NamedValueSet.freeze` on the object returned.

Returns
-------
datasetTypes : `NamedValueSet`
A set of all datasetTypes which correspond to the input
connection type specified in the connection class of this
`PipelineTask`

Raises
------
ValueError
Raised if dataset type connection definition differs from
registry definition.
LookupError
Raised if component parent StorageClass could not be determined
and storage_class_mapping does not contain the composite type,
or is set to None.

Notes
-----
This function is a closure over the variables ``registry``,
``taskDef``, and ``storage_class_mapping``.
"""
datasetTypes = NamedValueSet[DatasetType]()
for c in iterConnections(taskDef.connections, connectionType):
dimensions = set(getattr(c, "dimensions", set()))
if "skypix" in dimensions:
try:
datasetType = registry.getDatasetType(c.name)
except LookupError as err:
raise LookupError(
f"DatasetType '{c.name}' referenced by "
f"{type(taskDef.connections).__name__} uses 'skypix' as a dimension "
"placeholder, but does not already exist in the registry. "
"Note that reference catalog names are now used as the dataset "
"type name instead of 'ref_cat'."
) from err
rest1 = set(registry.dimensions.extract(dimensions - set(["skypix"])).names)
rest2 = set(
dim.name for dim in datasetType.dimensions if not isinstance(dim, SkyPixDimension)
)
if rest1 != rest2:
raise ValueError(
f"Non-skypix dimensions for dataset type {c.name} declared in "
f"connections ({rest1}) are inconsistent with those in "
f"registry's version of this dataset ({rest2})."
)
else:
# Component dataset types are not explicitly in the
# registry. This complicates consistency checks with
# registry and requires we work out the composite storage
# class.
registryDatasetType = None
try:
registryDatasetType = registry.getDatasetType(c.name)
except KeyError:
compositeName, componentName = DatasetType.splitDatasetTypeName(c.name)
if componentName:
if storage_class_mapping is None or compositeName not in storage_class_mapping:
raise LookupError(
"Component parent class cannot be determined, and "
"composite name was not in storage class mapping, or no "
"storage_class_mapping was supplied"
)
else:
parentStorageClass = storage_class_mapping[compositeName]
else:
parentStorageClass = None
datasetType = c.makeDatasetType(
registry.dimensions, parentStorageClass=parentStorageClass
)
registryDatasetType = datasetType
else:
datasetType = c.makeDatasetType(
registry.dimensions, parentStorageClass=registryDatasetType.parentStorageClass
)

if registryDatasetType and datasetType != registryDatasetType:
# The dataset types differ but first check to see if
# they are compatible before raising.
if is_input:
# This DatasetType must be compatible on get.
is_compatible = datasetType.is_compatible_with(registryDatasetType)
else:
# Has to be able to be converted to the expected
# type on put.
is_compatible = registryDatasetType.is_compatible_with(datasetType)
if is_compatible:
# For inputs we want the pipeline to use the
# pipeline definition, for outputs it should use
# the registry definition.
if not is_input:
datasetType = registryDatasetType
_LOG.debug(
"Dataset types differ (task %s != registry %s) but are compatible"
" for %s in %s.",
datasetType,
registryDatasetType,
"input" if is_input else "output",
taskDef.label,
)
else:
try:
# Explicitly check for storage class just to
# make more specific message.
_ = datasetType.storageClass
except KeyError:
raise ValueError(
"Storage class does not exist for supplied dataset type "
f"{datasetType} for {taskDef.label}."
) from None
raise ValueError(
f"Supplied dataset type ({datasetType}) inconsistent with "
f"registry definition ({registryDatasetType}) "
f"for {taskDef.label}."
)
datasetTypes.add(datasetType)
if freeze:
datasetTypes.freeze()
return datasetTypes

# optionally add initOutput dataset for config
initOutputs = makeDatasetTypesSet("initOutputs", is_input=False, freeze=False)
if include_configs:
initOutputs.add(
DatasetType(
taskDef.configDatasetName,
registry.dimensions.empty,
storageClass=pipeline_graph.WriteEdge.CONFIG_INIT_OUTPUT_STORAGE_CLASS,
)
)
initOutputs.freeze()

# optionally add output dataset for metadata
outputs = makeDatasetTypesSet("outputs", is_input=False, freeze=False)
if taskDef.metadataDatasetName is not None:
# Metadata is supposed to be of the TaskMetadata type; its
# dimensions correspond to a task quantum.
dimensions = registry.dimensions.extract(taskDef.connections.dimensions)

# Allow the storage class definition to be read from the existing
# dataset type definition if present.
try:
current = registry.getDatasetType(taskDef.metadataDatasetName)
except KeyError:
# No previous definition so use the default.
storageClass = pipeline_graph.WriteEdge.METADATA_OUTPUT_STORAGE_CLASS
else:
storageClass = current.storageClass.name

outputs.update({DatasetType(taskDef.metadataDatasetName, dimensions, storageClass)})
if taskDef.logOutputDatasetName is not None:
# Log output dimensions correspond to a task quantum.
dimensions = registry.dimensions.extract(taskDef.connections.dimensions)
outputs.update(
{
DatasetType(
taskDef.logOutputDatasetName,
dimensions,
pipeline_graph.WriteEdge.LOG_OUTPUT_STORAGE_CLASS,
)
}
)

outputs.freeze()
@classmethod
def _from_graph_nodes(
cls,
task_node: pipeline_graph.ResolvedTaskNode,
dataset_type_nodes: Mapping[str, pipeline_graph.ResolvedDatasetTypeNode],
include_configs: bool = True,
) -> TaskDatasetTypes:
"""Construct from `PipelineGraph` nodes.
inputs = makeDatasetTypesSet("inputs", is_input=True)
queryConstraints = NamedValueSet(
inputs[c.name]
for c in cast(Iterable[Input], iterConnections(taskDef.connections, "inputs"))
if not c.deferGraphConstraint
)

Parameters
----------
task_node : `pipeline_graph.TaskNode`
Task node to extract dataset types from.
dataset_type_nodes : `Mapping` [ `str`, \
`pipeline_graph.DatasetTypeNode` ]
Mapping of all dataset type nodes in the graph.
include_configs : `bool`, optional
If `True` (default) include config dataset types as
``initOutputs``.

Returns
-------
types : `TaskDatasetTypes`
The dataset types used by this task.
"""
return cls(
initInputs=makeDatasetTypesSet("initInputs", is_input=True),
initOutputs=initOutputs,
inputs=inputs,
queryConstraints=queryConstraints,
prerequisites=makeDatasetTypesSet("prerequisiteInputs", is_input=True),
outputs=outputs,
initInputs=NamedValueSet(
edge.adapt_dataset_type(dataset_type_nodes[edge.parent_dataset_type_name].dataset_type)
for edge in task_node.init_inputs
),
initOutputs=NamedValueSet(
edge.adapt_dataset_type(dataset_type_nodes[edge.parent_dataset_type_name].dataset_type)
for edge in (task_node.iter_all_init_outputs() if include_configs else task_node.init_outputs)
),
inputs=NamedValueSet(
edge.adapt_dataset_type(dataset_type_nodes[edge.parent_dataset_type_name].dataset_type)
for edge in task_node.inputs
),
queryConstraints=NamedValueSet(
edge.adapt_dataset_type(node.dataset_type)
for edge in task_node.inputs
if (node := dataset_type_nodes[edge.parent_dataset_type_name]).is_initial_query_constraint
),
prerequisites=NamedValueSet(
edge.adapt_dataset_type(dataset_type_nodes[edge.parent_dataset_type_name].dataset_type)
for edge in task_node.prerequisite_inputs
),
outputs=NamedValueSet(
edge.adapt_dataset_type(dataset_type_nodes[edge.parent_dataset_type_name].dataset_type)
for edge in task_node.iter_all_run_outputs()
),
)
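
In sketch form, the per-edge pattern used throughout `_from_graph_nodes` above (names taken from this diff; `rgraph` is assumed to be a resolved graph built exactly as in `fromTaskDef`):

(task_node,) = rgraph.tasks.values()
for edge in task_node.inputs:
    # Each edge records its parent dataset type name; adapt_dataset_type
    # applies the connection's component / storage-class overrides to the
    # registry-resolved definition.
    node = rgraph.dataset_types[edge.parent_dataset_type_name]
    dataset_type = edge.adapt_dataset_type(node.dataset_type)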


@@ -1240,7 +1093,7 @@ class PipelineDatasetTypes:
@classmethod
def fromPipeline(
cls,
pipeline: Union[Pipeline, Iterable[TaskDef]],
pipeline: Pipeline | Iterable[TaskDef],
*,
registry: Registry,
include_configs: bool = True,
@@ -1270,122 +1123,101 @@ def fromPipeline(
Raises
------
ValueError
IncompatibleDatasetTypeError
Raised if Tasks are inconsistent about which datasets are marked
prerequisite. This indicates that the Tasks cannot be run as part
of the same `Pipeline`.
"""
allInputs = NamedValueSet[DatasetType]()
allOutputs = NamedValueSet[DatasetType]()
allInitInputs = NamedValueSet[DatasetType]()
allInitOutputs = NamedValueSet[DatasetType]()
prerequisites = NamedValueSet[DatasetType]()
queryConstraints = NamedValueSet[DatasetType]()
if isinstance(pipeline, Pipeline):
mgraph = pipeline.to_graph()
else:
mgraph = pipeline_graph.MutablePipelineGraph()
for task_def in pipeline:
mgraph.add_task(task_def.label, task_def.taskClass, task_def.config, task_def.connections)
rgraph = mgraph.resolved(registry)
byTask = dict()
for task_node in rgraph.tasks.values():
byTask[task_node.label] = TaskDatasetTypes._from_graph_nodes(
task_node,
rgraph.dataset_types,
include_configs=include_configs,
)
result = cls(
initInputs=NamedValueSet(),
initOutputs=NamedValueSet(),
initIntermediates=NamedValueSet(),
inputs=NamedValueSet(),
queryConstraints=NamedValueSet(),
prerequisites=NamedValueSet(),
intermediates=NamedValueSet(),
outputs=NamedValueSet(),
byTask=byTask,
)
# All of the logic involving components below is unfortunate, because
# downstream code would be far simpler if PipelineDatasetTypes always
# converted them to their parent dataset types and left
# component-handling to TaskDatasetTypes (which is more or less what
# PipelineGraph does, by putting components in the edge objects). But
# including all components as well is what this code has done in the
# past and changing that would break downstream code.
for dataset_type_node in rgraph.init_dataset_types.values():
if consumers := rgraph.consumers_of(dataset_type_node.name):
dataset_types = [
(
dataset_type_node.dataset_type.makeComponentDatasetType(edge.component)
if edge.component is not None
else dataset_type_node.dataset_type
)
for edge in consumers.values()
]
if rgraph.producer_of(dataset_type_node.name) is None:
result.initInputs.update(dataset_types)
else:
result.initIntermediates.update(dataset_types)
else:
result.initOutputs.add(dataset_type_node.dataset_type)
for dataset_type_node in rgraph.run_dataset_types.values():
if consumers := rgraph.consumers_of(dataset_type_node.name):
dataset_types = [
(
dataset_type_node.dataset_type.makeComponentDatasetType(edge.component)
if edge.component is not None
else dataset_type_node.dataset_type
)
for edge in consumers.values()
]
if dataset_type_node.is_prerequisite:
result.prerequisites.update(dataset_types)
elif rgraph.producer_of(dataset_type_node.name) is None:
result.inputs.update(dataset_types)
if dataset_type_node.is_initial_query_constraint:
result.queryConstraints.add(dataset_type_node.dataset_type)
elif rgraph.consumers_of(dataset_type_node.name):
result.intermediates.update(dataset_types)
else:
result.outputs.add(dataset_type_node.dataset_type)
if include_packages:
allInitOutputs.add(
result.initOutputs.add(
DatasetType(
cls.packagesDatasetName,
registry.dimensions.empty,
storageClass="Packages",
)
)
# create a list of TaskDefs in case the input is a generator
pipeline = list(pipeline)

# collect all the output dataset types
typeStorageclassMap: Dict[str, str] = {}
for taskDef in pipeline:
for outConnection in iterConnections(taskDef.connections, "outputs"):
typeStorageclassMap[outConnection.name] = outConnection.storageClass

for taskDef in pipeline:
thisTask = TaskDatasetTypes.fromTaskDef(
taskDef,
registry=registry,
include_configs=include_configs,
storage_class_mapping=typeStorageclassMap,
)
allInitInputs.update(thisTask.initInputs)
allInitOutputs.update(thisTask.initOutputs)
allInputs.update(thisTask.inputs)
# Inputs are query constraints if any task considers them a query
# constraint.
queryConstraints.update(thisTask.queryConstraints)
prerequisites.update(thisTask.prerequisites)
allOutputs.update(thisTask.outputs)
byTask[taskDef.label] = thisTask
if not prerequisites.isdisjoint(allInputs):
raise ValueError(
"{} marked as both prerequisites and regular inputs".format(
{dt.name for dt in allInputs & prerequisites}
)
)
if not prerequisites.isdisjoint(allOutputs):
raise ValueError(
"{} marked as both prerequisites and outputs".format(
{dt.name for dt in allOutputs & prerequisites}
)
)
# Make sure that components which are marked as inputs get treated as
# intermediates if there is an output which produces the composite
# containing the component
intermediateComponents = NamedValueSet[DatasetType]()
intermediateComposites = NamedValueSet[DatasetType]()
outputNameMapping = {dsType.name: dsType for dsType in allOutputs}
for dsType in allInputs:
# get the name of a possible component
name, component = dsType.nameAndComponent()
# if there is a component name, that means this is a component
# DatasetType; if there is an output which produces the parent of
# this component, treat this input as an intermediate
if component is not None:
# This needs to be in this if block, because someone might have
# a composite that is a pure input from existing data
if name in outputNameMapping:
intermediateComponents.add(dsType)
intermediateComposites.add(outputNameMapping[name])

def checkConsistency(a: NamedValueSet, b: NamedValueSet) -> None:
common = a.names & b.names
for name in common:
# Any compatibility is allowed. This function does not know
# if a dataset type is to be used for input or output.
if not (a[name].is_compatible_with(b[name]) or b[name].is_compatible_with(a[name])):
raise ValueError(f"Conflicting definitions for dataset type: {a[name]} != {b[name]}.")

checkConsistency(allInitInputs, allInitOutputs)
checkConsistency(allInputs, allOutputs)
checkConsistency(allInputs, intermediateComposites)
checkConsistency(allOutputs, intermediateComposites)

def frozen(s: AbstractSet[DatasetType]) -> NamedValueSet[DatasetType]:
assert isinstance(s, NamedValueSet)
s.freeze()
return s

inputs = frozen(allInputs - allOutputs - intermediateComponents)

return cls(
initInputs=frozen(allInitInputs - allInitOutputs),
initIntermediates=frozen(allInitInputs & allInitOutputs),
initOutputs=frozen(allInitOutputs - allInitInputs),
inputs=inputs,
queryConstraints=frozen(queryConstraints & inputs),
# If there are storage class differences between inputs and outputs,
# the intermediates have to choose a priority. Here we choose that
# inputs to tasks must match the requested storage class, by
# applying the inputs over the top of the outputs.
intermediates=frozen(allOutputs & allInputs | intermediateComponents),
outputs=frozen(allOutputs - allInputs - intermediateComposites),
prerequisites=frozen(prerequisites),
byTask=MappingProxyType(byTask), # MappingProxyType -> frozen view of dict for immutability
)
result.initInputs.freeze()
result.initOutputs.freeze()
result.initIntermediates.freeze()
result.inputs.freeze()
result.queryConstraints.freeze()
result.prerequisites.freeze()
result.intermediates.freeze()
result.outputs.freeze()
return result

@classmethod
def initOutputNames(
cls,
pipeline: Union[Pipeline, Iterable[TaskDef]],
pipeline: Pipeline | Iterable[TaskDef],
*,
include_configs: bool = True,
include_packages: bool = True,
@@ -1407,19 +1239,16 @@ def initOutputNames(
datasetTypeName : `str`
Name of the dataset type.
"""
if isinstance(pipeline, Pipeline):
mgraph = pipeline.to_graph()
else:
mgraph = pipeline_graph.MutablePipelineGraph()
for task_def in pipeline:
mgraph.add_task(task_def.label, task_def.taskClass, task_def.config, task_def.connections)
if include_packages:
# Package versions dataset type
yield cls.packagesDatasetName

if isinstance(pipeline, Pipeline):
pipeline = pipeline.toExpandedPipeline()

for taskDef in pipeline:
# all task InitOutputs
for name in taskDef.connections.initOutputs:
attribute = getattr(taskDef.connections, name)
yield attribute.name

# config dataset name
if include_configs:
yield taskDef.configDatasetName
for task_node in mgraph.tasks.values():
edges = task_node.iter_all_init_outputs() if include_configs else task_node.init_outputs
for edge in edges:
yield edge.dataset_type_name
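
Taken together, a hedged end-to-end sketch of the two graph-backed entry points (assuming `pipeline` is a `Pipeline` and `registry` a `Registry`; all names as in this file):

from lsst.pipe.base.pipeline import PipelineDatasetTypes

# Classify every dataset type in the pipeline; internally this now builds a
# MutablePipelineGraph and resolves it against the registry.
dataset_types = PipelineDatasetTypes.fromPipeline(pipeline, registry=registry)
for label, task_types in dataset_types.byTask.items():
    print(label, task_types.inputs.names)

# Init-output names, by contrast, are computed without any registry access.
for name in PipelineDatasetTypes.initOutputNames(pipeline):
    print(name)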
