Skip to content

Commit

Permalink
Rename StokesPyTree -> Stokes
Browse files Browse the repository at this point in the history
  • Loading branch information
pchanial committed Jan 22, 2025
1 parent b055a07 commit 852a41c
Show file tree
Hide file tree
Showing 16 changed files with 183 additions and 183 deletions.
6 changes: 2 additions & 4 deletions src/furax/obs/_projections.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from furax.obs._detectors import DetectorArray
from furax.obs._samplings import Sampling
from furax.obs.landscapes import HealpixLandscape
from furax.obs.stokes import StokesPyTree
from furax.obs.stokes import Stokes


def create_projection_operator(
Expand All @@ -30,9 +30,7 @@ def create_projection_operator(
# remove the number of directions per pixels if there is only one.
indices = indices.reshape(indices.shape[0], indices.shape[2])

tod_structure = StokesPyTree.class_for(landscape.stokes).structure_for(
indices.shape, landscape.dtype
)
tod_structure = Stokes.class_for(landscape.stokes).structure_for(indices.shape, landscape.dtype)

rotation = QURotationOperator(samplings.pa, tod_structure)
reshape = RavelOperator(in_structure=landscape.structure)
Expand Down
10 changes: 5 additions & 5 deletions src/furax/obs/landscapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from abc import ABC, abstractmethod
from functools import partial

from furax.obs.stokes import StokesPyTree, ValidStokesType
from furax.obs.stokes import Stokes, ValidStokesType

if sys.version_info < (3, 11):
from typing_extensions import Self
Expand Down Expand Up @@ -101,7 +101,7 @@ def size(self) -> int:

@property
def structure(self) -> PyTree[jax.ShapeDtypeStruct]:
cls = StokesPyTree.class_for(self.stokes)
cls = Stokes.class_for(self.stokes)
return cls.structure_for(self.shape, self.dtype)

def tree_flatten(self): # type: ignore[no-untyped-def]
Expand All @@ -113,17 +113,17 @@ def tree_flatten(self): # type: ignore[no-untyped-def]
return (), aux_data

def full(self, fill_value: ScalarLike) -> PyTree[Shaped[Array, ' {self.npixel}']]:
cls = StokesPyTree.class_for(self.stokes)
cls = Stokes.class_for(self.stokes)
return cls.full(self.shape, fill_value, self.dtype)

def normal(self, key: Key[Array, '']) -> PyTree[Shaped[Array, ' {self.npixel}']]:
cls = StokesPyTree.class_for(self.stokes)
cls = Stokes.class_for(self.stokes)
return cls.normal(key, self.shape, self.dtype)

def uniform(
self, key: Key[Array, ''], low: float = 0.0, high: float = 1.0
) -> PyTree[Shaped[Array, ' {self.npixel}']]:
cls = StokesPyTree.class_for(self.stokes)
cls = Stokes.class_for(self.stokes)
return cls.uniform(self.shape, key, self.dtype, low, high)

def get_coverage(self, arg: Sampling) -> Integer[Array, ' 12*nside**2']:
Expand Down
28 changes: 14 additions & 14 deletions src/furax/obs/operators/_hwp.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
from furax.core.rules import AbstractBinaryRule

from ..stokes import (
StokesIPyTree,
StokesIQUPyTree,
StokesIQUVPyTree,
StokesPyTree,
Stokes,
StokesI,
StokesIQU,
StokesIQUV,
StokesPyTreeType,
StokesQUPyTree,
StokesQU,
ValidStokesType,
)
from ._qu_rotations import QURotationOperator, QURotationTransposeOperator
Expand All @@ -35,23 +35,23 @@ def create(
*,
angles: Float[Array, '...'] | None = None,
) -> AbstractLinearOperator:
in_structure = StokesPyTree.class_for(stokes).structure_for(shape, dtype)
in_structure = Stokes.class_for(stokes).structure_for(shape, dtype)
hwp = cls(in_structure)
if angles is None:
return hwp
rot = QURotationOperator(angles, in_structure)
rotated_hwp: AbstractLinearOperator = rot.T @ hwp @ rot
return rotated_hwp

def mv(self, x: StokesPyTreeType) -> StokesPyTree:
if isinstance(x, StokesIPyTree):
def mv(self, x: StokesPyTreeType) -> Stokes:
if isinstance(x, StokesI):
return x
if isinstance(x, StokesQUPyTree):
return StokesQUPyTree(x.q, -x.u)
if isinstance(x, StokesIQUPyTree):
return StokesIQUPyTree(x.i, x.q, -x.u)
if isinstance(x, StokesIQUVPyTree):
return StokesIQUVPyTree(x.i, x.q, -x.u, -x.v)
if isinstance(x, StokesQU):
return StokesQU(x.q, -x.u)
if isinstance(x, StokesIQU):
return StokesIQU(x.i, x.q, -x.u)
if isinstance(x, StokesIQUV):
return StokesIQUV(x.i, x.q, -x.u, -x.v)
raise NotImplementedError

def in_structure(self) -> PyTree[jax.ShapeDtypeStruct]:
Expand Down
12 changes: 6 additions & 6 deletions src/furax/obs/operators/_polarizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
from furax.core.rules import AbstractBinaryRule

from ..stokes import (
StokesIPyTree,
StokesPyTree,
Stokes,
StokesI,
StokesPyTreeType,
StokesQUPyTree,
StokesQU,
ValidStokesType,
)
from ._hwp import HWPOperator
Expand All @@ -32,7 +32,7 @@ def create(
*,
angles: Float[Array, '...'] | None = None,
) -> AbstractLinearOperator:
in_structure = StokesPyTree.class_for(stokes).structure_for(shape, dtype)
in_structure = Stokes.class_for(stokes).structure_for(shape, dtype)
polarizer = cls(in_structure)
if angles is None:
return polarizer
Expand All @@ -41,9 +41,9 @@ def create(
return rotated_polarizer

def mv(self, x: StokesPyTreeType) -> Float[Array, '...']:
if isinstance(x, StokesIPyTree):
if isinstance(x, StokesI):
return 0.5 * x.i
if isinstance(x, StokesQUPyTree):
if isinstance(x, StokesQU):
return 0.5 * x.q
return 0.5 * (x.i + x.q)

Expand Down
40 changes: 20 additions & 20 deletions src/furax/obs/operators/_qu_rotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
from furax.core.rules import AbstractBinaryRule, NoReduction

from ..stokes import (
StokesIPyTree,
StokesIQUPyTree,
StokesIQUVPyTree,
StokesPyTree,
Stokes,
StokesI,
StokesIQU,
StokesIQUV,
StokesPyTreeType,
StokesQUPyTree,
StokesQU,
ValidStokesType,
)

Expand All @@ -42,24 +42,24 @@ def create(
*,
angles: Float[Array, '...'],
) -> AbstractLinearOperator:
structure = StokesPyTree.class_for(stokes).structure_for(shape, dtype)
structure = Stokes.class_for(stokes).structure_for(shape, dtype)
return cls(angles, structure)

def mv(self, x: StokesPyTreeType) -> StokesPyTreeType:
if isinstance(x, StokesIPyTree):
if isinstance(x, StokesI):
return x

cos_2angles = jnp.cos(2 * self.angles)
sin_2angles = jnp.sin(2 * self.angles)
q = x.q * cos_2angles - x.u * sin_2angles
u = x.q * sin_2angles + x.u * cos_2angles

if isinstance(x, StokesQUPyTree):
return StokesQUPyTree(q, u)
if isinstance(x, StokesIQUPyTree):
return StokesIQUPyTree(x.i, q, u)
if isinstance(x, StokesIQUVPyTree):
return StokesIQUVPyTree(x.i, q, u, x.v)
if isinstance(x, StokesQU):
return StokesQU(q, u)
if isinstance(x, StokesIQU):
return StokesIQU(x.i, q, u)
if isinstance(x, StokesIQUV):
return StokesIQUV(x.i, q, u, x.v)
raise NotImplementedError

def transpose(self) -> AbstractLinearOperator:
Expand All @@ -73,20 +73,20 @@ class QURotationTransposeOperator(AbstractLazyInverseOrthogonalOperator):
operator: QURotationOperator

def mv(self, x: StokesPyTreeType) -> StokesPyTreeType:
if isinstance(x, StokesIPyTree):
if isinstance(x, StokesI):
return x

cos_2angles = jnp.cos(2 * self.operator.angles)
sin_2angles = jnp.sin(2 * self.operator.angles)
q = x.q * cos_2angles + x.u * sin_2angles
u = -x.q * sin_2angles + x.u * cos_2angles

if isinstance(x, StokesQUPyTree):
return StokesQUPyTree(q, u)
if isinstance(x, StokesIQUPyTree):
return StokesIQUPyTree(x.i, q, u)
if isinstance(x, StokesIQUVPyTree):
return StokesIQUVPyTree(x.i, q, u, x.v)
if isinstance(x, StokesQU):
return StokesQU(q, u)
if isinstance(x, StokesIQU):
return StokesIQU(x.i, q, u)
if isinstance(x, StokesIQUV):
return StokesIQUV(x.i, q, u, x.v)
raise NotImplementedError


Expand Down
Loading

0 comments on commit 852a41c

Please sign in to comment.