
DM-39582: Decrease memory usage and load times when reading graphs #348

Merged · 9 commits · Jul 4, 2023
1 change: 1 addition & 0 deletions doc/changes/DM-39582.api.md
@@ -0,0 +1 @@
Deprecated reconstituteDimensions argument from QuantumNode.from_simple
Contributor:

Probably removal.md instead of api.md?

3 changes: 3 additions & 0 deletions doc/changes/DM-39582.feature.md
@@ -0,0 +1,3 @@
The quantum graph loading back-end has been optimized so that duplicate objects in a serialized graph are deserialized to shared references instead of being created anew for each occurrence. This yields a large decrease in memory usage and shorter load times.
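
A minimal sketch of the idea behind this change, assuming a per-load cache keyed by object UUID; the names here are illustrative, not the actual daf_butler API:

from typing import Any
from uuid import UUID


class DeduplicatingLoader:
    """Return a shared instance for objects that appear repeatedly in a
    serialized graph, instead of materializing a fresh copy each time."""

    def __init__(self) -> None:
        self._cache: dict[UUID, Any] = {}

    def load(self, key: UUID, payload: dict[str, Any]) -> Any:
        # A cache hit hands back a reference to the existing object, so
        # N occurrences of a dataset cost one object plus N references.
        if (existing := self._cache.get(key)) is not None:
            return existing
        obj = self._materialize(payload)
        self._cache[key] = obj
        return obj

    def _materialize(self, payload: dict[str, Any]) -> Any:
        return dict(payload)  # stand-in for real deserialization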
4 changes: 2 additions & 2 deletions python/lsst/pipe/base/executionButlerBuilder.py
@@ -162,7 +162,7 @@ def _accumulate(
for type, refs in attr.items():
# This if block is because init inputs has a different
# signature for its items
if not isinstance(refs, list):
if not isinstance(refs, (list, tuple)):
refs = [refs]
for ref in refs:
if ref.isComponent():
@@ -177,7 +177,7 @@
attr = getattr(quantum, attrName)

for type, refs in attr.items():
if not isinstance(refs, list):
if not isinstance(refs, (list, tuple)):
refs = [refs]
if type.component() is not None:
type = type.makeCompositeDatasetType()
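To see why the tuple case matters: with the old isinstance(refs, list) check, a tuple of refs failed the test and was wrapped whole, so the loop iterated over a one-element list containing the tuple rather than over the refs themselves. A self-contained illustration, with plain strings standing in for DatasetRef objects:

def normalize(refs):
    # Old check was `isinstance(refs, list)`; a tuple failed it and was
    # wrapped as a one-element list containing the tuple.
    if not isinstance(refs, (list, tuple)):
        refs = [refs]
    return list(refs)


assert normalize("r1") == ["r1"]
assert normalize(("r1", "r2")) == ["r1", "r2"]  # old behaviour: [("r1", "r2")]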
9 changes: 7 additions & 2 deletions python/lsst/pipe/base/graph/_loadHelpers.py
@@ -31,9 +31,10 @@
from typing import TYPE_CHECKING, BinaryIO
from uuid import UUID

from lsst.daf.butler import DimensionUniverse
from lsst.daf.butler import DimensionUniverse, PersistenceContextVars
from lsst.resources import ResourceHandleProtocol, ResourcePath


if TYPE_CHECKING:
from ._versionDeserializers import DeserializerBase
from .graph import QuantumGraph
@@ -219,7 +220,11 @@ def load(
_readBytes = self._readBytes
if universe is None:
universe = headerInfo.universe
return self.deserializer.constructGraph(nodeSet, _readBytes, universe)
# use the daf butler context vars to aid in ensuring deduplication in
# object instantiation.
runner = PersistenceContextVars()
graph = runner.run(self.deserializer.constructGraph, nodeSet, _readBytes, universe)
return graph # type: ignore
Contributor:

I do not entirely understand PersistenceContextVars/context vars; it would be nice if it could work like a regular context manager, e.g.:

with PersistenceContextVars():
   return self.deserializer.constructGraph(nodeSet, _readBytes, universe)

but maybe it's not possible?

Contributor (Author):

I worked on that for a while, but the outcome kept coming out more awkward, not nearly as nice as that. I think it is mostly because context vars are geared more toward the asyncio, callback style of programming.


def _readBytes(self, start: int, stop: int) -> bytes:
"""Load the specified byte range from the ResourcePath object
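
As an aside on the context-manager question above, here is a sketch of why a runner method is the natural shape for context vars. Only the stdlib contextvars machinery is shown; PersistenceContextVars internals are assumed, not reproduced. copy_context().run() executes the callable inside a copied Context, so every ContextVar set during the call is discarded automatically when it returns:

import contextvars
from typing import Any, Callable

_cache: contextvars.ContextVar[dict | None] = contextvars.ContextVar("cache", default=None)


def run_with_fresh_cache(func: Callable[..., Any], *args: Any) -> Any:
    """Run func in a copied Context; ContextVars set inside do not leak out."""

    def wrapped() -> Any:
        _cache.set({})  # visible only within the copied Context
        return func(*args)

    return contextvars.copy_context().run(wrapped)

A with-statement version would instead have to capture the Token returned by every ContextVar.set() and reset each one in __exit__, which is presumably the awkwardness described above.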
2 changes: 2 additions & 0 deletions python/lsst/pipe/base/graph/_versionDeserializers.py
@@ -557,6 +557,8 @@ def constructGraph(

# Turn the JSON back into the pydantic model
nodeDeserialized = SerializedQuantumNode.direct(**dump)
del dump

# attach the dictionary of dimension records to the pydantic model
# these are stored separately because they are stored over and over
# and this saves a lot of space and time.
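The added del is presumably a peak-memory measure: it drops the last reference to the intermediate dict as soon as the pydantic model has been built from it, rather than keeping it alive until the name is rebound on the next loop iteration. A stand-alone illustration of the pattern:

import json


def build_nodes(blobs: list[bytes]) -> list[dict]:
    nodes = []
    for blob in blobs:
        dump = json.loads(blob)  # large temporary dict
        # Stand-in for SerializedQuantumNode.direct(**dump):
        node = {key: value for key, value in dump.items()}
        del dump  # release the temporary now, not at the next rebinding of `dump`
        nodes.append(node)
    return nodes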
17 changes: 13 additions & 4 deletions python/lsst/pipe/base/graph/quantumNode.py
@@ -25,6 +25,7 @@
import uuid
from dataclasses import dataclass
from typing import Any, NewType
import warnings

from lsst.daf.butler import (
DatasetRef,
@@ -96,6 +97,8 @@ class QuantumNode:
creation.
"""

__slots__ = ("quantum", "taskDef", "nodeId", "_precomputedHash")

def __post_init__(self) -> None:
# use setattr here to preserve the frozenness of the QuantumNode
self._precomputedHash: int
Expand Down Expand Up @@ -135,15 +138,21 @@ def from_simple(
universe: DimensionUniverse,
recontitutedDimensions: dict[int, tuple[str, DimensionRecord]] | None = None,
) -> QuantumNode:
if recontitutedDimensions is not None:
warnings.warn(
"The recontitutedDimensions argument is now ignored and may be removed after v 27",
category=DeprecationWarning,
)
return QuantumNode(
quantum=Quantum.from_simple(
simple.quantum, universe, reconstitutedDimensions=recontitutedDimensions
),
quantum=Quantum.from_simple(simple.quantum, universe),
taskDef=taskDefMap[simple.taskLabel],
nodeId=simple.nodeId,
)


_fields_set = {"quantum", "taskLabel", "nodeId"}


class SerializedQuantumNode(BaseModel):
quantum: SerializedQuantum
taskLabel: str
@@ -156,5 +165,5 @@ def direct(cls, *, quantum: dict[str, Any], taskLabel: str, nodeId: str) -> SerializedQuantumNode:
setter(node, "quantum", SerializedQuantum.direct(**quantum))
setter(node, "taskLabel", taskLabel)
setter(node, "nodeId", uuid.UUID(nodeId))
setter(node, "__fields_set__", {"quantum", "taskLabel", "nodeId"})
setter(node, "__fields_set__", _fields_set)
return node
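
A condensed sketch of the two patterns combined in the diff above: __slots__ removes the per-instance __dict__ (a significant saving when a graph holds many nodes), and object.__setattr__ bypasses the frozen-dataclass guard exactly once in __post_init__ to store a precomputed hash. Illustrative class, not the real QuantumNode:

from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    __slots__ = ("payload", "_precomputed_hash")

    payload: str

    def __post_init__(self) -> None:
        # Plain setattr would raise FrozenInstanceError; go through object directly.
        object.__setattr__(self, "_precomputed_hash", hash(self.payload))

    def __hash__(self) -> int:
        return self._precomputed_hash

Python 3.10+ also offers @dataclass(slots=True), though the explicit tuple is needed here to reserve a slot for the non-field _precomputed_hash.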
4 changes: 1 addition & 3 deletions python/lsst/pipe/base/tests/mocks/_pipeline_task.py
@@ -283,9 +283,7 @@ def runQuantum(

# store mock outputs
for name, refs in outputRefs:
if not isinstance(refs, list):
refs = [refs]
for ref in refs:
for ref in ensure_iterable(refs):
output = MockDataset(
ref=ref.to_simple(), quantum=mock_dataset_quantum, output_connection_name=name
)
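ensure_iterable here is presumably the helper from lsst.utils.iteration; a rough stand-in for what such a helper does, not its exact implementation:

from collections.abc import Iterable, Iterator
from typing import Any


def ensure_iterable(value: Any) -> Iterator[Any]:
    """Yield value's elements if it is a non-string iterable; otherwise
    yield value itself as a single element."""
    if isinstance(value, Iterable) and not isinstance(value, (str, bytes)):
        yield from value
    else:
        yield value

With this, for ref in ensure_iterable(refs) handles a single ref, a list, or a tuple uniformly, replacing the isinstance checks removed above.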