From 77407ade60b878ebfd696d32f2318759094f7bd3 Mon Sep 17 00:00:00 2001 From: Ioannis Paraskevakos Date: Tue, 3 Dec 2024 10:06:52 -0500 Subject: [PATCH 1/9] wip: functionality added, missing tests --- sotodlib/core/axisman.py | 52 ++++++++++++++++++++++++++++------------ 1 file changed, 37 insertions(+), 15 deletions(-) diff --git a/sotodlib/core/axisman.py b/sotodlib/core/axisman.py index 539b632cb..f081a7d58 100644 --- a/sotodlib/core/axisman.py +++ b/sotodlib/core/axisman.py @@ -349,7 +349,7 @@ def move(self, name, new_name): self._fields[new_name] = self._fields.pop(name) self._assignments[new_name] = self._assignments.pop(name) return self - + def add_axis(self, a): assert isinstance( a, AxisInterface) self._axes[a.name] = a.copy() @@ -358,20 +358,38 @@ def __contains__(self, name): return name in self._fields or name in self._axes def __getitem__(self, name): - if name in self._fields: - return self._fields[name] - if name in self._axes: - return self._axes[name] - raise KeyError(name) + + # We want to support options like: + # aman.focal_plane.xi . aman['focal_plane.xi'] + # We will safely assume that a getitem will always have '.' as the separator + attrs = name.split(".") + tmp_item = self + while attrs: + attr_name = attrs.pop(0) + if attr_name in tmp_item._fields: + tmp_item = tmp_item._fields[attr_name] + elif attr_name in tmp_item._axes: + tmp_item = tmp_item._axes[attr_name] + else: + raise KeyError(attr_name) + return tmp_item def __setitem__(self, name, val): - if name in self._fields: - self._fields[name] = val + + last_pos = name.rfind(".") + val_key = name[last_pos:] + attrs = name[:last_pos] + tmp_item = self[attrs] + + if val_key in tmp_item._fields: + tmp_item._fields[name] = val else: - raise KeyError(name) + raise KeyError(val_key) def __setattr__(self, name, value): # Assignment to members update those members + # We will assume that a path exists until the last member. + # If any member prior to that does not exist a keyerror is raised. if "_fields" in self.__dict__ and name in self._fields.keys(): self._fields[name] = value else: @@ -381,7 +399,11 @@ def __setattr__(self, name, value): def __getattr__(self, name): # Prevent members from override special class members. if name.startswith("__"): raise AttributeError(name) - return self[name] + try: + val = self[name] + except KeyError as ex: + raise AttributeError(name) from ex + return val def __dir__(self): return sorted(tuple(self.__dict__.keys()) + tuple(self.keys())) @@ -514,12 +536,12 @@ def concatenate(items, axis=0, other_fields='exact'): output.wrap(k, new_data[k], axis_map) else: if other_fields == "exact": - ## if every item named k is a scalar + ## if every item named k is a scalar err_msg = (f"The field '{k}' does not share axis '{axis}'; " f"{k} is not identical across all items " f"pass other_fields='drop' or 'first' or else " f"remove this field from the targets.") - + if np.any([np.isscalar(i[k]) for i in items]): if not np.all([np.isscalar(i[k]) for i in items]): raise ValueError(err_msg) @@ -527,14 +549,14 @@ def concatenate(items, axis=0, other_fields='exact'): raise ValueError(err_msg) output.wrap(k, items[0][k], axis_map) continue - + elif not np.all([i[k].shape==items[0][k].shape for i in items]): raise ValueError(err_msg) elif not np.all([np.array_equal(i[k], items[0][k], equal_nan=True) for i in items]): raise ValueError(err_msg) - + output.wrap(k, items[0][k].copy(), axis_map) - + elif other_fields == 'fail': raise ValueError( f"The field '{k}' does not share axis '{axis}'; " From 1cfd3b96114ef1ea489e7401521541cbd3cd91b4 Mon Sep 17 00:00:00 2001 From: Ioannis Paraskevakos Date: Wed, 4 Dec 2024 13:01:00 -0500 Subject: [PATCH 2/9] updating setitem to accept only AxisManager --- sotodlib/core/axisman.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/sotodlib/core/axisman.py b/sotodlib/core/axisman.py index f081a7d58..1cf44f5b5 100644 --- a/sotodlib/core/axisman.py +++ b/sotodlib/core/axisman.py @@ -374,17 +374,15 @@ def __getitem__(self, name): raise KeyError(attr_name) return tmp_item - def __setitem__(self, name, val): + def __setitem__(self, name:str, val: "AxisManager"): + if not isinstance(val, AxisManager): + raise ValueError("Only AxisManagers can be setting values") last_pos = name.rfind(".") - val_key = name[last_pos:] + val_key = name[last_pos + 1:] attrs = name[:last_pos] tmp_item = self[attrs] - - if val_key in tmp_item._fields: - tmp_item._fields[name] = val - else: - raise KeyError(val_key) + tmp_item.wrap(name=val_key, data=val, overwrite=True) def __setattr__(self, name, value): # Assignment to members update those members From a7f1faf25e0a4171e2c3e8cf28d9b7d4f0f843cb Mon Sep 17 00:00:00 2001 From: Ioannis Paraskevakos Date: Wed, 4 Dec 2024 13:22:30 -0500 Subject: [PATCH 3/9] wip: somewhat reversing setitem --- sotodlib/core/axisman.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/sotodlib/core/axisman.py b/sotodlib/core/axisman.py index 1cf44f5b5..34a2f9d0b 100644 --- a/sotodlib/core/axisman.py +++ b/sotodlib/core/axisman.py @@ -374,15 +374,14 @@ def __getitem__(self, name): raise KeyError(attr_name) return tmp_item - def __setitem__(self, name:str, val: "AxisManager"): + def __setitem__(self, name, val): - if not isinstance(val, AxisManager): - raise ValueError("Only AxisManagers can be setting values") last_pos = name.rfind(".") val_key = name[last_pos + 1:] attrs = name[:last_pos] tmp_item = self[attrs] - tmp_item.wrap(name=val_key, data=val, overwrite=True) + if val_key in tmp_item._fields: + tmp_item._fields[val_key] = val def __setattr__(self, name, value): # Assignment to members update those members From c414e67a88b402e6496db32cdc29fde9d4e73012 Mon Sep 17 00:00:00 2001 From: Ioannis Paraskevakos Date: Wed, 4 Dec 2024 14:20:35 -0500 Subject: [PATCH 4/9] fixed incoming bug and added test --- sotodlib/core/axisman.py | 12 +++++++++--- tests/test_core.py | 24 +++++++++++++++++++----- 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/sotodlib/core/axisman.py b/sotodlib/core/axisman.py index 34a2f9d0b..2b51d9de3 100644 --- a/sotodlib/core/axisman.py +++ b/sotodlib/core/axisman.py @@ -377,11 +377,17 @@ def __getitem__(self, name): def __setitem__(self, name, val): last_pos = name.rfind(".") - val_key = name[last_pos + 1:] - attrs = name[:last_pos] - tmp_item = self[attrs] + val_key = name + tmp_item = self + if last_pos > -1: + val_key = name[last_pos + 1:] + attrs = name[:last_pos] + tmp_item = self[attrs] + if val_key in tmp_item._fields: tmp_item._fields[val_key] = val + else: + raise KeyError(name) def __setattr__(self, name, value): # Assignment to members update those members diff --git a/tests/test_core.py b/tests/test_core.py index 86a3d1b06..94411ea62 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -66,7 +66,7 @@ def test_130_not_inplace(self): # This should return a separate thing. rman = aman.restrict('samps', (10, 30), in_place=False) - #self.assertNotEqual(aman.a1[0], 0.) + # self.assertNotEqual(aman.a1[0], 0.) self.assertEqual(len(aman.a1), 100) self.assertEqual(len(rman.a1), 20) self.assertNotEqual(aman.a1[10], 0.) @@ -190,23 +190,23 @@ def test_170_concat(self): # ... other_fields="exact" aman = core.AxisManager.concatenate([amanA, amanB], axis='dets') - + ## add scalars amanA.wrap("ans", 42) amanB.wrap("ans", 42) aman = core.AxisManager.concatenate([amanA, amanB], axis='dets') - + # ... other_fields="exact" amanB.azimuth[:] = 2. with self.assertRaises(ValueError): aman = core.AxisManager.concatenate([amanA, amanB], axis='dets') - + # ... other_fields="exact" and arrays of different shapes amanB.move("azimuth", None) amanB.wrap("azimuth", np.array([43,5,2,3])) with self.assertRaises(ValueError): aman = core.AxisManager.concatenate([amanA, amanB], axis='dets') - + # ... other_fields="fail" amanB.move("azimuth",None) amanB.wrap_new('azimuth', shape=('samps',))[:] = 2. @@ -299,6 +299,20 @@ def test_300_restrict(self): # wrap of AxisManager, merge. + def test_get_set(self): + dets = ["det0", "det1", "det2"] + n, ofs = 1000, 0 + aman = core.AxisManager( + core.LabelAxis("dets", dets), core.OffsetAxis("samps", n, ofs) + ) + child = core.AxisManager( + core.LabelAxis("dets", dets + ["det3"]), + core.OffsetAxis("samps", n, ofs - n // 2), + ) + aman.wrap("child", child) + self.assertAlmostEqual(aman["child.dets"].count, 3) + self.assertAlmostEqual(aman["child.dets"].name, "dets") + def test_400_child(self): dets = ['det0', 'det1', 'det2'] n, ofs = 1000, 0 From eed780a7bc2856c50912a27ec4871e513a34365a Mon Sep 17 00:00:00 2001 From: Ioannis Paraskevakos Date: Mon, 9 Dec 2024 16:07:03 -0500 Subject: [PATCH 5/9] Addressing PR comments --- docs/axisman.rst | 11 +++++++++++ sotodlib/core/axisman.py | 12 +++++++++++- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/docs/axisman.rst b/docs/axisman.rst index e7875361d..5cef8fb81 100644 --- a/docs/axisman.rst +++ b/docs/axisman.rst @@ -144,6 +144,17 @@ The output of the ``wrap`` cal should be:: Note the boresight entry is marked with a ``*``, indicating that it's an AxisManager rather than a numpy array. +To access data in an AxisManager, use a path-like syntax where +attribute names are separated by dots:: + + >>> n, ofs = 1000, 0 + >>> dets = ["det0", "det1", "det2"] + >>> aman = core.AxisManager(core.LabelAxis("dets", dets), core.OffsetAxis("samps", n, ofs)) + >>> child = core.AxisManager(core.LabelAxis("dets", dets + ["det3"]),core.OffsetAxis("samps", n, ofs - n // 2),) + >>> aman.wrap("child", child) + >>> print(aman["child.dets"]) + LabelAxis(3:'det0','det1','det2') + To slice this object, use the restrict() method. First, let's restrict in the 'dets' axis. Since it's an Axis of type LabelAxis, the restriction selector must be a list of strings:: diff --git a/sotodlib/core/axisman.py b/sotodlib/core/axisman.py index 2b51d9de3..ddceaeb11 100644 --- a/sotodlib/core/axisman.py +++ b/sotodlib/core/axisman.py @@ -355,7 +355,17 @@ def add_axis(self, a): self._axes[a.name] = a.copy() def __contains__(self, name): - return name in self._fields or name in self._axes + attrs = name.split(".") + tmp_item = self + while attrs: + attr_name = attrs.pop(0) + if attr_name in tmp_item._fields: + tmp_item = tmp_item._fields[attr_name] + elif attr_name in tmp_item._axes: + tmp_item = tmp_item._axes[attr_name] + else: + return False + return True def __getitem__(self, name): From 29500b11f58f74dec0769dd1574ecb51f08d0793 Mon Sep 17 00:00:00 2001 From: Ioannis Paraskevakos Date: Mon, 9 Dec 2024 16:08:40 -0500 Subject: [PATCH 6/9] adding missing test --- tests/test_core.py | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index 94411ea62..8deb846c7 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -298,7 +298,6 @@ def test_300_restrict(self): self.assertNotEqual(aman.a1[0, 0, 0, 1], 0.) # wrap of AxisManager, merge. - def test_get_set(self): dets = ["det0", "det1", "det2"] n, ofs = 1000, 0 @@ -309,9 +308,34 @@ def test_get_set(self): core.LabelAxis("dets", dets + ["det3"]), core.OffsetAxis("samps", n, ofs - n // 2), ) + + child2 = core.AxisManager( + core.LabelAxis("dets2", ["det4", "det5"]), + core.OffsetAxis("samps", n, ofs - n // 2), + ) aman.wrap("child", child) - self.assertAlmostEqual(aman["child.dets"].count, 3) - self.assertAlmostEqual(aman["child.dets"].name, "dets") + aman["child"].wrap("child2", child2) + self.assertEqual(aman["child.child2.dets2"].count, 2) + self.assertEqual(aman["child.dets"].name, "dets") + np.testing.assert_array_equal(aman["child.child2.dets2"].vals, np.array(["det4", "det5"])) + with self.assertRaises(KeyError): + aman["child.someentry"] + + with self.assertRaises(KeyError): + aman["child2"] + + with self.assertRaises(AttributeError): + aman["child.dets.an_extra_layer"] + + self.assertIn("child.dets", aman) + self.assertIn("child.dets2", aman) # I am not sure why this is true + self.assertNotIn("child.child2.someentry", aman) + + aman["child"] = child2 + print(aman["child"]) + self.assertEqual(aman["child.dets2"].count, 2) + self.assertEqual(aman["child.dets2"].name, "dets2") + np.testing.assert_array_equal(aman["child.dets2"].vals, np.array(["det4", "det5"])) def test_400_child(self): dets = ['det0', 'det1', 'det2'] From dcc805cba8844f329eeae68e5238fa9def92db3c Mon Sep 17 00:00:00 2001 From: Ioannis Paraskevakos Date: Thu, 12 Dec 2024 16:30:51 +0000 Subject: [PATCH 7/9] tests updated --- sotodlib/core/axisman.py | 8 ++++---- sotodlib/core/axisman_io.py | 2 +- tests/test_core.py | 27 ++++++++++++++++++++------- 3 files changed, 25 insertions(+), 12 deletions(-) diff --git a/sotodlib/core/axisman.py b/sotodlib/core/axisman.py index ddceaeb11..257650e39 100644 --- a/sotodlib/core/axisman.py +++ b/sotodlib/core/axisman.py @@ -394,10 +394,10 @@ def __setitem__(self, name, val): attrs = name[:last_pos] tmp_item = self[attrs] - if val_key in tmp_item._fields: - tmp_item._fields[val_key] = val - else: - raise KeyError(name) + if isinstance(val, AxisManager) and isinstance(tmp_item, AxisManager): + raise ValueError("Cannot assign AxisManager to AxisManager. Please use wrap method.") + + tmp_item.__setattr__(val_key, val) def __setattr__(self, name, value): # Assignment to members update those members diff --git a/sotodlib/core/axisman_io.py b/sotodlib/core/axisman_io.py index f31d27874..80dc4a5a5 100644 --- a/sotodlib/core/axisman_io.py +++ b/sotodlib/core/axisman_io.py @@ -65,7 +65,7 @@ def expand_RangesMatrix(flat_rm): if shape[0] == 0: return so3g.proj.RangesMatrix([], child_shape=shape[1:]) # Otherwise non-trivial - count = np.product(shape[:-1]) + count = np.prod(shape[:-1]) start, stride = 0, count // shape[0] for i in range(0, len(ends), stride): _e = ends[i:i+stride] - start diff --git a/tests/test_core.py b/tests/test_core.py index 8deb846c7..494299b34 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -3,6 +3,7 @@ import os import shutil +from networkx import selfloop_edges import numpy as np import astropy.units as u from sotodlib import core @@ -313,14 +314,19 @@ def test_get_set(self): core.LabelAxis("dets2", ["det4", "det5"]), core.OffsetAxis("samps", n, ofs - n // 2), ) + child2.wrap("tod", np.zeros((2,1000))) aman.wrap("child", child) aman["child"].wrap("child2", child2) self.assertEqual(aman["child.child2.dets2"].count, 2) self.assertEqual(aman["child.dets"].name, "dets") np.testing.assert_array_equal(aman["child.child2.dets2"].vals, np.array(["det4", "det5"])) - with self.assertRaises(KeyError): - aman["child.someentry"] + self.assertEqual(aman["child.child2.samps"].count, n // 2) + self.assertEqual(aman["child.child2.samps"].offset, 0) + self.assertEqual(aman["child.child2.samps"].count, aman.child.child2.samps.count) + self.assertEqual(aman["child.child2.samps"].offset, aman.child.child2.samps.offset) + np.testing.assert_array_equal(aman["child.child2.tod"], np.zeros((2,1000))) + with self.assertRaises(KeyError): aman["child2"] @@ -331,11 +337,18 @@ def test_get_set(self): self.assertIn("child.dets2", aman) # I am not sure why this is true self.assertNotIn("child.child2.someentry", aman) - aman["child"] = child2 - print(aman["child"]) - self.assertEqual(aman["child.dets2"].count, 2) - self.assertEqual(aman["child.dets2"].name, "dets2") - np.testing.assert_array_equal(aman["child.dets2"].vals, np.array(["det4", "det5"])) + with self.assertRaises(ValueError): + aman["child"] = child2 + + new_tods = np.ones((2, 500)) + aman.child.child2.tod = new_tods + np.testing.assert_array_equal(aman["child.child2.tod"], np.ones((2, 500))) + np.testing.assert_array_equal(aman.child.child2.tod, np.ones((2, 500))) + + new_tods = np.ones((2, 1500)) + aman["child.child2.tod"] = new_tods + np.testing.assert_array_equal(aman["child.child2.tod"], np.ones((2, 1500))) + np.testing.assert_array_equal(aman.child.child2.tod, np.ones((2, 1500))) def test_400_child(self): dets = ['det0', 'det1', 'det2'] From d316c517b865d38b7ab5b863453bc67982f3a517 Mon Sep 17 00:00:00 2001 From: Ioannis Paraskevakos Date: Mon, 16 Dec 2024 16:58:25 +0000 Subject: [PATCH 8/9] fixing docs to show advanced usage --- docs/axisman.rst | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/docs/axisman.rst b/docs/axisman.rst index 5cef8fb81..8974f22bb 100644 --- a/docs/axisman.rst +++ b/docs/axisman.rst @@ -144,8 +144,17 @@ The output of the ``wrap`` cal should be:: Note the boresight entry is marked with a ``*``, indicating that it's an AxisManager rather than a numpy array. -To access data in an AxisManager, use a path-like syntax where -attribute names are separated by dots:: +Data access under an AxisManager is done based on field names. For example: + >>> n, ofs = 1000, 0 + >>> dets = ["det0", "det1", "det2"] + >>> aman = core.AxisManager(core.LabelAxis("dets", dets), core.OffsetAxis("samps", n, ofs)) + >>> child = core.AxisManager(core.LabelAxis("dets", dets + ["det3"]),core.OffsetAxis("samps", n, ofs - n // 2),) + >>> aman.wrap("child", child) + >>> print(aman.child.dets) + LabelAxis(3:'det0','det1','det2') + +Advanced data access is possible by a path like syntax. This is especially useful when +data access is dynamic and the field name is not known in advance. For example: >>> n, ofs = 1000, 0 >>> dets = ["det0", "det1", "det2"] From a3c86bd9c16f69c45a2296b5055c135fbcaa35a9 Mon Sep 17 00:00:00 2001 From: Ioannis Paraskevakos Date: Wed, 8 Jan 2025 16:07:20 -0500 Subject: [PATCH 9/9] addressing Matthew's comments --- docs/axisman.rst | 25 ++++-------- sotodlib/core/axisman.py | 5 ++- tests/test_core.py | 86 +++++++++++++++++++++------------------- 3 files changed, 58 insertions(+), 58 deletions(-) diff --git a/docs/axisman.rst b/docs/axisman.rst index 8974f22bb..b4d497f6d 100644 --- a/docs/axisman.rst +++ b/docs/axisman.rst @@ -144,25 +144,16 @@ The output of the ``wrap`` cal should be:: Note the boresight entry is marked with a ``*``, indicating that it's an AxisManager rather than a numpy array. -Data access under an AxisManager is done based on field names. For example: - >>> n, ofs = 1000, 0 - >>> dets = ["det0", "det1", "det2"] - >>> aman = core.AxisManager(core.LabelAxis("dets", dets), core.OffsetAxis("samps", n, ofs)) - >>> child = core.AxisManager(core.LabelAxis("dets", dets + ["det3"]),core.OffsetAxis("samps", n, ofs - n // 2),) - >>> aman.wrap("child", child) - >>> print(aman.child.dets) - LabelAxis(3:'det0','det1','det2') +Data access under an AxisManager is done based on field names. For example:: + + >>> print(dset.boresight.az) + [0. 0. 0. ... 0. 0. 0.] Advanced data access is possible by a path like syntax. This is especially useful when -data access is dynamic and the field name is not known in advance. For example: - - >>> n, ofs = 1000, 0 - >>> dets = ["det0", "det1", "det2"] - >>> aman = core.AxisManager(core.LabelAxis("dets", dets), core.OffsetAxis("samps", n, ofs)) - >>> child = core.AxisManager(core.LabelAxis("dets", dets + ["det3"]),core.OffsetAxis("samps", n, ofs - n // 2),) - >>> aman.wrap("child", child) - >>> print(aman["child.dets"]) - LabelAxis(3:'det0','det1','det2') +data access is dynamic and the field name is not known in advance. For example:: + + >>> print(dset["boresight.az"]) + [0. 0. 0. ... 0. 0. 0.] To slice this object, use the restrict() method. First, let's restrict in the 'dets' axis. Since it's an Axis of type LabelAxis, diff --git a/sotodlib/core/axisman.py b/sotodlib/core/axisman.py index 257650e39..b799e091c 100644 --- a/sotodlib/core/axisman.py +++ b/sotodlib/core/axisman.py @@ -397,7 +397,10 @@ def __setitem__(self, name, val): if isinstance(val, AxisManager) and isinstance(tmp_item, AxisManager): raise ValueError("Cannot assign AxisManager to AxisManager. Please use wrap method.") - tmp_item.__setattr__(val_key, val) + if val_key in tmp_item._fields: + tmp_item._fields[val_key] = val + else: + raise KeyError(val_key) def __setattr__(self, name, value): # Assignment to members update those members diff --git a/tests/test_core.py b/tests/test_core.py index 494299b34..af6e0cd7c 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,9 +1,7 @@ import unittest import tempfile import os -import shutil -from networkx import selfloop_edges import numpy as np import astropy.units as u from sotodlib import core @@ -270,36 +268,7 @@ def test_180_overwrite(self): self.assertNotEqual(aman.a1[2,11], 0) self.assertNotEqual(aman.a1[1,10], 1.) - # Multi-dimensional restrictions. - - def test_200_multid(self): - dets = ['det0', 'det1', 'det2'] - a1 = np.zeros((len(dets), len(dets))) - a1[2, 2] = 1. - aman = core.AxisManager(core.LabelAxis('dets', dets)) - aman.wrap('a1', a1, [(0, 'dets'), (1, 'dets')]) - aman.restrict('dets', ['det1', 'det2']) - self.assertEqual(aman.a1.shape, (2, 2)) - self.assertNotEqual(aman.a1[1, 1], 0.) - - def test_300_restrict(self): - dets = ['det0', 'det1', 'det2'] - n, ofs = 1000, 5000 - aman = core.AxisManager( - core.LabelAxis('dets', dets), - core.OffsetAxis('samps', n, ofs)) - # Super-correlation matrix. - a1 = np.zeros((len(dets), len(dets), n, n)) - a1[1, 1, 20, 21] = 1. - aman.wrap('a1', a1, [(0, 'dets'), (1, 'dets'), - (2, 'samps'), (3, 'samps')]) - aman.restrict('dets', ['det1']).restrict('samps', (20 + ofs, 30 + ofs)) - self.assertEqual(aman.shape, (1, 10)) - self.assertEqual(aman.a1.shape, (1, 1, 10, 10)) - self.assertNotEqual(aman.a1[0, 0, 0, 1], 0.) - - # wrap of AxisManager, merge. - def test_get_set(self): + def test_190_get_set(self): dets = ["det0", "det1", "det2"] n, ofs = 1000, 0 aman = core.AxisManager( @@ -314,19 +283,25 @@ def test_get_set(self): core.LabelAxis("dets2", ["det4", "det5"]), core.OffsetAxis("samps", n, ofs - n // 2), ) - child2.wrap("tod", np.zeros((2,1000))) + child2.wrap("tod", np.zeros((2, 1000))) aman.wrap("child", child) aman["child"].wrap("child2", child2) self.assertEqual(aman["child.child2.dets2"].count, 2) self.assertEqual(aman["child.dets"].name, "dets") - np.testing.assert_array_equal(aman["child.child2.dets2"].vals, np.array(["det4", "det5"])) + np.testing.assert_array_equal( + aman["child.child2.dets2"].vals, np.array(["det4", "det5"]) + ) self.assertEqual(aman["child.child2.samps"].count, n // 2) self.assertEqual(aman["child.child2.samps"].offset, 0) - self.assertEqual(aman["child.child2.samps"].count, aman.child.child2.samps.count) - self.assertEqual(aman["child.child2.samps"].offset, aman.child.child2.samps.offset) + self.assertEqual( + aman["child.child2.samps"].count, aman.child.child2.samps.count + ) + self.assertEqual( + aman["child.child2.samps"].offset, aman.child.child2.samps.offset + ) + + np.testing.assert_array_equal(aman["child.child2.tod"], np.zeros((2, 1000))) - np.testing.assert_array_equal(aman["child.child2.tod"], np.zeros((2,1000))) - with self.assertRaises(KeyError): aman["child2"] @@ -334,8 +309,9 @@ def test_get_set(self): aman["child.dets.an_extra_layer"] self.assertIn("child.dets", aman) - self.assertIn("child.dets2", aman) # I am not sure why this is true + self.assertIn("child.dets2", aman) # I am not sure why this is true self.assertNotIn("child.child2.someentry", aman) + self.assertNotIn("child.child2.someentry.someotherentry", aman) with self.assertRaises(ValueError): aman["child"] = child2 @@ -344,12 +320,42 @@ def test_get_set(self): aman.child.child2.tod = new_tods np.testing.assert_array_equal(aman["child.child2.tod"], np.ones((2, 500))) np.testing.assert_array_equal(aman.child.child2.tod, np.ones((2, 500))) - + new_tods = np.ones((2, 1500)) aman["child.child2.tod"] = new_tods np.testing.assert_array_equal(aman["child.child2.tod"], np.ones((2, 1500))) np.testing.assert_array_equal(aman.child.child2.tod, np.ones((2, 1500))) + # Multi-dimensional restrictions. + + def test_200_multid(self): + dets = ['det0', 'det1', 'det2'] + a1 = np.zeros((len(dets), len(dets))) + a1[2, 2] = 1. + aman = core.AxisManager(core.LabelAxis('dets', dets)) + aman.wrap('a1', a1, [(0, 'dets'), (1, 'dets')]) + aman.restrict('dets', ['det1', 'det2']) + self.assertEqual(aman.a1.shape, (2, 2)) + self.assertNotEqual(aman.a1[1, 1], 0.) + + def test_300_restrict(self): + dets = ['det0', 'det1', 'det2'] + n, ofs = 1000, 5000 + aman = core.AxisManager( + core.LabelAxis('dets', dets), + core.OffsetAxis('samps', n, ofs)) + # Super-correlation matrix. + a1 = np.zeros((len(dets), len(dets), n, n)) + a1[1, 1, 20, 21] = 1. + aman.wrap('a1', a1, [(0, 'dets'), (1, 'dets'), + (2, 'samps'), (3, 'samps')]) + aman.restrict('dets', ['det1']).restrict('samps', (20 + ofs, 30 + ofs)) + self.assertEqual(aman.shape, (1, 10)) + self.assertEqual(aman.a1.shape, (1, 1, 10, 10)) + self.assertNotEqual(aman.a1[0, 0, 0, 1], 0.) + + # wrap of AxisManager, merge. + def test_400_child(self): dets = ['det0', 'det1', 'det2'] n, ofs = 1000, 0