Skip to content

Commit

Permalink
Merge pull request #362 from lsst/tickets/DM-40254
Browse files Browse the repository at this point in the history
DM-40254: Improve handling of calibration datasets in graph builder
  • Loading branch information
andy-slac committed Aug 1, 2023
2 parents 6a8d145 + caa5298 commit 496a576
Show file tree
Hide file tree
Showing 2 changed files with 178 additions and 51 deletions.
2 changes: 2 additions & 0 deletions doc/changes/DM-40254.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fix a bug in quantum graph builder which resulted in missing datastore records for calibration datasets.
This bug was causing failures for pipetask execution with quantum-backed butler.
227 changes: 176 additions & 51 deletions python/lsst/pipe/base/graphBuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from collections.abc import Collection, Iterable, Iterator, Mapping
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any
from typing import Any, TypeVar, cast

from lsst.daf.butler import (
CollectionType,
Expand Down Expand Up @@ -104,9 +104,12 @@ def resolved_ref(self) -> DatasetRef:
return self.ref


class _DatasetDict(NamedKeyDict[DatasetType, dict[DataCoordinate, _RefHolder]]):
_Refs = TypeVar("_Refs")


class _DatasetDictBase(NamedKeyDict[DatasetType, _Refs]):
"""A custom dictionary that maps `~lsst.daf.butler.DatasetType` to a nested
dictionary of the known `~lsst.daf.butler.DatasetRef` instances of that
collection of the known `~lsst.daf.butler.DatasetRef` instances of that
type.
Parameters
Expand All @@ -122,35 +125,12 @@ def __init__(self, *args: Any, universe: DimensionUniverse):
self.universe = universe

@classmethod
def fromDatasetTypes(
cls, datasetTypes: Iterable[DatasetType], *, universe: DimensionUniverse
) -> _DatasetDict:
"""Construct a dictionary from a flat iterable of
`~lsst.daf.butler.DatasetType` keys.
Parameters
----------
datasetTypes : `~collections.abc.Iterable` of \
`~lsst.daf.butler.DatasetType`
DatasetTypes to use as keys for the dict. Values will be empty
dictionaries.
universe : `~lsst.daf.butler.DimensionUniverse`
Universe of all possible dimensions.
Returns
-------
dictionary : `_DatasetDict`
A new `_DatasetDict` instance.
"""
return cls({datasetType: {} for datasetType in datasetTypes}, universe=universe)

@classmethod
def fromSubset(
def _fromSubset(
cls,
datasetTypes: Collection[DatasetType],
first: _DatasetDict,
*rest: _DatasetDict,
) -> _DatasetDict:
first: _DatasetDictBase,
*rest: _DatasetDictBase,
) -> _DatasetDictBase:
"""Return a new dictionary by extracting items corresponding to the
given keys from one or more existing dictionaries.
Expand All @@ -160,14 +140,16 @@ def fromSubset(
`~lsst.daf.butler.DatasetType`
DatasetTypes to use as keys for the dict. Values will be obtained
by lookups against ``first`` and ``rest``.
first : `_DatasetDict`
Another dictionary from which to extract values.
first : `_DatasetDictBase`
Another dictionary from which to extract values. Its actual type
must be identical to the type of the sub-class used to call this
method.
rest
Additional dictionaries from which to extract values.
Returns
-------
dictionary : `_DatasetDict`
dictionary : `_DatasetDictBase`
A new dictionary instance.
"""
combined = ChainMap(first, *rest)
Expand Down Expand Up @@ -300,6 +282,68 @@ def unpackMultiRefs(self, storage_classes: dict[str, str]) -> NamedKeyDict[Datas
`~lsst.daf.butler.DatasetType` instances and string names usable
as keys.
"""
raise NotImplementedError()


class _DatasetDict(_DatasetDictBase[dict[DataCoordinate, _RefHolder]]):
"""A custom dictionary that maps `~lsst.daf.butler.DatasetType` to a nested
dictionary of the known `~lsst.daf.butler.DatasetRef` instances of that
type.
"""

@classmethod
def fromDatasetTypes(
    cls, datasetTypes: Iterable[DatasetType], *, universe: DimensionUniverse
) -> _DatasetDict:
    """Construct a dictionary from a flat iterable of
    `~lsst.daf.butler.DatasetType` keys.

    Parameters
    ----------
    datasetTypes : `~collections.abc.Iterable` of \
            `~lsst.daf.butler.DatasetType`
        Dataset types to use as keys for the dict.  Each key is mapped
        to a new empty dictionary.
    universe : `~lsst.daf.butler.DimensionUniverse`
        Universe of all possible dimensions.

    Returns
    -------
    dictionary : `_DatasetDict`
        A new `_DatasetDict` instance.
    """
    # Each key must map to its own distinct dict, so build them one at a
    # time (dict.fromkeys would share a single mutable value).
    contents: dict[DatasetType, dict[DataCoordinate, _RefHolder]] = {}
    for dataset_type in datasetTypes:
        contents[dataset_type] = {}
    return cls(contents, universe=universe)

@classmethod
def fromSubset(
    cls,
    datasetTypes: Collection[DatasetType],
    first: _DatasetDict,
    *rest: _DatasetDict,
) -> _DatasetDict:
    """Return a new dictionary by extracting items corresponding to the
    given keys from one or more existing dictionaries.

    Parameters
    ----------
    datasetTypes : `~collections.abc.Collection` of \
            `~lsst.daf.butler.DatasetType`
        Dataset types to use as keys for the dict.  Values will be
        obtained by lookups against ``first`` and ``rest``.
    first : `_DatasetDict`
        Another dictionary from which to extract values.
    rest
        Additional dictionaries from which to extract values.

    Returns
    -------
    dictionary : `_DatasetDict`
        A new dictionary instance.
    """
    # The base-class helper returns the base type; narrow it for callers.
    subset = cls._fromSubset(datasetTypes, first, *rest)
    return cast(_DatasetDict, subset)

def unpackMultiRefs(self, storage_classes: dict[str, str]) -> NamedKeyDict[DatasetType, list[DatasetRef]]:
# Docstring inherited.
result = {}
for dataset_type, holders in self.items():
if (
Expand Down Expand Up @@ -358,6 +402,92 @@ def iter_resolved_refs(self) -> Iterator[DatasetRef]:
yield holder.resolved_ref


class _DatasetDictMulti(_DatasetDictBase[defaultdict[DataCoordinate, list[_RefHolder]]]):
    """A custom dictionary that maps `~lsst.daf.butler.DatasetType` to a nested
    dictionary of the known `~lsst.daf.butler.DatasetRef` instances of that
    type.  The nested dictionary can contain multiple refs for the same data
    ID, making it suitable for calibration datasets.
    """

    @classmethod
    def fromDatasetTypes(
        cls, datasetTypes: Iterable[DatasetType], *, universe: DimensionUniverse
    ) -> _DatasetDictMulti:
        """Construct a dictionary from a flat iterable of
        `~lsst.daf.butler.DatasetType` keys.

        Parameters
        ----------
        datasetTypes : `~collections.abc.Iterable` of \
                `~lsst.daf.butler.DatasetType`
            Dataset types to use as keys for the dict.  Each key is mapped
            to a new empty ``defaultdict`` of lists.
        universe : `~lsst.daf.butler.DimensionUniverse`
            Universe of all possible dimensions.

        Returns
        -------
        dictionary : `_DatasetDictMulti`
            A new `_DatasetDictMulti` instance.
        """
        # Build one defaultdict per key so values are not shared.
        contents: dict[DatasetType, defaultdict[DataCoordinate, list[_RefHolder]]] = {}
        for dataset_type in datasetTypes:
            contents[dataset_type] = defaultdict(list)
        return cls(contents, universe=universe)

    @classmethod
    def fromSubset(
        cls,
        datasetTypes: Collection[DatasetType],
        first: _DatasetDictMulti,
        *rest: _DatasetDictMulti,
    ) -> _DatasetDictMulti:
        """Return a new dictionary by extracting items corresponding to the
        given keys from one or more existing dictionaries.

        Parameters
        ----------
        datasetTypes : `~collections.abc.Collection` of \
                `~lsst.daf.butler.DatasetType`
            Dataset types to use as keys for the dict.  Values will be
            obtained by lookups against ``first`` and ``rest``.
        first : `_DatasetDictMulti`
            Another dictionary from which to extract values.
        rest
            Additional dictionaries from which to extract values.

        Returns
        -------
        dictionary : `_DatasetDictMulti`
            A new dictionary instance.
        """
        # Narrow the base-class helper's return type for callers.
        subset = cls._fromSubset(datasetTypes, first, *rest)
        return cast(_DatasetDictMulti, subset)

    def unpackMultiRefs(self, storage_classes: dict[str, str]) -> NamedKeyDict[DatasetType, list[DatasetRef]]:
        # Docstring inherited.
        unpacked: dict[DatasetType, list[DatasetRef]] = {}
        for dataset_type, holder_map in self.items():
            # Callers may request a storage class different from the one the
            # dataset type was defined with; if so, override it on both the
            # dataset type and every resolved ref.
            override = storage_classes.get(dataset_type.name, dataset_type.storageClass_name)
            if override != dataset_type.storageClass_name:
                dataset_type = dataset_type.overrideStorageClass(override)
                refs = [
                    holder.resolved_ref.overrideStorageClass(override)
                    for holder_list in holder_map.values()
                    for holder in holder_list
                ]
            else:
                refs = [
                    holder.resolved_ref
                    for holder_list in holder_map.values()
                    for holder in holder_list
                ]
            unpacked[dataset_type] = refs
        return NamedKeyDict(unpacked)

    def iter_resolved_refs(self) -> Iterator[DatasetRef]:
        """Iterate over all DatasetRef instances held by this data structure,
        assuming that each `_RefHolder` already carries a resolved ref.
        """
        for holders_by_data_id in self.values():
            for holder_list in holders_by_data_id.values():
                yield from (holder.resolved_ref for holder in holder_list)


class _QuantumScaffolding:
"""Helper class aggregating information about a `Quantum`, used when
constructing a `QuantumGraph`.
Expand Down Expand Up @@ -530,7 +660,7 @@ def __init__(
)
self.inputs = _DatasetDict.fromSubset(datasetTypes.inputs, parent.inputs, parent.intermediates)
self.outputs = _DatasetDict.fromSubset(datasetTypes.outputs, parent.intermediates, parent.outputs)
self.prerequisites = _DatasetDict.fromSubset(datasetTypes.prerequisites, parent.prerequisites)
self.prerequisites = _DatasetDictMulti.fromSubset(datasetTypes.prerequisites, parent.prerequisites)
self.dataIds: set[DataCoordinate] = set()
self.quanta = {}
self.storage_classes = {
Expand Down Expand Up @@ -581,10 +711,10 @@ def __repr__(self) -> str:
(`_DatasetDict`).
"""

prerequisites: _DatasetDict
prerequisites: _DatasetDictMulti
"""Dictionary containing information about input datasets that must be
present in the repository before any Pipeline containing this task is run
(`_DatasetDict`).
(`_DatasetDictMulti`).
"""

quanta: dict[DataCoordinate, _QuantumScaffolding]
Expand Down Expand Up @@ -739,13 +869,15 @@ def __init__(self, pipeline: Pipeline | Iterable[TaskDef], *, registry: Registry
"inputs",
"intermediates",
"outputs",
"prerequisites",
):
setattr(
self,
attr,
_DatasetDict.fromDatasetTypes(getattr(datasetTypes, attr), universe=registry.dimensions),
)
self.prerequisites = _DatasetDictMulti.fromDatasetTypes(
datasetTypes.prerequisites, universe=registry.dimensions
)
self.missing = _DatasetDict(universe=registry.dimensions)
self.defaultDatasetQueryConstraints = datasetTypes.queryConstraints
# Aggregate all dimensions for all non-init, non-prerequisite
Expand Down Expand Up @@ -804,9 +936,9 @@ def __repr__(self) -> str:
(`_DatasetDict`).
"""

prerequisites: _DatasetDict
prerequisites: _DatasetDictMulti
"""Datasets that are consumed when running this pipeline and looked up
per-Quantum when generating the graph (`_DatasetDict`).
per-Quantum when generating the graph (`_DatasetDictMulti`).
"""

defaultDatasetQueryConstraints: NamedValueSet[DatasetType]
Expand Down Expand Up @@ -1381,7 +1513,7 @@ def resolveDatasetRefs(
for ref in prereq_refs:
if ref is not None:
quantum.prerequisites[datasetType][ref.dataId] = _RefHolder(datasetType, ref)
task.prerequisites[datasetType][ref.dataId] = _RefHolder(datasetType, ref)
task.prerequisites[datasetType][ref.dataId].append(_RefHolder(datasetType, ref))

# Resolve all quantum inputs and outputs.
for dataset_type, refDict in quantum.inputs.items():
Expand Down Expand Up @@ -1457,20 +1589,13 @@ def makeQuantumGraph(
graph : `QuantumGraph`
The full `QuantumGraph`.
"""

def _make_refs(dataset_dict: _DatasetDict) -> Iterable[DatasetRef]:
"""Extract all DatasetRefs from the dictionaries"""
for ref_dict in dataset_dict.values():
for holder in ref_dict.values():
yield holder.resolved_ref

datastore_records: Mapping[str, DatastoreRecordData] | None = None
if datastore is not None:
datastore_records = datastore.export_records(
itertools.chain(
_make_refs(self.inputs),
_make_refs(self.initInputs),
_make_refs(self.prerequisites),
self.inputs.iter_resolved_refs(),
self.initInputs.iter_resolved_refs(),
self.prerequisites.iter_resolved_refs(),
)
)

Expand Down Expand Up @@ -1509,7 +1634,7 @@ def _get_registry_dataset_types(self, registry: Registry) -> Iterable[DatasetTyp
"""Make a list of all dataset types used by a graph as defined in
registry.
"""
chain = [
chain: list[_DatasetDict | _DatasetDictMulti] = [
self.initInputs,
self.initIntermediates,
self.initOutputs,
Expand Down

0 comments on commit 496a576

Please sign in to comment.