Skip to content

Commit

Permalink
Add new QuantumGraph.get_refs method
Browse files Browse the repository at this point in the history
  • Loading branch information
timj committed Nov 1, 2024
1 parent 1378e19 commit 6b11830
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 40 deletions.
87 changes: 87 additions & 0 deletions python/lsst/pipe/base/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import getpass
import io
import json
import logging
import lzma
import os
import struct
Expand Down Expand Up @@ -75,6 +76,7 @@
from .quantumNode import BuildId, QuantumNode

_T = TypeVar("_T", bound="QuantumGraph")
_LOG = logging.getLogger(__name__)

# modify this constant any time the on disk representation of the save file
# changes, and update the load helpers to behave properly for each version.
Expand Down Expand Up @@ -1656,3 +1658,88 @@ def init_output_run(self, butler: LimitedButler, existing: bool = True) -> None:
self.write_configs(butler, compare_existing=existing)
self.write_packages(butler, compare_existing=existing)
self.write_init_outputs(butler, skip_existing=existing)

def get_refs(
    self,
    *,
    include_inputs: bool = False,
    include_init_inputs: bool = False,
    include_init_outputs: bool = False,
    include_outputs: bool = False,
    conform_outputs: bool = True,
    include_global_init_outputs: bool = False,
) -> tuple[set[DatasetRef], dict[str, DatastoreRecordData]]:
    """Get the requested dataset refs from the graph.

    Parameters
    ----------
    include_inputs : `bool`, optional
        Include inputs. Datastore records attached to the quanta are only
        collected when this is `True`.
    include_init_inputs : `bool`, optional
        Include init inputs.
    include_init_outputs : `bool`, optional
        Include per-task init outputs.
    include_outputs : `bool`, optional
        Include outputs.
    conform_outputs : `bool`, optional
        Whether any outputs found should have their dataset types conformed
        with the registry dataset types. This is applied to outputs and
        init outputs only; inputs and init inputs are returned unchanged.
    include_global_init_outputs : `bool`, optional
        Include the graph-wide init outputs reported by
        ``globalInitOutputRefs``. These are not associated with a single
        task and so are not found by ``include_init_outputs``. They are
        counted with the init outputs and conformed in the same way.

    Returns
    -------
    refs : `set` [ `lsst.daf.butler.DatasetRef` ]
        The requested dataset refs found in the graph.
    datastore_records : `dict` [ `str`, \
            `lsst.daf.butler.datastore.record_data.DatastoreRecordData` ]
        Any datastore records found, keyed by datastore name.
    """
    datastore_records: dict[str, DatastoreRecordData] = {}
    init_input_refs: set[DatasetRef] = set()
    init_output_refs: set[DatasetRef] = set()

    if include_init_inputs or include_init_outputs:
        for task_def in self.iterTaskGraph():
            # The per-task accessors can return nothing for a task, so only
            # merge results that are non-empty.
            if include_init_inputs:
                if in_refs := self.initInputRefs(task_def):
                    init_input_refs.update(in_refs)
            if include_init_outputs:
                if out_refs := self.initOutputRefs(task_def):
                    init_output_refs.update(out_refs)

    if include_global_init_outputs:
        # Global init outputs belong to the graph rather than to any one
        # task, so they have to be requested from the graph directly.
        init_output_refs.update(self.globalInitOutputRefs())

    input_refs: set[DatasetRef] = set()
    output_refs: set[DatasetRef] = set()

    for qnode in self:
        if include_inputs:
            for other_refs in qnode.quantum.inputs.values():
                input_refs.update(other_refs)
            # Inputs can come with datastore records.
            for store_name, records in qnode.quantum.datastore_records.items():
                datastore_records.setdefault(store_name, DatastoreRecordData()).update(records)
        if include_outputs:
            for other_refs in qnode.quantum.outputs.values():
                output_refs.update(other_refs)

    if conform_outputs:
        # Get data repository definitions from the QuantumGraph; these can
        # have different storage classes than those in the quanta.
        dataset_types = {dstype.name: dstype for dstype in self.registryDatasetTypes()}

        def _update_ref(ref: DatasetRef) -> DatasetRef:
            # Fall back to the ref's own dataset type when the graph has no
            # registry definition for it, which leaves the ref unchanged.
            internal_dataset_type = dataset_types.get(ref.datasetType.name, ref.datasetType)
            if internal_dataset_type.storageClass_name != ref.datasetType.storageClass_name:
                ref = ref.overrideStorageClass(internal_dataset_type.storageClass_name)
            return ref

        # Convert output_refs to the data repository storage classes, too.
        output_refs = {_update_ref(ref) for ref in output_refs}
        init_output_refs = {_update_ref(ref) for ref in init_output_refs}

    _LOG.info(
        "Found the following datasets. InitInputs: %d; Inputs: %d; InitOutputs: %d; Outputs: %d",
        len(init_input_refs),
        len(input_refs),
        len(init_output_refs),
        len(output_refs),
    )
    refs = input_refs | init_input_refs | init_output_refs | output_refs
    return refs, datastore_records

Check warning on line 1745 in python/lsst/pipe/base/graph/graph.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/pipe/base/graph/graph.py#L1744-L1745

Added lines #L1744 - L1745 were not covered by tests
49 changes: 9 additions & 40 deletions python/lsst/pipe/base/script/retrieve_artifacts_for_quanta.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@

import logging

from lsst.daf.butler import DatasetRef, QuantumBackedButler
from lsst.daf.butler.datastore.record_data import DatastoreRecordData
from lsst.daf.butler import QuantumBackedButler
from lsst.pipe.base import QuantumGraph
from lsst.resources import ResourcePath

Expand Down Expand Up @@ -81,48 +80,18 @@ def retrieve_artifacts_for_quanta(
nodes = qgraph_node_id or None
qgraph = QuantumGraph.loadUri(graph, nodes=nodes)

Check warning on line 81 in python/lsst/pipe/base/script/retrieve_artifacts_for_quanta.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/pipe/base/script/retrieve_artifacts_for_quanta.py#L80-L81

Added lines #L80 - L81 were not covered by tests

refs, datastore_records = qgraph.get_refs(

Check warning on line 83 in python/lsst/pipe/base/script/retrieve_artifacts_for_quanta.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/pipe/base/script/retrieve_artifacts_for_quanta.py#L83

Added line #L83 was not covered by tests
include_inputs=include_inputs,
include_init_inputs=include_inputs,
include_outputs=include_outputs,
include_init_outputs=include_outputs,
conform_outputs=True, # Need to look for predicted outputs with correct storage class.
)

# Get data repository definitions from the QuantumGraph; these can have
# different storage classes than those in the quanta.
dataset_types = {dstype.name: dstype for dstype in qgraph.registryDatasetTypes()}

Check warning on line 93 in python/lsst/pipe/base/script/retrieve_artifacts_for_quanta.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/pipe/base/script/retrieve_artifacts_for_quanta.py#L93

Added line #L93 was not covered by tests

datastore_records: dict[str, DatastoreRecordData] = {}
refs: set[DatasetRef] = set()
if include_inputs:
# Collect input refs used by this graph.
for task_def in qgraph.iterTaskGraph():
if in_refs := qgraph.initInputRefs(task_def):
refs.update(in_refs)
for qnode in qgraph:
for otherRefs in qnode.quantum.inputs.values():
refs.update(otherRefs)
for store_name, records in qnode.quantum.datastore_records.items():
datastore_records.setdefault(store_name, DatastoreRecordData()).update(records)
n_inputs = len(refs)
if n_inputs:
_LOG.info("Found %d input dataset%s.", n_inputs, "" if n_inputs == 1 else "s")

if include_outputs:
# Collect output refs that could be created by this graph.
original_output_refs: set[DatasetRef] = set(qgraph.globalInitOutputRefs())
for task_def in qgraph.iterTaskGraph():
if out_refs := qgraph.initOutputRefs(task_def):
original_output_refs.update(out_refs)
for qnode in qgraph:
for otherRefs in qnode.quantum.outputs.values():
original_output_refs.update(otherRefs)

# Convert output_refs to the data repository storage classes, too.
for ref in original_output_refs:
internal_dataset_type = dataset_types.get(ref.datasetType.name, ref.datasetType)
if internal_dataset_type.storageClass_name != ref.datasetType.storageClass_name:
refs.add(ref.overrideStorageClass(internal_dataset_type.storageClass_name))
else:
refs.add(ref)

n_outputs = len(refs) - n_inputs
if n_outputs:
_LOG.info("Found %d output dataset%s.", n_outputs, "" if n_outputs == 1 else "s")

# Make QBB, its config is the same as output Butler.
qbb = QuantumBackedButler.from_predicted(

Check warning on line 96 in python/lsst/pipe/base/script/retrieve_artifacts_for_quanta.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/pipe/base/script/retrieve_artifacts_for_quanta.py#L96

Added line #L96 was not covered by tests
config=repo,
Expand Down

0 comments on commit 6b11830

Please sign in to comment.