Skip to content

Commit

Permalink
Gp/feat/aman arbitrary path (#1057)
Browse files Browse the repository at this point in the history
* wip: functionality added, missing tests

* updating setitem to accept only AxisManager

* wip: somewhat reversing setitem

* fixed incoming bug and added test

* Addressing PR comments

* adding missing test

* tests updated

* fixing docs to show advanced usage

* addressing Matthew's comments
  • Loading branch information
iparask authored Jan 10, 2025
1 parent 642b7df commit 34ffae6
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 23 deletions.
11 changes: 11 additions & 0 deletions docs/axisman.rst
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,17 @@ The output of the ``wrap`` cal should be::
Note the boresight entry is marked with a ``*``, indicating that it's
an AxisManager rather than a numpy array.

Data access under an AxisManager is done based on field names. For example::

>>> print(dset.boresight.az)
[0. 0. 0. ... 0. 0. 0.]

Advanced data access is possible by a path like syntax. This is especially useful when
data access is dynamic and the field name is not known in advance. For example::

>>> print(dset["boresight.az"])
[0. 0. 0. ... 0. 0. 0.]

To slice this object, use the restrict() method. First, let's
restrict in the 'dets' axis. Since it's an Axis of type LabelAxis,
the restriction selector must be a list of strings::
Expand Down
70 changes: 54 additions & 16 deletions sotodlib/core/axisman.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,29 +349,63 @@ def move(self, name, new_name):
self._fields[new_name] = self._fields.pop(name)
self._assignments[new_name] = self._assignments.pop(name)
return self

def add_axis(self, a):
assert isinstance( a, AxisInterface)
self._axes[a.name] = a.copy()

def __contains__(self, name):
return name in self._fields or name in self._axes
attrs = name.split(".")
tmp_item = self
while attrs:
attr_name = attrs.pop(0)
if attr_name in tmp_item._fields:
tmp_item = tmp_item._fields[attr_name]
elif attr_name in tmp_item._axes:
tmp_item = tmp_item._axes[attr_name]
else:
return False
return True

def __getitem__(self, name):
if name in self._fields:
return self._fields[name]
if name in self._axes:
return self._axes[name]
raise KeyError(name)

# We want to support options like:
# aman.focal_plane.xi . aman['focal_plane.xi']
# We will safely assume that a getitem will always have '.' as the separator
attrs = name.split(".")
tmp_item = self
while attrs:
attr_name = attrs.pop(0)
if attr_name in tmp_item._fields:
tmp_item = tmp_item._fields[attr_name]
elif attr_name in tmp_item._axes:
tmp_item = tmp_item._axes[attr_name]
else:
raise KeyError(attr_name)
return tmp_item

def __setitem__(self, name, val):
if name in self._fields:
self._fields[name] = val

last_pos = name.rfind(".")
val_key = name
tmp_item = self
if last_pos > -1:
val_key = name[last_pos + 1:]
attrs = name[:last_pos]
tmp_item = self[attrs]

if isinstance(val, AxisManager) and isinstance(tmp_item, AxisManager):
raise ValueError("Cannot assign AxisManager to AxisManager. Please use wrap method.")

if val_key in tmp_item._fields:
tmp_item._fields[val_key] = val
else:
raise KeyError(name)
raise KeyError(val_key)

def __setattr__(self, name, value):
# Assignment to members update those members
# We will assume that a path exists until the last member.
# If any member prior to that does not exist a keyerror is raised.
if "_fields" in self.__dict__ and name in self._fields.keys():
self._fields[name] = value
else:
Expand All @@ -381,7 +415,11 @@ def __setattr__(self, name, value):
def __getattr__(self, name):
# Prevent members from override special class members.
if name.startswith("__"): raise AttributeError(name)
return self[name]
try:
val = self[name]
except KeyError as ex:
raise AttributeError(name) from ex
return val

def __dir__(self):
return sorted(tuple(self.__dict__.keys()) + tuple(self.keys()))
Expand Down Expand Up @@ -514,27 +552,27 @@ def concatenate(items, axis=0, other_fields='exact'):
output.wrap(k, new_data[k], axis_map)
else:
if other_fields == "exact":
## if every item named k is a scalar
## if every item named k is a scalar
err_msg = (f"The field '{k}' does not share axis '{axis}'; "
f"{k} is not identical across all items "
f"pass other_fields='drop' or 'first' or else "
f"remove this field from the targets.")

if np.any([np.isscalar(i[k]) for i in items]):
if not np.all([np.isscalar(i[k]) for i in items]):
raise ValueError(err_msg)
if not np.all([np.array_equal(i[k], items[0][k], equal_nan=True) for i in items]):
raise ValueError(err_msg)
output.wrap(k, items[0][k], axis_map)
continue

elif not np.all([i[k].shape==items[0][k].shape for i in items]):
raise ValueError(err_msg)
elif not np.all([np.array_equal(i[k], items[0][k], equal_nan=True) for i in items]):
raise ValueError(err_msg)

output.wrap(k, items[0][k].copy(), axis_map)

elif other_fields == 'fail':
raise ValueError(
f"The field '{k}' does not share axis '{axis}'; "
Expand Down
2 changes: 1 addition & 1 deletion sotodlib/core/axisman_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def expand_RangesMatrix(flat_rm):
if shape[0] == 0:
return so3g.proj.RangesMatrix([], child_shape=shape[1:])
# Otherwise non-trivial
count = np.product(shape[:-1])
count = np.prod(shape[:-1])
start, stride = 0, count // shape[0]
for i in range(0, len(ends), stride):
_e = ends[i:i+stride] - start
Expand Down
69 changes: 63 additions & 6 deletions tests/test_core.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import unittest
import tempfile
import os
import shutil

import numpy as np
import astropy.units as u
Expand Down Expand Up @@ -66,7 +65,7 @@ def test_130_not_inplace(self):

# This should return a separate thing.
rman = aman.restrict('samps', (10, 30), in_place=False)
#self.assertNotEqual(aman.a1[0], 0.)
# self.assertNotEqual(aman.a1[0], 0.)
self.assertEqual(len(aman.a1), 100)
self.assertEqual(len(rman.a1), 20)
self.assertNotEqual(aman.a1[10], 0.)
Expand Down Expand Up @@ -190,23 +189,23 @@ def test_170_concat(self):

# ... other_fields="exact"
aman = core.AxisManager.concatenate([amanA, amanB], axis='dets')

## add scalars
amanA.wrap("ans", 42)
amanB.wrap("ans", 42)
aman = core.AxisManager.concatenate([amanA, amanB], axis='dets')

# ... other_fields="exact"
amanB.azimuth[:] = 2.
with self.assertRaises(ValueError):
aman = core.AxisManager.concatenate([amanA, amanB], axis='dets')

# ... other_fields="exact" and arrays of different shapes
amanB.move("azimuth", None)
amanB.wrap("azimuth", np.array([43,5,2,3]))
with self.assertRaises(ValueError):
aman = core.AxisManager.concatenate([amanA, amanB], axis='dets')

# ... other_fields="fail"
amanB.move("azimuth",None)
amanB.wrap_new('azimuth', shape=('samps',))[:] = 2.
Expand Down Expand Up @@ -269,6 +268,64 @@ def test_180_overwrite(self):
self.assertNotEqual(aman.a1[2,11], 0)
self.assertNotEqual(aman.a1[1,10], 1.)

def test_190_get_set(self):
dets = ["det0", "det1", "det2"]
n, ofs = 1000, 0
aman = core.AxisManager(
core.LabelAxis("dets", dets), core.OffsetAxis("samps", n, ofs)
)
child = core.AxisManager(
core.LabelAxis("dets", dets + ["det3"]),
core.OffsetAxis("samps", n, ofs - n // 2),
)

child2 = core.AxisManager(
core.LabelAxis("dets2", ["det4", "det5"]),
core.OffsetAxis("samps", n, ofs - n // 2),
)
child2.wrap("tod", np.zeros((2, 1000)))
aman.wrap("child", child)
aman["child"].wrap("child2", child2)
self.assertEqual(aman["child.child2.dets2"].count, 2)
self.assertEqual(aman["child.dets"].name, "dets")
np.testing.assert_array_equal(
aman["child.child2.dets2"].vals, np.array(["det4", "det5"])
)
self.assertEqual(aman["child.child2.samps"].count, n // 2)
self.assertEqual(aman["child.child2.samps"].offset, 0)
self.assertEqual(
aman["child.child2.samps"].count, aman.child.child2.samps.count
)
self.assertEqual(
aman["child.child2.samps"].offset, aman.child.child2.samps.offset
)

np.testing.assert_array_equal(aman["child.child2.tod"], np.zeros((2, 1000)))

with self.assertRaises(KeyError):
aman["child2"]

with self.assertRaises(AttributeError):
aman["child.dets.an_extra_layer"]

self.assertIn("child.dets", aman)
self.assertIn("child.dets2", aman) # I am not sure why this is true
self.assertNotIn("child.child2.someentry", aman)
self.assertNotIn("child.child2.someentry.someotherentry", aman)

with self.assertRaises(ValueError):
aman["child"] = child2

new_tods = np.ones((2, 500))
aman.child.child2.tod = new_tods
np.testing.assert_array_equal(aman["child.child2.tod"], np.ones((2, 500)))
np.testing.assert_array_equal(aman.child.child2.tod, np.ones((2, 500)))

new_tods = np.ones((2, 1500))
aman["child.child2.tod"] = new_tods
np.testing.assert_array_equal(aman["child.child2.tod"], np.ones((2, 1500)))
np.testing.assert_array_equal(aman.child.child2.tod, np.ones((2, 1500)))

# Multi-dimensional restrictions.

def test_200_multid(self):
Expand Down

0 comments on commit 34ffae6

Please sign in to comment.