Skip to content

Commit

Permalink
Merge pull request #25 from f0uriest/rc/scipy
Browse files Browse the repository at this point in the history
Add ``scipy.interpolate`` API
  • Loading branch information
f0uriest authored Mar 5, 2024
2 parents 3c66c30 + 5ced8f9 commit e178cba
Show file tree
Hide file tree
Showing 13 changed files with 1,820 additions and 69 deletions.
13 changes: 13 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,19 @@
Changelog
=========

- Adds a number of classes that replicate most of the functionality of the
corresponding classes from scipy.interpolate :
- ``scipy.interpolate.PPoly`` -> ``interpax.PPoly``
- ``scipy.interpolate.Akima1DInterpolator`` -> ``interpax.Akima1DInterpolator``
- ``scipy.interpolate.CubicHermiteSpline`` -> ``interpax.CubicHermiteSpline``
- ``scipy.interpolate.CubicSpline`` -> ``interpax.CubicSpline``
- ``scipy.interpolate.PchipInterpolator`` -> ``interpax.PchipInterpolator``
- Method ``"akima"`` now available for ``Interpolator.{1D, 2D, 3D}`` and corresponding
functions.
- Method ``"monotonic"`` now works in 2D and 3D, where it will preserve monotonicity
with respect to each coordinate individually.


v0.2.4
------
- Fixes for scalar valued query points
Expand Down
5 changes: 5 additions & 0 deletions docs/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ BUILDDIR = _build
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)


clean:
rm -rf _api/
rm -rf _build/

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
Expand Down
5 changes: 3 additions & 2 deletions docs/_templates/class.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
.. autosummary::
:toctree: {{ objname }}

{% for item in methods %}
{% if item != "__init__" %}

{% for item in all_methods %}
{%- if not item.startswith('_') or item in ['__call__',] %}
~{{ name }}.{{ item }}
{% endif %}
{%- endfor %}
Expand Down
81 changes: 54 additions & 27 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,38 +2,65 @@
API Documentation
=================

interp1d
********
.. autofunction:: interpax.interp1d
Interpolation of 1D, 2D, or 3D data
-----------------------------------

interp2d
********
.. autofunction:: interpax.interp2d
.. autosummary::
:toctree: _api/
:recursive:
:template: class.rst

interp3d
********
.. autofunction:: interpax.interp3d
interpax.Interpolator1D
interpax.Interpolator2D
interpax.Interpolator3D

fft_interp1d
************
.. autofunction:: interpax.fft_interp1d

fft_interp2d
************
.. autofunction:: interpax.fft_interp2d
``scipy.interpolate``-like classes
----------------------------------

approx_df
*********
.. autofunction:: interpax.approx_df
These classes implement most of the functionality of the SciPy classes with the same names,
except where noted in the documentation.

Interpolator1D
**************
.. autoclass:: interpax.Interpolator1D
.. autosummary::
:toctree: _api/
:recursive:
:template: class.rst

Interpolator2D
**************
.. autoclass:: interpax.Interpolator2D
interpax.Akima1DInterpolator
interpax.CubicHermiteSpline
interpax.CubicSpline
interpax.PchipInterpolator
interpax.PPoly

Interpolator3D
**************
.. autoclass:: interpax.Interpolator3D

Functional interface for 1D, 2D, 3D interpolation
-------------------------------------------------

.. autosummary::
:toctree: _api/
:recursive:

interpax.interp1d
interpax.interp2d
interpax.interp2d


Fourier interpolation of periodic functions in 1D and 2D
--------------------------------------------------------

.. autosummary::
:toctree: _api/
:recursive:

interpax.fft_interp1d
interpax.fft_interp2d


Approximating first derivatives for cubic splines
-------------------------------------------------

.. autosummary::
:toctree: _api/
:recursive:

interpax.approx_df
1 change: 0 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,6 @@ def linkcode_resolve(domain, info):

autodoc_default_options = {
"member-order": "bysource",
"special-members": "__call__",
"exclude-members": "__init__",
}
# Add any paths that contain templates here, relative to this directory.
Expand Down
2 changes: 1 addition & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


.. toctree::
:maxdepth: 2
:maxdepth: 3
:caption: Public API

api
Expand Down
7 changes: 7 additions & 0 deletions interpax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,13 @@
from . import _version
from ._fd_derivs import approx_df
from ._fourier import fft_interp1d, fft_interp2d
from ._ppoly import (
Akima1DInterpolator,
CubicHermiteSpline,
CubicSpline,
PchipInterpolator,
PPoly,
)
from ._spline import (
Interpolator1D,
Interpolator2D,
Expand Down
35 changes: 19 additions & 16 deletions interpax/_fd_derivs.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from functools import partial

import jax
import jax.numpy as jnp
from jax import jit

from .utils import errorif
from .utils import asarray_inexact, errorif


def approx_df(
Expand Down Expand Up @@ -42,10 +40,13 @@ def approx_df(
First derivative of f with respect to x.
"""
return _approx_df(x, f, method, axis, **kwargs)
# close over static args to deal with non-jittable kwargs
def fun(x, f):
return _approx_df(x, f, method, axis, **kwargs)

return jit(fun)(x, f)


@partial(jit, static_argnames=("method", "axis", "bc_type"))
def _approx_df(x, f, method, axis, c=0, bc_type="not-a-knot"):
if method == "cubic":
out = _cubic1(x, f, axis)
Expand Down Expand Up @@ -92,7 +93,7 @@ def _cubic1(x, f, axis):
return fx


def _validate_bc(bc_type, expected_deriv_shape):
def _validate_bc(bc_type, expected_deriv_shape, dtype):
if isinstance(bc_type, str):
errorif(bc_type == "periodic", NotImplementedError)
bc_type = (bc_type, bc_type)
Expand Down Expand Up @@ -136,20 +137,21 @@ def _validate_bc(bc_type, expected_deriv_shape):
if deriv_order not in [1, 2]:
raise ValueError("The specified derivative order must " "be 1 or 2.")

deriv_value = jnp.asarray(deriv_value)
deriv_value = asarray_inexact(deriv_value)
dtype = jnp.promote_types(dtype, deriv_value.dtype)
if deriv_value.shape != expected_deriv_shape:
raise ValueError(
"`deriv_value` shape {} is not the expected one {}.".format(
deriv_value.shape, expected_deriv_shape
)
)
validated_bc.append((deriv_order, deriv_value))
return validated_bc
return validated_bc, dtype


def _cubic2(x, f, axis, bc_type):
f = jnp.moveaxis(f, axis, 0)
bc = _validate_bc(bc_type, f.shape[1:])
bc, dtype = _validate_bc(bc_type, f.shape[1:], f.dtype)
dx = jnp.diff(x)
df = jnp.diff(f, axis=0)
dxr = dx.reshape([dx.shape[0]] + [1] * (f.ndim - 1))
Expand All @@ -173,7 +175,7 @@ def _cubic2(x, f, axis, bc_type):
# constructing a parabola passing through given points.
if n == 3 and bc[0] == "not-a-knot" and bc[1] == "not-a-knot":
A = jnp.zeros((3, 3)) # This is a standard matrix.
b = jnp.empty((3,) + f.shape[1:], dtype=f.dtype)
b = jnp.empty((3,) + f.shape[1:], dtype=dtype)

A = A.at[0, 0].set(1)
A = A.at[0, 1].set(1)
Expand All @@ -187,20 +189,21 @@ def _cubic2(x, f, axis, bc_type):
b = b.at[1].set(3 * (dxr[0] * df[1] + dxr[1] * df[0]))
b = b.at[2].set(2 * df[1])

s = jnp.linalg.solve(A, b)
fx = jnp.moveaxis(s, 0, axis)
solve = lambda b: jnp.linalg.solve(A, b)
fx = jnp.vectorize(solve, signature="(n)->(n)")(b.T).T
fx = jnp.moveaxis(fx, 0, axis)

else:

# Find derivative values at each x[i] by solving a tridiagonal
# system.
diag = jnp.zeros(n)
diag = jnp.zeros(n, dtype=x.dtype)
diag = diag.at[1:-1].set(2 * (dx[:-1] + dx[1:]))
upper_diag = jnp.zeros(n - 1)
upper_diag = jnp.zeros(n - 1, dtype=x.dtype)
upper_diag = upper_diag.at[1:].set(dx[:-1])
lower_diag = jnp.zeros(n - 1)
lower_diag = jnp.zeros(n - 1, dtype=x.dtype)
lower_diag = lower_diag.at[:-1].set(dx[1:])
b = jnp.zeros((n,) + f.shape[1:], dtype=f.dtype)
b = jnp.zeros((n,) + f.shape[1:], dtype=dtype)
b = b.at[1:-1].set(3 * (dxr[1:] * df[:-1] + dxr[:-1] * df[1:]))

bc_start, bc_end = bc
Expand Down
Loading

0 comments on commit e178cba

Please sign in to comment.