Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gp/feat/aman arbitrary path #1057

Merged
merged 10 commits into from
Jan 10, 2025
Merged
20 changes: 20 additions & 0 deletions docs/axisman.rst
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,26 @@ The output of the ``wrap`` cal should be::
Note the boresight entry is marked with a ``*``, indicating that it's
an AxisManager rather than a numpy array.

Data access under an AxisManager is done based on field names. For example:
iparask marked this conversation as resolved.
Show resolved Hide resolved
>>> n, ofs = 1000, 0
>>> dets = ["det0", "det1", "det2"]
>>> aman = core.AxisManager(core.LabelAxis("dets", dets), core.OffsetAxis("samps", n, ofs))
>>> child = core.AxisManager(core.LabelAxis("dets", dets + ["det3"]),core.OffsetAxis("samps", n, ofs - n // 2),)
>>> aman.wrap("child", child)
>>> print(aman.child.dets)
LabelAxis(3:'det0','det1','det2')

Advanced data access is possible by a path like syntax. This is especially useful when
data access is dynamic and the field name is not known in advance. For example:

>>> n, ofs = 1000, 0
>>> dets = ["det0", "det1", "det2"]
>>> aman = core.AxisManager(core.LabelAxis("dets", dets), core.OffsetAxis("samps", n, ofs))
>>> child = core.AxisManager(core.LabelAxis("dets", dets + ["det3"]),core.OffsetAxis("samps", n, ofs - n // 2),)
>>> aman.wrap("child", child)
>>> print(aman["child.dets"])
LabelAxis(3:'det0','det1','det2')

To slice this object, use the restrict() method. First, let's
restrict in the 'dets' axis. Since it's an Axis of type LabelAxis,
the restriction selector must be a list of strings::
Expand Down
69 changes: 52 additions & 17 deletions sotodlib/core/axisman.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,29 +349,60 @@ def move(self, name, new_name):
self._fields[new_name] = self._fields.pop(name)
self._assignments[new_name] = self._assignments.pop(name)
return self

def add_axis(self, a):
assert isinstance( a, AxisInterface)
self._axes[a.name] = a.copy()

def __contains__(self, name):
return name in self._fields or name in self._axes
attrs = name.split(".")
tmp_item = self
while attrs:
attr_name = attrs.pop(0)
if attr_name in tmp_item._fields:
tmp_item = tmp_item._fields[attr_name]
elif attr_name in tmp_item._axes:
tmp_item = tmp_item._axes[attr_name]
else:
return False
return True

def __getitem__(self, name):
if name in self._fields:
return self._fields[name]
if name in self._axes:
return self._axes[name]
raise KeyError(name)

# We want to support options like:
# aman.focal_plane.xi . aman['focal_plane.xi']
# We will safely assume that a getitem will always have '.' as the separator
attrs = name.split(".")
tmp_item = self
while attrs:
attr_name = attrs.pop(0)
if attr_name in tmp_item._fields:
tmp_item = tmp_item._fields[attr_name]
elif attr_name in tmp_item._axes:
tmp_item = tmp_item._axes[attr_name]
else:
raise KeyError(attr_name)
return tmp_item

def __setitem__(self, name, val):
if name in self._fields:
self._fields[name] = val
else:
raise KeyError(name)

last_pos = name.rfind(".")
val_key = name
tmp_item = self
if last_pos > -1:
val_key = name[last_pos + 1:]
attrs = name[:last_pos]
tmp_item = self[attrs]

if isinstance(val, AxisManager) and isinstance(tmp_item, AxisManager):
raise ValueError("Cannot assign AxisManager to AxisManager. Please use wrap method.")

tmp_item.__setattr__(val_key, val)
iparask marked this conversation as resolved.
Show resolved Hide resolved

def __setattr__(self, name, value):
# Assignment to members update those members
# We will assume that a path exists until the last member.
# If any member prior to that does not exist a keyerror is raised.
if "_fields" in self.__dict__ and name in self._fields.keys():
self._fields[name] = value
else:
Expand All @@ -381,7 +412,11 @@ def __setattr__(self, name, value):
def __getattr__(self, name):
# Prevent members from override special class members.
if name.startswith("__"): raise AttributeError(name)
return self[name]
try:
val = self[name]
except KeyError as ex:
raise AttributeError(name) from ex
return val

def __dir__(self):
return sorted(tuple(self.__dict__.keys()) + tuple(self.keys()))
Expand Down Expand Up @@ -514,27 +549,27 @@ def concatenate(items, axis=0, other_fields='exact'):
output.wrap(k, new_data[k], axis_map)
else:
if other_fields == "exact":
## if every item named k is a scalar
## if every item named k is a scalar
err_msg = (f"The field '{k}' does not share axis '{axis}'; "
f"{k} is not identical across all items "
f"pass other_fields='drop' or 'first' or else "
f"remove this field from the targets.")

if np.any([np.isscalar(i[k]) for i in items]):
if not np.all([np.isscalar(i[k]) for i in items]):
raise ValueError(err_msg)
if not np.all([np.array_equal(i[k], items[0][k], equal_nan=True) for i in items]):
raise ValueError(err_msg)
output.wrap(k, items[0][k], axis_map)
continue

elif not np.all([i[k].shape==items[0][k].shape for i in items]):
raise ValueError(err_msg)
elif not np.all([np.array_equal(i[k], items[0][k], equal_nan=True) for i in items]):
raise ValueError(err_msg)

output.wrap(k, items[0][k].copy(), axis_map)

elif other_fields == 'fail':
raise ValueError(
f"The field '{k}' does not share axis '{axis}'; "
Expand Down
2 changes: 1 addition & 1 deletion sotodlib/core/axisman_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def expand_RangesMatrix(flat_rm):
if shape[0] == 0:
return so3g.proj.RangesMatrix([], child_shape=shape[1:])
# Otherwise non-trivial
count = np.product(shape[:-1])
count = np.prod(shape[:-1])
start, stride = 0, count // shape[0]
for i in range(0, len(ends), stride):
_e = ends[i:i+stride] - start
Expand Down
61 changes: 56 additions & 5 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import shutil

from networkx import selfloop_edges
iparask marked this conversation as resolved.
Show resolved Hide resolved
import numpy as np
import astropy.units as u
from sotodlib import core
Expand Down Expand Up @@ -66,7 +67,7 @@ def test_130_not_inplace(self):

# This should return a separate thing.
rman = aman.restrict('samps', (10, 30), in_place=False)
#self.assertNotEqual(aman.a1[0], 0.)
# self.assertNotEqual(aman.a1[0], 0.)
self.assertEqual(len(aman.a1), 100)
self.assertEqual(len(rman.a1), 20)
self.assertNotEqual(aman.a1[10], 0.)
Expand Down Expand Up @@ -190,23 +191,23 @@ def test_170_concat(self):

# ... other_fields="exact"
aman = core.AxisManager.concatenate([amanA, amanB], axis='dets')

## add scalars
amanA.wrap("ans", 42)
amanB.wrap("ans", 42)
aman = core.AxisManager.concatenate([amanA, amanB], axis='dets')

# ... other_fields="exact"
amanB.azimuth[:] = 2.
with self.assertRaises(ValueError):
aman = core.AxisManager.concatenate([amanA, amanB], axis='dets')

# ... other_fields="exact" and arrays of different shapes
amanB.move("azimuth", None)
amanB.wrap("azimuth", np.array([43,5,2,3]))
with self.assertRaises(ValueError):
aman = core.AxisManager.concatenate([amanA, amanB], axis='dets')

# ... other_fields="fail"
amanB.move("azimuth",None)
amanB.wrap_new('azimuth', shape=('samps',))[:] = 2.
Expand Down Expand Up @@ -298,6 +299,56 @@ def test_300_restrict(self):
self.assertNotEqual(aman.a1[0, 0, 0, 1], 0.)

# wrap of AxisManager, merge.
def test_get_set(self):
iparask marked this conversation as resolved.
Show resolved Hide resolved
dets = ["det0", "det1", "det2"]
n, ofs = 1000, 0
aman = core.AxisManager(
core.LabelAxis("dets", dets), core.OffsetAxis("samps", n, ofs)
)
child = core.AxisManager(
core.LabelAxis("dets", dets + ["det3"]),
core.OffsetAxis("samps", n, ofs - n // 2),
)

child2 = core.AxisManager(
core.LabelAxis("dets2", ["det4", "det5"]),
core.OffsetAxis("samps", n, ofs - n // 2),
)
child2.wrap("tod", np.zeros((2,1000)))
aman.wrap("child", child)
aman["child"].wrap("child2", child2)
iparask marked this conversation as resolved.
Show resolved Hide resolved
self.assertEqual(aman["child.child2.dets2"].count, 2)
self.assertEqual(aman["child.dets"].name, "dets")
np.testing.assert_array_equal(aman["child.child2.dets2"].vals, np.array(["det4", "det5"]))
self.assertEqual(aman["child.child2.samps"].count, n // 2)
self.assertEqual(aman["child.child2.samps"].offset, 0)
self.assertEqual(aman["child.child2.samps"].count, aman.child.child2.samps.count)
self.assertEqual(aman["child.child2.samps"].offset, aman.child.child2.samps.offset)

np.testing.assert_array_equal(aman["child.child2.tod"], np.zeros((2,1000)))

with self.assertRaises(KeyError):
aman["child2"]

with self.assertRaises(AttributeError):
aman["child.dets.an_extra_layer"]

self.assertIn("child.dets", aman)
self.assertIn("child.dets2", aman) # I am not sure why this is true
iparask marked this conversation as resolved.
Show resolved Hide resolved
self.assertNotIn("child.child2.someentry", aman)
iparask marked this conversation as resolved.
Show resolved Hide resolved

with self.assertRaises(ValueError):
aman["child"] = child2

new_tods = np.ones((2, 500))
aman.child.child2.tod = new_tods
np.testing.assert_array_equal(aman["child.child2.tod"], np.ones((2, 500)))
np.testing.assert_array_equal(aman.child.child2.tod, np.ones((2, 500)))

new_tods = np.ones((2, 1500))
aman["child.child2.tod"] = new_tods
np.testing.assert_array_equal(aman["child.child2.tod"], np.ones((2, 1500)))
np.testing.assert_array_equal(aman.child.child2.tod, np.ones((2, 1500)))

def test_400_child(self):
dets = ['det0', 'det1', 'det2']
Expand Down
Loading