diff --git a/docs/axisman.rst b/docs/axisman.rst index e7875361d..b4d497f6d 100644 --- a/docs/axisman.rst +++ b/docs/axisman.rst @@ -144,6 +144,17 @@ The output of the ``wrap`` cal should be:: Note the boresight entry is marked with a ``*``, indicating that it's an AxisManager rather than a numpy array. +Data access under an AxisManager is done based on field names. For example:: + + >>> print(dset.boresight.az) + [0. 0. 0. ... 0. 0. 0.] + +Advanced data access is possible by a path like syntax. This is especially useful when +data access is dynamic and the field name is not known in advance. For example:: + + >>> print(dset["boresight.az"]) + [0. 0. 0. ... 0. 0. 0.] + To slice this object, use the restrict() method. First, let's restrict in the 'dets' axis. Since it's an Axis of type LabelAxis, the restriction selector must be a list of strings:: diff --git a/sotodlib/core/axisman.py b/sotodlib/core/axisman.py index 539b632cb..b799e091c 100644 --- a/sotodlib/core/axisman.py +++ b/sotodlib/core/axisman.py @@ -349,29 +349,63 @@ def move(self, name, new_name): self._fields[new_name] = self._fields.pop(name) self._assignments[new_name] = self._assignments.pop(name) return self - + def add_axis(self, a): assert isinstance( a, AxisInterface) self._axes[a.name] = a.copy() def __contains__(self, name): - return name in self._fields or name in self._axes + attrs = name.split(".") + tmp_item = self + while attrs: + attr_name = attrs.pop(0) + if attr_name in tmp_item._fields: + tmp_item = tmp_item._fields[attr_name] + elif attr_name in tmp_item._axes: + tmp_item = tmp_item._axes[attr_name] + else: + return False + return True def __getitem__(self, name): - if name in self._fields: - return self._fields[name] - if name in self._axes: - return self._axes[name] - raise KeyError(name) + + # We want to support options like: + # aman.focal_plane.xi . aman['focal_plane.xi'] + # We will safely assume that a getitem will always have '.' as the separator + attrs = name.split(".") + tmp_item = self + while attrs: + attr_name = attrs.pop(0) + if attr_name in tmp_item._fields: + tmp_item = tmp_item._fields[attr_name] + elif attr_name in tmp_item._axes: + tmp_item = tmp_item._axes[attr_name] + else: + raise KeyError(attr_name) + return tmp_item def __setitem__(self, name, val): - if name in self._fields: - self._fields[name] = val + + last_pos = name.rfind(".") + val_key = name + tmp_item = self + if last_pos > -1: + val_key = name[last_pos + 1:] + attrs = name[:last_pos] + tmp_item = self[attrs] + + if isinstance(val, AxisManager) and isinstance(tmp_item, AxisManager): + raise ValueError("Cannot assign AxisManager to AxisManager. Please use wrap method.") + + if val_key in tmp_item._fields: + tmp_item._fields[val_key] = val else: - raise KeyError(name) + raise KeyError(val_key) def __setattr__(self, name, value): # Assignment to members update those members + # We will assume that a path exists until the last member. + # If any member prior to that does not exist a keyerror is raised. if "_fields" in self.__dict__ and name in self._fields.keys(): self._fields[name] = value else: @@ -381,7 +415,11 @@ def __setattr__(self, name, value): def __getattr__(self, name): # Prevent members from override special class members. if name.startswith("__"): raise AttributeError(name) - return self[name] + try: + val = self[name] + except KeyError as ex: + raise AttributeError(name) from ex + return val def __dir__(self): return sorted(tuple(self.__dict__.keys()) + tuple(self.keys())) @@ -514,12 +552,12 @@ def concatenate(items, axis=0, other_fields='exact'): output.wrap(k, new_data[k], axis_map) else: if other_fields == "exact": - ## if every item named k is a scalar + ## if every item named k is a scalar err_msg = (f"The field '{k}' does not share axis '{axis}'; " f"{k} is not identical across all items " f"pass other_fields='drop' or 'first' or else " f"remove this field from the targets.") - + if np.any([np.isscalar(i[k]) for i in items]): if not np.all([np.isscalar(i[k]) for i in items]): raise ValueError(err_msg) @@ -527,14 +565,14 @@ def concatenate(items, axis=0, other_fields='exact'): raise ValueError(err_msg) output.wrap(k, items[0][k], axis_map) continue - + elif not np.all([i[k].shape==items[0][k].shape for i in items]): raise ValueError(err_msg) elif not np.all([np.array_equal(i[k], items[0][k], equal_nan=True) for i in items]): raise ValueError(err_msg) - + output.wrap(k, items[0][k].copy(), axis_map) - + elif other_fields == 'fail': raise ValueError( f"The field '{k}' does not share axis '{axis}'; " diff --git a/sotodlib/core/axisman_io.py b/sotodlib/core/axisman_io.py index f31d27874..80dc4a5a5 100644 --- a/sotodlib/core/axisman_io.py +++ b/sotodlib/core/axisman_io.py @@ -65,7 +65,7 @@ def expand_RangesMatrix(flat_rm): if shape[0] == 0: return so3g.proj.RangesMatrix([], child_shape=shape[1:]) # Otherwise non-trivial - count = np.product(shape[:-1]) + count = np.prod(shape[:-1]) start, stride = 0, count // shape[0] for i in range(0, len(ends), stride): _e = ends[i:i+stride] - start diff --git a/tests/test_core.py b/tests/test_core.py index 86a3d1b06..af6e0cd7c 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,7 +1,6 @@ import unittest import tempfile import os -import shutil import numpy as np import astropy.units as u @@ -66,7 +65,7 @@ def test_130_not_inplace(self): # This should return a separate thing. rman = aman.restrict('samps', (10, 30), in_place=False) - #self.assertNotEqual(aman.a1[0], 0.) + # self.assertNotEqual(aman.a1[0], 0.) self.assertEqual(len(aman.a1), 100) self.assertEqual(len(rman.a1), 20) self.assertNotEqual(aman.a1[10], 0.) @@ -190,23 +189,23 @@ def test_170_concat(self): # ... other_fields="exact" aman = core.AxisManager.concatenate([amanA, amanB], axis='dets') - + ## add scalars amanA.wrap("ans", 42) amanB.wrap("ans", 42) aman = core.AxisManager.concatenate([amanA, amanB], axis='dets') - + # ... other_fields="exact" amanB.azimuth[:] = 2. with self.assertRaises(ValueError): aman = core.AxisManager.concatenate([amanA, amanB], axis='dets') - + # ... other_fields="exact" and arrays of different shapes amanB.move("azimuth", None) amanB.wrap("azimuth", np.array([43,5,2,3])) with self.assertRaises(ValueError): aman = core.AxisManager.concatenate([amanA, amanB], axis='dets') - + # ... other_fields="fail" amanB.move("azimuth",None) amanB.wrap_new('azimuth', shape=('samps',))[:] = 2. @@ -269,6 +268,64 @@ def test_180_overwrite(self): self.assertNotEqual(aman.a1[2,11], 0) self.assertNotEqual(aman.a1[1,10], 1.) + def test_190_get_set(self): + dets = ["det0", "det1", "det2"] + n, ofs = 1000, 0 + aman = core.AxisManager( + core.LabelAxis("dets", dets), core.OffsetAxis("samps", n, ofs) + ) + child = core.AxisManager( + core.LabelAxis("dets", dets + ["det3"]), + core.OffsetAxis("samps", n, ofs - n // 2), + ) + + child2 = core.AxisManager( + core.LabelAxis("dets2", ["det4", "det5"]), + core.OffsetAxis("samps", n, ofs - n // 2), + ) + child2.wrap("tod", np.zeros((2, 1000))) + aman.wrap("child", child) + aman["child"].wrap("child2", child2) + self.assertEqual(aman["child.child2.dets2"].count, 2) + self.assertEqual(aman["child.dets"].name, "dets") + np.testing.assert_array_equal( + aman["child.child2.dets2"].vals, np.array(["det4", "det5"]) + ) + self.assertEqual(aman["child.child2.samps"].count, n // 2) + self.assertEqual(aman["child.child2.samps"].offset, 0) + self.assertEqual( + aman["child.child2.samps"].count, aman.child.child2.samps.count + ) + self.assertEqual( + aman["child.child2.samps"].offset, aman.child.child2.samps.offset + ) + + np.testing.assert_array_equal(aman["child.child2.tod"], np.zeros((2, 1000))) + + with self.assertRaises(KeyError): + aman["child2"] + + with self.assertRaises(AttributeError): + aman["child.dets.an_extra_layer"] + + self.assertIn("child.dets", aman) + self.assertIn("child.dets2", aman) # I am not sure why this is true + self.assertNotIn("child.child2.someentry", aman) + self.assertNotIn("child.child2.someentry.someotherentry", aman) + + with self.assertRaises(ValueError): + aman["child"] = child2 + + new_tods = np.ones((2, 500)) + aman.child.child2.tod = new_tods + np.testing.assert_array_equal(aman["child.child2.tod"], np.ones((2, 500))) + np.testing.assert_array_equal(aman.child.child2.tod, np.ones((2, 500))) + + new_tods = np.ones((2, 1500)) + aman["child.child2.tod"] = new_tods + np.testing.assert_array_equal(aman["child.child2.tod"], np.ones((2, 1500))) + np.testing.assert_array_equal(aman.child.child2.tod, np.ones((2, 1500))) + # Multi-dimensional restrictions. def test_200_multid(self):