diff --git a/jax_galsim/gsparams.py b/jax_galsim/gsparams.py index 89a4fd37..69499371 100644 --- a/jax_galsim/gsparams.py +++ b/jax_galsim/gsparams.py @@ -47,12 +47,8 @@ def from_galsim(cls, gsparams): ) @staticmethod + @implements(_galsim.GSParams.check) def check(gsparams, default=None, **kwargs): - """Checks that gsparams is either a valid GSParams instance or None. - - In the former case, it returns gsparams, in the latter it returns default - (GSParams.default if no other default specified). - """ if gsparams is None: if default is not None: if isinstance(default, GSParams): @@ -65,10 +61,8 @@ def check(gsparams, default=None, **kwargs): raise TypeError("Invalid GSParams: %s" % gsparams) return gsparams.withParams(**kwargs) + @implements(_galsim.GSParams.withParams) def withParams(self, **kwargs): - """Return a `GSParams` that is identical to the current one except for any keyword - arguments given here, which supersede the current value. - """ if len(kwargs) == 0: return self else: @@ -80,12 +74,8 @@ def withParams(self, **kwargs): return GSParams(**d) @staticmethod + @implements(_galsim.GSParams.combine) def combine(gsp_list): - """Combine a list of `GSParams` instances using the most restrictive parameter from each. - - Uses the minimum value for most parameters. For the following parameters, it uses the - maximum numerical value: minimum_fft_size, maximum_fft_size, stepk_minimum_hlr. - """ if len(gsp_list) == 1: return gsp_list[0] elif all(g == gsp_list[0] for g in gsp_list[1:]): diff --git a/jax_galsim/image.py b/jax_galsim/image.py index acd701cf..f8494a99 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -325,64 +325,56 @@ def __str__(self): # Read-only attributes: @property + @implements(_galsim.Image.dtype) def dtype(self): - """The dtype of the underlying numpy array.""" return self._dtype @property + @implements(_galsim.Image.bounds) def bounds(self): - """The bounds of the `Image`.""" return self._bounds @property + @implements(_galsim.Image.array) def array(self): - """The underlying numpy array.""" return self._array @property + @implements(_galsim.Image.nrow) def nrow(self): - """The number of rows in the image""" return self._array.shape[0] @property + @implements(_galsim.Image.ncol) def ncol(self): - """The number of columns in the image""" return self._array.shape[1] @property + @implements(_galsim.Image.isconst) def isconst(self): - """Whether the `Image` is constant. I.e. modifying its values is an error.""" return self._is_const @property + @implements(_galsim.Image.iscomplex) def iscomplex(self): - """Whether the `Image` values are complex.""" return self._array.dtype.kind == "c" @property + @implements(_galsim.Image.isinteger) def isinteger(self): - """Whether the `Image` values are integral.""" return self._array.dtype.kind in ("i", "u") @property + @implements( + _galsim.Image.iscontiguous, lax_description="In JAX all arrays are contiguous." + ) def iscontiguous(self): - """Indicates whether each row of the image is contiguous in memory. - - Note: it is ok for the end of one row to not be contiguous with the start of the - next row. This just checks that each individual row has a stride of 1. - """ return True # In JAX all arrays are contiguous (almost) # Allow scale to work as a PixelScale wcs. @property + @implements(_galsim.Image.scale) def scale(self): - """The pixel scale of the `Image`. Only valid if the wcs is a `PixelScale`. - - If the WCS is either not set (i.e. it is ``None``) or it is a `PixelScale`, then - it is permissible to change the scale with:: - - >>> image.scale = new_pixel_scale - """ try: return self.wcs.scale except Exception: @@ -404,57 +396,43 @@ def scale(self, value): # Convenience functions @property + @implements(_galsim.Image.xmin) def xmin(self): - """Alias for self.bounds.xmin.""" return self._bounds.xmin @property + @implements(_galsim.Image.xmax) def xmax(self): - """Alias for self.bounds.xmax.""" return self._bounds.xmax @property + @implements(_galsim.Image.ymin) def ymin(self): - """Alias for self.bounds.ymin.""" return self._bounds.ymin @property + @implements(_galsim.Image.ymax) def ymax(self): - """Alias for self.bounds.ymax.""" return self._bounds.ymax @property + @implements(_galsim.Image.outer_bounds) def outer_bounds(self): - """The bounds of the outer edge of the pixels. - - Equivalent to galsim.BoundsD(im.xmin-0.5, im.xmax+0.5, im.ymin-0.5, im.ymax+0.5) - """ return BoundsD( self.xmin - 0.5, self.xmax + 0.5, self.ymin - 0.5, self.ymax + 0.5 ) # real, imag for everything, even real images. @property + @implements(_galsim.Image.real) def real(self): - """Return the real part of an image. - - This is a property, not a function. So write ``im.real``, not ``im.real()``. - - This works for real or complex. For real images, it acts the same as `view`. - """ return self.__class__( self.array.real, bounds=self.bounds, wcs=self.wcs, make_const=self._is_const ) @property + @implements(_galsim.Image.imag) def imag(self): - """Return the imaginary part of an image. - - This is a property, not a function. So write ``im.imag``, not ``im.imag()``. - - This works for real or complex. For real images, the returned array is read-only and - all elements are 0. - """ return self.__class__( self.array.imag, bounds=self.bounds, @@ -464,26 +442,16 @@ def imag(self): ) @property + @implements(_galsim.Image.conjugate) def conjugate(self): - """Return the complex conjugate of an image. - - This works for real or complex. For real images, it acts the same as `view`. - - Note that for complex images, this is not a conjugate view into the original image. - So changing the original image does not change the conjugate (or vice versa). - """ return self.__class__(self.array.conjugate(), bounds=self.bounds, wcs=self.wcs) + @implements(_galsim.Image.copy) def copy(self): - """Make a copy of the `Image`""" return self.__class__(self.array.copy(), bounds=self.bounds, wcs=self.wcs) + @implements(_galsim.Image.get_pixel_centers) def get_pixel_centers(self): - """A convenience function to get the x and y values at the centers of the image pixels. - - Returns: - (x, y), each of which is a numpy array the same shape as ``self.array`` - """ x, y = jnp.meshgrid( jnp.arange(self.array.shape[1], dtype=float), jnp.arange(self.array.shape[0], dtype=float), @@ -500,19 +468,8 @@ def _make_empty(self, shape, dtype): else: return jnp.zeros(shape=shape, dtype=dtype) + @implements(_galsim.Image.resize) def resize(self, bounds, wcs=None): - """Resize the image to have a new bounds (must be a `BoundsI` instance) - - Note that the resized image will have uninitialized data. If you want to preserve - the existing data values, you should either use `subImage` (if you want a smaller - portion of the current `Image`) or make a new `Image` and copy over the current values - into a portion of the new image (if you are resizing to a larger `Image`). - - Parameters: - bounds: The new bounds to resize to. - wcs: If provided, also update the wcs to the given value. [default: None, - which means keep the existing wcs] - """ if self.isconst: raise GalSimImmutableError("Cannot modify an immutable Image", self) if not isinstance(bounds, BoundsI): @@ -522,11 +479,8 @@ def resize(self, bounds, wcs=None): if wcs is not None: self.wcs = wcs + @implements(_galsim.Image.subImage) def subImage(self, bounds): - """Return a view of a portion of the full image - - This is equivalent to self[bounds] - """ if not isinstance(bounds, BoundsI): raise TypeError("bounds must be a galsim.BoundsI instance") if not self.bounds.isDefined(): @@ -548,11 +502,8 @@ def subImage(self, bounds): # reorigin that you need to update the wcs. So that's taken care of in im.shift. return self.__class__(subarray, bounds=bounds, wcs=self.wcs) + @implements(_galsim.Image.setSubImage) def setSubImage(self, bounds, rhs): - """Set a portion of the full image to the values in another image - - This is equivalent to self[bounds] = rhs - """ if self.isconst: raise GalSimImmutableError("Cannot modify an immutable Image", self) if not isinstance(bounds, BoundsI): @@ -670,11 +621,8 @@ def wrap(self, bounds, hermitian=False): "Invalid value for hermitian", hermitian, (False, "x", "y") ) + @implements(_galsim.Image._wrap) def _wrap(self, bounds, hermx, hermy): - """A version of `wrap` without the sanity checks. - - Equivalent to ``image.wrap(bounds, hermitian=='x', hermitian=='y')``. - """ if not hermx and not hermy: from jax_galsim.core.wrap_image import wrap_nonhermitian @@ -826,13 +774,8 @@ def calculate_inverse_fft(self): return out @classmethod + @implements(_galsim.Image.good_fft_size) def good_fft_size(cls, input_size): - """Round the given input size up to the next higher power of 2 or 3 times a power of 2. - - This rounds up to the next higher value that is either 2^k or 3*2^k. If you are - going to be performing FFTs on an image, these will tend to be faster at performing - the FFT. - """ # we use the math module here since this function should not be jitted. import math @@ -848,8 +791,8 @@ def good_fft_size(cls, input_size): Nk = max(int(math.ceil(math.exp(min(log2n, log2n3)) - 1.0e-5)), 2) return Nk + @implements(_galsim.Image.copyFrom) def copyFrom(self, rhs): - """Copy the contents of another image""" if self.isconst: raise GalSimImmutableError("Cannot modify an immutable Image", self) if not isinstance(rhs, Image): @@ -924,13 +867,8 @@ def shift(self, *args, **kwargs): delta = parse_pos_args(args, kwargs, "dx", "dy", integer=True) self._shift(delta) + @implements(_galsim.Image._shift) def _shift(self, delta): - """Equivalent to `shift`, but without some of the sanity checks and ``delta`` must - be a `PositionI` instance. - - Parameters: - delta: The amount to shift as a `PositionI`. - """ self._bounds = self._bounds.shift(delta) if self.wcs is not None: self.wcs = self.wcs.shiftOrigin(delta) @@ -983,10 +921,8 @@ def getValue(self, x, y): ) return self._getValue(x, y) + @implements(_galsim.Image._getValue) def _getValue(self, x, y): - """Equivalent to `getValue`, except there are no checks that the values fall - within the bounds of the image. - """ return self.array[y - self.ymin, x - self.xmin] @implements(_galsim.Image.setValue) @@ -1006,15 +942,8 @@ def setValue(self, *args, **kwargs): ) self._setValue(pos.x, pos.y, value) + @implements(_galsim.Image._setValue) def _setValue(self, x, y, value): - """Equivalent to `setValue` except that there are no checks that the values - fall within the bounds of the image, and the coordinates must be given as ``x``, ``y``. - - Parameters: - x: The x coordinate of the pixel to set. - y: The y coordinate of the pixel to set. - value: The value to set the pixel to. - """ self._array = self._array.at[y - self.ymin, x - self.xmin].set(value) @implements(_galsim.Image.addValue) @@ -1034,23 +963,12 @@ def addValue(self, *args, **kwargs): ) self._addValue(pos.x, pos.y, value) + @implements(_galsim.Image._addValue) def _addValue(self, x, y, value): - """Equivalent to `addValue` except that there are no checks that the values - fall within the bounds of the image, and the coordinates must be given as ``x``, ``y``. - - Parameters: - x: The x coordinate of the pixel to add to. - y: The y coordinate of the pixel to add to. - value: The value to add to this pixel. - """ self._array = self._array.at[y - self.ymin, x - self.xmin].add(value) + @implements(_galsim.Image.fill) def fill(self, value): - """Set all pixel values to the given ``value`` - - Parameter: - value: The value to set all the pixels to. - """ if self.isconst: raise GalSimImmutableError("Cannot modify an immutable Image", self) if not self.bounds.isDefined(): @@ -1059,22 +977,18 @@ def fill(self, value): ) self._fill(value) + @implements(_galsim.Image._fill) def _fill(self, value): - """Equivalent to `fill`, except that there are no checks that the bounds are defined.""" self._array = jnp.zeros_like(self._array) + value + @implements(_galsim.Image.setZero) def setZero(self): - """Set all pixel values to zero.""" if self.isconst: raise GalSimImmutableError("Cannot modify an immutable Image", self) self._fill(0) + @implements(_galsim.Image.invertSelf) def invertSelf(self): - """Set all pixel values to their inverse: x -> 1/x. - - Note: any pixels whose value is 0 originally are ignored. They remain equal to 0 - on the output, rather than turning into inf. - """ if self.isconst: raise GalSimImmutableError("Cannot modify an immutable Image", self) if not self.bounds.isDefined(): @@ -1083,21 +997,14 @@ def invertSelf(self): ) self._invertSelf() + @implements(_galsim.Image._invertSelf) def _invertSelf(self): - """Equivalent to `invertSelf`, except that there are no checks that the bounds are defined.""" array = 1.0 / self._array array = array.at[jnp.isinf(array)].set(0.0) self._array = array.astype(self._array.dtype) + @implements(_galsim.Image.replaceNegative) def replaceNegative(self, replace_value=0): - """Replace any negative values currently in the image with 0 (or some other value). - - Sometimes FFT drawing can result in tiny negative values, which may be undesirable for - some purposes. This method replaces those values with 0 or some other value if desired. - - Parameters: - replace_value: The value with which to replace any negative pixels. [default: 0] - """ if self.isconst: raise GalSimImmutableError("Cannot modify an immutable Image", self) self._array = self.array.at[self.array < 0].set(replace_value) @@ -1214,48 +1121,80 @@ def _Image(array, bounds, wcs): # These are essentially aliases for the regular Image with the correct dtype +@implements( + _galsim.ImageUS, + lax_description=IMAGE_LAX_DOCS, +) def ImageUS(*args, **kwargs): """Alias for galsim.Image(..., dtype=numpy.uint16)""" kwargs["dtype"] = jnp.uint16 return Image(*args, **kwargs) +@implements( + _galsim.ImageUI, + lax_description=IMAGE_LAX_DOCS, +) def ImageUI(*args, **kwargs): """Alias for galsim.Image(..., dtype=numpy.uint32)""" kwargs["dtype"] = jnp.uint32 return Image(*args, **kwargs) +@implements( + _galsim.ImageS, + lax_description=IMAGE_LAX_DOCS, +) def ImageS(*args, **kwargs): """Alias for galsim.Image(..., dtype=numpy.int16)""" kwargs["dtype"] = jnp.int16 return Image(*args, **kwargs) +@implements( + _galsim.ImageI, + lax_description=IMAGE_LAX_DOCS, +) def ImageI(*args, **kwargs): """Alias for galsim.Image(..., dtype=numpy.int32)""" kwargs["dtype"] = jnp.int32 return Image(*args, **kwargs) +@implements( + _galsim.ImageF, + lax_description=IMAGE_LAX_DOCS, +) def ImageF(*args, **kwargs): """Alias for galsim.Image(..., dtype=numpy.float32)""" kwargs["dtype"] = jnp.float32 return Image(*args, **kwargs) +@implements( + _galsim.ImageD, + lax_description=IMAGE_LAX_DOCS, +) def ImageD(*args, **kwargs): """Alias for galsim.Image(..., dtype=numpy.float64)""" kwargs["dtype"] = jnp.float64 return Image(*args, **kwargs) +@implements( + _galsim.ImageCF, + lax_description=IMAGE_LAX_DOCS, +) def ImageCF(*args, **kwargs): """Alias for galsim.Image(..., dtype=numpy.complex64)""" kwargs["dtype"] = jnp.complex64 return Image(*args, **kwargs) +@implements( + _galsim.ImageCD, + lax_description=IMAGE_LAX_DOCS, +) def ImageCD(*args, **kwargs): """Alias for galsim.Image(..., dtype=numpy.complex128)""" kwargs["dtype"] = jnp.complex128