Skip to content

Commit

Permalink
Support reading and writing a dictionary of Ensembles to hdf5 files (#…
Browse files Browse the repository at this point in the history
…218)

* WIP - Initial commit to support writing out a dictionary collection of ensembles.

* Implementing a function to read dictionaries of qp ensembles from hdf5 files.

* Removing files that were accidentally committed.

* Adding test coverage for read and write dict.

* Marking test as skipped until the required tables_io work is released.

* Unskipping a test now that tables_io has new release.

* Adding pragma no cover to value error.
  • Loading branch information
drewoldag committed Mar 6, 2024
1 parent 248b3ee commit 96b3e13
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/qp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
data_length,
from_tables,
is_qp_file,
write_dict,
read_dict,
)
from .lazy_modules import *

Expand Down
38 changes: 38 additions & 0 deletions src/qp/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,42 @@ def concatenate(ensembles):
data[k] = np.squeeze(v)
return Ensemble(gen_func, data, ancil)

@staticmethod
def write_dict(filename, ensemble_dict, **kwargs):
output_tables = {}
for key, val in ensemble_dict.items():
# check that val is a qp.Ensemble
if not isinstance(val, Ensemble):
raise ValueError("All values in ensemble_dict must be qp.Ensemble") # pragma: no cover

output_tables[key] = val.build_tables()
io.writeDictsToHdf5(output_tables, filename, **kwargs)

@staticmethod
def read_dict(filename):
"""Assume that filename is an HDF5 file, containing multiple qp.Ensembles
that have been stored at nparrays."""
results = {}

# retrieve all the top level groups. Assume each top level group
# corresponds to an ensemble.
top_level_groups = io.readHdf5GroupNames(filename)

# for each top level group, convert the subgroups (data, meta, ancil) into
# a dictionary of dictionaries and pass the result to `from_tables`.
for top_level_group in top_level_groups:
tables = {}
keys = io.readHdf5GroupNames(filename, top_level_group)
for key_name in keys:
# retrieve the hdf5 group object
group_object, _ = io.readHdf5Group(filename, f"{top_level_group}/{key_name}")

# use the hdf5 group object to gather data into a dictionary
tables[key_name] = io.readHdf5GroupToDict(group_object)

results[top_level_group] = from_tables(tables)

return results

_FACTORY = Factory()

Expand All @@ -377,3 +413,5 @@ def instance():
data_length = _FACTORY.data_length
from_tables = _FACTORY.from_tables
is_qp_file = _FACTORY.is_qp_file
write_dict = _FACTORY.write_dict
read_dict = _FACTORY.read_dict
30 changes: 29 additions & 1 deletion tests/qp/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import unittest

import numpy as np

from numpy.testing import assert_array_equal, assert_array_almost_equal
import qp
from qp import test_data
from qp.plotting import init_matplotlib
Expand Down Expand Up @@ -235,6 +235,34 @@ def test_mixmod_with_negative_weights(self):
with self.assertRaises(ValueError):
_ = qp.mixmod(weights=weights, means=means, stds=sigmas)

def test_dictionary_output(self):
"""Test that writing and reading a dictionary of ensembles works as expected."""
key = "hist"
qp.hist_gen.make_test_data()
cls_test_data = qp.hist_gen.test_data[key]
ens_h = build_ensemble(cls_test_data)

key = "interp"
qp.interp_gen.make_test_data()
cls_test_data = qp.interp_gen.test_data[key]
ens_i = build_ensemble(cls_test_data)

output_dict = {
'hist': ens_h,
'interp': ens_i,
}

qp.factory.write_dict('test_dict.hdf5', output_dict)

input_dict = qp.factory.read_dict('test_dict.hdf5')

assert input_dict.keys() == output_dict.keys()

XVALS = np.linspace(0,3,100)
for ens_type in ["hist", "interp"]:
assert_array_equal(input_dict[ens_type].metadata()['pdf_name'], output_dict[ens_type].metadata()['pdf_name'])

assert_array_almost_equal(input_dict[ens_type].pdf(XVALS), output_dict[ens_type].pdf(XVALS))

if __name__ == "__main__":
unittest.main()

0 comments on commit 96b3e13

Please sign in to comment.