Skip to content

Commit

Permalink
Add approx_df to public API
Browse files Browse the repository at this point in the history
  • Loading branch information
f0uriest committed Nov 7, 2023
1 parent e3c7455 commit 77f9ea3
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 24 deletions.
4 changes: 4 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ fft_interp2d
************
.. autofunction:: interpax.fft_interp2d

approx_df
*********
.. autofunction:: interpax.approx_df

Interpolator1D
**************
.. autoclass:: interpax.Interpolator1D
Expand Down
1 change: 1 addition & 0 deletions interpax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Interpolator1D,
Interpolator2D,
Interpolator3D,
approx_df,
interp1d,
interp2d,
interp3d,
Expand Down
77 changes: 53 additions & 24 deletions interpax/_spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __init__(self, x, f, method="cubic", extrap=False, period=None, **kwargs):
self.period = period

if fx is None:
fx = _approx_df(x, f, method, axis, **kwargs)
fx = approx_df(x, f, method, axis, **kwargs)

self.derivs = {"fx": fx}

Expand Down Expand Up @@ -193,11 +193,11 @@ def __init__(self, x, y, f, method="cubic", extrap=False, period=None, **kwargs)
self.period = period

if fx is None:
fx = _approx_df(x, f, method, 0, **kwargs)
fx = approx_df(x, f, method, 0, **kwargs)
if fy is None:
fy = _approx_df(y, f, method, 1, **kwargs)
fy = approx_df(y, f, method, 1, **kwargs)
if fxy is None:
fxy = _approx_df(y, fx, method, 1, **kwargs)
fxy = approx_df(y, fx, method, 1, **kwargs)

self.derivs = {"fx": fx, "fy": fy, "fxy": fxy}

Expand Down Expand Up @@ -320,19 +320,19 @@ def __init__(self, x, y, z, f, method="cubic", extrap=False, period=None, **kwar
self.period = period

if fx is None:
fx = _approx_df(x, f, method, 0, **kwargs)
fx = approx_df(x, f, method, 0, **kwargs)
if fy is None:
fy = _approx_df(y, f, method, 1, **kwargs)
fy = approx_df(y, f, method, 1, **kwargs)
if fz is None:
fz = _approx_df(z, f, method, 2, **kwargs)
fz = approx_df(z, f, method, 2, **kwargs)
if fxy is None:
fxy = _approx_df(y, fx, method, 1, **kwargs)
fxy = approx_df(y, fx, method, 1, **kwargs)
if fxz is None:
fxz = _approx_df(z, fx, method, 2, **kwargs)
fxz = approx_df(z, fx, method, 2, **kwargs)
if fyz is None:
fyz = _approx_df(z, fy, method, 2, **kwargs)
fyz = approx_df(z, fy, method, 2, **kwargs)
if fxyz is None:
fxyz = _approx_df(z, fxy, method, 2, **kwargs)
fxyz = approx_df(z, fxy, method, 2, **kwargs)

self.derivs = {
"fx": fx,
Expand Down Expand Up @@ -486,7 +486,7 @@ def derivative2():

i = jnp.clip(jnp.searchsorted(x, xq, side="right"), 1, len(x) - 1)
if fx is None:
fx = _approx_df(x, f, method, axis, **kwargs)
fx = approx_df(x, f, method, axis, **kwargs)
assert fx.shape == f.shape

dx = x[i] - x[i - 1]
Expand Down Expand Up @@ -659,11 +659,11 @@ def derivative1():

elif method in CUBIC_METHODS:
if fx is None:
fx = _approx_df(x, f, method, 0, **kwargs)
fx = approx_df(x, f, method, 0, **kwargs)
if fy is None:
fy = _approx_df(y, f, method, 1, **kwargs)
fy = approx_df(y, f, method, 1, **kwargs)
if fxy is None:
fxy = _approx_df(y, fx, method, 1, **kwargs)
fxy = approx_df(y, fx, method, 1, **kwargs)
assert fx.shape == fy.shape == fxy.shape == f.shape

i = jnp.clip(jnp.searchsorted(x, xq, side="right"), 1, len(x) - 1)
Expand Down Expand Up @@ -912,19 +912,19 @@ def derivative1():

elif method in CUBIC_METHODS:
if fx is None:
fx = _approx_df(x, f, method, 0, **kwargs)
fx = approx_df(x, f, method, 0, **kwargs)
if fy is None:
fy = _approx_df(y, f, method, 1, **kwargs)
fy = approx_df(y, f, method, 1, **kwargs)
if fz is None:
fz = _approx_df(z, f, method, 2, **kwargs)
fz = approx_df(z, f, method, 2, **kwargs)
if fxy is None:
fxy = _approx_df(y, fx, method, 1, **kwargs)
fxy = approx_df(y, fx, method, 1, **kwargs)
if fxz is None:
fxz = _approx_df(z, fx, method, 2, **kwargs)
fxz = approx_df(z, fx, method, 2, **kwargs)
if fyz is None:
fyz = _approx_df(z, fy, method, 2, **kwargs)
fyz = approx_df(z, fy, method, 2, **kwargs)
if fxyz is None:
fxyz = _approx_df(z, fxy, method, 2, **kwargs)
fxyz = approx_df(z, fxy, method, 2, **kwargs)
assert (
fx.shape
== fy.shape
Expand Down Expand Up @@ -1095,8 +1095,37 @@ def noclip(fq, *_):


@partial(jit, static_argnames=("method", "axis"))
def _approx_df(x, f, method, axis, **kwargs):
"""Approximates derivatives for cubic spline interpolation."""
def approx_df(x, f, method="cubic", axis=-1, **kwargs):
"""Approximates first derivatives using cubic spline interpolation.
Parameters
----------
x : ndarray, shape(Nx,)
coordinates of known function values ("knots")
f : ndarray
Known function values. Should have length ``Nx`` along axis=axis
method : str
method of approximation
- ``'cubic'``: C1 cubic splines (aka local splines)
- ``'cubic2'``: C2 cubic splines (aka natural splines)
- ``'catmull-rom'``: C1 cubic centripetal "tension" splines
- ``'cardinal'``: C1 cubic general tension splines. If used, can also pass
keyword parameter ``c`` in float[0,1] to specify tension
- ``'monotonic'``: C1 cubic splines that attempt to preserve monotonicity in the
data, and will not introduce new extrema in the interpolated points
- ``'monotonic-0'``: same as ``'monotonic'`` but with 0 first derivatives at
both endpoints
axis : int
Axis along which f is varying.
Returns
-------
df : ndarray, shape(f.shape)
First derivative of f with respect to x.
"""
if method == "cubic":
dx = jnp.diff(x)
df = jnp.diff(f, axis=axis)
Expand Down

0 comments on commit 77f9ea3

Please sign in to comment.