Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/seds-operator' into seds-operator
Browse files Browse the repository at this point in the history
  • Loading branch information
ASKabalan committed Dec 9, 2024
2 parents 3fe96e8 + 5ad9cfc commit c32455a
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 2 deletions.
62 changes: 60 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from typing import get_args

import os
import healpy as hp
import jax
import numpy as np
import pytest
from jaxtyping import Array, Float

from furax.landscapes import StokesIQUPyTree, ValidStokesType
from tests.helpers import TEST_DATA_PLANCK, TEST_DATA_SAT
from furax.landscapes import StokesPyTree, StokesIQUPyTree, ValidStokesType, HealpixLandscape
from tests.helpers import TEST_DATA_PLANCK, TEST_DATA_SAT, TEST_DATA_FGBUSTER


def load_planck(nside: int) -> np.array:
Expand Down Expand Up @@ -47,3 +48,60 @@ def sat_nhits() -> Float[Array, '...']:
def stokes(request: pytest.FixtureRequest) -> ValidStokesType:
"""Parametrized fixture for I, QU, IQU and IQUV."""
return request.param


@pytest.fixture(scope='session')
def get_fgbuster_data():
os.makedirs(TEST_DATA_FGBUSTER, exist_ok=True)
# Check if file already exists
data_filename = f'{TEST_DATA_FGBUSTER}/fgbuster_data.npz'
nside = 32
stokes_type = 'IQU'
in_structure = HealpixLandscape(nside, stokes_type).structure
try:
# If the file already exists, we can skip data generation
fg_data = np.load(data_filename)
print(f"Data file '{data_filename}' already exists, skipping generation.")
freq_maps: Array = fg_data['freq_maps']
d = StokesPyTree.from_stokes(
I=freq_maps[:, 0, :], Q=freq_maps[:, 1, :], U=freq_maps[:, 2, :]
)
return fg_data, d, in_structure
except FileNotFoundError:
try:
from fgbuster import CMB, Dust, Synchrotron, get_observation, get_instrument
except ImportError:
raise ImportError(
'fgbuster is not installed. Please install it using `pip install fgbuster`'
)
instrument = get_instrument('LiteBIRD')
freq_maps = get_observation(instrument, 'c1d0s0', nside=nside)
nu = instrument['frequency'].values

# Generate FGBuster components
cmb_fgbuster_K_CMB = CMB().eval(nu)
dust_fgbuster_K_CMB = Dust(150.0).eval(nu, 1.54, 20.0)
synchrotron_fgbuster_K_CMB = Synchrotron(20.0).eval(nu, -3.0)

cmb_fgbuster_K_RJ = CMB(units='K_RJ').eval(nu)
dust_fgbuster_K_RJ = Dust(150.0, units='K_RJ').eval(nu, 1.54, 20.0)
synchrotron_fgbuster_K_RJ = Synchrotron(20.0, units='K_RJ').eval(nu, -3.0)

fg_data = {
'frequencies': nu,
'freq_maps': freq_maps,
'CMB_K_CMB': cmb_fgbuster_K_CMB,
'DUST_K_CMB': dust_fgbuster_K_CMB,
'SYNC_K_CMB': synchrotron_fgbuster_K_CMB,
'CMB_K_RJ': cmb_fgbuster_K_RJ,
'DUST_K_RJ': dust_fgbuster_K_RJ,
'SYNC_K_RJ': synchrotron_fgbuster_K_RJ,
}
# Save all required arrays to an .npz file
np.savez(data_filename, **fg_data)
print(f"Data saved to '{data_filename}'")

d = StokesPyTree.from_stokes(
I=freq_maps[:, 0, :], Q=freq_maps[:, 1, :], U=freq_maps[:, 2, :]
)
return fg_data, d, in_structure
Binary file added tests/data/fgbuster/fgbuster_data.npz
Binary file not shown.
1 change: 1 addition & 0 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
TEST_DATA_PLANCK = TEST_DATA / 'planck'
TEST_DATA_SAT = TEST_DATA / 'sat'
TEST_DATA_SEDS = Path(__file__).parent / 'seds/data'
TEST_DATA_FGBUSTER = TEST_DATA / 'fgbuster'


def arange(*shape: int, dtype=jnp.float32, start=1) -> jax.Array:
Expand Down
108 changes: 108 additions & 0 deletions tests/test_compsep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import jax
import jax.numpy as jnp
from furax.operators.seds import CMBOperator, DustOperator, SynchrotronOperator
from furax.landscapes import StokesPyTree


def test_cmb_k_cmb(get_fgbuster_data):
fg_data, d, in_structure = get_fgbuster_data
nu = fg_data['frequencies']

# Calculate CMB with K_CMB unit in furax
cmb_fgbuster = fg_data['CMB_K_CMB'][..., jnp.newaxis, jnp.newaxis] * fg_data['freq_maps']
cmb_fgbuster_tree = StokesPyTree.from_stokes(
I=cmb_fgbuster[:, 0, :], Q=cmb_fgbuster[:, 1, :], U=cmb_fgbuster[:, 2, :]
)

cmb_operator = CMBOperator(nu, in_structure=in_structure, units='K_CMB')
cmb_furax = cmb_operator(d)

assert jax.tree.all(jax.tree.map(jnp.allclose, cmb_furax, cmb_fgbuster_tree))


def test_cmb_k_rj(get_fgbuster_data):
fg_data, d, in_structure = get_fgbuster_data
nu = fg_data['frequencies']

# Calculate CMB with K_RJ unit in furax
cmb_fgbuster = fg_data['CMB_K_RJ'][..., jnp.newaxis, jnp.newaxis] * fg_data['freq_maps']
cmb_fgbuster_tree = StokesPyTree.from_stokes(
I=cmb_fgbuster[:, 0, :], Q=cmb_fgbuster[:, 1, :], U=cmb_fgbuster[:, 2, :]
)

cmb_operator = CMBOperator(nu, in_structure=in_structure, units='K_RJ')
cmb_furax = cmb_operator(d)

assert jax.tree.all(jax.tree.map(jnp.allclose, cmb_furax, cmb_fgbuster_tree))


def test_dust_k_cmb(get_fgbuster_data):
fg_data, d, in_structure = get_fgbuster_data
nu = fg_data['frequencies']

# Calculate Dust with K_CMB unit in furax
dust_fgbuster = fg_data['DUST_K_CMB'][..., jnp.newaxis, jnp.newaxis] * fg_data['freq_maps']
dust_fgbuster_tree = StokesPyTree.from_stokes(
I=dust_fgbuster[:, 0, :], Q=dust_fgbuster[:, 1, :], U=dust_fgbuster[:, 2, :]
)

dust_operator = DustOperator(
nu, in_structure=in_structure, frequency0=150.0, units='K_CMB', temperature=20.0, beta=1.54
)
dust_furax = dust_operator(d)

assert jax.tree.all(jax.tree.map(jnp.allclose, dust_furax, dust_fgbuster_tree))


def test_dust_k_rj(get_fgbuster_data):
fg_data, d, in_structure = get_fgbuster_data
nu = fg_data['frequencies']

# Calculate Dust with K_RJ unit in furax
dust_fgbuster = fg_data['DUST_K_RJ'][..., jnp.newaxis, jnp.newaxis] * fg_data['freq_maps']
dust_fgbuster_tree = StokesPyTree.from_stokes(
I=dust_fgbuster[:, 0, :], Q=dust_fgbuster[:, 1, :], U=dust_fgbuster[:, 2, :]
)

dust_operator = DustOperator(
nu, in_structure=in_structure, frequency0=150.0, units='K_RJ', temperature=20.0, beta=1.54
)
dust_furax = dust_operator(d)

assert jax.tree.all(jax.tree.map(jnp.allclose, dust_furax, dust_fgbuster_tree))


def test_synchrotron_k_cmb(get_fgbuster_data):
fg_data, d, in_structure = get_fgbuster_data
nu = fg_data['frequencies']

# Calculate Synchrotron with K_CMB unit in furax
synch_fgbuster = fg_data['SYNC_K_CMB'][..., jnp.newaxis, jnp.newaxis] * fg_data['freq_maps']
synch_fgbuster_tree = StokesPyTree.from_stokes(
I=synch_fgbuster[:, 0, :], Q=synch_fgbuster[:, 1, :], U=synch_fgbuster[:, 2, :]
)

synch_operator = SynchrotronOperator(
nu, in_structure=in_structure, frequency0=20.0, units='K_CMB', beta_pl=-3.0
)
synch_furax = synch_operator(d)

assert jax.tree.all(jax.tree.map(jnp.allclose, synch_furax, synch_fgbuster_tree))


def test_synchrotron_k_rj(get_fgbuster_data):
fg_data, d, in_structure = get_fgbuster_data
nu = fg_data['frequencies']

# Calculate Synchrotron with K_RJ unit in furax
synch_fgbuster = fg_data['SYNC_K_RJ'][..., jnp.newaxis, jnp.newaxis] * fg_data['freq_maps']
synch_fgbuster_tree = StokesPyTree.from_stokes(
I=synch_fgbuster[:, 0, :], Q=synch_fgbuster[:, 1, :], U=synch_fgbuster[:, 2, :]
)

synch_operator = SynchrotronOperator(
nu, in_structure=in_structure, frequency0=20.0, units='K_RJ', beta_pl=-3.0
)
synch_furax = synch_operator(d)

assert jax.tree.all(jax.tree.map(jnp.allclose, synch_furax, synch_fgbuster_tree))

0 comments on commit c32455a

Please sign in to comment.