diff --git a/jax_galsim/core/draw.py b/jax_galsim/core/draw.py index 00696da7..8b632f7d 100644 --- a/jax_galsim/core/draw.py +++ b/jax_galsim/core/draw.py @@ -1,7 +1,8 @@ -import galsim as _galsim +from collections import namedtuple + import jax import jax.numpy as jnp -from jax._src.numpy.util import _wraps +import numpy as np from jax_galsim.random import PoissonDeviate @@ -84,17 +85,147 @@ def phase(kpos): ) -def sample_poisson_flux(flux, eta_factor, rng=None): - """Sample the flux according to a Poisson distribution. +NPhotonsData = namedtuple( + "NPhotonsData", + [ + "n_photons", + "flux", + "flux_per_photon", + "max_sb", + "rng", + "poisson_flux", + "max_extra_noise", + ], +) + + +def calculate_n_photons( + flux, + eta_factor, + max_sb, + rng=None, + max_extra_noise=0, + poisson_flux=True, +): + """ + Calculate the number of photons to shoot for photon shooting. + + This routine is pure Python and is not JAX-compatible. Parameters: flux: The flux of the GSObject (e.g., ``obj.flux``). eta_factor: The flux per photon (e.g., ``obj._flux_per_photon``). + max_sb: The maximum surface brightness of the object (e.g., ``obj.max_sb``). rng: If provided, a random number generator to use for photon shooting, which may be any kind of `BaseDeviate` object. If ``rng`` is None, one will be automatically created, using the time as a seed. [default: None] + max_extra_noise: If provided, the allowed extra noise in each pixel when photon + shooting. This is only relevant if ``n_photons=0``, so the number of + photons is being automatically calculated. In that case, if the image + noise is dominated by the sky background, then you can get away with + using fewer shot photons than the full ``n_photons = flux``. + Essentially each shot photon can have a ``flux > 1``, which increases + the noise in each pixel. The ``max_extra_noise`` parameter specifies + how much extra noise per pixel is allowed because of this approximation. + A typical value for this might be ``max_extra_noise = sky_level / 100`` + where ``sky_level`` is the flux per pixel due to the sky. Note that + this uses a "variance" definition of noise, not a "sigma" definition. + [default: 0.] + poisson_flux: Whether to allow total object flux scaling to vary according to + Poisson statistics for ``n_photons`` samples when photon shooting. + [default: True, unless ``n_photons`` is given, in which case the default + is False] + + Returns: + n_photons: The number of photons. + g: The gain to use when shooting the photons. """ + n_photons, g, _ = _calculate_n_photons( + flux, + eta_factor, + max_sb, + rng, + max_extra_noise, + poisson_flux, + ) + return np.atleast_1d(n_photons).ravel()[0], np.atleast_1d(g).ravel()[0] + + +@jax.jit +def get_n_photons(n_photons_data): + _n_photons, g, _rng = jax.lax.cond( + n_photons_data.n_photons == 0.0, + _sample_zero, + _sample_nonzero, + n_photons_data, + ) + if n_photons_data.rng is not None: + n_photons_data.rng._state = _rng._state + return _n_photons, g, n_photons_data.rng + + +def _sample_nonzero(n_photons_data): + g, _rng = jax.lax.cond( + n_photons_data.poisson_flux, + lambda n_photons_data: _sample_poisson_flux( + n_photons_data.flux, n_photons_data.flux_per_photon, n_photons_data.rng + ), + lambda n_photons_data: (1.0, n_photons_data.rng), + n_photons_data, + ) + if n_photons_data.rng is not None: + n_photons_data.rng._state = _rng._state + vals = jnp.int_(n_photons_data.n_photons + 0.5), g, n_photons_data.rng + return vals + + +@jax.jit +def _sample_zero(n_photons_data): + Ntot, g, _rng = _calculate_n_photons( + n_photons_data.flux, + n_photons_data.flux_per_photon, + n_photons_data.max_sb, + rng=n_photons_data.rng, + max_extra_noise=n_photons_data.max_extra_noise, + poisson_flux=n_photons_data.poisson_flux, + ) + return Ntot, g, _rng + + +@jax.jit +def _calculate_n_photons( + flux, + eta_factor, + max_sb, + rng, + max_extra_noise, + poisson_flux, +): + _n_photons, _g, _rng = jax.lax.cond( + flux == 0.0, + lambda flux, eta_factor, max_sb, poisson_flux, max_extra_noise, rng: ( + 0, + 1.0, + rng, + ), + lambda flux, eta_factor, max_sb, poisson_flux, max_extra_noise, rng: _calculate_n_photons_flux_nonzero( + flux, eta_factor, max_sb, poisson_flux, max_extra_noise, rng + ), + flux, + eta_factor, + max_sb, + poisson_flux, + max_extra_noise, + rng, + ) + if rng is not None: + rng._state = _rng._state + return _n_photons, _g, rng + + +@jax.jit +def _sample_poisson_flux(flux, eta_factor, rng): # If we have both positive and negative photons, then the mix of these # already gives us some variation in the flux value from the variance # of how many are positive and how many are negative. @@ -111,54 +242,29 @@ def sample_poisson_flux(flux, eta_factor, rng=None): # We want the variance to be equal to flux, so we need an extra: # delta Var = (1 - 4*eta + 4*eta^2) * flux # = (1-2eta)^2 * flux - absflux = abs(flux) + absflux = jnp.abs(flux) mean = eta_factor * eta_factor * absflux pd = PoissonDeviate(rng, mean) pd_val = pd() - mean + absflux - return pd_val - - -@_wraps( - _galsim.GSObject._calculate_nphotons, - lax_description="""\ -Calculate the number of photons to shoot for photon shooting. - -This routine is pure Python and is not JAX-compatible. - -Parameters: - flux: The flux of the GSObject (e.g., ``obj.flux``). - eta_factor: The flux per photon (e.g., ``obj._flux_per_photon``). - max_sb: The maximum surface brightness of the object (e.g., ``obj.max_sb``). - rng: If provided, a random number generator to use for photon shooting, - which may be any kind of `BaseDeviate` object. If ``rng`` is None, one - will be automatically created, using the time as a seed. - [default: None] - max_extra_noise: If provided, the allowed extra noise in each pixel when photon - shooting. This is only relevant if ``n_photons=0``, so the number of - photons is being automatically calculated. In that case, if the image - noise is dominated by the sky background, then you can get away with - using fewer shot photons than the full ``n_photons = flux``. - Essentially each shot photon can have a ``flux > 1``, which increases - the noise in each pixel. The ``max_extra_noise`` parameter specifies - how much extra noise per pixel is allowed because of this approximation. - A typical value for this might be ``max_extra_noise = sky_level / 100`` - where ``sky_level`` is the flux per pixel due to the sky. Note that - this uses a "variance" definition of noise, not a "sigma" definition. - [default: 0.] - poisson_flux: Whether to allow total object flux scaling to vary according to - Poisson statistics for ``n_photons`` samples when photon shooting. - [default: True, unless ``n_photons`` is given, in which case the default - is False] - -""", -) -def calculate_n_photons( - flux, - eta_factor, - max_sb, - rng=None, - max_extra_noise=0, - poisson_flux=True, + return pd_val / absflux, rng + + +def _adjust_flux_g_poisson(poisson_flux, flux, mod_flux, eta_factor, rng, g): + ratio, rng = _sample_poisson_flux(flux, eta_factor, rng) + g *= ratio + mod_flux *= ratio + return jnp.abs(mod_flux), g, rng + + +def _scale_extra_noise(max_extra_noise, mod_flux, g, max_sb): + gfactor = 1.0 + max_extra_noise / jnp.abs(max_sb) + mod_flux /= gfactor + g *= gfactor + return mod_flux, g + + +def _calculate_n_photons_flux_nonzero( + flux, flux_per_photon, max_sb, poisson_flux, max_extra_noise, rng ): # For profiles that are positive definite, then N = flux. Easy. # @@ -211,188 +317,40 @@ def calculate_n_photons( # Returns the total flux placed inside the image bounds by photon shooting. # - if flux == 0.0: - return 0, 1.0 - # The _flux_per_photon property is (1-2eta) # This factor will already be accounted for by the shoot function, so don't include # that as part of our scaling here. There may be other adjustments though, so g=1 here. + eta_factor = flux_per_photon mod_flux = flux / (eta_factor * eta_factor) g = 1.0 # If requested, let the target flux value vary as a Poisson deviate - if poisson_flux: - pd_val = sample_poisson_flux(flux, eta_factor, rng=rng) - ratio = pd_val / abs(flux) - g *= ratio - mod_flux *= ratio - - n_photons = abs(mod_flux) - if max_extra_noise > 0.0: - gfactor = 1.0 + max_extra_noise / abs(max_sb) - n_photons /= gfactor - g *= gfactor - - # Make n_photons an integer. - iN = int(n_photons + 0.5) - - return iN, g - - -# the code below is a jax version of calculate_nphotons -# that I am not sure if we need or not. -# saving in a comment for now - -# def _calculate_nphotons(self, n_photons, poisson_flux, max_extra_noise, rng): -# _n_photons, _g, _rng = jax.lax.cond( -# self.flux == 0.0, -# lambda n_photons, poisson_flux, max_extra_noise, rng: (0, 1.0, rng), -# lambda n_photons, poisson_flux, max_extra_noise, rng: self._calculate_nphotons_nonzero( -# n_photons, poisson_flux, max_extra_noise, rng -# ), -# n_photons, -# poisson_flux, -# max_extra_noise, -# rng, -# ) -# if rng is not None: -# rng._state = _rng._state -# return _n_photons, _g - - -# def _adjust_flux_g_poisson(self, poisson_flux, flux, mod_flux, eta_factor, rng, g): -# from jax_galsim.random import PoissonDeviate - -# # If we have both positive and negative photons, then the mix of these -# # already gives us some variation in the flux value from the variance -# # of how many are positive and how many are negative. -# # The number of negative photons varies as a binomial distribution. -# # = eta * Ntot * g -# # = (1-eta) * Ntot * g -# # = (1-2eta) * Ntot * g = flux -# # Var(F-) = eta * (1-eta) * Ntot * g^2 -# # F+ = Ntot * g - F- is not an independent variable, so -# # Var(F+ - F-) = Var(Ntot*g - 2*F-) -# # = 4 * Var(F-) -# # = 4 * eta * (1-eta) * Ntot * g^2 -# # = 4 * eta * (1-eta) * flux -# # We want the variance to be equal to flux, so we need an extra: -# # delta Var = (1 - 4*eta + 4*eta^2) * flux -# # = (1-2eta)^2 * flux -# absflux = abs(flux) -# mean = eta_factor * eta_factor * absflux -# pd = PoissonDeviate(rng, mean) -# pd_val = pd() - mean + absflux -# ratio = pd_val / absflux -# g *= ratio -# mod_flux *= ratio -# return jnp.abs(mod_flux), g, rng - - -# def _scale_extra_noise(self, max_extra_noise, mod_flux, g, max_sb): -# gfactor = 1.0 + max_extra_noise / jnp.abs(max_sb) -# mod_flux /= gfactor -# g *= gfactor -# return mod_flux, g - - -# def _calculate_nphotons_nonzero(self, n_photons, poisson_flux, max_extra_noise, rng): -# # For profiles that are positive definite, then N = flux. Easy. -# # -# # However, some profiles shoot some of their photons with negative flux. This means that -# # we need a few more photons to get the right S/N = sqrt(flux). Take eta to be the -# # fraction of shot photons that have negative flux. -# # -# # S^2 = (N+ - N-)^2 = (N+ + N- - 2N-)^2 = (Ntot - 2N-)^2 = Ntot^2(1 - 2 eta)^2 -# # N^2 = Var(S) = (N+ + N-) = Ntot -# # -# # So flux = (S/N)^2 = Ntot (1-2eta)^2 -# # Ntot = flux / (1-2eta)^2 -# # -# # However, if each photon has a flux of 1, then S = (1-2eta) Ntot = flux / (1-2eta). -# # So in fact, each photon needs to carry a flux of g = 1-2eta to get the right -# # total flux. -# # -# # That's all the easy case. The trickier case is when we are sky-background dominated. -# # Then we can usually get away with fewer shot photons than the above. In particular, -# # if the noise from the photon shooting is much less than the sky noise, then we can -# # use fewer shot photons and essentially have each photon have a flux > 1. This is ok -# # as long as the additional noise due to this approximation is "much less than" the -# # noise we'll be adding to the image for the sky noise. -# # -# # Let's still have Ntot photons, but now each with a flux of g. And let's look at the -# # noise we get in the brightest pixel that has a nominal total flux of Imax. -# # -# # The number of photons hitting this pixel will be Imax/flux * Ntot. -# # The variance of this number is the same thing (Poisson counting). -# # So the noise in that pixel is: -# # -# # N^2 = Imax/flux * Ntot * g^2 -# # -# # And the signal in that pixel will be: -# # -# # S = Imax/flux * (N+ - N-) * g which has to equal Imax, so -# # g = flux / Ntot(1-2eta) -# # N^2 = Imax/Ntot * flux / (1-2eta)^2 -# # -# # As expected, we see that lowering Ntot will increase the noise in that (and every -# # other) pixel. -# # The input max_extra_noise parameter is the maximum value of spurious noise we want -# # to allow. -# # -# # So setting N^2 = Imax + nu, we get -# # -# # Ntot = flux / (1-2eta)^2 / (1 + nu/Imax) -# # g = (1 - 2eta) * (1 + nu/Imax) -# # -# # Returns the total flux placed inside the image bounds by photon shooting. -# # - -# flux = self.flux - -# # The _flux_per_photon property is (1-2eta) -# # This factor will already be accounted for by the shoot function, so don't include -# # that as part of our scaling here. There may be other adjustments though, so g=1 here. -# eta_factor = self._flux_per_photon -# mod_flux = flux / (eta_factor * eta_factor) -# g = 1.0 - -# # If requested, let the target flux value vary as a Poisson deviate -# mod_flux, g, _rng = jax.lax.cond( -# poisson_flux, -# lambda poisson_flux, flux, mod_flux, eta_factor, rng, g: self._adjust_flux_g_poisson( -# poisson_flux, flux, mod_flux, eta_factor, rng, g -# ), -# lambda poisson_flux, flux, mod_flux, eta_factor, rng, g: (mod_flux, g, rng), -# poisson_flux, -# flux, -# mod_flux, -# eta_factor, -# rng, -# g, -# ) -# if rng is not None: -# rng._state = _rng._state - -# mod_flux, g = jax.lax.cond( -# max_extra_noise > 0.0, -# lambda max_extra_noise, mod_flux, g, max_sb: self._scale_extra_noise( -# max_extra_noise, mod_flux, g, max_sb -# ), -# lambda max_extra_noise, mod_flux, g, max_sb: (mod_flux, g), -# max_extra_noise, -# mod_flux, -# g, -# self.max_sb, -# ) - -# # Make n_photons an integer and use input if requested -# n_photons = jax.lax.cond( -# n_photons == 0.0, -# lambda n_photons, mod_flux: jnp.ceil(mod_flux).astype(int), -# lambda n_photons, mod_flux: jnp.ceil(n_photons).astype(int), -# n_photons, -# mod_flux, -# ) - -# return n_photons, g, rng + mod_flux, g, _rng = jax.lax.cond( + poisson_flux, + lambda poisson_flux, flux, mod_flux, eta_factor, rng, g: _adjust_flux_g_poisson( + poisson_flux, flux, mod_flux, eta_factor, rng, g + ), + lambda poisson_flux, flux, mod_flux, eta_factor, rng, g: (mod_flux, g, rng), + poisson_flux, + flux, + mod_flux, + eta_factor, + rng, + g, + ) + if rng is not None: + rng._state = _rng._state + + mod_flux, g = jax.lax.cond( + max_extra_noise > 0.0, + lambda max_extra_noise, mod_flux, g, max_sb: _scale_extra_noise( + max_extra_noise, mod_flux, g, max_sb + ), + lambda max_extra_noise, mod_flux, g, max_sb: (mod_flux, g), + max_extra_noise, + mod_flux, + g, + max_sb, + ) + + return jnp.ceil(mod_flux).astype(int), g, rng diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index 6d893e4d..6d32fd6f 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -7,13 +7,12 @@ from jax._src.numpy.util import _wraps import jax_galsim.photon_array as pa -from jax_galsim.core.draw import calculate_n_photons, sample_poisson_flux +from jax_galsim.core.draw import NPhotonsData, get_n_photons from jax_galsim.core.utils import is_equal_with_arrays from jax_galsim.errors import ( GalSimError, GalSimIncompatibleValuesError, GalSimNotImplementedError, - GalSimRangeError, GalSimValueError, galsim_warn, ) @@ -666,7 +665,7 @@ def drawImage( center=None, use_true_center=True, offset=None, - n_photons=0.0, + n_photons=None, rng=None, max_extra_noise=0.0, poisson_flux=None, @@ -1104,32 +1103,35 @@ def _drawKImage( @_wraps(_galsim.GSObject._calculate_nphotons) def _calculate_nphotons(self, n_photons, poisson_flux, max_extra_noise, rng): - if n_photons == 0.0: - Ntot, g = calculate_n_photons( - self.flux, - self._flux_per_photon, - self.max_sb, - rng=rng, - max_extra_noise=max_extra_noise, - poisson_flux=poisson_flux, - ) - else: - Ntot = int(n_photons + 0.5) - if poisson_flux: - pd_val = sample_poisson_flux(self.flux, self._flux_per_photon, rng=rng) - g = pd_val / jnp.abs(self.flux) - else: - g = 1.0 - - return Ntot, g + npd = NPhotonsData( + n_photons=n_photons, + poisson_flux=poisson_flux, + max_extra_noise=max_extra_noise, + rng=rng, + flux=self.flux, + flux_per_photon=self._flux_per_photon, + max_sb=self.max_sb, + ) + n_photons, g, _rng = get_n_photons(npd) + if rng is not None: + rng._state = _rng._state + return n_photons, g @_wraps( _galsim.GSObject.makePhot, - lax_description="The JAX-GalSim version of `makePhot` does not support the deprecated surface_ops argument.", + lax_description="""\ +The JAX-GalSim version of `makePhot` + + - does not support the deprecated surface_ops argument + - does little to no error checking on the inputs + - uses a default of ``n_photons=None`` instead of ``n_photons=0`` + to indicate that the number of photons should be determined + from the flux and gain +""", ) def makePhot( self, - n_photons=0, + n_photons=None, rng=None, max_extra_noise=0.0, poisson_flux=None, @@ -1143,24 +1145,17 @@ def makePhot( depr("surface_ops", 2.3, "photon_ops") photon_ops = surface_ops - # Make sure the type of n_photons is correct and has a valid value: - if not n_photons >= 0.0: - raise GalSimRangeError("Invalid n_photons < 0.", n_photons, 0.0, None) - if poisson_flux is None: # If n_photons is given, poisson_flux = False - poisson_flux = n_photons == 0.0 + poisson_flux = n_photons is None - # Check that either n_photons is set to something or flux is set to something - if n_photons == 0.0 and self.flux == 1.0: - galsim_warn( - "Warning: drawImage for object with flux == 1, area == 1, and " - "exptime == 1, but n_photons == 0. This will only shoot a single photon." + if n_photons is not None: + Ntot = int(n_photons + 0.5) + _, g = self._calculate_nphotons( + n_photons, poisson_flux, max_extra_noise, rng ) - - Ntot, g = self._calculate_nphotons( - n_photons, poisson_flux, max_extra_noise, rng - ) + else: + Ntot, g = self._calculate_nphotons(0.0, poisson_flux, max_extra_noise, rng) try: photons = self.shoot(Ntot, rng) @@ -1186,14 +1181,24 @@ def makePhot( @_wraps( _galsim.GSObject.drawPhot, - lax_description="The JAX-GalSim version of `drawPhot` does not support the deprecated surface_ops argument.", + lax_description="""\ +The JAX-GalSim version of `drawPhot` + + - does not support the deprecated surface_ops argument + - does little to no error checking on the inputs + - uses a default of ``n_photons=None`` instead of ``n_photons=0`` + to indicate that the number of photons should be determined + from the flux and gain + - the maxN option requires the use of fixed photon array sizes or a fixed + number of photons +""", ) def drawPhot( self, image, gain=1.0, add_to_image=False, - n_photons=0, + n_photons=None, rng=None, max_extra_noise=0.0, poisson_flux=None, @@ -1203,20 +1208,9 @@ def drawPhot( orig_center=PositionI(0, 0), local_wcs=None, ): - # Make sure the type of n_photons is correct and has a valid value: - if not n_photons >= 0.0: - raise GalSimRangeError("Invalid n_photons < 0.", n_photons, 0.0, None) - + # If n_photons is given and poisson_flux is None, poisson_flux = False if poisson_flux is None: - # If n_photons is given, poisson_flux = False - poisson_flux = n_photons == 0.0 - - # Check that either n_photons is set to something or flux is set to something - if n_photons == 0.0 and self.flux == 1.0: - galsim_warn( - "Warning: drawImage for object with flux == 1, area == 1, and " - "exptime == 1, but n_photons == 0. This will only shoot a single photon." - ) + poisson_flux = n_photons is None # Make sure the image is set up to have unit pixel scale and centered at 0,0. if image.wcs is None or not image.wcs._isPixelScale: @@ -1229,9 +1223,14 @@ def drawPhot( elif not isinstance(sensor, Sensor): raise TypeError("The sensor provided is not a Sensor instance") - Ntot, g = self._calculate_nphotons( - n_photons, poisson_flux, max_extra_noise, rng - ) + if n_photons is not None: + Ntot = int(n_photons + 0.5) + _, g = self._calculate_nphotons( + n_photons, poisson_flux, max_extra_noise, rng + ) + else: + Ntot, g = self._calculate_nphotons(0.0, poisson_flux, max_extra_noise, rng) + g = jax.lax.cond( gain != 1.0, lambda g, gain: g / gain, @@ -1240,40 +1239,64 @@ def drawPhot( gain, ) - if maxN is None: - maxN = Ntot - if not add_to_image: image.setZero() - ( - photons, - _rng, - added_flux, - _Nleft, - _image, - _photon_ops, - _sensor, - ) = _draw_phot_while_loop( - PhotonArray(maxN), - rng, - self, - image, - g, - Ntot, - maxN, - photon_ops, - local_wcs, - sensor, - orig_center, - ) + if maxN is None: + ( + added_flux, + _image, + _sensor, + _photon_ops, + _rng, + _, + photons, + ) = _draw_phot_while_loop_shoot( + Ntot, + Ntot, + Ntot, + self, + rng, + g, + image, + photon_ops, + sensor, + orig_center, + local_wcs, + False, + 0.0, + ) + else: + ( + photons, + _rng, + added_flux, + _Nleft, + _image, + _photon_ops, + _sensor, + ) = _draw_phot_while_loop( + PhotonArray(maxN), + rng, + self, + image, + g, + Ntot, + maxN, + photon_ops, + local_wcs, + sensor, + orig_center, + ) if rng is not None: rng._state = _rng._state else: rng = _rng for i in range(len(photon_ops)): photon_ops[i] = _photon_ops[i] + image._array = _image._array + # TODO: how to update the sensor? if sensor.__class__ is not Sensor: raise GalSimNotImplementedError( @@ -1285,13 +1308,13 @@ def drawPhot( @_wraps(_galsim.GSObject.shoot) def shoot(self, n_photons, rng=None): photons = pa.PhotonArray(n_photons) - if n_photons == 0: - # It's ok to shoot 0, but downstream can have problems with it, so just stop now. - return photons - if rng is None: - rng = BaseDeviate() - self._shoot(photons, rng) + if photons._x.shape[0] > 0: + _rng = BaseDeviate(rng) + self._shoot(photons, _rng) + if rng is not None: + rng._state = _rng._state + return photons @_wraps(_galsim.GSObject._shoot) @@ -1329,6 +1352,71 @@ def tree_unflatten(cls, aux_data, children): return cls(**(children[0]), **aux_data) +def _draw_phot_while_loop_shoot( + maxN, + thisN, + Ntot, + obj, + rng, + g, + image, + photon_ops, + sensor, + orig_center, + local_wcs, + resume, + added_flux, +): + try: + photons = obj.shoot(maxN, rng) + except (GalSimError, NotImplementedError) as e: + raise GalSimNotImplementedError( + "Unable to draw this GSObject with photon shooting. Perhaps it " + "is a Deconvolve or is a compound including one or more " + "Deconvolve objects.\nOriginal error: %r" % (e) + ) + # we drew maxN, but only keep thisN of them + photons._num_keep = thisN + + photons = jax.lax.cond( + # weird way to say gain == 1 and thisN == Ntot + jnp.abs(g - 1.0) + jnp.abs(thisN - Ntot) == 0, + lambda photons, g, thisN, Ntot: photons, + # the factor here is thisN / Ntot since we drew thisN photons, but use a total of Ntot photons + lambda photons, g, thisN, Ntot: photons.scaleFlux(g * thisN / Ntot), + photons, + g, + thisN, + Ntot, + ) + + photons = jax.lax.cond( + image.scale != 1.0, + lambda photons, scale: photons.scaleXY( + 1.0 / scale + ), # Convert x,y to image coords if necessary + lambda photons, scale: photons, + photons, + image.scale, + ) + + for op in photon_ops: + op.applyTo(photons, local_wcs, rng) + + if image.dtype in (jnp.float32, jnp.float64): + added_flux += sensor.accumulate(photons, image, orig_center, resume=resume) + resume = True # Resume from this point if there are any further iterations. + else: + # Need a temporary + from jax_galsim.image import ImageD + + im1 = ImageD(bounds=image.bounds) + added_flux += sensor.accumulate(photons, im1, orig_center) + image += im1 + + return added_flux, image, sensor, photon_ops, rng, resume, photons + + @partial(jax.jit, static_argnames=("maxN",)) def _draw_phot_while_loop( photons, @@ -1378,72 +1466,44 @@ def _body_fun(args): # Shoot at most maxN at a time thisN = jnp.minimum(maxN, Nleft) - try: - photons = obj.shoot(maxN, rng) - except (GalSimError, NotImplementedError) as e: - raise GalSimNotImplementedError( - "Unable to draw this GSObject with photon shooting. Perhaps it " - "is a Deconvolve or is a compound including one or more " - "Deconvolve objects.\nOriginal error: %r" % (e) - ) - photons.flux = jnp.where( - jnp.arange(maxN) < thisN, - photons.flux, - 0.0, - ) - - photons = jax.lax.cond( - # weird way to say gain == 1 and thisN == Ntot - jnp.abs(g - 1.0) + jnp.abs(thisN - Ntot) == 0, - lambda photons, g, thisN, Ntot: photons, - # the factor here is (maxN / thisN) * (thisN / Ntot) = maxN / Ntot - # the first bit is that we drew maxN photons, but only thisN of them are valid - # the second bit is that we only drew thisN photons, but use a total of Ntot photons - lambda photons, g, thisN, Ntot: photons.scaleFlux(g * maxN / Ntot), - photons, - g, + ( + _added_flux, + _image, + _sensor, + _photon_ops, + _rng, + _resume, + _photons, + ) = _draw_phot_while_loop_shoot( + maxN, thisN, Ntot, + obj, + rng, + g, + image, + photon_ops, + sensor, + orig_center, + local_wcs, + resume, + added_flux, ) - photons = jax.lax.cond( - image.scale != 1.0, - lambda photons, scale: photons.scaleXY( - 1.0 / scale - ), # Convert x,y to image coords if necessary - lambda photons, scale: photons, - photons, - image.scale, - ) - - for op in photon_ops: - op.applyTo(photons, local_wcs, rng) - - if image.dtype in (jnp.float32, jnp.float64): - added_flux += sensor.accumulate(photons, image, orig_center, resume=resume) - resume = True # Resume from this point if there are any further iterations. - else: - # Need a temporary - from jax_galsim.image import ImageD - - im1 = ImageD(bounds=image.bounds) - added_flux += sensor.accumulate(photons, im1, orig_center) - image += im1 - Nleft -= thisN return ( - photons, - rng, - added_flux, + _photons, + _rng, + _added_flux, obj, Nleft, - resume, - image, + _resume, + _image, g, - photon_ops, + _photon_ops, local_wcs, - sensor, + _sensor, orig_center, ) diff --git a/jax_galsim/interpolant.py b/jax_galsim/interpolant.py index 9cd661b9..fb220232 100644 --- a/jax_galsim/interpolant.py +++ b/jax_galsim/interpolant.py @@ -11,7 +11,6 @@ import jax.numpy as jnp from galsim.errors import GalSimValueError from jax._src.numpy.util import _wraps -from jax.tree_util import Partial as jax_partial from jax.tree_util import register_pytree_node_class from jax_galsim.bessel import si @@ -19,96 +18,7 @@ from jax_galsim.errors import GalSimError from jax_galsim.gsparams import GSParams from jax_galsim.random import UniformDeviate - - -@jax.jit -def _rejection_sample(photons, rng, tot_xrange, xval, pos_flux, neg_flux, max_val): - """Use rejection sampling to generate photons from a given 1D interpolant function. - - We sample both x and y values from the interpolant function. - - Parameters - ---------- - photons : PhotonArray - The photon array to shoot into. - rng : BaseDeviate - The random number generator to use for drawing photons. - tot_xrange : float - The total range of the interpolant function from the most negative - point to the most positive point. The interpolant is assumed to be - symmetric about zero. - xval : callable - The interpolant function. Will only be called with positive values. - pos_flux : float - The total integral under all positive regions of the interpolant function. - neg_flux : float - The absolute value of the total integral under all negative regions of - the interpolant function. - max_val : float - The maximum value of the interpolant function. Usually this is xval(0.0) and - is 1.0. - """ - - def _cond_fun(args): - # we stop drawing when we have tot photons - # curr records how many we have currently - _, _, tot, _, curr = args - return curr < tot - - def _body_fun(args): - arr, sign, tot, ud, curr = args - # arr is the array we are filling with photon positions - # sign is the array of signs of the interpolant function at the photon positions - # tot is the total number of photons to draw - # ud is the random number generator for uniform deviates from 0 to 1 - # curr is the current number of photons drawn - - # we first draw a random x location centered at zero with a - # total range of tot_xrange - xloc = (ud() - 0.5) * tot_xrange - - # next we draw a random y value between 0 and max_val - yv = ud() * max_val - xloc_val = xval(xloc) - - # this cond operator keeps the photon if the y value we drew is - # below the interpolant function at the x location we drew - arr, sign, curr = jax.lax.cond( - yv <= jnp.abs(xloc_val), - # if we keep it, assign the location, assign the sign, and increment curr - lambda arr, sign, curr, xloc, xloc_val: ( - arr.at[curr].set(xloc), - sign.at[curr].set(jnp.sign(xloc_val)), - curr + 1, - ), - # otherwise we pass - lambda arr, sign, curr, xloc, xloc_val: (arr, sign, curr), - arr, - sign, - curr, - xloc, - xloc_val, - ) - return arr, sign, tot, ud, curr - - ud = UniformDeviate(rng) - - # we first make the x and y positions - photons.x, _sign_x, _, ud, _ = jax.lax.while_loop( - _cond_fun, - _body_fun, - (jnp.zeros_like(photons.x), jnp.zeros_like(photons.x), photons.size(), ud, 0), - ) - photons.y, _sign_y, _, ud, _ = jax.lax.while_loop( - _cond_fun, - _body_fun, - (jnp.zeros_like(photons.y), jnp.zeros_like(photons.y), photons.size(), ud, 0), - ) - # this magic formula comes from looking closely at the galsim code in Interpolant.cpp - # and how things get adjusted down the line OneDimensionalDeviate.cpp - flux_per = (pos_flux + neg_flux) ** 2 / photons.size() - photons.flux = _sign_x * _sign_y * flux_per - return photons, rng +from jax_galsim.utilities import lazy_property @_wraps(_galsim.interpolant.Interpolant) @@ -350,9 +260,36 @@ def urange(self): % self.__class__.__name__ ) + @lazy_property + def _shoot_cdf(self): + x = jnp.linspace(-self.xrange, self.xrange, 10000) + px = jnp.abs(self._xval_noraise(jnp.abs(x))) + dx = x[1] - x[0] + # cumulative trapezoidal rule + # see scipy.integrate.cumulative_trapezoidal + cdfx = jnp.concatenate( + [jnp.array([0]), jnp.cumsum((px[1:] + px[:-1]) * 0.5 * dx)] + ) + cdfx /= cdfx[-1] + return x, cdfx + def _shoot(self, photons, rng): - raise NotImplementedError( - "%s does not implement shoot" % self.__class__.__name__ + x, cdfx = self._shoot_cdf + ud = UniformDeviate(rng) + ux = ud.generate(photons.x) + uy = ud.generate(photons.y) + photons.x = jnp.interp(ux, cdfx, x) + photons.y = jnp.interp(uy, cdfx, x) + if photons.size() > 0: + flux_per_photon = ( + self.positive_flux + self.negative_flux + ) ** 2 / photons.size() + else: + flux_per_photon = 0.0 + photons.flux = ( + flux_per_photon + * jnp.sign(self._xval_noraise(photons.x)) + * jnp.sign(self._xval_noraise(photons.y)) ) # subclasses should implement __init__, _xval, _uval, @@ -644,21 +581,6 @@ def ixrange(self): """The total integral range of the interpolant. Typically 2 * xrange.""" return 4 - def _shoot(self, photons, rng): - _photons, _rng = _rejection_sample( - photons, - rng, - self.xrange * 2.0, - jax_partial(self.__class__._xval), - self.positive_flux, - self.negative_flux, - self._xval_noraise(0.0), - ) - photons.x = _photons.x - photons.y = _photons.y - photons.flux = _photons.flux - rng._state = _rng._state - @_wraps(_galsim.interpolant.Quintic) @register_pytree_node_class @@ -754,21 +676,6 @@ def ixrange(self): """The total integral range of the interpolant. Typically 2 * xrange.""" return 6 - def _shoot(self, photons, rng): - _photons, _rng = _rejection_sample( - photons, - rng, - self.xrange * 2.0, - jax_partial(self.__class__._xval), - self.positive_flux, - self.negative_flux, - self._xval_noraise(0.0), - ) - photons.x = _photons.x - photons.y = _photons.y - photons.flux = _photons.flux - rng._state = _rng._state - @_wraps(_galsim.interpolant.Lanczos) @register_pytree_node_class @@ -1745,21 +1652,6 @@ def unit_integrals(self, max_len=None): else: return self._unit_integrals_no_conserve_dc[self._n][:n] - def _shoot(self, photons, rng): - _photons, _rng = _rejection_sample( - photons, - rng, - self.xrange * 2.0, - jax_partial(self.__class__._xval, self._n, self._conserve_dc, self._K_arr), - self.positive_flux, - self.negative_flux, - self._xval_noraise(0.0), - ) - photons.x = _photons.x - photons.y = _photons.y - photons.flux = _photons.flux - rng._state = _rng._state - # we apply JIT here to esnure the class init is fast @jax.jit diff --git a/jax_galsim/photon_array.py b/jax_galsim/photon_array.py index 004f543d..da72dfbb 100644 --- a/jax_galsim/photon_array.py +++ b/jax_galsim/photon_array.py @@ -1,3 +1,5 @@ +from contextlib import contextmanager + import galsim as _galsim import jax import jax.numpy as jnp @@ -16,6 +18,20 @@ from ._pyfits import pyfits +_JAX_GALSIM_PHOTON_ARRAY_SIZE = None + + +@contextmanager +def fixed_photon_array_size(size): + """Context manager to temporarily set a fixed size for photon arrays.""" + global _JAX_GALSIM_PHOTON_ARRAY_SIZE + old_size = _JAX_GALSIM_PHOTON_ARRAY_SIZE + _JAX_GALSIM_PHOTON_ARRAY_SIZE = size + try: + yield + finally: + _JAX_GALSIM_PHOTON_ARRAY_SIZE = old_size + @_wraps( _galsim.PhotonArray, @@ -41,20 +57,35 @@ def __init__( pupil_u=None, pupil_v=None, time=None, + _nokeep=None, ): - self._N = N + # self._N = N + self._Ntot = _JAX_GALSIM_PHOTON_ARRAY_SIZE or N + # if ( + # _JAX_GALSIM_PHOTON_ARRAY_SIZE is not None + # and isinstance(N, int) + # and N > _JAX_GALSIM_PHOTON_ARRAY_SIZE + # ): + # raise GalSimValueError( + # f"The given photon array size {N} is larger than " + # f"the allowed total size {_JAX_GALSIM_PHOTON_ARRAY_SIZE}." + # ) + if _nokeep is not None: + self._nokeep = _nokeep + else: + self._nokeep = jnp.arange(self._Ntot) >= N # Only x, y, flux are built by default, since these are always required. # The others we leave as None unless/until they are needed. - self._x = jnp.zeros(self._N, dtype=float) - self._y = jnp.zeros(self._N, dtype=float) - self._flux = jnp.zeros(self._N, dtype=float) - self._dxdz = jnp.full(self._N, jnp.nan, dtype=float) - self._dydz = jnp.full(self._N, jnp.nan, dtype=float) - self._wave = jnp.full(self._N, jnp.nan, dtype=float) - self._pupil_u = jnp.full(self._N, jnp.nan, dtype=float) - self._pupil_v = jnp.full(self._N, jnp.nan, dtype=float) - self._time = jnp.full(self._N, jnp.nan, dtype=float) + self._x = jnp.zeros(self._Ntot, dtype=float) + self._y = jnp.zeros(self._Ntot, dtype=float) + self._flux = jnp.zeros(self._Ntot, dtype=float) + self._dxdz = jnp.full(self._Ntot, jnp.nan, dtype=float) + self._dydz = jnp.full(self._Ntot, jnp.nan, dtype=float) + self._wave = jnp.full(self._Ntot, jnp.nan, dtype=float) + self._pupil_u = jnp.full(self._Ntot, jnp.nan, dtype=float) + self._pupil_v = jnp.full(self._Ntot, jnp.nan, dtype=float) + self._time = jnp.full(self._Ntot, jnp.nan, dtype=float) self._is_corr = jnp.array(False) if x is not None: @@ -113,59 +144,78 @@ def _fromArrays( time=None, is_corr=False, ): + if ( + _JAX_GALSIM_PHOTON_ARRAY_SIZE is not None + and x.shape[0] != _JAX_GALSIM_PHOTON_ARRAY_SIZE + ): + raise GalSimValueError( + "The given arrays do not match the expected total size", + x.shape[0], + _JAX_GALSIM_PHOTON_ARRAY_SIZE, + ) + ret = cls.__new__(cls) - ret._N = x.shape[0] + # ret._N = x.shape[0] + ret._Ntot = _JAX_GALSIM_PHOTON_ARRAY_SIZE or x.shape[0] ret._x = x.copy() ret._y = y.copy() ret._flux = flux.copy() + ret._nokeep = jnp.arange(ret._Ntot) >= x.shape[0] ret._dxdz = ( - dxdz.copy() if dxdz is not None else jnp.full(ret._N, jnp.nan, dtype=float) + dxdz.copy() + if dxdz is not None + else jnp.full(ret._Ntot, jnp.nan, dtype=float) ) ret._dydz = ( - dydz.copy() if dydz is not None else jnp.full(ret._N, jnp.nan, dtype=float) + dydz.copy() + if dydz is not None + else jnp.full(ret._Ntot, jnp.nan, dtype=float) ) ret._wave = ( wavelength.copy() if wavelength is not None - else jnp.full(ret._N, jnp.nan, dtype=float) + else jnp.full(ret._Ntot, jnp.nan, dtype=float) ) ret._pupil_u = ( pupil_u.copy() if pupil_u is not None - else jnp.full(ret._N, jnp.nan, dtype=float) + else jnp.full(ret._Ntot, jnp.nan, dtype=float) ) ret._pupil_v = ( pupil_v.copy() if pupil_v is not None - else jnp.full(ret._N, jnp.nan, dtype=float) + else jnp.full(ret._Ntot, jnp.nan, dtype=float) ) ret._time = ( - time.copy() if time is not None else jnp.full(ret._N, jnp.nan, dtype=float) + time.copy() + if time is not None + else jnp.full(ret._Ntot, jnp.nan, dtype=float) ) ret._is_corr = jnp.array(is_corr) return ret def tree_flatten(self): children = ( - (self.x, self.y, self.flux), + (self._x, self._y, self._flux, self._nokeep), { - "dxdz": self.dxdz, - "dydz": self.dydz, - "wavelength": self.wavelength, - "pupil_u": self.pupil_u, - "pupil_v": self.pupil_v, - "time": self.time, - "is_corr": self.isCorrelated(), + "dxdz": self._dxdz, + "dydz": self._dydz, + "wavelength": self._wave, + "pupil_u": self._pupil_u, + "pupil_v": self._pupil_v, + "time": self._time, + "is_corr": self._is_corr, }, ) - aux_data = (self._N,) + aux_data = (self._Ntot,) return (children, aux_data) @classmethod def tree_unflatten(cls, aux_data, children): """Recreates an instance of the class from flatten representation""" ret = cls.__new__(cls) - ret._N = aux_data[0] + ret._Ntot = aux_data[0] + ret._nokeep = children[0][3] ret._x = children[0][0] ret._y = children[0][1] ret._flux = children[0][2] @@ -180,10 +230,20 @@ def tree_unflatten(cls, aux_data, children): def size(self): """Return the size of the photon array. Equivalent to ``len(self)``.""" - return self._N + return self._Ntot def __len__(self): - return self._N + return self._Ntot + + @property + def _num_keep(self): + """The number of actual photons in the array.""" + return jnp.sum(~self._nokeep).astype(int) + + @_num_keep.setter + def _num_keep(self, num_keep): + """Set the number of actual photons in the array.""" + self._nokeep = jnp.arange(self._Ntot) >= num_keep @property def x(self): @@ -210,7 +270,13 @@ def y(self, value): @property def flux(self): """The flux of the photons.""" - return self._flux + return jax.lax.cond( + self._Ntot == self._num_keep, + lambda flux, ratio: flux, + lambda flux, ratio: flux * ratio, + jnp.where(self._nokeep, 0.0, self._flux), + self._Ntot / self._num_keep, + ) @flux.setter def flux(self, value): @@ -375,6 +441,22 @@ def scaleXY(self, scale): return self + def _sort_by_nokeep(self): + # now sort things to keep to the left + sinds = jnp.argsort(self._nokeep) + self._x = self._x.at[sinds].get() + self._y = self._y.at[sinds].get() + self._flux = self._flux.at[sinds].get() + self._nokeep = self._nokeep.at[sinds].get() + self._dxdz = self._dxdz.at[sinds].get() + self._dydz = self._dydz.at[sinds].get() + self._wave = self._wave.at[sinds].get() + self._pupil_u = self._pupil_u.at[sinds].get() + self._pupil_v = self._pupil_v.at[sinds].get() + self._time = self._time.at[sinds].get() + + return self + def assignAt(self, istart, rhs): """Assign the contents of another `PhotonArray` to this one starting at istart.""" if istart + rhs.size() > self.size(): @@ -385,6 +467,7 @@ def assignAt(self, istart, rhs): self._x = self._x.at[istart : istart + rhs.size()].set(rhs.x) self._y = self._y.at[istart : istart + rhs.size()].set(rhs.y) self._flux = self._flux.at[istart : istart + rhs.size()].set(rhs.flux) + self._nokeep = self._nokeep.at[istart : istart + rhs.size()].set(rhs._nokeep) self._dxdz = self._dxdz.at[istart : istart + rhs.size()].set(rhs.dxdz) self._dydz = self._dydz.at[istart : istart + rhs.size()].set(rhs.dydz) self._wave = self._wave.at[istart : istart + rhs.size()].set(rhs.wavelength) @@ -392,7 +475,7 @@ def assignAt(self, istart, rhs): self._pupil_v = self._pupil_v.at[istart : istart + rhs.size()].set(rhs.pupil_v) self._time = self._time.at[istart : istart + rhs.size()].set(rhs.time) - return self + return self._sort_by_nokeep() def _assign_from_categorical_index(self, cat_inds, cat_ind_to_assign, rhs): """Assign the contents of another `PhotonArray` to this one at locations @@ -402,6 +485,7 @@ def _assign_from_categorical_index(self, cat_inds, cat_ind_to_assign, rhs): self._x = jnp.where(msk, rhs._x, self._x) self._y = jnp.where(msk, rhs._y, self._y) self._flux = jnp.where(msk, rhs._flux, self._flux) + self._nokeep = jnp.where(msk, rhs._nokeep, self._nokeep) self._dxdz = jnp.where(msk, rhs._dxdz, self._dxdz) self._dydz = jnp.where(msk, rhs._dydz, self._dydz) @@ -410,7 +494,7 @@ def _assign_from_categorical_index(self, cat_inds, cat_ind_to_assign, rhs): self._pupil_v = jnp.where(msk, rhs._pupil_v, self._pupil_v) self._time = jnp.where(msk, rhs._time, self._time) - return self + return self._sort_by_nokeep() def convolve(self, rhs, rng=None): """Convolve this `PhotonArray` with another. @@ -428,7 +512,7 @@ def convolve(self, rhs, rng=None): rng = BaseDeviate(rng) rsinds = jrng.choice( rng._state.split_one(), - self.size(), + self._Ntot, shape=(self.size(),), replace=False, ) @@ -436,7 +520,9 @@ def convolve(self, rhs, rng=None): sinds = jax.lax.cond( jnp.array(self.isCorrelated()) & jnp.array(rhs.isCorrelated()), - lambda nrsinds, rsinds: rsinds, + lambda nrsinds, rsinds: rsinds.at[ + jnp.argsort(rhs._nokeep.at[rsinds].get()) + ].get(), lambda nrsinds, rsinds: nrsinds, nrsinds, rsinds, @@ -527,6 +613,7 @@ def __repr__(self): ) if self.hasAllocatedTimes(): s += ", time=array(%r)" % np.array(self.time).tolist() + s += ", _nokeep=array(%r)" % np.array(self._nokeep).tolist() s += ")" return s @@ -541,6 +628,7 @@ def __eq__(self, other): and jnp.array_equal(self.x, other.x) and jnp.array_equal(self.y, other.y) and jnp.array_equal(self.flux, other.flux) + and jnp.array_equal(self._nokeep, other._nokeep) and jnp.array_equal(self.dxdz, other.dxdz, equal_nan=True) and jnp.array_equal(self.dydz, other.dydz, equal_nan=True) and jnp.array_equal(self.wavelength, other.wavelength, equal_nan=True) @@ -574,7 +662,7 @@ def addTo(self, image): _arr, _flux_sum = _add_photons_to_image( self._x, self._y, - self._flux, + jnp.where(self._nokeep, 0.0, self._flux) * self._Ntot / self._num_keep, image.bounds.xmin, image.bounds.ymin, image._array, @@ -651,6 +739,9 @@ def write(self, file_name): cols.append(pyfits.Column(name="x", format="D", array=np.array(self.x))) cols.append(pyfits.Column(name="y", format="D", array=np.array(self.y))) cols.append(pyfits.Column(name="flux", format="D", array=np.array(self.flux))) + cols.append( + pyfits.Column(name="_nokeep", format="L", array=np.array(self._nokeep)) + ) if self.hasAllocatedAngles(): cols.append( @@ -708,6 +799,7 @@ def read(cls, file_name): y=jnp.array(data["y"]), flux=jnp.array(data["flux"]), ) + photons._nokeep = jnp.array(data["_nokeep"]) if "dxdz" in names: photons.dxdz = jnp.array(data["dxdz"]) photons.dydz = jnp.array(data["dydz"]) diff --git a/jax_galsim/random.py b/jax_galsim/random.py index 1b4ce1f4..d45a9169 100644 --- a/jax_galsim/random.py +++ b/jax_galsim/random.py @@ -22,6 +22,19 @@ - Within a single routine linking may work. - You may encounter errors related to global side effects for some combinations of linked states and jitted/vmapped routines. + +Seeding the JAX-GalSim PRNG can be done in a few ways: + + - pass seed=None (This is equivalent to passing seed=0) + - pass an integer seed (This method will throw errors if the integer is traced by JAX.) + - pass another JAX-GalSim PRNG + - pass a JAX PRNG key made via `jax.random.key`. + +**JAX PRNG keys made via `jax.random.PRNGKey` are not supported.** + +When using JAX-GalSim PRNGs and JIT, you should always return the PRNG from the function +and then set the state on input PRNG via `prng.reset(new_prng)`. This will ensure that the +PRNG state is propagated correctly outside the JITed code. """ @@ -33,8 +46,8 @@ class _DeviateState: Parameters ---------- - key : jax.random.PRNGKey - The JAX PRNG key made via `jrandom.PRNGKey` or equivalent. + key : key data with dtype `jax.dtypes.prng_key` + The JAX PRNG key made via `jrandom.key` """ def __init__(self, key): @@ -79,13 +92,13 @@ def generates_in_pairs(self): _galsim.BaseDeviate.seed, lax_description="The JAX version of this method does no type checking.", ) - def seed(self, seed=0): + def seed(self, seed=None): self._seed(seed=seed) @_wraps(_galsim.BaseDeviate._seed) - def _seed(self, seed=0): + def _seed(self, seed=None): _initial_seed = seed or secrets.randbelow(2**31) - self._state.key = jrandom.PRNGKey(_initial_seed) + self._state.key = jrandom.key(_initial_seed) @_wraps( _galsim.BaseDeviate.reset, @@ -96,8 +109,10 @@ def reset(self, seed=None): self._state = seed elif isinstance(seed, BaseDeviate): self._state = seed._state - elif isinstance(seed, jax.Array) and seed.shape == (2,): - self._state = _DeviateState(wrap_key_data(seed)) + elif hasattr(seed, "dtype") and jax.dtypes.issubdtype( + seed.dtype, jax.dtypes.prng_key + ): + self._state = _DeviateState(seed) elif isinstance(seed, str): self._state = _DeviateState( wrap_key_data(jnp.array(eval(seed), dtype=jnp.uint32)) @@ -108,7 +123,7 @@ def reset(self, seed=None): ) else: _initial_seed = seed or secrets.randbelow(2**31) - self._state = _DeviateState(jrandom.PRNGKey(_initial_seed)) + self._state = _DeviateState(jrandom.key(_initial_seed)) @property def _key(self): diff --git a/jax_galsim/utilities.py b/jax_galsim/utilities.py index 47059b8f..05af6a36 100644 --- a/jax_galsim/utilities.py +++ b/jax_galsim/utilities.py @@ -4,6 +4,7 @@ import jax import jax.numpy as jnp from jax._src.numpy.util import _wraps +from jax.tree_util import tree_flatten from jax_galsim.errors import GalSimIncompatibleValuesError, GalSimValueError from jax_galsim.position import PositionD, PositionI @@ -11,27 +12,51 @@ printoptions = _galsim.utilities.printoptions +def has_tracers(x): + """Return True if the data is equal, False otherwise. Handles jax.Array types.""" + for item in tree_flatten(x)[0]: + if isinstance(item, jax.core.Tracer): + return True + return False + + @_wraps( _galsim.utilities.lazy_property, lax_description=( "The LAX version of this decorator uses an `_workspace` attribute " "attached to the object so that the cache can easily be discarded " - "for certain operations." + "for certain operations. It also will not cache jax.core.Tracer objects " + "in order to avoid side-effects in jit/grad/vmap transformations." ), ) -def lazy_property(func): - attname = func.__name__ + "_cached" - - @property - @functools.wraps(func) - def _func(self): - if not hasattr(self, "_workspace"): - self._workspace = {} - if attname not in self._workspace: - self._workspace[attname] = func(self) - return self._workspace[attname] - - return _func +def lazy_property(func_=None, cache_jax_tracers=False): + # see https://stackoverflow.com/a/57268935 + def _decorator(func): + attname = func.__name__ + "_cached" + + @property + @functools.wraps(func) + def wrapper(self): + if not hasattr(self, "_workspace"): + self._workspace = {} + if attname not in self._workspace: + val = func(self) + if cache_jax_tracers or (not has_tracers(val)): + self._workspace[attname] = val + else: + val = self._workspace[attname] + return val + + return wrapper + + if callable(func_): + return _decorator(func_) + elif func_ is None: + return _decorator + else: + raise RuntimeWarning( + "Positional arguments are not supported for the lazy_property decorator" + ) @_wraps(_galsim.utilities.parse_pos_args) diff --git a/tests/GalSim b/tests/GalSim index 710cca28..9e8d6565 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 710cca286c5fcd229d1c309aaf6e5c61ec81f9dc +Subproject commit 9e8d6565e88260586911339d1b3d8f32a7a8e1ba diff --git a/tests/jax/galsim/test_draw_jax.py b/tests/jax/galsim/test_draw_jax.py index cb93d249..fe4656e7 100644 --- a/tests/jax/galsim/test_draw_jax.py +++ b/tests/jax/galsim/test_draw_jax.py @@ -1115,19 +1115,22 @@ def test_shoot(): # in exact arithmetic. We had an assert there which blew up in a not very nice way. obj = galsim.Gaussian(sigma=0.2398318) + 0.1*galsim.Gaussian(sigma=0.47966352) obj = obj.withFlux(100001) - image1 = galsim.ImageF(32,32, init_value=100) + # JAX-Galsim adjusts the images to double here + image1 = galsim.ImageD(32,32, init_value=100) rng = galsim.BaseDeviate(1234) obj.drawImage(image1, method='phot', poisson_flux=False, add_to_image=True, rng=rng, maxN=100000) # The test here is really just that it doesn't crash. # But let's do something to check correctness. - image2 = galsim.ImageF(32,32) + # JAX-Galsim adjusts the images to double here + image2 = galsim.ImageD(32,32) rng = galsim.BaseDeviate(1234) obj.drawImage(image2, method='phot', poisson_flux=False, add_to_image=False, rng=rng, maxN=100000) image2 += 100 - np.testing.assert_array_almost_equal(image2.array, image1.array, decimal=12) + # with double, we get the same result to 10 decimal places + np.testing.assert_array_almost_equal(image2.array, image1.array, decimal=10) # Also check that you get the same answer with a smaller maxN. image3 = galsim.ImageF(32,32, init_value=100) @@ -1141,13 +1144,15 @@ def test_shoot(): np.testing.assert_array_equal(image4.array, 0) # Warns if flux is 1 and n_photons not given. + # JAX-GalSim doesn't warn in this case psf = galsim.Gaussian(sigma=3) - with assert_warns(galsim.GalSimWarning): - psf.drawImage(method='phot') - with assert_warns(galsim.GalSimWarning): - psf.drawPhot(image4) - with assert_warns(galsim.GalSimWarning): - psf.makePhot() + # with assert_warns(galsim.GalSimWarning): + # psf.drawImage(method='phot') + # with assert_warns(galsim.GalSimWarning): + # psf.drawPhot(image4) + # with assert_warns(galsim.GalSimWarning): + # psf.makePhot() + # With n_photons=1, it's fine. psf.drawImage(method='phot', n_photons=1) psf.drawPhot(image4, n_photons=1) @@ -1204,23 +1209,24 @@ def test_drawImage_area_exptime(): msg = "obj.drawImage(method='phot') unexpectedly produced equal images with different rng" assert not np.allclose(im5.array, im4.array), msg - # Shooting with flux=1 raises a warning. - obj1 = obj.withFlux(1) - with assert_warns(galsim.GalSimWarning): - obj1.drawImage(method='phot') - # But not if we explicitly tell it to shoot 1 photon - with assert_raises(AssertionError): - assert_warns(galsim.GalSimWarning, obj1.drawImage, method='phot', n_photons=1) - # Likewise for makePhot - with assert_warns(galsim.GalSimWarning): - obj1.makePhot() - with assert_raises(AssertionError): - assert_warns(galsim.GalSimWarning, obj1.makePhot, n_photons=1) - # And drawPhot - with assert_warns(galsim.GalSimWarning): - obj1.drawPhot(im1) - with assert_raises(AssertionError): - assert_warns(galsim.GalSimWarning, obj1.drawPhot, im1, n_photons=1) + # JAX-GalSim doesn't raise for these things + # # Shooting with flux=1 raises a warning. + # obj1 = obj.withFlux(1) + # with assert_warns(galsim.GalSimWarning): + # obj1.drawImage(method='phot') + # # But not if we explicitly tell it to shoot 1 photon + # with assert_raises(AssertionError): + # assert_warns(galsim.GalSimWarning, obj1.drawImage, method='phot', n_photons=1) + # # Likewise for makePhot + # with assert_warns(galsim.GalSimWarning): + # obj1.makePhot() + # with assert_raises(AssertionError): + # assert_warns(galsim.GalSimWarning, obj1.makePhot, n_photons=1) + # # And drawPhot + # with assert_warns(galsim.GalSimWarning): + # obj1.drawPhot(im1) + # with assert_raises(AssertionError): + # assert_warns(galsim.GalSimWarning, obj1.drawPhot, im1, n_photons=1) @timer diff --git a/tests/jax/test_interpolatedimage_utils.py b/tests/jax/test_interpolatedimage_utils.py index 6c0200f0..c76278ef 100644 --- a/tests/jax/test_interpolatedimage_utils.py +++ b/tests/jax/test_interpolatedimage_utils.py @@ -329,10 +329,7 @@ def test_interpolatedimage_utils_jax_galsim_fft_vs_galsim_fft(n): Lanczos(7), ], ) -def test_interpolatedimage_interpolant_rejection_sample(interp): - from jax.tree_util import Partial as jax_partial - - from jax_galsim.interpolant import _rejection_sample +def test_interpolatedimage_interpolant_sample(interp): from jax_galsim.photon_array import PhotonArray from jax_galsim.random import BaseDeviate @@ -340,15 +337,7 @@ def test_interpolatedimage_interpolant_rejection_sample(interp): ntot = 1000000 photons = PhotonArray(ntot) - photons, _ = _rejection_sample( - photons, - rng, - interp.xrange * 2.0, - jax_partial(interp._xval_noraise), - interp.positive_flux, - interp.negative_flux, - interp._xval_noraise(0.0), - ) + interp._shoot(photons, rng) h, bins = jnp.histogram(photons.x, bins=500) mid = (bins[1:] + bins[:-1]) / 2.0 @@ -365,11 +354,12 @@ def test_interpolatedimage_interpolant_rejection_sample(interp): np.testing.assert_allclose(fdev[msk], 0, rtol=0, atol=5.0, err_msg=f"{interp}") np.testing.assert_allclose(fdev[~msk], 0, rtol=0, atol=15.0, err_msg=f"{interp}") - if interp.__class__.__name__ == "Quintic" and False: + if interp.__class__.__name__ in ["Quintic", "Lanczos"] and False: import proplot as pplt - fig, axs = pplt.subplots(figsize=(4, 4)) - axs.hist(photons.x, bins=500, log=False) + fig, axs = pplt.subplots(figsize=(6, 6)) + axs.hist(photons.x, bins=500, log=True) axs.plot(mid, yv, color="k") + axs.format(title=interp.__class__.__name__) fig.show() breakpoint() diff --git a/tests/jax/test_jitting.py b/tests/jax/test_jitting.py index 9b99b0a5..6a427f94 100644 --- a/tests/jax/test_jitting.py +++ b/tests/jax/test_jitting.py @@ -7,6 +7,7 @@ import jax_galsim as galsim from jax_galsim.core.draw import calculate_n_photons from jax_galsim.core.testing import time_code_block +from jax_galsim.photon_array import fixed_photon_array_size # Defining jitting identity identity = jax.jit(lambda x: x) @@ -239,7 +240,7 @@ def _build_and_draw(hlr, fwhm, jit=True): final._flux_per_photon, final.max_sb, poisson_flux=False, - )[0] + )[0].item() gain = 1.0 if jit: return _draw_it_jit(final, n, n_photons, gain) @@ -283,3 +284,67 @@ def _draw_it_jit(obj, n, nphotons, gain): img = _build_and_draw(0.5, 1.0) np.testing.assert_allclose(img.array.sum(), 1100.0) + + +def test_jitting_draw_phot_fixed(): + def _build_and_draw(hlr, fwhm, jit=True): + gal = galsim.Exponential( + half_light_radius=hlr, flux=1000.0 + ) + galsim.Exponential(half_light_radius=hlr * 2.0, flux=100.0) + psf = galsim.Gaussian(fwhm=fwhm, flux=1.0) + final = galsim.Convolve( + [gal, psf], + ) + n = final.getGoodImageSize(0.2).item() + n += 1 + n_photons = calculate_n_photons( + final.flux, + final._flux_per_photon, + final.max_sb, + poisson_flux=False, + )[0] + gain = 1.0 + if jit: + return _draw_it_jit(final, n, n_photons, gain) + else: + with fixed_photon_array_size(2048): + return final.drawImage( + nx=n, + ny=n, + scale=0.2, + method="phot", + n_photons=n_photons, + poisson_flux=False, + gain=gain, + ) + + @partial(jax.jit, static_argnums=(1, 2)) + def _draw_it_jit(obj, n, nphotons, gain): + with fixed_photon_array_size(2048): + return obj.drawImage( + nx=n, + ny=n, + scale=0.2, + n_photons=nphotons, + method="phot", + poisson_flux=False, + gain=gain, + maxN=101, + ) + + with time_code_block("warmup no-jit"): + img = _build_and_draw(0.5, 1.0, jit=False) + np.testing.assert_allclose(img.array.sum(), 1100.0) + + with time_code_block("no-jit"): + img = _build_and_draw(0.5, 1.0, jit=False) + np.testing.assert_allclose(img.array.sum(), 1100.0) + + with time_code_block("warmup jit"): + img = _build_and_draw(0.5, 1.0) + np.testing.assert_allclose(img.array.sum(), 1100.0) + + with time_code_block("jit"): + img = _build_and_draw(0.5, 1.0) + + np.testing.assert_allclose(img.array.sum(), 1100.0) diff --git a/tests/jax/test_photon_shooting_jax.py b/tests/jax/test_photon_shooting_jax.py index 9346a149..bb2dd3b2 100644 --- a/tests/jax/test_photon_shooting_jax.py +++ b/tests/jax/test_photon_shooting_jax.py @@ -1,3 +1,4 @@ +import galsim as _galsim import jax import jax.numpy as jnp import numpy as np @@ -6,6 +7,7 @@ import jax_galsim from jax_galsim.core.testing import time_code_block +from jax_galsim.photon_array import fixed_photon_array_size def test_photon_shooting_jax_make_from_image_notranspose(): @@ -168,3 +170,69 @@ def test_photon_shooting_jax_offset(offset): ) np.testing.assert_allclose(img_fft.array, img_phot.array, rtol=rtol, atol=atol) + + +def test_photon_shooting_jax_vmapping(): + n_stamps = 100 + rng = np.random.RandomState(1234) + shifts = jnp.array(rng.uniform(-1, 1, size=(n_stamps, 2))) + hlrs = jnp.array(rng.uniform(0.1, 1.0, size=(n_stamps,))) + fwhms = jnp.array(rng.uniform(0.9, 1.0, size=(n_stamps,))) + fluxes = jnp.array(rng.uniform(100, 1000, size=(n_stamps,))) + rng = jax_galsim.BaseDeviate(1234) + seeds = [] + for i in range(n_stamps): + seeds.append(jax.random.key(i + 1)) + max_n_phot = 2048 + seeds = jnp.array(seeds) + + @jax.jit + def _draw(hlr, fwhm, shift, flux, seed): + obj = jax_galsim.Convolve( + [ + jax_galsim.Exponential(half_light_radius=hlr, flux=flux).shift(*shift), + jax_galsim.Gaussian(fwhm=fwhm, flux=1.0), + ] + ) + with fixed_photon_array_size(max_n_phot): + return obj.drawImage( + nx=33, + ny=33, + scale=0.2, + method="phot", + rng=jax_galsim.BaseDeviate(seed), + ) + + with time_code_block("one warmup"): + img = _draw(hlrs[0], fwhms[0], shifts[0], fluxes[0], seeds[0]) + with time_code_block("one"): + img = _draw(hlrs[0], fwhms[0], shifts[0], fluxes[0], seeds[0]) + print(img.array.shape, img.bounds, img.array.sum(), fluxes[0]) + + _vmap_draw = jax.jit(jax.vmap(_draw, in_axes=(0, 0, 0, 0, 0))) + with time_code_block("vmap warmup"): + imgs = _vmap_draw(hlrs, fwhms, shifts, fluxes, seeds) + with time_code_block("vmap"): + imgs = _vmap_draw(hlrs, fwhms, shifts, fluxes, seeds) + print(imgs.array.shape) + + np.testing.assert_allclose(img.array.sum(), imgs.array[0].sum()) + + def _draw_galsim(hlr, fwhm, shift, flux, seed): + obj = _galsim.Convolve( + [ + _galsim.Exponential(half_light_radius=hlr, flux=flux).shift(*shift), + _galsim.Gaussian(fwhm=fwhm, flux=1.0), + ] + ) + return obj.drawImage( + nx=33, + ny=33, + scale=0.2, + method="phot", + rng=_galsim.BaseDeviate(seed), + ) + + with time_code_block("galsim"): + for i in range(n_stamps): + _draw_galsim(hlrs[i], fwhms[i], shifts[i], fluxes[i], i + 1)