Skip to content

Commit

Permalink
Merge pull request #464 from lsst/tickets/DM-35396
Browse files Browse the repository at this point in the history
DM-35396: Record provenance in QuantumContext
  • Loading branch information
timj authored Feb 4, 2025
2 parents fcfd244 + dd217c8 commit 704effe
Show file tree
Hide file tree
Showing 30 changed files with 150 additions and 75 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ repos:
- id: trailing-whitespace
- id: check-toml
- repo: https://github.com/psf/black
rev: 24.10.0
rev: 25.1.0
hooks:
- id: black
# It is recommended to specify the latest version of Python
Expand All @@ -16,7 +16,7 @@ repos:
# https://pre-commit.com/#top_level-default_language_version
language_version: python3.11
- repo: https://github.com/pycqa/isort
rev: 5.13.2
rev: 6.0.0
hooks:
- id: isort
name: isort (python)
Expand Down
3 changes: 3 additions & 0 deletions doc/changes/DM-35396.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
* Modified ``QuantumContext`` such that it now tracks all datasets that are retrieved and records them in ``dataset_provenance``.
This provenance is then passed to Butler on ``put()``.
* Added ``QuantumContext.add_additional_provenance()`` to allow a pipeline task author to attach additional provenance information to be recorded and associated with a particular input dataset.
2 changes: 2 additions & 0 deletions doc/changes/DM-35396.misc.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Modified ``TaskMetadata`` such that it can now be assigned an empty list.
This list can be retrieved with ``getArray``, but if an attempt is made to retrieve a scalar from it, `KeyError` will be raised.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,9 @@ target-version = "py311"
exclude = [
"__init__.py",
]
[tool.ruff.format]
docstring-code-format = true
docstring-code-line-length = 79

[tool.ruff.lint]
ignore = [
Expand Down
51 changes: 42 additions & 9 deletions python/lsst/pipe/base/_quantumContext.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,21 @@
__all__ = ("ExecutionResources", "QuantumContext")

import numbers
import uuid
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from typing import Any

import astropy.units as u
from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, LimitedButler, Quantum
from lsst.daf.butler import (
DataCoordinate,
DatasetProvenance,
DatasetRef,
DatasetType,
DimensionUniverse,
LimitedButler,
Quantum,
)
from lsst.utils.introspection import get_full_type_name
from lsst.utils.logging import PeriodicLogger, getLogger

Expand Down Expand Up @@ -174,6 +183,8 @@ class QuantumContext:
single execution of this node in the pipeline graph.
resources : `ExecutionResources`, optional
The resources allocated for executing quanta.
quantum_id : `uuid.UUID` or `None`, optional
The ID of the quantum being executed. Used for provenance.
Notes
-----
Expand All @@ -191,7 +202,12 @@ class QuantumContext:
resources: ExecutionResources

def __init__(
self, butler: LimitedButler, quantum: Quantum, *, resources: ExecutionResources | None = None
self,
butler: LimitedButler,
quantum: Quantum,
*,
resources: ExecutionResources | None = None,
quantum_id: uuid.UUID | None = None,
):
self.quantum = quantum
if resources is None:
Expand All @@ -202,7 +218,7 @@ def __init__(
self.allOutputs = set()
for refs in quantum.inputs.values():
for ref in refs:
self.allInputs.add((ref.datasetType, ref.dataId))
self.allInputs.add((ref.datasetType, ref.dataId, ref.id))
for dataset_type, refs in quantum.outputs.items():
if dataset_type.name.endswith(METADATA_OUTPUT_CONNECTION_NAME) or dataset_type.name.endswith(
LOG_OUTPUT_CONNECTION_NAME
Expand All @@ -212,27 +228,30 @@ def __init__(
# write them itself; that's for the execution system to do.
continue
for ref in refs:
self.allOutputs.add((ref.datasetType, ref.dataId))
self.outputsPut: set[tuple[DatasetType, DataCoordinate]] = set()
self.allOutputs.add((ref.datasetType, ref.dataId, ref.id))
self.outputsPut: set[tuple[DatasetType, DataCoordinate, uuid.UUID]] = set()
self.__butler = butler
self.dataset_provenance = DatasetProvenance(quantum_id=quantum_id)

def _get(self, ref: DeferredDatasetRef | DatasetRef | None) -> Any:
    """Retrieve an input dataset (or a deferred-load handle) from the
    butler, recording it in the quantum's dataset provenance.

    Parameters
    ----------
    ref : `DeferredDatasetRef`, `DatasetRef`, or `None`
        Reference to the input dataset. `None` is passed through and
        returned unchanged.

    Returns
    -------
    obj : `object`
        The dataset, a deferred handle (for a `DeferredDatasetRef`),
        or `None` if ``ref`` was `None`.
    """
    # Butler methods below will check for unresolved DatasetRefs and
    # raise appropriately, so no need for us to do that here.
    if isinstance(ref, DeferredDatasetRef):
        # Verify the ref is a declared input of this quantum, then
        # record it as provenance before deferring the actual load.
        self._checkMembership(ref.datasetRef, self.allInputs)
        self.dataset_provenance.add_input(ref.datasetRef)
        return self.__butler.getDeferred(ref.datasetRef)
    elif ref is None:
        return None
    else:
        self._checkMembership(ref, self.allInputs)
        # Record the retrieval so it can be attached to outputs on put().
        self.dataset_provenance.add_input(ref)
        return self.__butler.get(ref)

def _put(self, value: Any, ref: DatasetRef) -> None:
"""Store data in butler."""
self._checkMembership(ref, self.allOutputs)
self.__butler.put(value, ref)
self.outputsPut.add((ref.datasetType, ref.dataId))
self.__butler.put(value, ref, provenance=self.dataset_provenance)
self.outputsPut.add((ref.datasetType, ref.dataId, ref.id))

def get(
self,
Expand Down Expand Up @@ -261,7 +280,7 @@ def get(
Returns
-------
return : `object`
This function returns arbitrary objects fetched from the bulter.
This function returns arbitrary objects fetched from the butler.
The structure these objects are returned in depends on the type of
the input argument. If the input dataset argument is a
`InputQuantizedConnection`, then the return type will be a
Expand Down Expand Up @@ -425,7 +444,7 @@ def _checkMembership(self, ref: list[DatasetRef] | DatasetRef, inout: set) -> No
if not isinstance(ref, list | tuple):
ref = [ref]
for r in ref:
if (r.datasetType, r.dataId) not in inout:
if (r.datasetType, r.dataId, r.id) not in inout:
raise ValueError("DatasetRef is not part of the Quantum being processed")

@property
Expand All @@ -434,3 +453,17 @@ def dimensions(self) -> DimensionUniverse:
repository (`~lsst.daf.butler.DimensionUniverse`).
"""
return self.__butler.dimensions

def add_additional_provenance(self, ref: DatasetRef, extra: dict[str, int | float | str | bool]) -> None:
    """Add additional provenance information to the dataset provenance.

    Parameters
    ----------
    ref : `DatasetRef`
        The dataset to attach provenance to. This dataset must have been
        retrieved by this quantum context.
    extra : `dict` [ `str`, `int` | `float` | `str` | `bool` ]
        Additional information to attach as provenance information. Keys
        must be strings and values must be simple scalars.
    """
    # Extra provenance is keyed by the dataset's UUID; presumably the
    # ref must already be known to dataset_provenance via add_input()
    # -- TODO confirm behavior for unregistered dataset IDs.
    self.dataset_provenance.add_extra_provenance(ref.id, extra)
9 changes: 8 additions & 1 deletion python/lsst/pipe/base/_task_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,12 @@ def __getitem__(self, key: str) -> Any:
if key0 in self.metadata:
return self.metadata[key0]
if key0 in self.arrays:
return self.arrays[key0][-1]
arr = self.arrays[key0]
if not arr:
# If there are no elements then returning a scalar
# is an error.
raise KeyError(f"'{key}' not found")
return arr[-1]
raise KeyError(f"'{key}' not found")
# Hierarchical lookup so the top key can only be in the metadata
# property. Trap KeyError and reraise so that the correct key
Expand Down Expand Up @@ -613,6 +618,8 @@ def _validate_value(self, value: Any) -> tuple[str, Any]:
# For model consistency, need to check that every item in the
# list has the same type.
value = list(value)
if not value:
return "array", value

type0 = type(value[0])
for i in value:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -489,8 +489,7 @@ def from_builder(
result.query_args["collections"] = builder.input_collections
else:
raise QuantumGraphBuilderError(
f"Unable to handle type {builder.dataset_query_constraint} "
"given as datasetQueryConstraint."
f"Unable to handle type {builder.dataset_query_constraint} given as datasetQueryConstraint."
)
builder.log.verbose("Querying for data IDs with arguments:")
builder.log.verbose(" dimensions=%s,", list(result.query_args["dimensions"].names))
Expand Down
5 changes: 3 additions & 2 deletions python/lsst/pipe/base/caching_limited_butler.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

from lsst.daf.butler import (
DatasetId,
DatasetProvenance,
DatasetRef,
DeferredDatasetHandle,
DimensionUniverse,
Expand Down Expand Up @@ -155,7 +156,7 @@ def stored_many(self, refs: Iterable[DatasetRef]) -> dict[DatasetRef, bool]:
def isWriteable(self) -> bool:
return self._wrapped.isWriteable()

def put(self, obj: Any, ref: DatasetRef) -> DatasetRef:
def put(self, obj: Any, ref: DatasetRef, /, *, provenance: DatasetProvenance | None = None) -> DatasetRef:
if ref.datasetType.name in self._cache_on_put:
self._cache[ref.datasetType.name] = (
ref.id,
Expand All @@ -167,7 +168,7 @@ def put(self, obj: Any, ref: DatasetRef) -> DatasetRef:
),
)
_LOG.debug("Cached dataset %s on put", ref)
return self._wrapped.put(obj, ref)
return self._wrapped.put(obj, ref, provenance=provenance)

def pruneDatasets(
self,
Expand Down
3 changes: 1 addition & 2 deletions python/lsst/pipe/base/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Module defining config classes for PipelineTask.
"""
"""Module defining config classes for PipelineTask."""

from __future__ import annotations

Expand Down
3 changes: 1 addition & 2 deletions python/lsst/pipe/base/configOverrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Module which defines ConfigOverrides class and related methods.
"""
"""Module which defines ConfigOverrides class and related methods."""
from __future__ import annotations

__all__ = ["ConfigOverrides"]
Expand Down
42 changes: 24 additions & 18 deletions python/lsst/pipe/base/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Module defining connection classes for PipelineTask.
"""
"""Module defining connection classes for PipelineTask."""

from __future__ import annotations

Expand Down Expand Up @@ -590,26 +589,33 @@ class attribute must match a function argument name in the ``run``
>>> from lsst.pipe.base import connectionTypes as cT
>>> from lsst.pipe.base import PipelineTaskConnections
>>> from lsst.pipe.base import PipelineTaskConfig
>>> class ExampleConnections(PipelineTaskConnections,
... dimensions=("A", "B"),
... defaultTemplates={"foo": "Example"}):
... inputConnection = cT.Input(doc="Example input",
... dimensions=("A", "B"),
... storageClass=Exposure,
... name="{foo}Dataset")
... outputConnection = cT.Output(doc="Example output",
... dimensions=("A", "B"),
... storageClass=Exposure,
... name="{foo}output")
>>> class ExampleConfig(PipelineTaskConfig,
... pipelineConnections=ExampleConnections):
... pass
>>> class ExampleConnections(
... PipelineTaskConnections,
... dimensions=("A", "B"),
... defaultTemplates={"foo": "Example"},
... ):
... inputConnection = cT.Input(
... doc="Example input",
... dimensions=("A", "B"),
... storageClass=Exposure,
... name="{foo}Dataset",
... )
... outputConnection = cT.Output(
... doc="Example output",
... dimensions=("A", "B"),
... storageClass=Exposure,
... name="{foo}output",
... )
>>> class ExampleConfig(
... PipelineTaskConfig, pipelineConnections=ExampleConnections
... ):
... pass
>>> config = ExampleConfig()
>>> config.connections.foo = Modified
>>> config.connections.outputConnection = "TotallyDifferent"
>>> connections = ExampleConnections(config=config)
>>> assert(connections.inputConnection.name == "ModifiedDataset")
>>> assert(connections.outputConnection.name == "TotallyDifferent")
>>> assert connections.inputConnection.name == "ModifiedDataset"
>>> assert connections.outputConnection.name == "TotallyDifferent"
"""

# We annotate these attributes as mutable sets because that's what they are
Expand Down
26 changes: 24 additions & 2 deletions python/lsst/pipe/base/mermaid_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,11 +363,33 @@ def pipeline2mermaid(
edges: list[tuple[str, str, bool]] = []

def get_task_id(idx: int) -> str:
"""Generate a safe Mermaid node ID for a task."""
"""Generate a safe Mermaid node ID for a task.
Parameters
----------
idx : `int`
Task index.
Returns
-------
id : `str`
Node ID for a task.
"""
return f"TASK_{idx}"

def get_dataset_id(name: str) -> str:
"""Generate a safe Mermaid node ID for a dataset."""
"""Generate a safe Mermaid node ID for a dataset.
Parameters
----------
name : `str`
Dataset name.
Returns
-------
id : `str`
Node ID for the dataset.
"""
# Replace non-alphanumerics with underscores.
return "DATASET_" + re.sub(r"[^0-9A-Za-z_]", "_", name)

Expand Down
3 changes: 1 addition & 2 deletions python/lsst/pipe/base/pipelineTask.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Define `PipelineTask` class and related methods.
"""
"""Define `PipelineTask` class and related methods."""

from __future__ import annotations

Expand Down
2 changes: 1 addition & 1 deletion python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def diff_tasks(self, other: PipelineGraph) -> list[str]:
b = other.tasks[label]
if a.task_class != b.task_class:
messages.append(
f"Task {label!r} has class {a.task_class_name} in A, " f"but {b.task_class_name} in B."
f"Task {label!r} has class {a.task_class_name} in A, but {b.task_class_name} in B."
)
messages.extend(a.diff_edges(b))
return messages
Expand Down
3 changes: 1 addition & 2 deletions python/lsst/pipe/base/quantum_provenance_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1249,8 +1249,7 @@ def assemble_quantum_provenance_graph(
"""
if read_caveats not in ("lazy", "exhaustive", None):
raise TypeError(
f"Invalid option {read_caveats!r} for read_caveats; "
"should be 'lazy', 'exhaustive', or None."
f"Invalid option {read_caveats!r} for read_caveats; should be 'lazy', 'exhaustive', or None."
)
output_runs = []
for graph in qgraphs:
Expand Down
3 changes: 1 addition & 2 deletions python/lsst/pipe/base/taskFactory.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Module defining TaskFactory interface.
"""
"""Module defining TaskFactory interface."""

from __future__ import annotations

Expand Down
3 changes: 1 addition & 2 deletions python/lsst/pipe/base/tests/pipelineStepTester.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Utility to facilitate testing of pipelines consisting of multiple steps.
"""
"""Utility to facilitate testing of pipelines consisting of multiple steps."""

__all__ = ["PipelineStepTester"]

Expand Down
3 changes: 1 addition & 2 deletions python/lsst/pipe/base/tests/simpleQGraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Bunch of common classes and methods for use in unit tests.
"""
"""Bunch of common classes and methods for use in unit tests."""
from __future__ import annotations

__all__ = ["AddTaskConfig", "AddTask", "AddTaskFactoryMock"]
Expand Down
Loading

0 comments on commit 704effe

Please sign in to comment.