Skip to content

Commit

Permalink
Merge pull request #2 from beckermr/photon-fixed
Browse files Browse the repository at this point in the history
ENH add fixed photon arrays
  • Loading branch information
beckermr authored Nov 21, 2023
2 parents 63e40d5 + 9be573f commit de1aad9
Show file tree
Hide file tree
Showing 11 changed files with 777 additions and 606 deletions.
412 changes: 185 additions & 227 deletions jax_galsim/core/draw.py

Large diffs are not rendered by default.

340 changes: 200 additions & 140 deletions jax_galsim/gsobject.py

Large diffs are not rendered by default.

168 changes: 30 additions & 138 deletions jax_galsim/interpolant.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,104 +11,14 @@
import jax.numpy as jnp
from galsim.errors import GalSimValueError
from jax._src.numpy.util import _wraps
from jax.tree_util import Partial as jax_partial
from jax.tree_util import register_pytree_node_class

from jax_galsim.bessel import si
from jax_galsim.core.utils import is_equal_with_arrays
from jax_galsim.errors import GalSimError
from jax_galsim.gsparams import GSParams
from jax_galsim.random import UniformDeviate


@jax.jit
def _rejection_sample(photons, rng, tot_xrange, xval, pos_flux, neg_flux, max_val):
"""Use rejection sampling to generate photons from a given 1D interpolant function.
We sample both x and y values from the interpolant function.
Parameters
----------
photons : PhotonArray
The photon array to shoot into.
rng : BaseDeviate
The random number generator to use for drawing photons.
tot_xrange : float
The total range of the interpolant function from the most negative
point to the most positive point. The interpolant is assumed to be
symmetric about zero.
xval : callable
The interpolant function. Will only be called with positive values.
pos_flux : float
The total integral under all positive regions of the interpolant function.
neg_flux : float
The absolute value of the total integral under all negative regions of
the interpolant function.
max_val : float
The maximum value of the interpolant function. Usually this is xval(0.0) and
is 1.0.
"""

def _cond_fun(args):
# we stop drawing when we have tot photons
# curr records how many we have currently
_, _, tot, _, curr = args
return curr < tot

def _body_fun(args):
arr, sign, tot, ud, curr = args
# arr is the array we are filling with photon positions
# sign is the array of signs of the interpolant function at the photon positions
# tot is the total number of photons to draw
# ud is the random number generator for uniform deviates from 0 to 1
# curr is the current number of photons drawn

# we first draw a random x location centered at zero with a
# total range of tot_xrange
xloc = (ud() - 0.5) * tot_xrange

# next we draw a random y value between 0 and max_val
yv = ud() * max_val
xloc_val = xval(xloc)

# this cond operator keeps the photon if the y value we drew is
# below the interpolant function at the x location we drew
arr, sign, curr = jax.lax.cond(
yv <= jnp.abs(xloc_val),
# if we keep it, assign the location, assign the sign, and increment curr
lambda arr, sign, curr, xloc, xloc_val: (
arr.at[curr].set(xloc),
sign.at[curr].set(jnp.sign(xloc_val)),
curr + 1,
),
# otherwise we pass
lambda arr, sign, curr, xloc, xloc_val: (arr, sign, curr),
arr,
sign,
curr,
xloc,
xloc_val,
)
return arr, sign, tot, ud, curr

ud = UniformDeviate(rng)

# we first make the x and y positions
photons.x, _sign_x, _, ud, _ = jax.lax.while_loop(
_cond_fun,
_body_fun,
(jnp.zeros_like(photons.x), jnp.zeros_like(photons.x), photons.size(), ud, 0),
)
photons.y, _sign_y, _, ud, _ = jax.lax.while_loop(
_cond_fun,
_body_fun,
(jnp.zeros_like(photons.y), jnp.zeros_like(photons.y), photons.size(), ud, 0),
)
# this magic formula comes from looking closely at the galsim code in Interpolant.cpp
# and how things get adjusted down the line OneDimensionalDeviate.cpp
flux_per = (pos_flux + neg_flux) ** 2 / photons.size()
photons.flux = _sign_x * _sign_y * flux_per
return photons, rng
from jax_galsim.utilities import lazy_property


@_wraps(_galsim.interpolant.Interpolant)
Expand Down Expand Up @@ -350,9 +260,36 @@ def urange(self):
% self.__class__.__name__
)

@lazy_property
def _shoot_cdf(self):
x = jnp.linspace(-self.xrange, self.xrange, 10000)
px = jnp.abs(self._xval_noraise(jnp.abs(x)))
dx = x[1] - x[0]
# cumulative trapezoidal rule
# see scipy.integrate.cumulative_trapezoidal
cdfx = jnp.concatenate(
[jnp.array([0]), jnp.cumsum((px[1:] + px[:-1]) * 0.5 * dx)]
)
cdfx /= cdfx[-1]
return x, cdfx

def _shoot(self, photons, rng):
raise NotImplementedError(
"%s does not implement shoot" % self.__class__.__name__
x, cdfx = self._shoot_cdf
ud = UniformDeviate(rng)
ux = ud.generate(photons.x)
uy = ud.generate(photons.y)
photons.x = jnp.interp(ux, cdfx, x)
photons.y = jnp.interp(uy, cdfx, x)
if photons.size() > 0:
flux_per_photon = (
self.positive_flux + self.negative_flux
) ** 2 / photons.size()
else:
flux_per_photon = 0.0
photons.flux = (
flux_per_photon
* jnp.sign(self._xval_noraise(photons.x))
* jnp.sign(self._xval_noraise(photons.y))
)

# subclasses should implement __init__, _xval, _uval,
Expand Down Expand Up @@ -644,21 +581,6 @@ def ixrange(self):
"""The total integral range of the interpolant. Typically 2 * xrange."""
return 4

def _shoot(self, photons, rng):
_photons, _rng = _rejection_sample(
photons,
rng,
self.xrange * 2.0,
jax_partial(self.__class__._xval),
self.positive_flux,
self.negative_flux,
self._xval_noraise(0.0),
)
photons.x = _photons.x
photons.y = _photons.y
photons.flux = _photons.flux
rng._state = _rng._state


@_wraps(_galsim.interpolant.Quintic)
@register_pytree_node_class
Expand Down Expand Up @@ -754,21 +676,6 @@ def ixrange(self):
"""The total integral range of the interpolant. Typically 2 * xrange."""
return 6

def _shoot(self, photons, rng):
_photons, _rng = _rejection_sample(
photons,
rng,
self.xrange * 2.0,
jax_partial(self.__class__._xval),
self.positive_flux,
self.negative_flux,
self._xval_noraise(0.0),
)
photons.x = _photons.x
photons.y = _photons.y
photons.flux = _photons.flux
rng._state = _rng._state


@_wraps(_galsim.interpolant.Lanczos)
@register_pytree_node_class
Expand Down Expand Up @@ -1745,21 +1652,6 @@ def unit_integrals(self, max_len=None):
else:
return self._unit_integrals_no_conserve_dc[self._n][:n]

def _shoot(self, photons, rng):
_photons, _rng = _rejection_sample(
photons,
rng,
self.xrange * 2.0,
jax_partial(self.__class__._xval, self._n, self._conserve_dc, self._K_arr),
self.positive_flux,
self.negative_flux,
self._xval_noraise(0.0),
)
photons.x = _photons.x
photons.y = _photons.y
photons.flux = _photons.flux
rng._state = _rng._state


# we apply JIT here to esnure the class init is fast
@jax.jit
Expand Down
Loading

0 comments on commit de1aad9

Please sign in to comment.