diff --git a/python/lsst/pipe/base/_quantumContext.py b/python/lsst/pipe/base/_quantumContext.py index 5f0740f5..09b9a549 100644 --- a/python/lsst/pipe/base/_quantumContext.py +++ b/python/lsst/pipe/base/_quantumContext.py @@ -27,7 +27,7 @@ __all__ = ("ButlerQuantumContext", "ExecutionResources", "QuantumContext") import numbers -from collections.abc import Sequence +from collections.abc import Callable, Sequence from dataclasses import dataclass from typing import Any @@ -66,6 +66,48 @@ def __post_init__(self) -> None: max_mem = max_mem.to(u.B) object.__setattr__(self, "max_mem", max_mem) + def _reduce_kwargs(self) -> dict[str, Any]: + """Return a dict of the keyword arguments that should be used + by `__reduce__`. + + This is necessary because the dataclass is defined to be keyword + only and we wish the default pickling to only store a plain number + for the memory allocation and not a large Quantity. + + Returns + ------- + kwargs : `dict` + Keyword arguments to be used when pickling. + """ + kwargs: dict[str, Any] = {"num_cores": self.num_cores} + if self.max_mem is not None: + # .value is a numpy float. Cast it to a python int since we + # do not want fractional bytes. + # __post_init__ ensures that this is a Quantity but mypy cannot + # work that out. + kwargs["max_mem"] = int(self.max_mem.value) # type: ignore[union-attr] + return kwargs + + @staticmethod + def _unpickle_via_factory( + factory: Callable[..., ExecutionResources], args: Sequence[Any], kwargs: dict[str, Any] + ) -> ExecutionResources: + """Unpickle something by calling a factory. + + Allows unpickle using `__reduce__` with keyword + arguments as well as positional arguments. + """ + return factory(**kwargs) + + def __reduce__( + self, + ) -> tuple[ + Callable[[Callable[..., ExecutionResources], Sequence[Any], dict[str, Any]], ExecutionResources], + tuple[type[ExecutionResources], Sequence[Any], dict[str, Any]], + ]: + """Pickler.""" + return self._unpickle_via_factory, (self.__class__, [], self._reduce_kwargs()) + class QuantumContext: """A Butler-like class specialized for a single quantum along with diff --git a/tests/test_pipelineTask.py b/tests/test_pipelineTask.py index dee2d7b1..87ddd285 100644 --- a/tests/test_pipelineTask.py +++ b/tests/test_pipelineTask.py @@ -22,6 +22,7 @@ """Simple unit test for PipelineTask. """ +import pickle import unittest from typing import Any @@ -320,6 +321,22 @@ def testButlerQC(self): self.assertEqual(butlerQC.resources.num_cores, 1) self.assertEqual(butlerQC.resources.max_mem, 5 * u.B) + def test_ExecutionResources(self): + res = pipeBase.ExecutionResources() + self.assertEqual(res.num_cores, 1) + self.assertIsNone(res.max_mem) + self.assertEqual(pickle.loads(pickle.dumps(res)), res) + + res = pipeBase.ExecutionResources(num_cores=4, max_mem=1 * u.MiB) + self.assertEqual(res.num_cores, 4) + self.assertEqual(res.max_mem.value, 1024 * 1024) + self.assertEqual(pickle.loads(pickle.dumps(res)), res) + + res = pipeBase.ExecutionResources(max_mem=512) + self.assertEqual(res.num_cores, 1) + self.assertEqual(res.max_mem.value, 512) + self.assertEqual(pickle.loads(pickle.dumps(res)), res) + with self.assertRaises(u.UnitConversionError): pipeBase.ExecutionResources(max_mem=1 * u.m)