Skip to content

Commit

Permalink
added asserts to check shape and types of inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
nestor committed Dec 6, 2023
1 parent 8e0921a commit 81b7c90
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 7 deletions.
17 changes: 15 additions & 2 deletions src/toast/jax/maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,21 @@ def check_pytree_axis(data, axis, info=""):
for i, (d, a) in enumerate(zip(data, axis)):
check_pytree_axis(d, a, f"{info}[{i}]")
elif isinstance(axis, type):
assert isinstance(data, type), f"{info} type ({type(data).__name__}) does not match provided axis ({pytree_to_string(axis)})"
# we do not cover the case of single values as they are assumed to be matching
is_single_number_tracer = isinstance(data, jnp.ndarray) and (data.size == 1)
data_type = data.dtype if is_single_number_tracer else type(data) # deals with JAX tracers being sorts of arrays
if jnp.issubdtype(axis, jnp.integer):
# integer types all batched together to simplify axis writing
assert jnp.issubdtype(data_type, jnp.integer), f"{info} type ({data_type.__name__}) does not match provided axis ({pytree_to_string(axis)})"
elif jnp.issubdtype(axis, jnp.floating):
# float types all batched together to simplify axis writing
assert jnp.issubdtype(data_type, jnp.floating), f"{info} type ({data_type.__name__}) does not match provided axis ({pytree_to_string(axis)})"
elif jnp.issubdtype(axis, bool):
# bool types all batched together to simplify axis writing
assert jnp.issubdtype(data_type, bool), f"{info} type ({data_type.__name__}) does not match provided axis ({pytree_to_string(axis)})"
else:
# other, more general, types
assert isinstance(data, axis), f"{info} type ({data_type.__name__}) does not match provided axis ({pytree_to_string(axis)})"
# we do not cover the case of other single values as they are assumed to be matching

def find_in_pytree(condition, structure):
"""
Expand Down
9 changes: 4 additions & 5 deletions src/toast/ops/pixels_healpix/kernels_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,16 +129,15 @@ def pixels_healpix_interval(
use_flags = False

# does the computation
dummy_sub_map = jnp.empty_like(pixels_indexed) # jnp.zeros_like(pixels_indexed)
dummy_hit_submaps = jnp.empty_like(pixels_indexed) #hit_submaps[dummy_sub_map]
dummy_sub_map = jnp.zeros_like(pixels_indexed)
dummy_hit_submaps = hit_submaps[dummy_sub_map]
outputs = (pixels_indexed, dummy_sub_map, dummy_hit_submaps)
new_pixels_indexed, sub_map, new_hit_submaps = pixels_healpix_inner(quats_indexed, use_flags, flags, flag_mask, hit_submaps, n_pix_submap, hpix, nest,
interval_starts, interval_ends, intervals_max_length,
outputs)

print(f"DEBUGGING: use_flags:{use_flags}")
# TODO sub_map:(1, 360000) n_samp:360000 n_intervals:1 intervals_max_length:360000
# TODO check order of all things?
# TODO why the performance drop?
# TODO use_flags:True sub_map:(1, 360000) n_samp:360000 n_intervals:1 intervals_max_length:360000

# updates results and returns
pixels = pixels.at[pixel_index,:].set(new_pixels_indexed)
Expand Down

0 comments on commit 81b7c90

Please sign in to comment.