Improve handling of calibration datasets in graph builder (DM-40254)
There may be multiple calibration datasets for the same dataset type and
data ID in one graph (with different timespans). This patch changes the
internal graph builder structure for prerequisite datasets to allow
multiple datasets for one data ID.
andy-slac committed Aug 1, 2023
1 parent 6a8d145 commit 96d7e54
Showing 1 changed file with 176 additions and 51 deletions.

python/lsst/pipe/base/graphBuilder.py
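
The heart of the change is the shape of the nested mapping the graph builder keeps for prerequisite datasets. A minimal before/after sketch of the container types, with plain stand-ins for the real DatasetType, DataCoordinate, and _RefHolder classes (the stand-in names are illustrative, not the butler API):

    from collections import defaultdict

    # Stand-ins for the real key/value classes.
    DatasetTypeName = str
    DataId = tuple
    RefHolder = object

    # Before: _DatasetDict holds at most one ref holder per data ID, so a
    # second calibration with the same data ID (different timespan) would
    # overwrite the first.
    before: dict[DatasetTypeName, dict[DataId, RefHolder]] = {}

    # After: _DatasetDictMulti holds a list of ref holders per data ID, so
    # calibrations with distinct timespans can coexist.
    after: dict[DatasetTypeName, defaultdict[DataId, list[RefHolder]]] = {}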
@@ -35,7 +35,7 @@
from collections.abc import Collection, Iterable, Iterator, Mapping
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any
from typing import Any, TypeVar, cast

from lsst.daf.butler import (
CollectionType,
@@ -104,9 +104,12 @@ def resolved_ref(self) -> DatasetRef:
return self.ref


class _DatasetDict(NamedKeyDict[DatasetType, dict[DataCoordinate, _RefHolder]]):
_Refs = TypeVar("_Refs")


class _DatasetDictBase(NamedKeyDict[DatasetType, _Refs]):
"""A custom dictionary that maps `~lsst.daf.butler.DatasetType` to a nested
dictionary of the known `~lsst.daf.butler.DatasetRef` instances of that
collection of the known `~lsst.daf.butler.DatasetRef` instances of that
type.
Parameters
@@ -122,35 +125,12 @@ def __init__(self, *args: Any, universe: DimensionUniverse):
self.universe = universe

@classmethod
def fromDatasetTypes(
cls, datasetTypes: Iterable[DatasetType], *, universe: DimensionUniverse
) -> _DatasetDict:
"""Construct a dictionary from a flat iterable of
`~lsst.daf.butler.DatasetType` keys.
Parameters
----------
datasetTypes : `~collections.abc.Iterable` of \
`~lsst.daf.butler.DatasetType`
DatasetTypes to use as keys for the dict. Values will be empty
dictionaries.
universe : `~lsst.daf.butler.DimensionUniverse`
Universe of all possible dimensions.
Returns
-------
dictionary : `_DatasetDict`
A new `_DatasetDict` instance.
"""
return cls({datasetType: {} for datasetType in datasetTypes}, universe=universe)

@classmethod
def fromSubset(
def _fromSubset(
cls,
datasetTypes: Collection[DatasetType],
first: _DatasetDict,
*rest: _DatasetDict,
) -> _DatasetDict:
first: _DatasetDictBase,
*rest: _DatasetDictBase,
) -> _DatasetDictBase:
"""Return a new dictionary by extracting items corresponding to the
given keys from one or more existing dictionaries.
@@ -160,14 +140,16 @@ def fromSubset(
`~lsst.daf.butler.DatasetType`
DatasetTypes to use as keys for the dict. Values will be obtained
by lookups against ``first`` and ``rest``.
first : `_DatasetDict`
Another dictionary from which to extract values.
first : `_DatasetDictBase`
Another dictionary from which to extract values. Its actual type
must be identical to the subclass used to call this method.
rest
Additional dictionaries from which to extract values.
Returns
-------
dictionary : `_DatasetDict`
dictionary : `_DatasetDictBase`
A new dictionary instance.
"""
combined = ChainMap(first, *rest)
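
The shared `_fromSubset` helper merges its source dictionaries with `collections.ChainMap`, so a dataset type found in `first` shadows the same key in any of `rest`. A standalone illustration of that left-to-right lookup order, using plain dicts in place of the dataset dictionaries:

    from collections import ChainMap

    first = {"bias": 1, "dark": 2}
    rest = {"dark": 20, "flat": 30}

    combined = ChainMap(first, rest)
    # Lookups consult the maps left to right: "dark" comes from `first`.
    subset = {key: combined[key] for key in ("dark", "flat")}
    assert subset == {"dark": 2, "flat": 30}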
@@ -300,6 +282,68 @@ def unpackMultiRefs(self, storage_classes: dict[str, str]) -> NamedKeyDict[DatasetType, list[DatasetRef]]:
`~lsst.daf.butler.DatasetType` instances and string names usable
as keys.
"""
raise NotImplementedError()


class _DatasetDict(_DatasetDictBase[dict[DataCoordinate, _RefHolder]]):
"""A custom dictionary that maps `~lsst.daf.butler.DatasetType` to a nested
dictionary of the known `~lsst.daf.butler.DatasetRef` instances of that
type.
"""

@classmethod
def fromDatasetTypes(
cls, datasetTypes: Iterable[DatasetType], *, universe: DimensionUniverse
) -> _DatasetDict:
"""Construct a dictionary from a flat iterable of
`~lsst.daf.butler.DatasetType` keys.
Parameters
----------
datasetTypes : `~collections.abc.Iterable` of \
`~lsst.daf.butler.DatasetType`
DatasetTypes to use as keys for the dict. Values will be empty
dictionaries.
universe : `~lsst.daf.butler.DimensionUniverse`
Universe of all possible dimensions.
Returns
-------
dictionary : `_DatasetDict`
A new `_DatasetDict` instance.
"""
return cls({datasetType: {} for datasetType in datasetTypes}, universe=universe)

@classmethod
def fromSubset(
cls,
datasetTypes: Collection[DatasetType],
first: _DatasetDict,
*rest: _DatasetDict,
) -> _DatasetDict:
"""Return a new dictionary by extracting items corresponding to the
given keys from one or more existing dictionaries.
Parameters
----------
datasetTypes : `~collections.abc.Iterable` of \
`~lsst.daf.butler.DatasetType`
DatasetTypes to use as keys for the dict. Values will be obtained
by lookups against ``first`` and ``rest``.
first : `_DatasetDict`
Another dictionary from which to extract values.
rest
Additional dictionaries from which to extract values.
Returns
-------
dictionary : `_DatasetDict`
A new dictionary instance.
"""
return cast(_DatasetDict, cls._fromSubset(datasetTypes, first, *rest))

def unpackMultiRefs(self, storage_classes: dict[str, str]) -> NamedKeyDict[DatasetType, list[DatasetRef]]:
# Docstring inherited.
result = {}
for dataset_type, holders in self.items():
if (
@@ -358,6 +402,92 @@ def iter_resolved_refs(self) -> Iterator[DatasetRef]:
yield holder.resolved_ref


class _DatasetDictMulti(_DatasetDictBase[defaultdict[DataCoordinate, list[_RefHolder]]]):
"""A custom dictionary that maps `~lsst.daf.butler.DatasetType` to a nested
dictionary of the known `~lsst.daf.butler.DatasetRef` instances of that
type. The nested dictionary can contain multiple refs for the same data
ID, making it suitable for use with calibration datasets.
"""

@classmethod
def fromDatasetTypes(
cls, datasetTypes: Iterable[DatasetType], *, universe: DimensionUniverse
) -> _DatasetDictMulti:
"""Construct a dictionary from a flat iterable of
`~lsst.daf.butler.DatasetType` keys.
Parameters
----------
datasetTypes : `~collections.abc.Iterable` of \
`~lsst.daf.butler.DatasetType`
DatasetTypes to use as keys for the dict. Values will be empty
dictionaries.
universe : `~lsst.daf.butler.DimensionUniverse`
Universe of all possible dimensions.
Returns
-------
dictionary : `_DatasetDictMulti`
A new `_DatasetDictMulti` instance.
"""
return cls({datasetType: defaultdict(list) for datasetType in datasetTypes}, universe=universe)

@classmethod
def fromSubset(
cls,
datasetTypes: Collection[DatasetType],
first: _DatasetDictMulti,
*rest: _DatasetDictMulti,
) -> _DatasetDictMulti:
"""Return a new dictionary by extracting items corresponding to the
given keys from one or more existing dictionaries.
Parameters
----------
datasetTypes : `~collections.abc.Iterable` of \
`~lsst.daf.butler.DatasetType`
DatasetTypes to use as keys for the dict. Values will be obtained
by lookups against ``first`` and ``rest``.
first : `_DatasetDictMulti`
Another dictionary from which to extract values.
rest
Additional dictionaries from which to extract values.
Returns
-------
dictionary : `_DatasetDictMulti`
A new dictionary instance.
"""
return cast(_DatasetDictMulti, cls._fromSubset(datasetTypes, first, *rest))

def unpackMultiRefs(self, storage_classes: dict[str, str]) -> NamedKeyDict[DatasetType, list[DatasetRef]]:
# Docstring inherited.
result = {}
for dataset_type, holder_map in self.items():
if (
override := storage_classes.get(dataset_type.name, dataset_type.storageClass_name)
) != dataset_type.storageClass_name:
dataset_type = dataset_type.overrideStorageClass(override)
refs = []
for holder_list in holder_map.values():
refs += [holder.resolved_ref.overrideStorageClass(override) for holder in holder_list]
else:
refs = []
for holder_list in holder_map.values():
refs += [holder.resolved_ref for holder in holder_list]
result[dataset_type] = refs
return NamedKeyDict(result)

def iter_resolved_refs(self) -> Iterator[DatasetRef]:
"""Iterate over all DatasetRef instances held by this data structure,
assuming that each `_RefHolder` already carries a resolved ref.
"""
for holders_by_data_id in self.values():
for holder_list in holders_by_data_id.values():
for holder in holder_list:
yield holder.resolved_ref
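
Because `fromDatasetTypes` seeds each dataset type with a `defaultdict(list)`, callers can append to a data ID's bucket without checking for the key first, and `unpackMultiRefs` simply flattens the buckets. A self-contained sketch of that accumulate-then-flatten pattern, with strings standing in for real refs and data IDs:

    from collections import defaultdict

    # One bucket of refs per data ID, as _DatasetDictMulti keeps per type.
    holders: defaultdict[str, list[str]] = defaultdict(list)

    # Two calibration refs share a data ID but cover different timespans.
    holders["detector=1"].append("bias@2023-01")
    holders["detector=1"].append("bias@2023-06")

    # unpackMultiRefs-style flattening: every held ref, in insertion order.
    refs: list[str] = []
    for holder_list in holders.values():
        refs += holder_list
    assert refs == ["bias@2023-01", "bias@2023-06"]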


class _QuantumScaffolding:
"""Helper class aggregating information about a `Quantum`, used when
constructing a `QuantumGraph`.
@@ -530,7 +660,7 @@ def __init__(
)
self.inputs = _DatasetDict.fromSubset(datasetTypes.inputs, parent.inputs, parent.intermediates)
self.outputs = _DatasetDict.fromSubset(datasetTypes.outputs, parent.intermediates, parent.outputs)
self.prerequisites = _DatasetDict.fromSubset(datasetTypes.prerequisites, parent.prerequisites)
self.prerequisites = _DatasetDictMulti.fromSubset(datasetTypes.prerequisites, parent.prerequisites)
self.dataIds: set[DataCoordinate] = set()
self.quanta = {}
self.storage_classes = {
@@ -581,10 +711,10 @@ def __repr__(self) -> str:
(`_DatasetDict`).
"""

prerequisites: _DatasetDict
prerequisites: _DatasetDictMulti
"""Dictionary containing information about input datasets that must be
present in the repository before any Pipeline containing this task is run
(`_DatasetDict`).
(`_DatasetDictMulti`).
"""

quanta: dict[DataCoordinate, _QuantumScaffolding]
@@ -739,13 +869,15 @@ def __init__(self, pipeline: Pipeline | Iterable[TaskDef], *, registry: Registry
"inputs",
"intermediates",
"outputs",
"prerequisites",
):
setattr(
self,
attr,
_DatasetDict.fromDatasetTypes(getattr(datasetTypes, attr), universe=registry.dimensions),
)
self.prerequisites = _DatasetDictMulti.fromDatasetTypes(
datasetTypes.prerequisites, universe=registry.dimensions
)
self.missing = _DatasetDict(universe=registry.dimensions)
self.defaultDatasetQueryConstraints = datasetTypes.queryConstraints
# Aggregate all dimensions for all non-init, non-prerequisite
@@ -804,9 +936,9 @@ def __repr__(self) -> str:
(`_DatasetDict`).
"""

prerequisites: _DatasetDict
prerequisites: _DatasetDictMulti
"""Datasets that are consumed when running this pipeline and looked up
per-Quantum when generating the graph (`_DatasetDict`).
per-Quantum when generating the graph (`_DatasetDictMulti`).
"""

defaultDatasetQueryConstraints: NamedValueSet[DatasetType]
@@ -1381,7 +1513,7 @@ def resolveDatasetRefs(
for ref in prereq_refs:
if ref is not None:
quantum.prerequisites[datasetType][ref.dataId] = _RefHolder(datasetType, ref)
task.prerequisites[datasetType][ref.dataId] = _RefHolder(datasetType, ref)
task.prerequisites[datasetType][ref.dataId].append(_RefHolder(datasetType, ref))

# Resolve all quantum inputs and outputs.
for dataset_type, refDict in quantum.inputs.items():
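
The one-line change above is the behavioral crux: a single quantum still maps each data ID to one prerequisite ref (as the unchanged `quantum.prerequisites` assignment shows, since its timespan selects one calibration), but the task-level aggregate now appends, because different quanta can resolve the same data ID to different calibrations. A toy contrast of the old overwrite against the new append (hypothetical string refs, not the butler API):

    from collections import defaultdict

    # Two quanta resolve the same data ID to different calibration refs.
    resolved = [("detector=1", "bias@2023-01"), ("detector=1", "bias@2023-06")]

    overwrite: dict[str, str] = {}
    accumulate: defaultdict[str, list[str]] = defaultdict(list)

    for data_id, ref in resolved:
        overwrite[data_id] = ref         # old behavior: second ref clobbers the first
        accumulate[data_id].append(ref)  # new behavior: both refs survive

    assert overwrite == {"detector=1": "bias@2023-06"}
    assert accumulate == {"detector=1": ["bias@2023-01", "bias@2023-06"]}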
@@ -1457,20 +1589,13 @@ def makeQuantumGraph(
graph : `QuantumGraph`
The full `QuantumGraph`.
"""

def _make_refs(dataset_dict: _DatasetDict) -> Iterable[DatasetRef]:
"""Extract all DatasetRefs from the dictionaries"""
for ref_dict in dataset_dict.values():
for holder in ref_dict.values():
yield holder.resolved_ref

datastore_records: Mapping[str, DatastoreRecordData] | None = None
if datastore is not None:
datastore_records = datastore.export_records(
itertools.chain(
_make_refs(self.inputs),
_make_refs(self.initInputs),
_make_refs(self.prerequisites),
self.inputs.iter_resolved_refs(),
self.initInputs.iter_resolved_refs(),
self.prerequisites.iter_resolved_refs(),
)
)

@@ -1509,7 +1634,7 @@ def _get_registry_dataset_types(self, registry: Registry) -> Iterable[DatasetTyp
"""Make a list of all dataset types used by a graph as defined in
registry.
"""
chain = [
chain: list[_DatasetDict | _DatasetDictMulti] = [
self.initInputs,
self.initIntermediates,
self.initOutputs,
