Merge pull request #444 from lsst/tickets/DM-38041
DM-38041: rewrite pre-exec-init logic to work without QGs and respect storage class differences
TallJimbo authored Sep 12, 2024
2 parents 3efb6ca + 9a3e471 commit 8292ef4
Showing 8 changed files with 1,232 additions and 19 deletions.
3 changes: 3 additions & 0 deletions doc/changes/DM-38041.feature.md
@@ -0,0 +1,3 @@
Add support for initializing processing output runs with just a pipeline graph, not a quantum graph.

This also moves much of the logic for initializing output runs from `lsst.ctrl.mpexec.PreExecInit` to `PipelineGraph` and `QuantumGraph` methods.
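
For context, a minimal usage sketch of the new initialization flow added in this commit (not part of the diff; the graph and repository paths are hypothetical):

    from lsst.pipe.base import QuantumGraph

    # Load a previously saved quantum graph (hypothetical path).
    qgraph = QuantumGraph.loadUri("/data/pipeline.qgraph")
    # Build a quantum-backed butler for init I/O, then initialize the run.
    qbb = qgraph.make_init_qbb("/repo/butler.yaml")
    qgraph.init_output_run(qbb)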
8 changes: 4 additions & 4 deletions python/lsst/pipe/base/graph/_versionDeserializers.py
@@ -545,8 +545,8 @@ def constructGraph(
container = {}
datasetDict = _DatasetTracker(createInverse=True)
taskToQuantumNode: defaultdict[TaskDef, set[QuantumNode]] = defaultdict(set)
- initInputRefs: dict[TaskDef, list[DatasetRef]] = {}
- initOutputRefs: dict[TaskDef, list[DatasetRef]] = {}
+ initInputRefs: dict[str, list[DatasetRef]] = {}
+ initOutputRefs: dict[str, list[DatasetRef]] = {}

if universe is not None:
if not universe.isCompatibleWith(self.infoMappings.universe):
@@ -597,11 +597,11 @@ def constructGraph(

# initInputRefs and initOutputRefs are optional
if (refs := taskDefDump.get("initInputRefs")) is not None:
- initInputRefs[recreatedTaskDef] = [
+ initInputRefs[recreatedTaskDef.label] = [
cast(DatasetRef, DatasetRef.from_json(ref, universe=universe)) for ref in refs
]
if (refs := taskDefDump.get("initOutputRefs")) is not None:
- initOutputRefs[recreatedTaskDef] = [
+ initOutputRefs[recreatedTaskDef.label] = [
cast(DatasetRef, DatasetRef.from_json(ref, universe=universe)) for ref in refs
]

271 changes: 262 additions & 9 deletions python/lsst/pipe/base/graph/graph.py
@@ -46,22 +46,28 @@

import networkx as nx
from lsst.daf.butler import (
Config,
DatasetId,
DatasetRef,
DatasetType,
DimensionRecordsAccumulator,
DimensionUniverse,
LimitedButler,
Quantum,
QuantumBackedButler,
)
from lsst.daf.butler.datastore.record_data import DatastoreRecordData
from lsst.daf.butler.persistence_context import PersistenceContextVars
from lsst.daf.butler.registry import ConflictingDefinitionError
from lsst.resources import ResourcePath, ResourcePathExpression
from lsst.utils.introspection import get_full_type_name
from lsst.utils.packages import Packages
from networkx.drawing.nx_agraph import write_dot

from ..config import PipelineTaskConfig
from ..connections import iterConnections
from ..pipeline import TaskDef
- from ..pipeline_graph import PipelineGraph
+ from ..pipeline_graph import PipelineGraph, compare_packages, log_config_mismatch
from ._implDetails import DatasetTypeName, _DatasetTracker
from ._loadHelpers import LoadHelper
from ._versionDeserializers import DESERIALIZER_MAP
@@ -286,14 +292,14 @@ def _buildGraphs(
# insertion
self._taskToQuantumNode = dict(self._taskToQuantumNode.items())

- self._initInputRefs: dict[TaskDef, list[DatasetRef]] = {}
- self._initOutputRefs: dict[TaskDef, list[DatasetRef]] = {}
+ self._initInputRefs: dict[str, list[DatasetRef]] = {}
+ self._initOutputRefs: dict[str, list[DatasetRef]] = {}
self._globalInitOutputRefs: list[DatasetRef] = []
self._registryDatasetTypes: list[DatasetType] = []
if initInputs is not None:
- self._initInputRefs = {taskDef: list(refs) for taskDef, refs in initInputs.items()}
+ self._initInputRefs = {taskDef.label: list(refs) for taskDef, refs in initInputs.items()}
if initOutputs is not None:
- self._initOutputRefs = {taskDef: list(refs) for taskDef, refs in initOutputs.items()}
+ self._initOutputRefs = {taskDef.label: list(refs) for taskDef, refs in initOutputs.items()}
if globalInitOutputs is not None:
self._globalInitOutputRefs = list(globalInitOutputs)
if registryDatasetTypes is not None:
@@ -812,6 +818,38 @@ def metadata(self) -> MappingProxyType[str, Any]:
"""
return MappingProxyType(self._metadata)

def get_init_input_refs(self, task_label: str) -> list[DatasetRef]:
"""Return the DatasetRefs for the given task's init inputs.
Parameters
----------
task_label : `str`
Label of the task.
Returns
-------
refs : `list` [ `lsst.daf.butler.DatasetRef` ]
Dataset references. Guaranteed to be a new list, not internal
state.
"""
return list(self._initInputRefs.get(task_label, ()))

def get_init_output_refs(self, task_label: str) -> list[DatasetRef]:
"""Return the DatasetRefs for the given task's init outputs.
Parameters
----------
task_label : `str`
Label of the task.
Returns
-------
refs : `list` [ `lsst.daf.butler.DatasetRef` ]
Dataset references. Guaranteed to be a new list, not internal
state.
"""
return list(self._initOutputRefs.get(task_label, ()))
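
# Usage sketch (not part of this diff; the task label "isr" is hypothetical):
# these accessors are keyed by task label rather than by ``TaskDef`` and
# always return a fresh list:
#
#     refs = qgraph.get_init_output_refs("isr")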

def initInputRefs(self, taskDef: TaskDef) -> list[DatasetRef] | None:
"""Return DatasetRefs for a given task InitInputs.
@@ -826,7 +864,7 @@ def initInputRefs(self, taskDef: TaskDef) -> list[DatasetRef] | None:
DatasetRef for the task InitInput, can be `None`. This can return
either resolved or non-resolved reference.
"""
- return self._initInputRefs.get(taskDef)
+ return self._initInputRefs.get(taskDef.label)

def initOutputRefs(self, taskDef: TaskDef) -> list[DatasetRef] | None:
"""Return DatasetRefs for a given task InitOutputs.
@@ -843,7 +881,7 @@ def initOutputRefs(self, taskDef: TaskDef) -> list[DatasetRef] | None:
either resolved or non-resolved reference. Resolved reference will
match Quantum's initInputs if this is an intermediate dataset type.
"""
- return self._initOutputRefs.get(taskDef)
+ return self._initOutputRefs.get(taskDef.label)

def globalInitOutputRefs(self) -> list[DatasetRef]:
"""Return DatasetRefs for global InitOutputs.
@@ -1027,9 +1065,9 @@ def _buildSaveObjectImpl(self, returnHeader: bool = False) -> bytearray | tuple[
taskDef.config.saveToStream(stream)
taskDescription["config"] = stream.getvalue()
taskDescription["label"] = taskDef.label
- if (refs := self._initInputRefs.get(taskDef)) is not None:
+ if (refs := self._initInputRefs.get(taskDef.label)) is not None:
taskDescription["initInputRefs"] = [ref.to_json() for ref in refs]
- if (refs := self._initOutputRefs.get(taskDef)) is not None:
+ if (refs := self._initOutputRefs.get(taskDef.label)) is not None:
taskDescription["initOutputRefs"] = [ref.to_json() for ref in refs]

inputs = []
@@ -1403,3 +1441,218 @@ def getSummary(self) -> QgraphSummary:
qts.numOutputs[k.name] += 1

return summary

def make_init_qbb(
self,
butler_config: Config | ResourcePathExpression,
*,
config_search_paths: Iterable[str] | None = None,
) -> QuantumBackedButler:
"""Construct an quantum-backed butler suitable for reading and writing
init input and init output datasets, respectively.
This requires the full graph to have been loaded.
Parameters
----------
butler_config : `~lsst.daf.butler.Config` or \
`~lsst.resources.ResourcePathExpression`
A butler repository root, configuration filename, or configuration
instance.
config_search_paths : `~collections.abc.Iterable` [ `str` ], optional
Additional search paths for butler configuration.
Returns
-------
qbb : `~lsst.daf.butler.QuantumBackedButler`
A limited butler that can ``get`` init-input datasets and ``put``
init-output datasets.
"""
universe = self.universe
# Collect all init input/output dataset IDs.
predicted_inputs: set[DatasetId] = set()
predicted_outputs: set[DatasetId] = set()
pipeline_graph = self.pipeline_graph
for task_label in pipeline_graph.tasks:
predicted_inputs.update(ref.id for ref in self.get_init_input_refs(task_label))
predicted_outputs.update(ref.id for ref in self.get_init_output_refs(task_label))
predicted_outputs.update(ref.id for ref in self.globalInitOutputRefs())
# remove intermediates from inputs
predicted_inputs -= predicted_outputs
# Extracting datastore records from the quantum graph is very inefficient:
# we have to scan all quanta and look at their datastore records.
datastore_records: dict[str, DatastoreRecordData] = {}
for quantum_node in self:
for store_name, records in quantum_node.quantum.datastore_records.items():
subset = records.subset(predicted_inputs)
if subset is not None:
datastore_records.setdefault(store_name, DatastoreRecordData()).update(subset)

dataset_types = {dstype.name: dstype for dstype in self.registryDatasetTypes()}
# Make butler from everything.
return QuantumBackedButler.from_predicted(
config=butler_config,
predicted_inputs=predicted_inputs,
predicted_outputs=predicted_outputs,
dimensions=universe,
datastore_records=datastore_records,
search_paths=list(config_search_paths) if config_search_paths is not None else None,
dataset_types=dataset_types,
)
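
# Usage sketch (not part of this diff; the label "isr" is hypothetical): the
# returned QBB is a `LimitedButler`, so init-input datasets can be read
# directly:
#
#     qbb = qgraph.make_init_qbb(butler_config)
#     for ref in qgraph.get_init_input_refs("isr"):
#         obj = qbb.get(ref)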

def write_init_outputs(self, butler: LimitedButler, skip_existing: bool = True) -> None:
"""Write the init-output datasets for all tasks in the quantum graph.
Parameters
----------
butler : `lsst.daf.butler.LimitedButler`
A limited butler data repository client.
skip_existing : `bool`, optional
If `True` (default) ignore init-outputs that already exist. If
`False`, raise.
Raises
------
lsst.daf.butler.registry.ConflictingDefinitionError
Raised if an init-output dataset already exists and
``skip_existing=False``.
"""
# Extract init-input and init-output refs from the QG.
input_refs: dict[str, DatasetRef] = {}
output_refs: dict[str, DatasetRef] = {}
for task_node in self.pipeline_graph.tasks.values():
input_refs.update(
{ref.datasetType.name: ref for ref in self.get_init_input_refs(task_node.label)}
)
output_refs.update(
{
ref.datasetType.name: ref
for ref in self.get_init_output_refs(task_node.label)
if ref.datasetType.name != task_node.init.config_output.dataset_type_name
}
)
for ref, is_stored in butler.stored_many(output_refs.values()).items():
if is_stored:
if not skip_existing:
raise ConflictingDefinitionError(f"Init-output dataset {ref} already exists.")
# We'll `put` whatever's left in output_refs at the end.
del output_refs[ref.datasetType.name]
# Instantiate tasks, reading overall init-inputs and gathering
# init-output in-memory objects.
init_outputs: list[tuple[Any, DatasetType]] = []
self.pipeline_graph.instantiate_tasks(
get_init_input=lambda dataset_type: butler.get(
input_refs[dataset_type.name].overrideStorageClass(dataset_type.storageClass)
),
init_outputs=init_outputs,
)
# Write init-outputs that weren't already present.
for obj, dataset_type in init_outputs:
if new_ref := output_refs.get(dataset_type.name):
assert (
new_ref.datasetType.storageClass_name == dataset_type.storageClass_name
), "QG init refs should use task connection storage classes."
butler.put(obj, new_ref)
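
# Usage sketch (not part of this diff): config datasets are deliberately
# excluded here and written by ``write_configs`` instead; to raise rather
# than skip when outputs already exist:
#
#     qgraph.write_init_outputs(qbb, skip_existing=False)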

def write_configs(self, butler: LimitedButler, compare_existing: bool = True) -> None:
"""Write the config datasets for all tasks in the quantum graph.
Parameters
----------
butler : `lsst.daf.butler.LimitedButler`
A limited butler data repository client.
compare_existing : `bool`, optional
If `True` check configs that already exist for consistency. If
`False`, always raise if configs already exist.
Raises
------
lsst.daf.butler.registry.ConflictingDefinitionError
Raised if an config dataset already exists and
``compare_existing=False``, or if the existing config is not
consistent with the config in the quantum graph.
"""
to_put: list[tuple[PipelineTaskConfig, DatasetRef]] = []
for task_node in self.pipeline_graph.tasks.values():
dataset_type_name = task_node.init.config_output.dataset_type_name
(ref,) = [
ref
for ref in self.get_init_output_refs(task_node.label)
if ref.datasetType.name == dataset_type_name
]
try:
old_config = butler.get(ref)
except (LookupError, FileNotFoundError):
old_config = None
if old_config is not None:
if not compare_existing:
raise ConflictingDefinitionError(f"Config dataset {ref} already exists.")
if not task_node.config.compare(old_config, shortcut=False, output=log_config_mismatch):
raise ConflictingDefinitionError(
f"Config does not match existing task config {dataset_type_name!r} in "
"butler; tasks configurations must be consistent within the same run collection."
)
else:
to_put.append((task_node.config, ref))
# We do writes at the end to minimize the mess we leave behind when we
# raise an exception.
for config, ref in to_put:
butler.put(config, ref)
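
# Usage sketch (not part of this diff): catching a mismatch against a
# previously initialized run.
#
#     from lsst.daf.butler.registry import ConflictingDefinitionError
#     try:
#         qgraph.write_configs(qbb)
#     except ConflictingDefinitionError as err:
#         print(f"run already initialized with different configs: {err}")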

def write_packages(self, butler: LimitedButler, compare_existing: bool = True) -> None:
"""Write the 'packages' dataset for the currently-active software
versions.
Parameters
----------
butler : `lsst.daf.butler.LimitedButler`
A limited butler data repository client.
compare_existing : `bool`, optional
If `True` check packages that already exist for consistency. If
`False`, always raise if the packages dataset already exists.
Raises
------
lsst.daf.butler.registry.ConflictingDefinitionError
Raised if the packages dataset already exists and is not consistent
with the current packages.
"""
new_packages = Packages.fromSystem()
(ref,) = self.globalInitOutputRefs()
try:
packages = butler.get(ref)
except (LookupError, FileNotFoundError):
packages = None
if packages is not None:
if not compare_existing:
raise ConflictingDefinitionError(f"Packages dataset {ref} already exists.")
if compare_packages(packages, new_packages):
# We have to remove the existing dataset first; the butler has no
# replace option.
butler.pruneDatasets([ref], unstore=True, purge=True)
butler.put(packages, ref)
else:
butler.put(new_packages, ref)
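
# Usage sketch (not part of this diff): when an existing packages dataset is
# consistent but missing entries, it is replaced (prune, then put); an
# inconsistent one is expected to raise, per the Raises section above.
#
#     qgraph.write_packages(qbb, compare_existing=True)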

def init_output_run(self, butler: LimitedButler, existing: bool = True) -> None:
"""Initialize a new output RUN collection by writing init-output
datasets (including configs and packages).

Parameters
----------
butler : `lsst.daf.butler.LimitedButler`
A limited butler data repository client.
existing : `bool`, optional
If `True`, check or ignore outputs that already exist. If `False`,
always raise if an output dataset already exists.

Raises
------
lsst.daf.butler.registry.ConflictingDefinitionError
Raised if there are existing init output datasets, and either
``existing=False`` or their contents are not compatible with this
graph.
"""
self.write_configs(butler, compare_existing=existing)
self.write_packages(butler, compare_existing=existing)
self.write_init_outputs(butler, skip_existing=existing)
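
Since a full ``Butler`` is also a ``LimitedButler``, the same entry point works outside of quantum-backed execution (sketch, not from this diff; the repository path and run name are hypothetical):

    from lsst.daf.butler import Butler

    butler = Butler("/repo", writeable=True, run="u/me/run1")
    qgraph.init_output_run(butler)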