Skip to content

Commit

Permalink
added docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
sidd3888 committed Aug 14, 2024
1 parent e1e08e4 commit b690d71
Show file tree
Hide file tree
Showing 8 changed files with 307 additions and 0 deletions.
70 changes: 70 additions & 0 deletions src/multinterp/backend/_cupy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,25 @@


def cupy_multinterp(grids, values, args, options=None):
"""
Perform multivariate interpolation using CuPy.
Parameters
----------
grids : array-like
Grid points in the domain.
values : array-like
Functional values at the grid points.
args : array-like
Points at which to interpolate data.
options : dict, optional
Additional options for interpolation.
Returns
-------
array-like
Interpolated values of the function.
"""
mc_kwargs = update_mc_kwargs(options)

args = cp.asarray(args)
Expand All @@ -18,6 +37,27 @@ def cupy_multinterp(grids, values, args, options=None):


def cupy_gradinterp(grids, values, args, axis=None, options=None):
"""
Computes the interpolated value of the gradient evaluated at specified points using CuPy.
Parameters
----------
grids : list of array-like
Grid points in the domain.
values : array-like
Functional values at the grid points.
args : array-like
Points at which to interpolate data.
axis : int, optional
Axis along which to compute the gradient.
options : dict, optional
Additional options for interpolation.
Returns
-------
array-like
Interpolated values of the gradient.
"""
mc_kwargs = update_mc_kwargs(options)
eo = options.get("edge_order", 1) if options else 1

Expand All @@ -40,6 +80,21 @@ def cupy_gradinterp(grids, values, args, axis=None, options=None):


def cupy_get_coordinates(grids, args):
"""
Takes input values and converts them to coordinates with respect to the specified grid.
Parameters
----------
grids : cp.array
Grid points for each dimension.
args : cp.array
Points at which to interpolate data.
Returns
-------
cp.array
Coordinates with respect to the grid.
"""
coords = cp.empty_like(args)
for dim, grid in enumerate(grids):
grid_size = cp.arange(grid.size)
Expand All @@ -49,6 +104,21 @@ def cupy_get_coordinates(grids, args):


def cupy_map_coordinates(values, coords, **kwargs):
"""
Run the map_coordinates function from the cupyx.scipy.ndimage module on the specified values.
Parameters
----------
values : cp.array
Functional values from which to interpolate.
coords : cp.array
Coordinates at which to interpolate values.
Returns
-------
cp.array
Interpolated values.
"""
original_shape = coords[0].shape
coords = coords.reshape(len(values.shape), -1)
output = map_coordinates(values, coords, **kwargs)
Expand Down
57 changes: 57 additions & 0 deletions src/multinterp/backend/_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,27 @@ def jax_multinterp(grids, values, args, options=None):


def jax_gradinterp(grids, values, args, axis=None, options=None):
"""
Computes the interpolated value of the gradient evaluated at specified points using JAX.
Parameters
----------
grids : list of array-like
Grid points in the domain.
values : array-like
Functional values at the grid points.
args : array-like
Points at which to interpolate data.
axis : int, optional
Axis along which to compute the gradient.
options : dict, optional
Additional options for interpolation.
Returns
-------
array-like
Interpolated values of the gradient.
"""
mc_kwargs = update_mc_kwargs(options, jax=True)
eo = options.get("edge_order", 1) if options else 1

Expand All @@ -64,6 +85,21 @@ def jax_gradinterp(grids, values, args, axis=None, options=None):

@jit
def jax_get_coordinates(grids, args):
"""
Takes input values and converts them to coordinates with respect to the specified grid.
Parameters
----------
grids : jnp.array
Grid points for each dimension.
args : jnp.array
Points at which to interpolate data.
Returns
-------
jnp.array
Coordinates of the specified input points with respect to the grid.
"""
grid_sizes = [jnp.arange(grid.size) for grid in grids]
return jnp.array(
[
Expand All @@ -75,6 +111,27 @@ def jax_get_coordinates(grids, args):

@functools.partial(jit, static_argnums=(2, 3, 4))
def jax_map_coordinates(values, coords, order=None, mode=None, cval=None):
"""
Run the map_coordinates function from the jax.scipy.ndimage module on the specified values.
Parameters
----------
values : jnp.array
The functional values from which to interpolate.
coords : jnp.array
The coordinates at which to interpolate the values.
order : int, optional
The order of interpolation, 0 for Nearest-Neighbour, 1 for Linear.
mode : str, optional
Method to handle extrapolation. See JAX documentation for options.
cval : float, optional
Value to use for extrapolation under 'constant' method.
Returns
-------
jnp.array
Interpolated values at specified coordinates.
"""
original_shape = coords[0].shape
coords = coords.reshape(len(values.shape), -1)
output = map_coordinates(values, coords, order, mode, cval)
Expand Down
37 changes: 37 additions & 0 deletions src/multinterp/backend/_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,25 @@


def numba_multinterp(grids, values, args, options=None):
"""
Perform multivariate interpolation using JIT-compiled functions with Numba.
Parameters
----------
grids : array-like
Grid points in the domain.
values: array-like
Functional values at the grid points.
args: array-like
Points at which to interpolate data.
options: dict, optional
Additional options for interpolation.
Returns
-------
array-like
Interpolated values of the function.
"""
mc_kwargs = update_mc_kwargs(options)

args = np.asarray(args)
Expand All @@ -20,6 +39,21 @@ def numba_multinterp(grids, values, args, options=None):

@njit(parallel=True, cache=True, fastmath=True)
def numba_get_coordinates(grids, args):
"""
Converts input arguments to coordinates with respect to the specified grid. JIT-compiled using Numba.
Parameters
----------
grids : typed.List
Curvilinear grids for each dimension.
args : np.ndarray
Values in the domain at which the function is to be interpolated.
Returns
-------
np.ndarray
Coordinates of the input arguments.
"""
coords = np.empty_like(args)
for dim in prange(len(grids)):
grid_size = np.arange(grids[dim].size)
Expand All @@ -30,6 +64,9 @@ def numba_get_coordinates(grids, args):

# same as scipy map_coordinates until replacement is found
def numba_map_coordinates(values, coords, **kwargs):
"""
Identical to scipy_map_coordinates until a replacement is found. See documentation for scipy_map_coordinates.
"""
original_shape = coords[0].shape
coords = coords.reshape(len(values.shape), -1)
output = map_coordinates(values, coords, **kwargs)
Expand Down
70 changes: 70 additions & 0 deletions src/multinterp/backend/_scipy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,25 @@


def scipy_multinterp(grids, values, args, options=None):
"""
Perform multivariate interpolation using SciPy.
Parameters
----------
grids : list of array-like
Grid points in the domain.
values : array-like
Functional values at the grid points.
args : array-like
Points at which to interpolate data.
options : dict, optional
Additional options for interpolation.
Returns
-------
array-like
Interpolated values of the function.
"""
mc_kwargs = update_mc_kwargs(options)

args = np.asarray(args)
Expand All @@ -18,6 +37,27 @@ def scipy_multinterp(grids, values, args, options=None):


def scipy_gradinterp(grids, values, args, axis=None, options=None):
"""
Computes the interpolated value of the gradient evaluated at specified points using SciPy.
Parameters
----------
grids : list of array-like
Grid points in the domain.
values : array-like
Functional values at the grid points.
args : array-like
Points at which to interpolate data.
axis : int, optional
Axis along which to compute the gradient.
options : dict, optional
Additional options for interpolation.
Returns
-------
array-like
Interpolated values of the gradient.
"""
mc_kwargs = update_mc_kwargs(options)
eo = options.get("edge_order", 1) if options else 1

Expand All @@ -40,6 +80,21 @@ def scipy_gradinterp(grids, values, args, axis=None, options=None):


def scipy_get_coordinates(grids, args):
"""
Takes input values and converts them to coordinates with respect to the specified grid.
Parameters
----------
grids : np.array
Grid points for each dimension.
args : np.array
Points at which to interpolate data.
Returns
-------
np.array
Coordinates with respect to the grid.
"""
coords = np.empty_like(args)
for dim, grid in enumerate(grids):
grid_size = np.arange(grid.size)
Expand All @@ -49,6 +104,21 @@ def scipy_get_coordinates(grids, args):


def scipy_map_coordinates(values, coords, **kwargs):
"""
Run the map_coordinates function from the scipy.ndimage module on the specified values.
Parameters
----------
values : np.array
Functional values from which to interpolate.
coords : np.array
Coordinates at which to interpolate values.
Returns
-------
np.array
Interpolated values of the function.
"""
original_shape = coords[0].shape
coords = coords.reshape(len(values.shape), -1)
output = map_coordinates(values, coords, **kwargs)
Expand Down
28 changes: 28 additions & 0 deletions src/multinterp/curvilinear/_scikit_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,22 @@


class PiecewiseAffineInterp(_CurvilinearGrid, MultivariateInterp):
"""Curvilinear interpolator that uses the PiecewiseAffineTransform from scikit-image."""

def __init__(self, values, grids, options=None):
"""
Initialize a PiecewiseAffineInterp object.
Parameters
----------
values : np.ndarray
Functional values on a curvilinear grid.
grids : np.ndarray
Coordinates of the points in the curvilinear grid.
options : dict, optional
Additional keyword arguments to pass to the map_coordinates backend.
"""

super().__init__(values, grids, backend="scipy")
self.mc_kwargs = update_mc_kwargs(options)

Expand All @@ -25,6 +40,19 @@ def __init__(self, values, grids, options=None):
self.interpolator = interpolator

def _get_coordinates(self, args):
"""Obtain the index coordinates for each of the arguments.
Parameters
----------
args : np.ndarray
Arguments to be interpolated.
Returns
-------
np.ndarray
Index coordinates for each of the arguments.
"""

_input = args.reshape((self.ndim, -1)).T
output = self.interpolator(_input).T.copy()
return output.reshape(args.shape)
1 change: 1 addition & 0 deletions src/multinterp/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(self, values, backend="scipy"):


class _StructuredGrid(_AbstractGrid):
"""Abstract class for interpolating on a structured grid. Serves as a base class for regular and abstract unstructured grid interpolators."""
def __init__(self, values, grids, backend="scipy"):
super().__init__(values, backend=backend)

Expand Down
Loading

0 comments on commit b690d71

Please sign in to comment.