Skip to content

Commit

Permalink
Move WCSLandscape from sotodlib_utils to obs.landscape
Browse files Browse the repository at this point in the history
  • Loading branch information
Wuhyun committed Feb 11, 2025
1 parent 2003a0b commit 0a7db5d
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 32 deletions.
33 changes: 1 addition & 32 deletions src/furax/interfaces/toast/mapmaker/sotodlib_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import matplotlib.pyplot as plt
import numpy as np
import pixell
from astropy.wcs import WCS
from jax import Array, ShapeDtypeStruct
from jaxtyping import Bool, DTypeLike, Float, Inexact, Integer, PyTree
from numpy.typing import NDArray
Expand All @@ -31,44 +30,14 @@
from furax.mapmaking.preconditioner import BJPreconditioner
from furax.mapmaking.utils import psd_to_invntt
from furax.obs import QURotationOperator
from furax.obs.landscapes import HealpixLandscape, StokesLandscape
from furax.obs.landscapes import HealpixLandscape, WCSLandscape
from furax.obs.stokes import Stokes, StokesIQU, StokesPyTreeType, ValidStokesType

from . import templates

""" Custom FURAX classes and operators """


class WCSLandscape(StokesLandscape):
"""Stokes PyTree for WCS maps
Not fully implemented yet, potentially should be added to FURAX
"""

def __init__(
self,
shape: tuple[int, ...],
wcs: WCS,
stokes: ValidStokesType,
dtype: DTypeLike = np.float32,
) -> None:
super().__init__(shape, stokes, dtype)
self.wcs = wcs

def tree_flatten(self): # type: ignore[no-untyped-def]
aux_data = {
'shape': self.shape,
'dtype': self.dtype,
'stokes': self.stokes,
'wcs': self.wcs,
} # static values
return (), aux_data

def world2pixel(
self, theta: Float[Array, '...'], phi: Float[Array, '...']
) -> tuple[Float[Array, '...'], ...]:
raise NotImplementedError()


class StokesIndexOperator(AbstractLinearOperator):
"""Operator for integer index operation on Stokes PyTrees
The indices are assumed to be identical for I, Q and U
Expand Down
50 changes: 50 additions & 0 deletions src/furax/obs/landscapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
import jax.numpy as jnp
import jax_healpy as jhp
import numpy as np
from astropy.coordinates import SkyCoord
from astropy.units import u
from astropy.wcs import WCS
from jaxtyping import Array, DTypeLike, Float, Integer, Key, PyTree, ScalarLike, Shaped

from furax.obs._samplings import Sampling
Expand Down Expand Up @@ -256,3 +259,50 @@ def tree_flatten(self): # type: ignore[no-untyped-def]
'frequencies': self.frequencies,
} # static values
return (), aux_data


@jax.tree_util.register_pytree_node_class
class WCSLandscape(StokesLandscape):
"""Class representing an astropy WCS map of Stokes vectors."""

def __init__(
self,
shape: tuple[int, ...],
wcs: WCS,
stokes: ValidStokesType = 'IQU',
dtype: DTypeLike = np.float64,
) -> None:
super().__init__(shape, stokes, dtype)
self.wcs = wcs

def tree_flatten(self): # type: ignore[no-untyped-def]
aux_data = {
'shape': self.shape,
'dtype': self.dtype,
'stokes': self.stokes,
'wcs': self.wcs,
} # static values
return (), aux_data

def world2pixel(
self, theta: Float[Array, ' *dims'], phi: Float[Array, ' *dims']
) -> tuple[Integer[Array, ' *dims'], ...]:
r"""Convert angles to WCS map indices.
Args:
theta (float): Spherical :math:`\theta` angle.
phi (float): Spherical :math:`\phi` angle.
Returns:
WCS map index pairs
"""

def f(theta, phi): # type: ignore[no-untyped-def]
# SkyCoord takes (lon,lat)
pix_i, pix_j = self.wcs.world_to_pixel(SkyCoord(phi, (np.pi / 2 - theta), unit=u.rad))
return tuple(np.array(np.round([pix_i, pix_j]), dtype=np.int64))

struct = jax.ShapeDtypeStruct(theta.shape, jnp.int64)
result_shape = (struct, struct)

return jax.pure_callback(f, result_shape, theta, phi) # type: ignore[no-any-return]

0 comments on commit 0a7db5d

Please sign in to comment.