Skip to content

Commit

Permalink
fix: wrap interpolant
Browse files Browse the repository at this point in the history
  • Loading branch information
beckermr committed Sep 9, 2024
1 parent 05bdc49 commit 469e139
Showing 1 changed file with 28 additions and 94 deletions.
122 changes: 28 additions & 94 deletions jax_galsim/interpolant.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,31 +44,8 @@ def __setstate__(self, d):
self.__dict__ = d

@staticmethod
@implements(_galsim.interpolant.Interpolant.from_name)
def from_name(name, tol=None, gsparams=None):
"""A factory function to create an `Interpolant` of the correct type according to
the (string) name of the `Interpolant`.
This is mostly used to simplify how config files specify the `Interpolant` to use.
Valid names are:
- 'delta' = `Delta`
- 'nearest' = `Nearest`
- 'sinc' = `SincInterpolant`
- 'linear' = `Linear`
- 'cubic' = `Cubic`
- 'quintic' = `Quintic`
- 'lanczosN' = `Lanczos` (where N is an integer, given the ``n`` parameter)
In addition, if you want to specify the ``conserve_dc`` option for `Lanczos`, you can
append either T or F to represent ``conserve_dc = True/False`` (respectively). Otherwise,
the default ``conserve_dc=True`` is used.
Parameters:
name: The name of the interpolant to create.
tol: [deprecated]
gsparams: An optional `GSParams` argument. [default: None]
"""
if tol is not None:
from galsim.deprecated import depr

Expand Down Expand Up @@ -119,19 +96,20 @@ def from_name(name, tol=None, gsparams=None):
)

@property
@implements(_galsim.interpolant.Interpolant.gsparams)
def gsparams(self):
"""The `GSParams` of the `Interpolant`"""
return self._gsparams

@property
@implements(_galsim.interpolant.Interpolant.tol)
def tol(self):
from galsim.deprecated import depr

depr("interpolant.tol", 2.2, "interpolant.gsparams.kvalue_accuracy")
return self._gsparams.kvalue_accuracy

@implements(_galsim.interpolant.Interpolant.withGSParams)
def withGSParams(self, gsparams=None, **kwargs):
"""Create a version of the current interpolant with the given gsparams"""
if gsparams == self.gsparams:
return self
# Checking gsparams
Expand Down Expand Up @@ -165,17 +143,8 @@ def __ne__(self, other):
def __hash__(self):
return hash(repr(self))

@implements(_galsim.interpolant.Interpolant.xval)
def xval(self, x):
"""Calculate the value of the interpolant kernel at one or more x values
Parameters:
x: The value (as a float) or values (as a np.array) at which to compute the
amplitude of the Interpolant kernel.
Returns:
xval: The value(s) at the x location(s). If x was an array, then this is also
an array.
"""
if jnp.ndim(x) > 1:
raise GalSimValueError("xval only takes scalar or 1D array values", x)

Expand All @@ -184,17 +153,8 @@ def xval(self, x):
def _xval_noraise(self, x):
return self.__class__._xval(x)

@implements(_galsim.interpolant.Interpolant.kval)
def kval(self, k):
"""Calculate the value of the interpolant kernel in Fourier space at one or more k values.
Parameters:
k: The value (as a float) or values (as a np.array) at which to compute the
amplitude of the Interpolant kernel in Fourier space.
Returns:
kval: The k-value(s) at the k location(s). If k was an array, then this is also
an array.
"""
if jnp.ndim(k) > 1:
raise GalSimValueError("kval only takes scalar or 1D array values", k)

Expand All @@ -203,17 +163,8 @@ def kval(self, k):
def _kval_noraise(self, k):
return self.__class__._uval(k / 2.0 / jnp.pi)

@implements(_galsim.interpolant.Interpolant.unit_integrals)
def unit_integrals(self, max_len=None):
"""Compute the unit integrals of the real-space kernel.
integrals[i] = int(xval(x), i-0.5, i+0.5)
Parameters:
max_len: The maximum length of the returned array. (ignored)
Returns:
integrals: An array of unit integrals of length max_len or smaller.
"""
return self._unit_integrals

def tree_flatten(self):
Expand All @@ -233,14 +184,15 @@ def tree_unflatten(cls, aux_data, children):
return cls(**aux_data)

@property
@implements(_galsim.interpolant.Interpolant.positive_flux)
def positive_flux(self):
"""The positive-flux fraction of the interpolation kernel."""
if not hasattr(self, "_positive_flux"):
# subclasses can define this method if _positive_flux is not set
self._comp_fluxes()
return self._positive_flux

@property
@implements(_galsim.interpolant.Interpolant.negative_flux)
def negative_flux(self):
"""The negative-flux fraction of the interpolation kernel."""
if not hasattr(self, "_negative_flux"):
Expand Down Expand Up @@ -350,13 +302,13 @@ def urange(self):
return 1.0 / self._gsparams.kvalue_accuracy

@property
@implements(_galsim.interpolant.Delta.xrange)
def xrange(self):
"""The maximum extent of the interpolant from the origin (in pixels)."""
return 0.0

@property
@implements(_galsim.interpolant.Delta.ixrange)
def ixrange(self):
"""The total integral range of the interpolant. Typically 2 * xrange."""
return 0

def _shoot(self, photons, rng):
Expand Down Expand Up @@ -395,13 +347,13 @@ def urange(self):
return 1.0 / (np.pi * self._gsparams.kvalue_accuracy)

@property
@implements(_galsim.interpolant.Nearest.xrange)
def xrange(self):
"""The maximum extent of the interpolant from the origin (in pixels)."""
return 0.5

@property
@implements(_galsim.interpolant.Nearest.ixrange)
def ixrange(self):
"""The total integral range of the interpolant. Typically 2 * xrange."""
return 1

def _shoot(self, photons, rng):
Expand Down Expand Up @@ -469,22 +421,13 @@ def urange(self):
return 0.5

@property
@implements(_galsim.interpolant.SincInterpolant.xrange)
def xrange(self):
"""The maximum extent of the interpolant from the origin (in pixels)."""
# Technically infinity, but truncated by the tolerance.
return 1.0 / (np.pi * self._gsparams.kvalue_accuracy)

@implements(_galsim.interpolant.SincInterpolant.unit_integrals)
def unit_integrals(self, max_len=None):
"""Compute the unit integrals of the real-space kernel.
integrals[i] = int(xval(x), i-0.5, i+0.5)
Parameters:
max_len: The maximum length of the returned array.
Returns:
integrals: An array of unit integrals of length max_len or smaller.
"""
n = self.ixrange // 2 + 1
n = n if max_len is None else min(n, max_len)
return _sinc_unit_integrals(self.ixrange)[:n]
Expand Down Expand Up @@ -535,13 +478,13 @@ def urange(self):
return 1.0 / np.sqrt(self._gsparams.kvalue_accuracy) / np.pi

@property
@implements(_galsim.interpolant.Linear.xrange)
def xrange(self):
"""The maximum extent of the interpolant from the origin (in pixels)."""
return 1.0

@property
@implements(_galsim.interpolant.Linear.ixrange)
def ixrange(self):
"""The total integral range of the interpolant. Typically 2 * xrange."""
return 2

def _shoot(self, photons, rng):
Expand Down Expand Up @@ -606,13 +549,13 @@ def urange(self):
)

@property
@implements(_galsim.interpolant.Cubic.xrange)
def xrange(self):
"""The maximum extent of the interpolant from the origin (in pixels)."""
return 2.0

@property
@implements(_galsim.interpolant.ixrange)
def ixrange(self):
"""The total integral range of the interpolant. Typically 2 * xrange."""
return 4


Expand Down Expand Up @@ -701,13 +644,13 @@ def urange(self):
)

@property
@implements(_galsim.interpolant.Quintic.xrange)
def xrange(self):
"""The maximum extent of the interpolant from the origin (in pixels)."""
return 3.0

@property
@implements(_galsim.interpolant.Quintic.ixrange)
def ixrange(self):
"""The total integral range of the interpolant. Typically 2 * xrange."""
return 6


Expand Down Expand Up @@ -1573,52 +1516,43 @@ def urange(self):
return self._umax

@property
@implements(_galsim.interpolant.Lanczos.n)
def n(self):
"""The order of the Lanczos function."""
return self._n

@property
@implements(_galsim.interpolant.Lanczos.conserve_dc)
def conserve_dc(self):
"""Whether this interpolant is modified to improve flux conservation."""
return self._conserve_dc

@property
@implements(_galsim.interpolant.Lanczos.xrange)
def xrange(self):
"""The maximum extent of the interpolant from the origin (in pixels)."""
return self._n

@property
@implements(_galsim.interpolant.Lanczos.ixrange)
def ixrange(self):
"""The total integral range of the interpolant. Typically 2 * xrange."""
return 2 * self._n

@property
@implements(_galsim.interpolant.Lanczos.positive_flux)
def positive_flux(self):
"""The positive-flux fraction of the interpolation kernel."""
if self._conserve_dc:
return self._posflux_conserve_dc[self._n]
else:
return self._posflux_no_conserve_dc[self._n]

@property
@implements(_galsim.interpolant.Lanczos.negative_flux)
def negative_flux(self):
"""The negative-flux fraction of the interpolation kernel."""
if self._conserve_dc:
return self._negflux_conserve_dc[self._n]
else:
return self._negflux_no_conserve_dc[self._n]

@implements(_galsim.interpolant.Lanczos.unit_integrals)
def unit_integrals(self, max_len=None):
"""Compute the unit integrals of the real-space kernel.
integrals[i] = int(xval(x), i-0.5, i+0.5)
Parameters:
max_len: The maximum length of the returned array. (ignored)
Returns:
integrals: An array of unit integrals of length max_len or smaller.
"""
if max_len is not None and max_len < self._n + 1:
n = max_len
else:
Expand Down

0 comments on commit 469e139

Please sign in to comment.