diff --git a/mypy.ini b/mypy.ini index caec25da68..8097d51c08 100644 --- a/mypy.ini +++ b/mypy.ini @@ -18,6 +18,9 @@ ignore_missing_imports = True [mypy-numpy.*] ignore_missing_imports = True +[mypy-matplotlib.*] +ignore_missing_imports = True + [mypy-pyarrow.*] ignore_missing_imports = True diff --git a/python/lsst/daf/butler/configs/storageClasses.yaml b/python/lsst/daf/butler/configs/storageClasses.yaml index f7a58c9d79..7198e20f2c 100644 --- a/python/lsst/daf/butler/configs/storageClasses.yaml +++ b/python/lsst/daf/butler/configs/storageClasses.yaml @@ -342,10 +342,14 @@ storageClasses: dict: lsst.utils.packages.Packages NumpyArray: pytype: numpy.ndarray + converters: + matplotlib.figure.Figure: lsst.daf.butler.formatters.matplotlib.MatplotlibFormatter.dummyConverter Thumbnail: pytype: numpy.ndarray Plot: pytype: matplotlib.figure.Figure + converters: + numpy.ndarray: lsst.daf.butler.formatters.matplotlib.MatplotlibFormatter.fromArray MetricValue: pytype: lsst.verify.Measurement StampsBase: diff --git a/python/lsst/daf/butler/formatters/matplotlib.py b/python/lsst/daf/butler/formatters/matplotlib.py index d7dc4d2c01..69769d841f 100644 --- a/python/lsst/daf/butler/formatters/matplotlib.py +++ b/python/lsst/daf/butler/formatters/matplotlib.py @@ -27,6 +27,9 @@ from typing import Any +import matplotlib.pyplot as plt +import numpy as np + from .file import FileFormatter @@ -38,8 +41,22 @@ class MatplotlibFormatter(FileFormatter): def _readFile(self, path: str, pytype: type[Any] | None = None) -> Any: # docstring inherited from FileFormatter._readFile - raise NotImplementedError(f"matplotlib figures cannot be read by the butler; path is {path}") + return plt.imread(path) def _writeFile(self, inMemoryDataset: Any) -> None: # docstring inherited from FileFormatter._writeFile inMemoryDataset.savefig(self.fileDescriptor.location.path) + + @staticmethod + def fromArray(cls: np.ndarray) -> plt.Figure: + """Convert an array into a Figure.""" + fig = plt.figure() + plt.imshow(cls) + return fig + + @staticmethod + def dummyCovnerter(cls: np.ndarray) -> np.ndarray: + """This converter exists to trick the Butler into allowing + a numpy array on read with ``storageClass='NumpyArray'``. + """ + return cls diff --git a/tests/test_matplotlibFormatter.py b/tests/test_matplotlibFormatter.py index 5987027abc..080097d445 100644 --- a/tests/test_matplotlibFormatter.py +++ b/tests/test_matplotlibFormatter.py @@ -27,6 +27,8 @@ import unittest from random import Random +import numpy as np + try: import matplotlib @@ -78,8 +80,13 @@ def testMatplotlibFormatter(self): pyplot.gcf().savefig(file.name) self.assertTrue(filecmp.cmp(local.ospath, file.name, shallow=True)) self.assertTrue(butler.exists(ref)) - with self.assertRaises(ValueError): - butler.get(ref) + + fig = butler.get(ref) + # Ensure that the result is a figure + self.assertTrue(isinstance(fig, pyplot.Figure)) + image = butler.get(ref, storageClass="NumpyArray") + self.assertTrue(isinstance(image, np.ndarray)) + butler.pruneDatasets([ref], unstore=True, purge=True) self.assertFalse(butler.exists(ref))