From 96b3e13690eacb82e2aa7c6fcc17c3bc49ac8ba6 Mon Sep 17 00:00:00 2001 From: Drew Oldag <47493171+drewoldag@users.noreply.github.com> Date: Wed, 6 Mar 2024 15:57:03 -0800 Subject: [PATCH] Support reading and writing a dictionary of Ensembles to hdf5 files (#218) * WIP - Initial commit to support writing out a dictionary collection of ensembles. * Implementing a function to read dictionaries of qp ensembles from hdf5 files. * Removing files that were accidentally committed. * Adding test coverage for read and write dict. * Marking test as skipped until the required tables_io work is released. * Unskipping a test now that tables_io has new release. * Adding pragma no cover to value error. --- src/qp/__init__.py | 2 ++ src/qp/factory.py | 38 ++++++++++++++++++++++++++++++++++++++ tests/qp/test_ensemble.py | 30 +++++++++++++++++++++++++++++- 3 files changed, 69 insertions(+), 1 deletion(-) diff --git a/src/qp/__init__.py b/src/qp/__init__.py index f9cfb46..f23f038 100644 --- a/src/qp/__init__.py +++ b/src/qp/__init__.py @@ -24,6 +24,8 @@ data_length, from_tables, is_qp_file, + write_dict, + read_dict, ) from .lazy_modules import * diff --git a/src/qp/factory.py b/src/qp/factory.py index a8ef2c0..d00240c 100644 --- a/src/qp/factory.py +++ b/src/qp/factory.py @@ -357,6 +357,42 @@ def concatenate(ensembles): data[k] = np.squeeze(v) return Ensemble(gen_func, data, ancil) + @staticmethod + def write_dict(filename, ensemble_dict, **kwargs): + output_tables = {} + for key, val in ensemble_dict.items(): + # check that val is a qp.Ensemble + if not isinstance(val, Ensemble): + raise ValueError("All values in ensemble_dict must be qp.Ensemble") # pragma: no cover + + output_tables[key] = val.build_tables() + io.writeDictsToHdf5(output_tables, filename, **kwargs) + + @staticmethod + def read_dict(filename): + """Assume that filename is an HDF5 file, containing multiple qp.Ensembles + that have been stored at nparrays.""" + results = {} + + # retrieve all the top level groups. Assume each top level group + # corresponds to an ensemble. + top_level_groups = io.readHdf5GroupNames(filename) + + # for each top level group, convert the subgroups (data, meta, ancil) into + # a dictionary of dictionaries and pass the result to `from_tables`. + for top_level_group in top_level_groups: + tables = {} + keys = io.readHdf5GroupNames(filename, top_level_group) + for key_name in keys: + # retrieve the hdf5 group object + group_object, _ = io.readHdf5Group(filename, f"{top_level_group}/{key_name}") + + # use the hdf5 group object to gather data into a dictionary + tables[key_name] = io.readHdf5GroupToDict(group_object) + + results[top_level_group] = from_tables(tables) + + return results _FACTORY = Factory() @@ -377,3 +413,5 @@ def instance(): data_length = _FACTORY.data_length from_tables = _FACTORY.from_tables is_qp_file = _FACTORY.is_qp_file +write_dict = _FACTORY.write_dict +read_dict = _FACTORY.read_dict diff --git a/tests/qp/test_ensemble.py b/tests/qp/test_ensemble.py index 46a0cbe..73d6a5e 100644 --- a/tests/qp/test_ensemble.py +++ b/tests/qp/test_ensemble.py @@ -6,7 +6,7 @@ import unittest import numpy as np - +from numpy.testing import assert_array_equal, assert_array_almost_equal import qp from qp import test_data from qp.plotting import init_matplotlib @@ -235,6 +235,34 @@ def test_mixmod_with_negative_weights(self): with self.assertRaises(ValueError): _ = qp.mixmod(weights=weights, means=means, stds=sigmas) + def test_dictionary_output(self): + """Test that writing and reading a dictionary of ensembles works as expected.""" + key = "hist" + qp.hist_gen.make_test_data() + cls_test_data = qp.hist_gen.test_data[key] + ens_h = build_ensemble(cls_test_data) + + key = "interp" + qp.interp_gen.make_test_data() + cls_test_data = qp.interp_gen.test_data[key] + ens_i = build_ensemble(cls_test_data) + + output_dict = { + 'hist': ens_h, + 'interp': ens_i, + } + + qp.factory.write_dict('test_dict.hdf5', output_dict) + + input_dict = qp.factory.read_dict('test_dict.hdf5') + + assert input_dict.keys() == output_dict.keys() + + XVALS = np.linspace(0,3,100) + for ens_type in ["hist", "interp"]: + assert_array_equal(input_dict[ens_type].metadata()['pdf_name'], output_dict[ens_type].metadata()['pdf_name']) + + assert_array_almost_equal(input_dict[ens_type].pdf(XVALS), output_dict[ens_type].pdf(XVALS)) if __name__ == "__main__": unittest.main()