Skip to content

Commit

Permalink
Add pickle support for ExecutionResources
Browse files Browse the repository at this point in the history
  • Loading branch information
timj committed Jun 21, 2023
1 parent be91805 commit d67b968
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 1 deletion.
44 changes: 43 additions & 1 deletion python/lsst/pipe/base/_quantumContext.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
__all__ = ("ButlerQuantumContext", "ExecutionResources", "QuantumContext")

import numbers
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from typing import Any

Expand Down Expand Up @@ -66,6 +66,48 @@ def __post_init__(self) -> None:
max_mem = max_mem.to(u.B)
object.__setattr__(self, "max_mem", max_mem)

def _reduce_kwargs(self) -> dict[str, Any]:
"""Return a dict of the keyword arguments that should be used
by `__reduce__`.
This is necessary because the dataclass is defined to be keyword
only and we wish the default pickling to only store a plain number
for the memory allocation and not a large Quantity.
Returns
-------
kwargs : `dict`
Keyword arguments to be used when pickling.
"""
kwargs: dict[str, Any] = {"num_cores": self.num_cores}
if self.max_mem is not None:
# .value is a numpy float. Cast it to a python int since we
# do not want fractional bytes.
# __post_init__ ensures that this is a Quantity but mypy cannot
# work that out.
kwargs["max_mem"] = int(self.max_mem.value) # type: ignore[union-attr]
return kwargs

@staticmethod
def _unpickle_via_factory(
factory: Callable[..., ExecutionResources], args: Sequence[Any], kwargs: dict[str, Any]
) -> ExecutionResources:
"""Unpickle something by calling a factory.
Allows unpickle using `__reduce__` with keyword
arguments as well as positional arguments.
"""
return factory(**kwargs)

def __reduce__(
self,
) -> tuple[
Callable[[Callable[..., ExecutionResources], Sequence[Any], dict[str, Any]], ExecutionResources],
tuple[type[ExecutionResources], Sequence[Any], dict[str, Any]],
]:
"""Pickler."""
return self._unpickle_via_factory, (self.__class__, [], self._reduce_kwargs())


class QuantumContext:
"""A Butler-like class specialized for a single quantum along with
Expand Down
17 changes: 17 additions & 0 deletions tests/test_pipelineTask.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"""Simple unit test for PipelineTask.
"""

import pickle
import unittest
from typing import Any

Expand Down Expand Up @@ -320,6 +321,22 @@ def testButlerQC(self):
self.assertEqual(butlerQC.resources.num_cores, 1)
self.assertEqual(butlerQC.resources.max_mem, 5 * u.B)

def test_ExecutionResources(self):
res = pipeBase.ExecutionResources()
self.assertEqual(res.num_cores, 1)
self.assertIsNone(res.max_mem)
self.assertEqual(pickle.loads(pickle.dumps(res)), res)

res = pipeBase.ExecutionResources(num_cores=4, max_mem=1 * u.MiB)
self.assertEqual(res.num_cores, 4)
self.assertEqual(res.max_mem.value, 1024 * 1024)
self.assertEqual(pickle.loads(pickle.dumps(res)), res)

res = pipeBase.ExecutionResources(max_mem=512)
self.assertEqual(res.num_cores, 1)
self.assertEqual(res.max_mem.value, 512)
self.assertEqual(pickle.loads(pickle.dumps(res)), res)

with self.assertRaises(u.UnitConversionError):
pipeBase.ExecutionResources(max_mem=1 * u.m)

Expand Down

0 comments on commit d67b968

Please sign in to comment.