
implement variable mu(z) computation #80

Open
tlmakinen opened this issue Jan 6, 2022 · 0 comments

It would be great to implement a couple of functions for computing the distance modulus mu(z) (e.g. for supernova cosmology).

I've whipped up a couple of functions that work on my end within a NumPyro-based BHM (Bayesian hierarchical model) sampler, based on BAHAMAS:

import jax.numpy as np
from jax import grad, jit, vmap, random, lax
import jax
import jax_cosmo as jc
import scipy.constants as cnst

inference_type = "w"

# 1/E(z): the integrand of the dimensionless comoving distance for a
# (possibly curved) cosmology with constant dark-energy equation of state w
@jit
def integrand(zba, omegam, omegade, w):
    return 1.0/np.sqrt(
        omegam*(1+zba)**3 + omegade*(1+zba)**(3.+3.*w) + (1.-omegam-omegade)*(1.+zba)**2
    )

# integral of the integrand above from 0 to z (Romberg rule from jax_cosmo),
# vmapped over the array of redshifts z
@jit
def hubble(z, omegam, omegade, w):
    fn = lambda zi: jc.scipy.integrate.romb(integrand, 0., zi, args=(omegam, omegade, w))
    I = jax.vmap(fn)(z)
    return I
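As a quick sanity check of the vmapped-integral approach, here is a minimal sketch that swaps jax_cosmo's Romberg rule for a plain fixed-grid trapezoid rule (an assumption made only to keep the example free of jax_cosmo) and compares it with the Einstein-de Sitter case (omegam = 1, omegade = 0), where the integral has the closed form 2(1 - (1+z)^(-1/2)). `inv_E`, `comoving_integral_trapz` and `N_GRID` are hypothetical names, not part of any package.

```python
import jax
import jax.numpy as jnp

N_GRID = 257  # trapezoid nodes per redshift (fixed, so it stays static under jit)

def inv_E(zp, omegam, omegade, w):
    # 1/E(z'), the same expression as `integrand` above
    return 1.0 / jnp.sqrt(omegam * (1 + zp) ** 3
                          + omegade * (1 + zp) ** (3.0 + 3.0 * w)
                          + (1.0 - omegam - omegade) * (1 + zp) ** 2)

@jax.jit
def comoving_integral_trapz(z, omegam, omegade, w):
    """Hypothetical stand-in for `hubble`: integral of 1/E from 0 to each z,
    using the trapezoid rule on a fixed grid, vmapped over the array z."""
    def one(zi):
        zp = zi * jnp.linspace(0.0, 1.0, N_GRID)  # nodes on [0, zi]
        f = inv_E(zp, omegam, omegade, w)
        dz = zi / (N_GRID - 1)
        return dz * (jnp.sum(f) - 0.5 * (f[0] + f[-1]))
    return jax.vmap(one)(z)

# Einstein-de Sitter check: omegam = 1, omegade = 0 gives 2 * (1 - (1+z)**-0.5)
z = jnp.array([0.0, 0.5, 1.2])
approx = comoving_integral_trapz(z, 1.0, 0.0, -1.0)
exact = 2.0 * (1.0 - 1.0 / jnp.sqrt(1.0 + z))
print(approx, exact)
```

The same comparison can be run against the jax_cosmo version by calling `hubble(z, 1.0, 0.0, -1.0)` instead.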

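Then we can compute a luminosity distance Dlz that accounts for curvature (through omegakmag = sqrt|1 - omegam - omegade|) by nesting a couple of lax.cond statements: sin for a closed universe, sinh for an open one, and the plain integral for a flat one: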

@jit
def Dlz(omegam, omegade, h, z, w, z_helio):

    # which inference are we doing ?
    if inference_type == "omegade":
        omegakmag = np.sqrt(np.abs(1 - omegam - omegade))
    else:
        omegakmag = 0.

    hubbleint = hubble(z, omegam, omegade, w)
    # compare with a tolerance: exact float equality is fragile, e.g. when
    # omegade = 1 - omegam is computed in float32
    condition1 = np.isclose(omegam + omegade, 1.)  # flat
    condition2 = (omegam + omegade > 1.)           # closed

    # if (omegam+omegade) > 1: closed universe
    def ifbigger(omegakmag):
        return (cnst.c*1e-5*(1+z_helio)/(h*omegakmag)) * np.sin(hubbleint*omegakmag)

    # if (omegam+omegade) < 1: open universe
    def ifsmaller(omegakmag):
        return (cnst.c*1e-5*(1+z_helio)/(h*omegakmag)) * np.sinh(hubbleint*omegakmag)

    # if (omegam+omegade) == 1: flat universe
    def equalfun(omegakmag):
        return cnst.c*1e-5*(1+z_helio)*hubbleint/h

    # if not flat, choose between the closed (sin) and open (sinh) forms
    def notequalfun(omegakmag):
        return lax.cond(condition2, ifbigger, ifsmaller, omegakmag)

    distance = lax.cond(condition1, equalfun, notequalfun, omegakmag)

    return distance
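One possible alternative, sketched under the assumption that this function may later be batched (e.g. with `jax.vmap` over cosmologies): when the predicate of `lax.cond` is batched, JAX turns it into a select that evaluates both branches, and although the unused branch's value is discarded, a division by omegakmag = 0 there can still turn gradients into NaN. A common workaround is a branch-free version using `jnp.where` with a safe value substituted before dividing (the "double where" pattern). `transverse_factor` is a hypothetical helper returning the dimensionless factor S_k(I) that multiplies (1 + z_helio) c/H0; `omegak = 1 - omegam - omegade`, and the flatness tolerance `eps` is an assumed value.

```python
import jax
import jax.numpy as jnp

def transverse_factor(I, omegak, eps=1e-6):
    """Hypothetical helper: S_k(I) for omegak = 1 - omegam - omegade.

    I is the dimensionless comoving integral (the output of `hubble`).
    Returns sinh(sqrt(omegak) I)/sqrt(omegak) for an open universe,
    sin(sqrt(-omegak) I)/sqrt(-omegak) for a closed one, and I when flat.
    """
    flat = jnp.abs(omegak) < eps
    # "double where": replace |omegak| by 1 in the flat case *before* the
    # sqrt and the division, so no branch ever divides by zero
    sqrt_k = jnp.sqrt(jnp.where(flat, 1.0, jnp.abs(omegak)))
    open_ = jnp.sinh(sqrt_k * I) / sqrt_k
    closed = jnp.sin(sqrt_k * I) / sqrt_k
    curved = jnp.where(omegak > 0, open_, closed)
    return jnp.where(flat, I, curved)

# flat, open and closed cases for I = 0.5
print(transverse_factor(0.5, 0.0), transverse_factor(0.5, 0.04),
      transverse_factor(0.5, -0.04))
# the gradient with respect to omegak stays finite at omegak = 0
print(jax.grad(transverse_factor, argnums=1)(0.5, 0.0))
```

Inside `Dlz`, this would replace the two nested `lax.cond` calls with something like `distance = cnst.c*1e-5*(1+z_helio)/h * transverse_factor(hubbleint, 1. - omegam - omegade)`; the `lax.cond` version is fine as long as the predicate is not batched.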


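# muz: distance modulus as a function of the parameters and redshift
# (Dlz is in Mpc, so +25 converts 5*log10(D_L/Mpc) to 5*log10(D_L/10 pc))
@jit
def muz(omegam, w, z):
    z_helio = z # should this be different ?
    omegade = 1. - omegam  # flat universe: omegam + omegade = 1
    #w = -1.0 # freeze w
    h = 0.72
    return (5.0 * np.log10(Dlz(omegam, omegade, h, z, w, z_helio))+25.)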
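For reference, these are the standard FLRW expressions the functions above implement, written with symbols chosen here to mirror the variable names (I(z) is `hubble`, D_H is the `cnst.c*1e-5/h` prefactor):

```latex
E(z) = \sqrt{\Omega_m (1+z)^3 + \Omega_{de} (1+z)^{3(1+w)} + \Omega_k (1+z)^2},
\qquad \Omega_k = 1 - \Omega_m - \Omega_{de}

I(z) = \int_0^z \frac{dz'}{E(z')}, \qquad
D_H = \frac{c}{H_0}
    = \frac{299792.458\ \mathrm{km\,s^{-1}}}{100\,h\ \mathrm{km\,s^{-1}\,Mpc^{-1}}}
    = \frac{2997.92458}{h}\ \mathrm{Mpc}

D_L(z) = (1+z_{\rm hel})\, D_H \times
\begin{cases}
  \sin\!\big(\sqrt{|\Omega_k|}\, I(z)\big) / \sqrt{|\Omega_k|}, & \Omega_m + \Omega_{de} > 1 \\
  I(z), & \Omega_m + \Omega_{de} = 1 \\
  \sinh\!\big(\sqrt{|\Omega_k|}\, I(z)\big) / \sqrt{|\Omega_k|}, & \Omega_m + \Omega_{de} < 1
\end{cases}

\mu(z) = 5 \log_{10}\!\left(\frac{D_L}{10\ \mathrm{pc}}\right)
       = 5 \log_{10}\!\left(\frac{D_L}{\mathrm{Mpc}}\right) + 25
```

Since scipy.constants.c is in m/s, `cnst.c*1e-5` is exactly 2997.92458, i.e. c/(100 km/s/Mpc) expressed in Mpc, and the +25 comes from 10 pc = 10^-5 Mpc.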

The calculation for 500 supernova distances is fast (if this is the first call to muz, the timing also includes JIT compilation):

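import matplotlib.pyplot as plt

# start just above z = 0: D_L(0) = 0, so mu(0) would be -inf
zs = np.linspace(0.01, 1.2, num=500)
print('time to compute 500 SNIa distance integrals:')
# muz takes (omegam, w, z); w = -1 corresponds to a cosmological constant
%time mymus = muz(0.3, -1.0, zs)
plt.plot(zs, mymus, label='fid')

plt.xlabel(r'$z$')
plt.ylabel(r'$\mu(z; \mathcal{C})$')
plt.legend()
plt.show()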
time to compute 500 SNIa distance integrals:
CPU times: user 1.23 s, sys: 9.98 ms, total: 1.24 s
Wall time: 1.28 s

[Figure: distance modulus mu(z; C) against redshift z for the fiducial parameters ('fid')]
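Because JAX compiles on the first call and dispatches work asynchronously, it can help to time compilation and execution separately. A rough sketch follows, under two assumptions: `mu_eds` is a hypothetical closed-form Einstein-de Sitter (omegam = 1) distance modulus used only so the example runs without jax_cosmo, and `block_until_ready()` is called so the timer waits for the result rather than just the dispatch.

```python
import time
import jax
import jax.numpy as jnp

C_KM_S = 299792.458  # speed of light in km/s

@jax.jit
def mu_eds(z, h):
    """Hypothetical example: distance modulus for Omega_m = 1, whose
    comoving integral has the closed form 2 * (1 - (1+z)**-0.5)."""
    d_h = C_KM_S / (100.0 * h)                        # Hubble distance in Mpc
    d_c = 2.0 * d_h * (1.0 - 1.0 / jnp.sqrt(1.0 + z))  # comoving distance in Mpc
    return 5.0 * jnp.log10((1.0 + z) * d_c) + 25.0

zs = jnp.linspace(0.01, 1.2, 500)

t0 = time.perf_counter()
mu_eds(zs, 0.72).block_until_ready()         # first call: trace + compile + run
t_first = time.perf_counter() - t0

t0 = time.perf_counter()
mus = mu_eds(zs, 0.72).block_until_ready()   # reuses the compiled function
t_second = time.perf_counter() - t0

print(f"first call: {t_first:.4f} s, second call: {t_second:.4f} s")
```

Applied to `muz`, the same pattern (one warm-up call, then a timed call with `block_until_ready()`) would separate the one-off compile cost from the per-evaluation cost that matters inside a sampler.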

A lot of this might be redundant, but it would be great to see it integrated into the full package. I'm polishing my sampler code on my end, so at the very least these functions could live there.
