diff --git a/CHANGELOG.md b/CHANGELOG.md index d6b626b..2f88bd1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,19 @@ Changelog ========= +- Adds a number of classes that replicate most of the functionality of the +corresponding classes from scipy.interpolate : + - ``scipy.interpolate.PPoly`` -> ``interpax.PPoly`` + - ``scipy.interpolate.Akima1DInterpolator`` -> ``interpax.Akima1DInterpolator`` + - ``scipy.interpolate.CubicHermiteSpline`` -> ``interpax.CubicHermiteSpline`` + - ``scipy.interpolate.CubicSpline`` -> ``interpax.CubicSpline`` + - ``scipy.interpolate.PchipInterpolator`` -> ``interpax.PchipInterpolator`` +- Method ``"akima"`` now available for ``Interpolator.{1D, 2D, 3D}`` and corresponding +functions. +- Method ``"monotonic"`` now works in 2D and 3D, where it will preserve monotonicity +with respect to each coordinate individually. + + v0.2.4 ------ - Fixes for scalar valued query points diff --git a/docs/Makefile b/docs/Makefile index d4bb2cb..29adaed 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -12,6 +12,11 @@ BUILDDIR = _build help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +clean: + rm -rf _api/ + rm -rf _build/ + .PHONY: help Makefile # Catch-all target: route all unknown targets to Sphinx using the new diff --git a/docs/_templates/class.rst b/docs/_templates/class.rst index c6d931c..1b62b09 100644 --- a/docs/_templates/class.rst +++ b/docs/_templates/class.rst @@ -12,8 +12,9 @@ .. autosummary:: :toctree: {{ objname }} - {% for item in methods %} - {% if item != "__init__" %} + + {% for item in all_methods %} + {%- if not item.startswith('_') or item in ['__call__',] %} ~{{ name }}.{{ item }} {% endif %} {%- endfor %} diff --git a/docs/api.rst b/docs/api.rst index 8eeae9e..9727881 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -2,38 +2,65 @@ API Documentation ================= -interp1d -******** -.. autofunction:: interpax.interp1d +Interpolation of 1D, 2D, or 3D data +----------------------------------- -interp2d -******** -.. autofunction:: interpax.interp2d +.. autosummary:: + :toctree: _api/ + :recursive: + :template: class.rst -interp3d -******** -.. autofunction:: interpax.interp3d + interpax.Interpolator1D + interpax.Interpolator2D + interpax.Interpolator3D -fft_interp1d -************ -.. autofunction:: interpax.fft_interp1d -fft_interp2d -************ -.. autofunction:: interpax.fft_interp2d +``scipy.interpolate``-like classes +---------------------------------- -approx_df -********* -.. autofunction:: interpax.approx_df +These classes implement most of the functionality of the SciPy classes with the same names, +except where noted in the documentation. -Interpolator1D -************** -.. autoclass:: interpax.Interpolator1D +.. autosummary:: + :toctree: _api/ + :recursive: + :template: class.rst -Interpolator2D -************** -.. autoclass:: interpax.Interpolator2D + interpax.Akima1DInterpolator + interpax.CubicHermiteSpline + interpax.CubicSpline + interpax.PchipInterpolator + interpax.PPoly -Interpolator3D -************** -.. autoclass:: interpax.Interpolator3D + +Functional interface for 1D, 2D, 3D interpolation +------------------------------------------------- + +.. autosummary:: + :toctree: _api/ + :recursive: + + interpax.interp1d + interpax.interp2d + interpax.interp2d + + +Fourier interpolation of periodic functions in 1D and 2D +-------------------------------------------------------- + +.. autosummary:: + :toctree: _api/ + :recursive: + + interpax.fft_interp1d + interpax.fft_interp2d + + +Approximating first derivatives for cubic splines +------------------------------------------------- + +.. autosummary:: + :toctree: _api/ + :recursive: + + interpax.approx_df diff --git a/docs/conf.py b/docs/conf.py index b41201d..a8486b2 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -159,7 +159,6 @@ def linkcode_resolve(domain, info): autodoc_default_options = { "member-order": "bysource", - "special-members": "__call__", "exclude-members": "__init__", } # Add any paths that contain templates here, relative to this directory. diff --git a/docs/index.rst b/docs/index.rst index 9b05d85..7035a5b 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -2,7 +2,7 @@ .. toctree:: - :maxdepth: 2 + :maxdepth: 3 :caption: Public API api diff --git a/interpax/__init__.py b/interpax/__init__.py index 539bdb5..b650dbf 100644 --- a/interpax/__init__.py +++ b/interpax/__init__.py @@ -3,6 +3,13 @@ from . import _version from ._fd_derivs import approx_df from ._fourier import fft_interp1d, fft_interp2d +from ._ppoly import ( + Akima1DInterpolator, + CubicHermiteSpline, + CubicSpline, + PchipInterpolator, + PPoly, +) from ._spline import ( Interpolator1D, Interpolator2D, diff --git a/interpax/_fd_derivs.py b/interpax/_fd_derivs.py index 895f506..896daeb 100644 --- a/interpax/_fd_derivs.py +++ b/interpax/_fd_derivs.py @@ -1,10 +1,8 @@ -from functools import partial - import jax import jax.numpy as jnp from jax import jit -from .utils import errorif +from .utils import asarray_inexact, errorif def approx_df( @@ -42,10 +40,13 @@ def approx_df( First derivative of f with respect to x. """ - return _approx_df(x, f, method, axis, **kwargs) + # close over static args to deal with non-jittable kwargs + def fun(x, f): + return _approx_df(x, f, method, axis, **kwargs) + + return jit(fun)(x, f) -@partial(jit, static_argnames=("method", "axis", "bc_type")) def _approx_df(x, f, method, axis, c=0, bc_type="not-a-knot"): if method == "cubic": out = _cubic1(x, f, axis) @@ -92,7 +93,7 @@ def _cubic1(x, f, axis): return fx -def _validate_bc(bc_type, expected_deriv_shape): +def _validate_bc(bc_type, expected_deriv_shape, dtype): if isinstance(bc_type, str): errorif(bc_type == "periodic", NotImplementedError) bc_type = (bc_type, bc_type) @@ -136,7 +137,8 @@ def _validate_bc(bc_type, expected_deriv_shape): if deriv_order not in [1, 2]: raise ValueError("The specified derivative order must " "be 1 or 2.") - deriv_value = jnp.asarray(deriv_value) + deriv_value = asarray_inexact(deriv_value) + dtype = jnp.promote_types(dtype, deriv_value.dtype) if deriv_value.shape != expected_deriv_shape: raise ValueError( "`deriv_value` shape {} is not the expected one {}.".format( @@ -144,12 +146,12 @@ def _validate_bc(bc_type, expected_deriv_shape): ) ) validated_bc.append((deriv_order, deriv_value)) - return validated_bc + return validated_bc, dtype def _cubic2(x, f, axis, bc_type): f = jnp.moveaxis(f, axis, 0) - bc = _validate_bc(bc_type, f.shape[1:]) + bc, dtype = _validate_bc(bc_type, f.shape[1:], f.dtype) dx = jnp.diff(x) df = jnp.diff(f, axis=0) dxr = dx.reshape([dx.shape[0]] + [1] * (f.ndim - 1)) @@ -173,7 +175,7 @@ def _cubic2(x, f, axis, bc_type): # constructing a parabola passing through given points. if n == 3 and bc[0] == "not-a-knot" and bc[1] == "not-a-knot": A = jnp.zeros((3, 3)) # This is a standard matrix. - b = jnp.empty((3,) + f.shape[1:], dtype=f.dtype) + b = jnp.empty((3,) + f.shape[1:], dtype=dtype) A = A.at[0, 0].set(1) A = A.at[0, 1].set(1) @@ -187,20 +189,21 @@ def _cubic2(x, f, axis, bc_type): b = b.at[1].set(3 * (dxr[0] * df[1] + dxr[1] * df[0])) b = b.at[2].set(2 * df[1]) - s = jnp.linalg.solve(A, b) - fx = jnp.moveaxis(s, 0, axis) + solve = lambda b: jnp.linalg.solve(A, b) + fx = jnp.vectorize(solve, signature="(n)->(n)")(b.T).T + fx = jnp.moveaxis(fx, 0, axis) else: # Find derivative values at each x[i] by solving a tridiagonal # system. - diag = jnp.zeros(n) + diag = jnp.zeros(n, dtype=x.dtype) diag = diag.at[1:-1].set(2 * (dx[:-1] + dx[1:])) - upper_diag = jnp.zeros(n - 1) + upper_diag = jnp.zeros(n - 1, dtype=x.dtype) upper_diag = upper_diag.at[1:].set(dx[:-1]) - lower_diag = jnp.zeros(n - 1) + lower_diag = jnp.zeros(n - 1, dtype=x.dtype) lower_diag = lower_diag.at[:-1].set(dx[1:]) - b = jnp.zeros((n,) + f.shape[1:], dtype=f.dtype) + b = jnp.zeros((n,) + f.shape[1:], dtype=dtype) b = b.at[1:-1].set(3 * (dxr[1:] * df[:-1] + dxr[:-1] * df[1:])) bc_start, bc_end = bc diff --git a/interpax/_ppoly.py b/interpax/_ppoly.py new file mode 100644 index 0000000..9bf3a9c --- /dev/null +++ b/interpax/_ppoly.py @@ -0,0 +1,805 @@ +"""Functions for interpolating splines that are JAX differentiable. + +The docstrings and API are from SciPy, under a BSD license: + +Copyright (c) 2001-2002 Enthought, Inc. 2003-2024, SciPy Developers. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials provided + with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +""" + +from functools import partial +from typing import Union + +import equinox as eqx +import jax +import jax.numpy as jnp +from jax import jit + +from ._coefs import A_CUBIC +from ._spline import approx_df +from .utils import asarray_inexact, errorif + + +class PPoly(eqx.Module): + """Piecewise polynomial in terms of coefficients and breakpoints. + + The polynomial between ``x[i]`` and ``x[i + 1]`` is written in the + local power basis:: + + S = sum(c[m, i] * (xp - x[i])**(k-m) for m in range(k+1)) + + where ``k`` is the degree of the polynomial. + + Parameters + ---------- + c : ndarray, shape (k, m, ...) + Polynomial coefficients, order `k` and `m` intervals. + x : ndarray, shape (m+1,) + Polynomial breakpoints. Must be sorted in either increasing or + decreasing order. + extrapolate : bool or 'periodic', optional + If bool, determines whether to extrapolate to out-of-bounds points + based on first and last intervals, or to return NaNs. If 'periodic', + periodic extrapolation is used. Default is True. + axis : int, optional + Interpolation axis. Default is zero. + check : bool + Whether to perform checks on the input. Should be False if used under JIT. + + Notes + ----- + High-order polynomials in the power basis can be numerically + unstable. Precision problems can start to appear for orders + larger than 20-30. + """ + + _c: jax.Array + _x: jax.Array + _extrapolate: Union[bool, str] = eqx.field(static=True) + _axis: int = eqx.field(static=True) + + def __init__( + self, + c: jax.Array, + x: jax.Array, + extrapolate: Union[bool, str] = None, + axis: int = 0, + check: bool = True, + ): + c = asarray_inexact(c) + x = asarray_inexact(x) + + errorif( + c.ndim < 2, + ValueError, + "Coefficients array must be at least 2-dimensional.", + ) + errorif(x.ndim != 1, ValueError, "x must be 1-dimensional") + errorif(x.size < 2, ValueError, "at least 2 breakpoints are needed") + errorif(c.ndim < 2, ValueError, "c must have at least 2 dimensions") + + axis = axis % (c.ndim - 1) + if extrapolate is None: + extrapolate = True + elif extrapolate != "periodic": + extrapolate = bool(extrapolate) + + if axis != 0: + # move the interpolation axis to be the first one in self.c + # More specifically, the target shape for self.c is (k, m, ...), + # and axis !=0 means that we have c.shape (..., k, m, ...) + # ^ + # axis + # So we roll two of them. + c = jnp.moveaxis(c, axis + 1, 0) + c = jnp.moveaxis(c, axis + 1, 0) + + errorif( + c.shape[0] == 0, + ValueError, + "polynomial must be at least of order 0", + ) + errorif( + c.shape[1] != x.size - 1, + ValueError, + "number of coefficients != len(x)-1", + ) + + if check: + + dx = jnp.diff(x) + errorif( + jnp.any(dx < 0), ValueError, "`x` must be strictly increasing sequence." + ) + + self._extrapolate = extrapolate + self._axis = axis + self._x = x + self._c = c + + @property + def c(self) -> jax.Array: + """Array of spline coefficients, shape(order, knots-1, ...).""" + return self._c + + @property + def x(self) -> jax.Array: + """Array of knot values, shape(knots).""" + return self._x + + @property + def extrapolate(self) -> Union[bool, str]: + """Whether to extrapolate beyond domain of known values.""" + return self._extrapolate + + @property + def axis(self) -> int: + """Axis along which to interpolate.""" + return self._axis + + @classmethod + def construct_fast( + cls, + c: jax.Array, + x: jax.Array, + extrapolate: Union[bool, str] = None, + axis: int = 0, + ): + """Construct the piecewise polynomial without making checks. + + Takes the same parameters as the constructor. Input arguments + ``c`` and ``x`` must be arrays of the correct shape and type. The + ``c`` array can only be of dtypes float and complex, and ``x`` + array must have dtype float. + """ + self = object.__new__(cls) + object.__setattr__(self, "_c", c) + object.__setattr__(self, "_x", x) + object.__setattr__(self, "_extrapolate", extrapolate) + object.__setattr__(self, "_axis", axis) + return self + + @partial(jit, static_argnames=("nu", "extrapolate")) + def __call__(self, x: jax.Array, nu: int = 0, extrapolate: Union[bool, str] = None): + """Evaluate the piecewise polynomial or its derivative. + + Parameters + ---------- + x : array_like + Points to evaluate the interpolant at. + nu : int, optional + Order of derivative to evaluate. Must be non-negative. + extrapolate : {bool, 'periodic', None}, optional + If bool, determines whether to extrapolate to out-of-bounds points + based on first and last intervals, or to return NaNs. + If 'periodic', periodic extrapolation is used. + If None (default), use `self.extrapolate`. + + Returns + ------- + y : array_like + Interpolated values. Shape is determined by replacing + the interpolation axis in the original array with the shape of x. + + Notes + ----- + Derivatives are evaluated piecewise for each polynomial + segment, even if the polynomial is not differentiable at the + breakpoints. The polynomial intervals are considered half-open, + ``[a, b)``, except for the last interval which is closed + ``[a, b]``. + """ + if extrapolate is None: + extrapolate = self.extrapolate + x = asarray_inexact(x) + x_shape, x_ndim = x.shape, x.ndim + x = x.flatten() + + # With periodic extrapolation we map x to the segment + # [self.x[0], self.x[-1]]. + if extrapolate == "periodic": + x = self.x[0] + (x - self.x[0]) % (self.x[-1] - self.x[0]) + extrapolate = False + + # TODO: implement extrap + + i = jnp.clip(jnp.searchsorted(self.x, x, side="right"), 1, len(self.x) - 1) + + t = x - self.x[i - 1] + c = self.c[:, i - 1] + + c = jnp.vectorize(lambda x: jnp.polyder(x, nu), signature="(n)->(m)")(c.T).T + y = jnp.vectorize(jnp.polyval, signature="(n),()->()")(c.T, t).T + + y = y.reshape(x_shape + self.c.shape[2:]) + + if not extrapolate: + mask = jnp.logical_or(x > self.x[-1], x < self.x[0]) + y = jnp.where(mask, jnp.nan, y.T).T + + if self.axis != 0: + # transpose to move the calculated values to the interpolation axis + l = list(range(y.ndim)) + l = l[x_ndim : x_ndim + self.axis] + l[:x_ndim] + l[x_ndim + self.axis :] + y = y.transpose(l) + return y + + def derivative(self, nu: int = 1): + """Construct a new piecewise polynomial representing the derivative. + + Parameters + ---------- + nu : int, optional + Order of derivative to evaluate. Default is 1, i.e., compute the + first derivative. If negative, the antiderivative is returned. + + Returns + ------- + pp : PPoly + Piecewise polynomial of order k2 = k - n representing the derivative + of this polynomial. + + Notes + ----- + Derivatives are evaluated piecewise for each polynomial + segment, even if the polynomial is not differentiable at the + breakpoints. The polynomial intervals are considered half-open, + ``[a, b)``, except for the last interval which is closed + ``[a, b]``. + """ + if nu < 0: + return self.antiderivative(-nu) + + if nu == 0: + c2 = self.c.copy() + else: + c2 = jnp.vectorize(lambda x: jnp.polyder(x, nu), signature="(n)->(m)")( + self.c.T + ).T + + if c2.shape[0] == 0: + # derivative of order 0 is zero + c2 = jnp.zeros((1,) + c2.shape[1:], dtype=c2.dtype) + + return self.construct_fast(c2, self.x, self.extrapolate, self.axis) + + def antiderivative(self, nu: int = 1): + """Construct a new piecewise polynomial representing the antiderivative. + + Antiderivative is also the indefinite integral of the function, + and derivative is its inverse operation. + + Parameters + ---------- + nu : int, optional + Order of antiderivative to evaluate. Default is 1, i.e., compute + the first integral. If negative, the derivative is returned. + + Returns + ------- + pp : PPoly + Piecewise polynomial of order k2 = k + n representing + the antiderivative of this polynomial. + + Notes + ----- + The antiderivative returned by this function is continuous and + continuously differentiable to order n-1, up to floating point + rounding error. + + If antiderivative is computed and ``self.extrapolate='periodic'``, + it will be set to False for the returned instance. This is done because + the antiderivative is no longer periodic and its correct evaluation + outside of the initially given x interval is difficult. + """ + if nu <= 0: + return self.derivative(-nu) + + if nu == 0: + c2 = self.c.copy() + else: + c2 = self.c.copy() + for _ in range(nu): + c2 = jnp.vectorize(jnp.polyint, signature="(n)->(m)")(c2.T).T + # need to patch up continuity + dx = jnp.diff(self.x) + z = jnp.vectorize(jnp.polyval, signature="(n),()->()")(c2.T, dx).T + c2 = c2.at[-1, 1:].add(jnp.cumsum(z, axis=self.axis)[:-1]) + + if self.extrapolate == "periodic": + extrapolate = False + else: + extrapolate = self.extrapolate + + return self.construct_fast(c2, self.x, extrapolate, self.axis) + + def integrate(self, a: float, b: float, extrapolate: Union[bool, str] = None): + """Compute a definite integral over a piecewise polynomial. + + Parameters + ---------- + a : float + Lower integration bound + b : float + Upper integration bound + extrapolate : {bool, 'periodic', None}, optional + If bool, determines whether to extrapolate to out-of-bounds points + based on first and last intervals, or to return NaNs. + If 'periodic', periodic extrapolation is used. + If None (default), use `self.extrapolate`. + + Returns + ------- + ig : array_like + Definite integral of the piecewise polynomial over [a, b] + """ + if extrapolate is None: + extrapolate = self.extrapolate + + integral = self.antiderivative(1) + # Swap integration bounds if needed + sign = 1 - 2 * (b < a) + a, b = jnp.sort(jnp.array([a, b])) + + # Compute the integral. + if extrapolate == "periodic": + # Split the integral into the part over period (can be several + # of them) and the remaining part. + + xs, xe = self.x[0], self.x[-1] + period = xe - xs + interval = b - a + n_periods, left = jnp.divmod(interval, period) + + def truefun(): + return (integral(xe) - integral(xs)) * n_periods + + def falsefun(): + return ( + jnp.zeros(self.c.shape[2:]) + if self.c.shape[2:] + else jnp.array([0.0]) + ) + + out = jax.lax.cond(n_periods > 0, truefun, falsefun) + + # Map a to [xs, xe], b is always a + left. + a = xs + (a - xs) % period + b = a + left + + # If b <= xe then we need to integrate over [a, b], otherwise + # over [a, xe] and from xs to what is remained. + + def truefun(out): + return out + (integral(b) - integral(a)) + + def falsefun(out): + out += integral(xe) - integral(a) + out += integral(xs + left + a - xe) - integral(xs) + return out + + out = jax.lax.cond(b <= xe, truefun, falsefun, out) + else: + out = integral(b, extrapolate=extrapolate) - integral( + a, extrapolate=extrapolate + ) + + return sign * out.reshape(self.c.shape[2:]) + + def solve(self, y=0.0, discontinuity=True, extrapolate=None): + """Not currently implemented.""" + raise NotImplementedError + + def roots(self, discontinuity=True, extrapolate=None): + """Not currently implemented.""" + raise NotImplementedError + + def extend(self, c, x, right=True): + """Not currently implemented.""" + raise NotImplementedError + + @classmethod + def from_spline(cls, tck, extrapolate=None): + """Not currently implemented.""" + raise NotImplementedError + + @classmethod + def from_bernstein_basis(cls, bp, extrapolate=None): + """Not currently implemented.""" + raise NotImplementedError + + +def prepare_input(x, y, axis, dydx=None, check=True): + """Prepare input for cubic spline interpolators. + + All data are converted to numpy arrays and checked for correctness. + Axes equal to `axis` of arrays `y` and `dydx` are moved to be the 0th + axis. The value of `axis` is converted to lie in + [0, number of dimensions of `y`). + """ + x, y = map(asarray_inexact, (x, y)) + dx = jnp.diff(x) + axis = axis % y.ndim + errorif( + jnp.issubdtype(x.dtype, jnp.complexfloating), + ValueError, + "`x` must contain real values.", + ) + x = x.astype(float) + + if dydx is not None: + dydx = asarray_inexact(dydx) + errorif( + y.shape != dydx.shape, + ValueError, + "The shapes of `y` and `dydx` must be identical.", + ) + dtype = jnp.promote_types(y.dtype, dydx.dtype) + dydx = dydx.astype(dtype) + y = y.astype(dtype) + if check: + errorif(x.ndim != 1, ValueError, "`x` must be 1-dimensional.") + errorif(x.shape[0] < 2, ValueError, "`x` must contain at least 2 elements.") + errorif( + x.shape[0] != y.shape[axis], + ValueError, + f"The length of `y` along `axis`={axis} doesn't match the length of `x`", + ) + errorif( + not jnp.all(jnp.isfinite(x)), + ValueError, + "`x` must contain only finite values.", + ) + errorif( + not jnp.all(jnp.isfinite(y)), + ValueError, + "`y` must contain only finite values.", + ) + errorif( + (dydx is not None) and (not jnp.all(jnp.isfinite(dydx))), + ValueError, + "`dydx` must contain only finite values.", + ) + errorif( + jnp.any(dx <= 0), ValueError, "`x` must be strictly increasing sequence." + ) + + return x, dx, y, axis, dydx + + +class CubicHermiteSpline(PPoly): + """Piecewise-cubic interpolator matching values and first derivatives. + + The result is represented as a `PPoly` instance. + + Parameters + ---------- + x : array_like, shape (n,) + 1-D array containing values of the independent variable. + Values must be real, finite and in strictly increasing order. + y : array_like + Array containing values of the dependent variable. It can have + arbitrary number of dimensions, but the length along ``axis`` + (see below) must match the length of ``x``. Values must be finite. + dydx : array_like + Array containing derivatives of the dependent variable. It can have + arbitrary number of dimensions, but the length along ``axis`` + (see below) must match the length of ``x``. Values must be finite. + axis : int, optional + Axis along which `y` is assumed to be varying. Meaning that for + ``x[i]`` the corresponding values are ``np.take(y, i, axis=axis)``. + Default is 0. + extrapolate : {bool, 'periodic', None}, optional + If bool, determines whether to extrapolate to out-of-bounds points + based on first and last intervals, or to return NaNs. If 'periodic', + periodic extrapolation is used. If None (default), it is set to True. + check : bool + Whether to perform checks on the input. Should be False if used under JIT. + + See Also + -------- + Akima1DInterpolator : Akima 1D interpolator. + PchipInterpolator : PCHIP 1-D monotonic cubic interpolator. + CubicSpline : Cubic spline data interpolator. + PPoly : Piecewise polynomial in terms of coefficients and breakpoints + + """ + + def __init__( + self, + x: jax.Array, + y: jax.Array, + dydx: jax.Array, + axis: int = 0, + extrapolate: Union[bool, str] = None, + check: bool = True, + ): + if extrapolate is None: + extrapolate = True + + x, dx, y, axis, dydx = prepare_input(x, y, axis, dydx, check) + + y = jnp.moveaxis(y, axis, 0) + dydx = jnp.moveaxis(dydx, axis, 0) + dxr = dx.reshape([dx.shape[0]] + [1] * (y.ndim - 1)) + F = jnp.stack([y[:-1], y[1:], dydx[:-1] * dxr, dydx[1:] * dxr], axis=0).T + c = jnp.vectorize(jnp.matmul, signature="(n,n),(n)->(n)")(A_CUBIC, F)[..., ::-1] + # handle non-uniform spacing + c = c / (dx[:, None] ** jnp.arange(4)[::-1]) + # c has shape (..., m, k) for m knots and order k + c = c.T # (k, m, ...) + # c.shape = (k, m, ...), but we want it to be (..., k, m, ...) + # ^ + # axis + # So we roll two of them. + c = jnp.moveaxis(c, 0, axis + 1) # (m, ..., k) + c = jnp.moveaxis(c, 0, axis + 1) # (..., k, m, ...) + super().__init__(c, x, extrapolate=extrapolate, axis=axis) + + +class PchipInterpolator(CubicHermiteSpline): + r"""PCHIP 1-D monotonic cubic interpolation. + + ``x`` and ``y`` are arrays of values used to approximate some function f, + with ``y = f(x)``. The interpolant uses monotonic cubic splines + to find the value of new points. (PCHIP stands for Piecewise Cubic + Hermite Interpolating Polynomial). + + Parameters + ---------- + x : ndarray, shape (npoints, ) + A 1-D array of monotonically increasing real values. ``x`` cannot + include duplicate values (otherwise f is overspecified) + y : ndarray, shape (..., npoints, ...) + A N-D array of real values. ``y``'s length along the interpolation + axis must be equal to the length of ``x``. Use the ``axis`` + parameter to select the interpolation axis. + axis : int, optional + Axis in the ``y`` array corresponding to the x-coordinate values. Defaults + to ``axis=0``. + extrapolate : bool, optional + Whether to extrapolate to out-of-bounds points based on first + and last intervals, or to return NaNs. + check : bool + Whether to perform checks on the input. Should be False if used under JIT. + + See Also + -------- + CubicHermiteSpline : Piecewise-cubic interpolator. + Akima1DInterpolator : Akima 1D interpolator. + CubicSpline : Cubic spline data interpolator. + PPoly : Piecewise polynomial in terms of coefficients and breakpoints. + + Notes + ----- + The interpolator preserves monotonicity in the interpolation data and does + not overshoot if the data is not smooth. + + The first derivatives are guaranteed to be continuous, but the second + derivatives may jump at :math:`x_k`. + + Determines the derivatives at the points :math:`x_k`, :math:`f'_k`, + by using PCHIP algorithm [1]_. + + Let :math:`h_k = x_{k+1} - x_k`, and :math:`d_k = (y_{k+1} - y_k) / h_k` + are the slopes at internal points :math:`x_k`. + If the signs of :math:`d_k` and :math:`d_{k-1}` are different or either of + them equals zero, then :math:`f'_k = 0`. Otherwise, it is given by the + weighted harmonic mean + + .. math:: + + \frac{w_1 + w_2}{f'_k} = \frac{w_1}{d_{k-1}} + \frac{w_2}{d_k} + + where :math:`w_1 = 2 h_k + h_{k-1}` and :math:`w_2 = h_k + 2 h_{k-1}`. + + The end slopes are set using a one-sided scheme [2]_. + + References + ---------- + .. [1] F. N. Fritsch and J. Butland, + A method for constructing local + monotone piecewise cubic interpolants, + SIAM J. Sci. Comput., 5(2), 300-304 (1984). + doi:`10.1137/0905021`. + .. [2] see, e.g., C. Moler, Numerical Computing with Matlab, 2004. + doi:`10.1137/1.9780898717952` + + """ + + def __init__( + self, + x: jax.Array, + y: jax.Array, + axis: int = 0, + extrapolate: Union[bool, str] = None, + check: bool = True, + ): + x, _, y, axis, _ = prepare_input(x, y, axis, check=check) + dydx = approx_df(x, y, "monotonic", axis=axis) + super().__init__(x, y, dydx, axis=axis, extrapolate=extrapolate) + + +class Akima1DInterpolator(CubicHermiteSpline): + """Akima interpolator. + + Fit piecewise cubic polynomials, given vectors x and y. The interpolation + method by Akima uses a continuously differentiable sub-spline built from + piecewise cubic polynomials. The resultant curve passes through the given + data points and will appear smooth and natural. + + Parameters + ---------- + x : ndarray, shape (npoints, ) + 1-D array of monotonically increasing real values. + y : ndarray, shape (..., npoints, ...) + N-D array of real values. The length of ``y`` along the interpolation axis + must be equal to the length of ``x``. Use the ``axis`` parameter to + select the interpolation axis. + axis : int, optional + Axis in the ``y`` array corresponding to the x-coordinate values. Defaults + to ``axis=0``. + extrapolate : bool, optional + Whether to extrapolate to out-of-bounds points based on first + and last intervals, or to return NaNs. + check : bool + Whether to perform checks on the input. Should be False if used under JIT. + + See Also + -------- + PchipInterpolator : PCHIP 1-D monotonic cubic interpolator. + CubicSpline : Cubic spline data interpolator. + PPoly : Piecewise polynomial in terms of coefficients and breakpoints + + Notes + ----- + Use only for precise data, as the fitted curve passes through the given + points exactly. This routine is useful for plotting a pleasingly smooth + curve through a few given points for purposes of plotting. + + References + ---------- + [1] A new method of interpolation and smooth curve fitting based + on local procedures. Hiroshi Akima, J. ACM, October 1970, 17(4), + 589-602. + + """ + + def __init__( + self, + x: jax.Array, + y: jax.Array, + axis: int = 0, + extrapolate: Union[bool, str] = None, + check: bool = True, + ): + x, _, y, axis, _ = prepare_input(x, y, axis, check=check) + t = approx_df(x, y, method="akima", axis=axis) + super().__init__(x, y, t, axis=axis, extrapolate=extrapolate) + + +class CubicSpline(CubicHermiteSpline): + """Cubic spline data interpolator. + + Interpolate data with a piecewise cubic polynomial which is twice + continuously differentiable [1]_. The result is represented as a `PPoly` + instance with breakpoints matching the given data. + + Parameters + ---------- + x : array_like, shape (n,) + 1-D array containing values of the independent variable. + Values must be real, finite and in strictly increasing order. + y : array_like + Array containing values of the dependent variable. It can have + arbitrary number of dimensions, but the length along ``axis`` + (see below) must match the length of ``x``. Values must be finite. + axis : int, optional + Axis along which `y` is assumed to be varying. Meaning that for + ``x[i]`` the corresponding values are ``np.take(y, i, axis=axis)``. + Default is 0. + bc_type : string or 2-tuple, optional + Boundary condition type. Two additional equations, given by the + boundary conditions, are required to determine all coefficients of + polynomials on each segment [2]_. + + If `bc_type` is a string, then the specified condition will be applied + at both ends of a spline. Available conditions are: + + * 'not-a-knot' (default): The first and second segment at a curve end + are the same polynomial. It is a good default when there is no + information on boundary conditions. + * 'periodic': The interpolated functions is assumed to be periodic + of period ``x[-1] - x[0]``. The first and last value of `y` must be + identical: ``y[0] == y[-1]``. This boundary condition will result in + ``y'[0] == y'[-1]`` and ``y''[0] == y''[-1]``. + * 'clamped': The first derivative at curves ends are zero. Assuming + a 1D `y`, ``bc_type=((1, 0.0), (1, 0.0))`` is the same condition. + * 'natural': The second derivative at curve ends are zero. Assuming + a 1D `y`, ``bc_type=((2, 0.0), (2, 0.0))`` is the same condition. + + If `bc_type` is a 2-tuple, the first and the second value will be + applied at the curve start and end respectively. The tuple values can + be one of the previously mentioned strings (except 'periodic') or a + tuple `(order, deriv_values)` allowing to specify arbitrary + derivatives at curve ends: + + * `order`: the derivative order, 1 or 2. + * `deriv_value`: array_like containing derivative values, shape must + be the same as `y`, excluding ``axis`` dimension. For example, if + `y` is 1-D, then `deriv_value` must be a scalar. If `y` is 3-D with + the shape (n0, n1, n2) and axis=2, then `deriv_value` must be 2-D + and have the shape (n0, n1). + extrapolate : {bool, 'periodic', None}, optional + If bool, determines whether to extrapolate to out-of-bounds points + based on first and last intervals, or to return NaNs. If 'periodic', + periodic extrapolation is used. If None (default), ``extrapolate`` is + set to 'periodic' for ``bc_type='periodic'`` and to True otherwise. + check : bool + Whether to perform checks on the input. Should be False if used under JIT. + + See Also + -------- + Akima1DInterpolator : Akima 1D interpolator. + PchipInterpolator : PCHIP 1-D monotonic cubic interpolator. + PPoly : Piecewise polynomial in terms of coefficients and breakpoints. + + Notes + ----- + Parameters `bc_type` and ``extrapolate`` work independently, i.e. the + former controls only construction of a spline, and the latter only + evaluation. + + When a boundary condition is 'not-a-knot' and n = 2, it is replaced by + a condition that the first derivative is equal to the linear interpolant + slope. When both boundary conditions are 'not-a-knot' and n = 3, the + solution is sought as a parabola passing through given points. + + References + ---------- + .. [1] `Cubic Spline Interpolation + `_ + on Wikiversity. + .. [2] Carl de Boor, "A Practical Guide to Splines", Springer-Verlag, 1978. + """ + + def __init__( + self, + x: jax.Array, + y: jax.Array, + axis: int = 0, + bc_type: Union[str, tuple] = "not-a-knot", + extrapolate: Union[bool, str] = None, + check: bool = True, + ): + x, _, y, axis, _ = prepare_input(x, y, axis, check=check) + df = approx_df(x, y, "cubic2", axis, bc_type=bc_type) + super().__init__(x, y, df, axis=axis, extrapolate=extrapolate) diff --git a/interpax/_spline.py b/interpax/_spline.py index 4fbb6a3..c44230d 100644 --- a/interpax/_spline.py +++ b/interpax/_spline.py @@ -12,7 +12,7 @@ from ._coefs import A_BICUBIC, A_CUBIC, A_TRICUBIC from ._fd_derivs import approx_df -from .utils import errorif, isbool +from .utils import asarray_inexact, errorif, isbool CUBIC_METHODS = ( "cubic", @@ -63,11 +63,6 @@ class Interpolator1D(eqx.Module): periodicity of the function. If given, function is assumed to be periodic on the interval [0,period]. None denotes no periodicity - Notes - ----- - This class is registered as a PyTree in JAX (it is actually an equinox.Module) - so should be compatible with standard JAX transformations (jit, grad, vmap, etc.) - """ x: jax.Array @@ -87,7 +82,7 @@ def __init__( period: Union[None, float] = None, **kwargs, ): - x, f = map(jnp.asarray, (x, f)) + x, f = map(asarray_inexact, (x, f)) axis = kwargs.get("axis", 0) fx = kwargs.pop("fx", None) @@ -174,11 +169,6 @@ class Interpolator2D(eqx.Module): otherwise function is assumed to be periodic on the interval [0,period]. Use a single value for the same in both directions. - Notes - ----- - This class is registered as a PyTree in JAX (it is actually an equinox.Module) - so should be compatible with standard JAX transformations (jit, grad, vmap, etc.) - """ x: jax.Array @@ -200,7 +190,7 @@ def __init__( period: Union[None, float, tuple] = None, **kwargs, ): - x, y, f = map(jnp.asarray, (x, y, f)) + x, y, f = map(asarray_inexact, (x, y, f)) axis = kwargs.get("axis", 0) fx = kwargs.pop("fx", None) fy = kwargs.pop("fy", None) @@ -303,11 +293,6 @@ class Interpolator3D(eqx.Module): otherwise function is assumed to be periodic on the interval [0,period]. Use a single value for the same in both directions. - Notes - ----- - This class is registered as a PyTree in JAX (it is actually an equinox.Module) - so should be compatible with standard JAX transformations (jit, grad, vmap, etc.) - """ x: jax.Array @@ -331,7 +316,7 @@ def __init__( period: Union[None, float, tuple] = None, **kwargs, ): - x, y, z, f = map(jnp.asarray, (x, y, z, f)) + x, y, z, f = map(asarray_inexact, (x, y, z, f)) axis = kwargs.get("axis", 0) errorif( @@ -491,7 +476,7 @@ def interp1d( which caches the calculation of the derivatives and spline coefficients. """ - xq, x, f = map(jnp.asarray, (xq, x, f)) + xq, x, f = map(asarray_inexact, (xq, x, f)) axis = kwargs.get("axis", 0) fx = kwargs.pop("fx", None) outshape = xq.shape + f.shape[1:] @@ -646,7 +631,7 @@ def interp2d( # noqa: C901 - FIXME: break this up into simpler pieces coefficients. """ - xq, yq, x, y, f = map(jnp.asarray, (xq, yq, x, y, f)) + xq, yq, x, y, f = map(asarray_inexact, (xq, yq, x, y, f)) fx = kwargs.pop("fx", None) fy = kwargs.pop("fy", None) fxy = kwargs.pop("fxy", None) @@ -861,7 +846,7 @@ def interp3d( # noqa: C901 - FIXME: break this up into simpler pieces coefficients. """ - xq, yq, zq, x, y, z, f = map(jnp.asarray, (xq, yq, zq, x, y, z, f)) + xq, yq, zq, x, y, z, f = map(asarray_inexact, (xq, yq, zq, x, y, z, f)) errorif( (len(x) != f.shape[0]) or (x.ndim != 1), ValueError, diff --git a/interpax/utils.py b/interpax/utils.py index 3536735..707ba8f 100644 --- a/interpax/utils.py +++ b/interpax/utils.py @@ -2,6 +2,8 @@ import warnings +import jax.numpy as jnp + def isbool(x): """Check if something is boolean or ndarray of bool type.""" @@ -22,3 +24,12 @@ def warnif(cond, err=UserWarning, msg=""): """Throw a warning if condition is met.""" if cond: warnings.warn(msg, err) + + +def asarray_inexact(x): + """Convert to jax array with floating point dtype.""" + x = jnp.asarray(x) + dtype = x.dtype + if not jnp.issubdtype(dtype, jnp.inexact): + dtype = jnp.result_type(x, jnp.array(1.0)) + return x.astype(dtype) diff --git a/setup.cfg b/setup.cfg index de4042e..5630678 100644 --- a/setup.cfg +++ b/setup.cfg @@ -65,9 +65,11 @@ ignore = D401, # don't care about docstrings in __dunder__ methods D105, + per-file-ignores = # need to import things to top level even if they aren't used there interpax/__init__.py: F401 + tests/*: D100, D101, D102, D103, D106 max-line-length = 88 exclude = docs/*, build/*, local/*, .git/*, versioneer.py, interpax/_version.py diff --git a/tests/test_scipy.py b/tests/test_scipy.py new file mode 100644 index 0000000..c1304f1 --- /dev/null +++ b/tests/test_scipy.py @@ -0,0 +1,893 @@ +"""Tests for scipy API. + +Tests mostly copied from scipy with minor rewrites. + +Copyright (c) 2001-2002 Enthought, Inc. 2003-2024, SciPy Developers. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials provided + with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +""" + +import io +import warnings + +import numpy as np +import pytest +import scipy.interpolate +from jax import config as jax_config +from numpy.testing import ( + assert_, + assert_allclose, + assert_array_almost_equal, + assert_array_equal, + assert_equal, +) +from pytest import raises as assert_raises +from scipy.interpolate import splev, splint, splrep + +from interpax import ( + Akima1DInterpolator, + CubicHermiteSpline, + CubicSpline, + PchipInterpolator, + PPoly, +) + +jax_config.update("jax_enable_x64", True) + + +class TestAkima1DInterpolator: + def test_eval(self): + x = np.arange(0.0, 11.0) + y = np.array([0.0, 2.0, 1.0, 3.0, 2.0, 6.0, 5.5, 5.5, 2.7, 5.1, 3.0]) + ak = Akima1DInterpolator(x, y) + xi = np.array( + [0.0, 0.5, 1.0, 1.5, 2.5, 3.5, 4.5, 5.1, 6.5, 7.2, 8.6, 9.9, 10.0] + ) + yi = np.array( + [ + 0.0, + 1.375, + 2.0, + 1.5, + 1.953125, + 2.484375, + 4.1363636363636366866103344, + 5.9803623910336236590978842, + 5.5067291516462386624652936, + 5.2031367459745245795943447, + 4.1796554159017080820603951, + 3.4110386597938129327189927, + 3.0, + ] + ) + assert_allclose(ak(xi), yi) + + def test_eval_2d(self): + x = np.arange(0.0, 11.0) + y = np.array([0.0, 2.0, 1.0, 3.0, 2.0, 6.0, 5.5, 5.5, 2.7, 5.1, 3.0]) + y = np.column_stack((y, 2.0 * y)) + ak = Akima1DInterpolator(x, y) + xi = np.array( + [0.0, 0.5, 1.0, 1.5, 2.5, 3.5, 4.5, 5.1, 6.5, 7.2, 8.6, 9.9, 10.0] + ) + yi = np.array( + [ + 0.0, + 1.375, + 2.0, + 1.5, + 1.953125, + 2.484375, + 4.1363636363636366866103344, + 5.9803623910336236590978842, + 5.5067291516462386624652936, + 5.2031367459745245795943447, + 4.1796554159017080820603951, + 3.4110386597938129327189927, + 3.0, + ] + ) + yi = np.column_stack((yi, 2.0 * yi)) + assert_allclose(ak(xi), yi) + + def test_eval_3d(self): + x = np.arange(0.0, 11.0) + y_ = np.array([0.0, 2.0, 1.0, 3.0, 2.0, 6.0, 5.5, 5.5, 2.7, 5.1, 3.0]) + y = np.empty((11, 2, 2)) + y[:, 0, 0] = y_ + y[:, 1, 0] = 2.0 * y_ + y[:, 0, 1] = 3.0 * y_ + y[:, 1, 1] = 4.0 * y_ + ak = Akima1DInterpolator(x, y) + xi = np.array( + [0.0, 0.5, 1.0, 1.5, 2.5, 3.5, 4.5, 5.1, 6.5, 7.2, 8.6, 9.9, 10.0] + ) + yi = np.empty((13, 2, 2)) + yi_ = np.array( + [ + 0.0, + 1.375, + 2.0, + 1.5, + 1.953125, + 2.484375, + 4.1363636363636366866103344, + 5.9803623910336236590978842, + 5.5067291516462386624652936, + 5.2031367459745245795943447, + 4.1796554159017080820603951, + 3.4110386597938129327189927, + 3.0, + ] + ) + yi[:, 0, 0] = yi_ + yi[:, 1, 0] = 2.0 * yi_ + yi[:, 0, 1] = 3.0 * yi_ + yi[:, 1, 1] = 4.0 * yi_ + assert_allclose(ak(xi), yi) + + def test_degenerate_case_multidimensional(self): + # This test is for issue #5683. + x = np.array([0, 1, 2]) + y = np.vstack((x, x**2)).T + ak = Akima1DInterpolator(x, y) + x_eval = np.array([0.5, 1.5]) + y_eval = ak(x_eval) + assert_allclose(y_eval, np.vstack((x_eval, x_eval**2)).T) + + def test_extend(self): + x = np.arange(0.0, 11.0) + y = np.array([0.0, 2.0, 1.0, 3.0, 2.0, 6.0, 5.5, 5.5, 2.7, 5.1, 3.0]) + ak = Akima1DInterpolator(x, y) + with pytest.raises(NotImplementedError): + ak.extend(None, None) + + +class TestPPolyCommon: + # test basic functionality for PPoly and BPoly + def test_sort_check(self): + c = np.array([[1, 4], [2, 5], [3, 6]]) + x = np.array([0, 1, 0.5]) + assert_raises(ValueError, PPoly, c, x) + + def test_ctor_c(self): + # wrong shape: `c` must be at least 2D + with assert_raises(ValueError): + PPoly([1, 2], [0, 1]) + + def test_extend(self): + # Test adding new points to the piecewise polynomial + np.random.seed(1234) + + order = 3 + x = np.unique(np.r_[0, 10 * np.random.rand(30), 10]) + c = 2 * np.random.rand(order + 1, len(x) - 1, 2, 3) - 1 + + for cls in (PPoly,): + pp = cls(c[:, :9], x[:10]) + with pytest.raises(NotImplementedError): + pp.extend(None, None) + + def test_shape(self): + np.random.seed(1234) + c = np.random.rand(8, 12, 5, 6, 7) + x = np.sort(np.random.rand(13)) + xp = np.random.rand(3, 4) + for cls in (PPoly,): + p = cls(c, x) + assert_equal(p(xp).shape, (3, 4, 5, 6, 7)) + + # 'scalars' + for cls in (PPoly,): + p = cls(c[..., 0, 0, 0], x) + + assert_equal(np.shape(p(0.5)), ()) + assert_equal(np.shape(p(np.array(0.5))), ()) + + assert_raises(TypeError, p, np.array([[0.1, 0.2], [0.4]], dtype=object)) + + def test_complex_coef(self): + np.random.seed(12345) + x = np.sort(np.random.random(13)) + c = np.random.random((8, 12)) * (1.0 + 0.3j) + c_re, c_im = c.real, c.imag + xp = np.random.random(5) + for cls in (PPoly,): + p, p_re, p_im = cls(c, x), cls(c_re, x), cls(c_im, x) + for nu in [0, 1, 2]: + assert_allclose(p(xp, nu).real, p_re(xp, nu)) + assert_allclose(p(xp, nu).imag, p_im(xp, nu)) + + def test_axis(self): + np.random.seed(12345) + c = np.random.rand(3, 4, 5, 6, 7, 8) + c_s = c.shape + xp = np.random.random((1, 2)) + for axis in (0, 1, 2, 3): + m = c.shape[axis + 1] + x = np.sort(np.random.rand(m + 1)) + for cls in (PPoly,): + p = cls(c, x, axis=axis) + assert_equal( + p.c.shape, c_s[axis : axis + 2] + c_s[:axis] + c_s[axis + 2 :] + ) + res = p(xp) + targ_shape = c_s[:axis] + xp.shape + c_s[2 + axis :] + assert_equal(res.shape, targ_shape) + + # deriv/antideriv does not drop the axis + for p1 in [ + cls(c, x, axis=axis).derivative(), + cls(c, x, axis=axis).derivative(2), + cls(c, x, axis=axis).antiderivative(), + cls(c, x, axis=axis).antiderivative(2), + ]: + assert_equal(p1.axis, p.axis) + + # c array needs two axes for the coefficients and intervals, so + # we expect 0 <= axis < c.ndim-1; raise otherwise + for axis in (-1, 4, 5, 6): + for cls in (PPoly,): + assert_raises(ValueError, cls, **dict(c=c, x=x, axis=axis)) + + +class TestPolySubclassing: + class P(PPoly): + pass + + class B(P): + pass + + def _make_polynomials(self): + np.random.seed(1234) + x = np.sort(np.random.random(3)) + c = np.random.random((4, 2)) + return self.P(c, x), self.B(c, x) + + def test_derivative(self): + pp, bp = self._make_polynomials() + for p in (pp, bp): + pd = p.derivative() + assert_equal(p.__class__, pd.__class__) + + ppa = pp.antiderivative() + assert_equal(pp.__class__, ppa.__class__) + + +class TestPPoly: + def test_simple(self): + c = np.array([[1, 4], [2, 5], [3, 6]]) + x = np.array([0, 0.5, 1]) + p = PPoly(c, x) + + assert_allclose(p(0.3), 1 * 0.3**2 + 2 * 0.3 + 3) + assert_allclose(p(0.7), 4 * (0.7 - 0.5) ** 2 + 5 * (0.7 - 0.5) + 6) + + def test_periodic(self): + c = np.array([[1, 4], [2, 5], [3, 6]]) + x = np.array([0, 0.5, 1]) + p = PPoly(c, x, extrapolate="periodic") + + assert_allclose(p(1.3), 1 * 0.3**2 + 2 * 0.3 + 3) + assert_allclose(p(-0.3), 4 * (0.7 - 0.5) ** 2 + 5 * (0.7 - 0.5) + 6) + + assert_allclose(p(1.3, 1), 2 * 0.3 + 2) + assert_allclose(p(-0.3, 1), 8 * (0.7 - 0.5) + 5) + + def test_read_only(self): + c = np.array([[1, 4], [2, 5], [3, 6]]) + x = np.array([0, 0.5, 1]) + xnew = np.array([0, 0.1, 0.2]) + PPoly(c, x, extrapolate="periodic") + + for writeable in (True, False): + x.flags.writeable = writeable + c.flags.writeable = writeable + f = PPoly(c, x) + vals = f(xnew) + assert_(np.isfinite(vals).all()) + + def test_multi_shape(self): + c = np.random.rand(6, 2, 1, 2, 3) + x = np.array([0, 0.5, 1]) + p = PPoly(c, x) + assert_equal(p.x.shape, x.shape) + assert_equal(p.c.shape, c.shape) + assert_equal(p(0.3).shape, c.shape[2:]) + + assert_equal(p(np.random.rand(5, 6)).shape, (5, 6) + c.shape[2:]) + + dp = p.derivative() + assert_equal(dp.c.shape, (5, 2, 1, 2, 3)) + ip = p.antiderivative() + assert_equal(ip.c.shape, (7, 2, 1, 2, 3)) + + def test_construct_fast(self): + np.random.seed(1234) + c = np.array([[1, 4], [2, 5], [3, 6]], dtype=float) + x = np.array([0, 0.5, 1]) + p = PPoly.construct_fast(c, x) + assert_allclose(p(0.3), 1 * 0.3**2 + 2 * 0.3 + 3) + assert_allclose(p(0.7), 4 * (0.7 - 0.5) ** 2 + 5 * (0.7 - 0.5) + 6) + + def test_vs_alternative_implementations(self): + np.random.seed(1234) + c = np.random.rand(3, 12, 22) + x = np.sort(np.r_[0, np.random.rand(11), 1]) + + p = PPoly(c, x) + + xp = np.r_[0.3, 0.5, 0.33, 0.6] + expected = _ppoly_eval_1(c, x, xp) + assert_allclose(p(xp), expected) + + expected = _ppoly_eval_2(c[:, :, 0], x, xp) + assert_allclose(p(xp)[:, 0], expected) + + def test_derivative_simple(self): + np.random.seed(1234) + c = np.array([[4, 3, 2, 1]]).T + dc = np.array([[3 * 4, 2 * 3, 2]]).T + ddc = np.array([[2 * 3 * 4, 1 * 2 * 3]]).T + x = np.array([0, 1]) + + pp = PPoly(c, x) + dpp = PPoly(dc, x) + ddpp = PPoly(ddc, x) + + assert_allclose(pp.derivative().c, dpp.c) + assert_allclose(pp.derivative(2).c, ddpp.c) + + def test_derivative_eval(self): + np.random.seed(1234) + x = np.sort(np.r_[0, np.random.rand(11), 1]) + y = np.random.rand(len(x)) + + spl = splrep(x, y, s=0) + spp = scipy.interpolate.PPoly.from_spline(spl) + pp = PPoly(spp.c, spp.x) + + xi = np.linspace(0, 1, 200) + for dx in range(0, 3): + assert_allclose(pp(xi, dx), splev(xi, spl, dx)) + + def test_derivative(self): + np.random.seed(1234) + x = np.sort(np.r_[0, np.random.rand(11), 1]) + y = np.random.rand(len(x)) + + spl = splrep(x, y, s=0, k=5) + spp = scipy.interpolate.PPoly.from_spline(spl) + pp = PPoly(spp.c, spp.x) + + xi = np.linspace(0, 1, 200) + for dx in range(0, 10): + assert_allclose(pp(xi, dx), pp.derivative(dx)(xi), err_msg="dx=%d" % (dx,)) + + def test_antiderivative_of_constant(self): + # https://github.com/scipy/scipy/issues/4216 + p = PPoly([[1.0]], [0, 1]) + assert_array_equal(p.antiderivative().c, PPoly([[1], [0]], [0, 1]).c) + assert_array_equal(p.antiderivative().x, PPoly([[1], [0]], [0, 1]).x) + + def test_antiderivative_regression_4355(self): + # https://github.com/scipy/scipy/issues/4355 + p = PPoly([[1.0, 0.5]], [0, 1, 2]) + q = p.antiderivative() + assert_array_equal(q.c, [[1, 0.5], [0, 1]]) + assert_array_equal(q.x, [0, 1, 2]) + assert_allclose(p.integrate(0, 2), 1.5) + assert_allclose(q(2) - q(0), 1.5) + + def test_antiderivative_simple(self): + np.random.seed(1234) + # [ p1(x) = 3*x**2 + 2*x + 1, + # p2(x) = 1.6875] + c = np.array([[3, 2, 1], [0, 0, 1.6875]]).T + # [ pp1(x) = x**3 + x**2 + x, + # pp2(x) = 1.6875*(x - 0.25) + pp1(0.25)] + ic = np.array([[1, 1, 1, 0], [0, 0, 1.6875, 0.328125]]).T + # [ ppp1(x) = (1/4)*x**4 + (1/3)*x**3 + (1/2)*x**2, + # ppp2(x) = (1.6875/2)*(x - 0.25)**2 + pp1(0.25)*x + ppp1(0.25)] + iic = np.array( + [ + [1 / 4, 1 / 3, 1 / 2, 0, 0], + [0, 0, 1.6875 / 2, 0.328125, 0.037434895833333336], + ] + ).T + x = np.array([0, 0.25, 1]) + + pp = PPoly(c, x) + ipp = pp.antiderivative() + iipp = pp.antiderivative(2) + iipp2 = ipp.antiderivative() + + assert_allclose(ipp.x, x) + assert_allclose(ipp.c.T, ic.T) + assert_allclose(iipp.c.T, iic.T) + assert_allclose(iipp2.c.T, iic.T) + + def test_antiderivative_vs_derivative(self): + np.random.seed(1234) + x = np.linspace(0, 1, 30) ** 2 + y = np.random.rand(len(x)) + spl = splrep(x, y, s=0, k=5) + spp = scipy.interpolate.PPoly.from_spline(spl) + pp = PPoly(spp.c, spp.x) + + for dx in range(0, 10): + ipp = pp.antiderivative(dx) + + # check that derivative is inverse op + pp2 = ipp.derivative(dx) + assert_allclose(pp.c, pp2.c) + + # check continuity + for k in range(dx): + pp2 = ipp.derivative(k) + + r = 1e-13 + endpoint = r * pp2.x[:-1] + (1 - r) * pp2.x[1:] + + assert_allclose( + pp2(pp2.x[1:]), + pp2(endpoint), + rtol=1e-7, + err_msg="dx=%d k=%d" % (dx, k), + ) + + def test_antiderivative_continuity(self): + c = np.array([[2, 1, 2, 2], [2, 1, 3, 3]]).T + x = np.array([0, 0.5, 1]) + + p = PPoly(c, x) + ip = p.antiderivative() + + # check continuity + assert_allclose(ip(0.5 - 1e-9), ip(0.5 + 1e-9), rtol=1e-8) + + # check that only lowest order coefficients were changed + p2 = ip.derivative() + assert_allclose(p2.c, p.c) + + def test_integrate(self): + np.random.seed(1234) + x = np.sort(np.r_[0, np.random.rand(11), 1]) + y = np.random.rand(len(x)) + + spl = splrep(x, y, s=0, k=5) + spp = scipy.interpolate.PPoly.from_spline(spl) + pp = PPoly(spp.c, spp.x) + + a, b = 0.3, 0.9 + ig = pp.integrate(a, b) + + ipp = pp.antiderivative() + assert_allclose(ig, ipp(b) - ipp(a)) + assert_allclose(ig, splint(a, b, spl)) + + a, b = -0.3, 0.9 + ig = pp.integrate(a, b, extrapolate=True) + assert_allclose(ig, ipp(b) - ipp(a)) + + assert_(np.isnan(pp.integrate(a, b, extrapolate=False)).all()) + + def test_integrate_readonly(self): + x = np.array([1, 2, 4]) + c = np.array([[0.0, 0.0], [-1.0, -1.0], [2.0, -0.0], [1.0, 2.0]]) + + for writeable in (True, False): + x.flags.writeable = writeable + + P = PPoly(c, x) + vals = P.integrate(1, 4) + + assert_(np.isfinite(vals).all()) + + def test_integrate_periodic(self): + x = np.array([1, 2, 4]) + c = np.array([[0.0, 0.0], [-1.0, -1.0], [2.0, -0.0], [1.0, 2.0]]) + + P = PPoly(c, x, extrapolate="periodic") + I = P.antiderivative() + + period_int = I(4) - I(1) + + assert_allclose(P.integrate(1, 4), period_int) + assert_allclose(P.integrate(-10, -7), period_int) + assert_allclose(P.integrate(-10, -4), 2 * period_int) + + assert_allclose(P.integrate(1.5, 2.5), I(2.5) - I(1.5)) + assert_allclose(P.integrate(3.5, 5), I(2) - I(1) + I(4) - I(3.5)) + assert_allclose(P.integrate(3.5 + 12, 5 + 12), I(2) - I(1) + I(4) - I(3.5)) + assert_allclose( + P.integrate(3.5, 5 + 12), I(2) - I(1) + I(4) - I(3.5) + 4 * period_int + ) + + assert_allclose(P.integrate(0, -1), I(2) - I(3)) + assert_allclose(P.integrate(-9, -10), I(2) - I(3)) + assert_allclose(P.integrate(0, -10), I(2) - I(3) - 3 * period_int) + + def test_extrapolate_attr(self): + # 1 - x**2 + c = np.array([[-1, 0, 1]]).T + x = np.array([0, 1]) + + for extrapolate in [True, False, None]: + pp = PPoly(c, x, extrapolate=extrapolate) + pp_d = pp.derivative() + pp_i = pp.antiderivative() + + if extrapolate is False: + assert_(np.isnan(pp([-0.1, 1.1])).all()) + assert_(np.isnan(pp_i([-0.1, 1.1])).all()) + assert_(np.isnan(pp_d([-0.1, 1.1])).all()) + else: + assert_allclose(pp([-0.1, 1.1]), [1 - 0.1**2, 1 - 1.1**2]) + assert_(not np.isnan(pp_i([-0.1, 1.1])).any()) + assert_(not np.isnan(pp_d([-0.1, 1.1])).any()) + + +def _ppoly_eval_1(c, x, xps): + """Evaluate piecewise polynomial manually.""" + out = np.zeros((len(xps), c.shape[2])) + for i, xp in enumerate(xps): + if xp < 0 or xp > 1: + out[i, :] = np.nan + continue + j = np.searchsorted(x, xp) - 1 + d = xp - x[j] + assert_(x[j] <= xp < x[j + 1]) + r = sum(c[k, j] * d ** (c.shape[0] - k - 1) for k in range(c.shape[0])) + out[i, :] = r + return out + + +def _ppoly_eval_2(coeffs, breaks, xnew, fill=np.nan): + """Evaluate piecewise polynomial manually (another way).""" + a = breaks[0] + b = breaks[-1] + K = coeffs.shape[0] + + saveshape = np.shape(xnew) + xnew = np.ravel(xnew) + res = np.empty_like(xnew) + mask = (xnew >= a) & (xnew <= b) + res[~mask] = fill + xx = xnew.compress(mask) + indxs = np.searchsorted(breaks, xx) - 1 + indxs = indxs.clip(0, len(breaks)) + pp = coeffs + diff = xx - breaks.take(indxs) + V = np.vander(diff, N=K) + values = np.array([np.dot(V[k, :], pp[:, indxs[k]]) for k in range(len(xx))]) + res[mask] = values + res.shape = saveshape + return res + + +class TestPCHIP: + def _make_random(self, npts=20): + np.random.seed(1234) + xi = np.sort(np.random.random(npts)) + yi = np.random.random(npts) + return PchipInterpolator(xi, yi), xi, yi + + def test_overshoot(self): + # PCHIP should not overshoot + p, xi, yi = self._make_random() + for i in range(len(xi) - 1): + x1, x2 = xi[i], xi[i + 1] + y1, y2 = yi[i], yi[i + 1] + if y1 > y2: + y1, y2 = y2, y1 + xp = np.linspace(x1, x2, 10) + yp = p(xp) + assert_(((y1 <= yp + 1e-15) & (yp <= y2 + 1e-15)).all()) + + def test_monotone(self): + # PCHIP should preserve monotonicty + p, xi, yi = self._make_random() + for i in range(len(xi) - 1): + x1, x2 = xi[i], xi[i + 1] + y1, y2 = yi[i], yi[i + 1] + xp = np.linspace(x1, x2, 10) + yp = p(xp) + assert_(((y2 - y1) * (yp[1:] - yp[:1]) > 0).all()) + + def test_cast(self): + # regression test for integer input data, see gh-3453 + data = np.array( + [ + [0, 4, 12, 27, 47, 60, 79, 87, 99, 100], + [-33, -33, -19, -2, 12, 26, 38, 45, 53, 55], + ] + ) + xx = np.arange(100) + curve = PchipInterpolator(data[0], data[1])(xx) + + data1 = data * 1.0 + curve1 = PchipInterpolator(data1[0], data1[1])(xx) + + assert_allclose(curve, curve1, atol=1e-14, rtol=1e-14) + + def test_nag(self): + # Example from NAG C implementation, + # http://nag.com/numeric/cl/nagdoc_cl25/html/e01/e01bec.html + # suggested in gh-5326 as a smoke test for the way the derivatives + # are computed (see also gh-3453) + dataStr = """ + 7.99 0.00000E+0 + 8.09 0.27643E-4 + 8.19 0.43750E-1 + 8.70 0.16918E+0 + 9.20 0.46943E+0 + 10.00 0.94374E+0 + 12.00 0.99864E+0 + 15.00 0.99992E+0 + 20.00 0.99999E+0 + """ + data = np.loadtxt(io.StringIO(dataStr)) + pch = PchipInterpolator(data[:, 0], data[:, 1]) + + resultStr = """ + 7.9900 0.0000 + 9.1910 0.4640 + 10.3920 0.9645 + 11.5930 0.9965 + 12.7940 0.9992 + 13.9950 0.9998 + 15.1960 0.9999 + 16.3970 1.0000 + 17.5980 1.0000 + 18.7990 1.0000 + 20.0000 1.0000 + """ + result = np.loadtxt(io.StringIO(resultStr)) + assert_allclose(result[:, 1], pch(result[:, 0]), rtol=0.0, atol=5e-5) + + def test_endslopes(self): + # this is a smoke test for gh-3453: PCHIP interpolator should not + # set edge slopes to zero if the data do not suggest zero edge derivatives + x = np.array([0.0, 0.1, 0.25, 0.35]) + y1 = np.array([279.35, 0.5e3, 1.0e3, 2.5e3]) + y2 = np.array([279.35, 2.5e3, 1.50e3, 1.0e3]) + for pp in (PchipInterpolator(x, y1), PchipInterpolator(x, y2)): + for t in (x[0], x[-1]): + assert_(pp(t, 1) != 0) + + def test_all_zeros(self): + x = np.arange(10) + y = np.zeros_like(x) + + # this should work and not generate any warnings + with warnings.catch_warnings(): + warnings.filterwarnings("error") + pch = PchipInterpolator(x, y) + + xx = np.linspace(0, 9, 101) + assert_array_equal(pch(xx), 0.0) + + def test_two_points(self): + # regression test for gh-6222: PchipInterpolator([0, 1], [0, 1]) fails because + # it tries to use a three-point scheme to estimate edge derivatives, + # while there are only two points available. + # Instead, it should construct a linear interpolator. + x = np.linspace(0, 1, 11) + p = PchipInterpolator([0, 1], [0, 2]) + assert_allclose(p(x), 2 * x, atol=1e-15) + + def test_PchipInterpolator(self): + assert_array_almost_equal( + PchipInterpolator([1, 2, 3], [4, 5, 6])([0.5], nu=1), [1.0] + ) + + assert_array_almost_equal( + PchipInterpolator([1, 2, 3], [4, 5, 6])([0.5], nu=0), [3.5] + ) + + +class TestCubicSpline: + @staticmethod + def check_correctness(S, bc_start="not-a-knot", bc_end="not-a-knot", tol=1e-14): + """Check that spline coefficients satisfy the continuity and bc.""" + x = S.x + c = S.c + dx = np.diff(x) + dx = dx.reshape([dx.shape[0]] + [1] * (c.ndim - 2)) + dxi = dx[:-1] + + # Check C2 continuity. + assert_allclose( + c[3, 1:], + c[0, :-1] * dxi**3 + c[1, :-1] * dxi**2 + c[2, :-1] * dxi + c[3, :-1], + rtol=tol, + atol=tol, + ) + assert_allclose( + c[2, 1:], + 3 * c[0, :-1] * dxi**2 + 2 * c[1, :-1] * dxi + c[2, :-1], + rtol=tol, + atol=tol, + ) + assert_allclose(c[1, 1:], 3 * c[0, :-1] * dxi + c[1, :-1], rtol=tol, atol=tol) + + # Check that we found a parabola, the third derivative is 0. + if x.size == 3 and bc_start == "not-a-knot" and bc_end == "not-a-knot": + assert_allclose(c[0], 0, rtol=tol, atol=tol) + return + + # Check periodic boundary conditions. + if bc_start == "periodic": + assert_allclose(S(x[0], 0), S(x[-1], 0), rtol=tol, atol=tol) + assert_allclose(S(x[0], 1), S(x[-1], 1), rtol=tol, atol=tol) + assert_allclose(S(x[0], 2), S(x[-1], 2), rtol=tol, atol=tol) + return + + # Check other boundary conditions. + if bc_start == "not-a-knot": + if x.size == 2: + slope = (S(x[1]) - S(x[0])) / dx[0] + assert_allclose(S(x[0], 1), slope, rtol=tol, atol=tol) + else: + assert_allclose(c[0, 0], c[0, 1], rtol=tol, atol=tol) + elif bc_start == "clamped": + assert_allclose(S(x[0], 1), 0, rtol=tol, atol=tol) + elif bc_start == "natural": + assert_allclose(S(x[0], 2), 0, rtol=tol, atol=tol) + else: + order, value = bc_start + assert_allclose(S(x[0], order), value, rtol=tol, atol=tol) + + if bc_end == "not-a-knot": + if x.size == 2: + slope = (S(x[1]) - S(x[0])) / dx[0] + assert_allclose(S(x[1], 1), slope, rtol=tol, atol=tol) + else: + assert_allclose(c[0, -1], c[0, -2], rtol=tol, atol=tol) + elif bc_end == "clamped": + assert_allclose(S(x[-1], 1), 0, rtol=tol, atol=tol) + elif bc_end == "natural": + assert_allclose(S(x[-1], 2), 0, rtol=2 * tol, atol=2 * tol) + else: + order, value = bc_end + assert_allclose(S(x[-1], order), value, rtol=tol, atol=tol) + + def check_all_bc(self, x, y, axis): + deriv_shape = list(y.shape) + del deriv_shape[axis] + first_deriv = np.empty(deriv_shape) + first_deriv.fill(2) + second_deriv = np.empty(deriv_shape) + second_deriv.fill(-1) + bc_all = [ + "not-a-knot", + "natural", + "clamped", + (1, first_deriv), + (2, second_deriv), + ] + for bc in bc_all[:3]: + S = CubicSpline(x, y, axis=axis, bc_type=bc) + self.check_correctness(S, bc, bc) + + for bc_start in bc_all: + for bc_end in bc_all: + S = CubicSpline(x, y, axis=axis, bc_type=(bc_start, bc_end)) + self.check_correctness(S, bc_start, bc_end, tol=2e-14) + + def test_general(self): + x = np.array([-1, 0, 0.5, 2, 4, 4.5, 5.5, 9]) + y = np.array([0, -0.5, 2, 3, 2.5, 1, 1, 0.5]) + for n in [2, 3, x.size]: + self.check_all_bc(x[:n], y[:n], 0) + + Y = np.empty((2, n, 2)) + Y[0, :, 0] = y[:n] + Y[0, :, 1] = y[:n] - 1 + Y[1, :, 0] = y[:n] + 2 + Y[1, :, 1] = y[:n] + 3 + self.check_all_bc(x[:n], Y, 1) + + def test_dtypes(self): + x = np.array([0, 1, 2, 3], dtype=int) + y = np.array([-5, 2, 3, 1], dtype=int) + S = CubicSpline(x, y) + self.check_correctness(S) + + y = np.array([-1 + 1j, 0.0, 1 - 1j, 0.5 - 1.5j]) + S = CubicSpline(x, y) + self.check_correctness(S) + + S = CubicSpline(x, x**3, bc_type=("natural", (1, 2j))) + self.check_correctness(S, "natural", (1, 2j)) + + y = np.array([-5, 2, 3, 1]) + S = CubicSpline(x, y, bc_type=[(1, 2 + 0.5j), (2, 0.5 - 1j)]) + self.check_correctness(S, (1, 2 + 0.5j), (2, 0.5 - 1j)) + + def test_incorrect_inputs(self): + x = np.array([1, 2, 3, 4]) + y = np.array([1, 2, 3, 4]) + xc = np.array([1 + 1j, 2, 3, 4]) + xn = np.array([np.nan, 2, 3, 4]) + xo = np.array([2, 1, 3, 4]) + yn = np.array([np.nan, 2, 3, 4]) + y3 = [1, 2, 3] + x1 = [1] + y1 = [1] + + assert_raises(ValueError, CubicSpline, xc, y) + assert_raises(ValueError, CubicSpline, xn, y) + assert_raises(ValueError, CubicSpline, x, yn) + assert_raises(ValueError, CubicSpline, xo, y) + assert_raises(ValueError, CubicSpline, x, y3) + assert_raises(ValueError, CubicSpline, x[:, np.newaxis], y) + assert_raises(ValueError, CubicSpline, x1, y1) + + wrong_bc = [ + ("periodic", "clamped"), + ((2, 0), (3, 10)), + ((1, 0),), + (0.0, 0.0), + "not-a-typo", + ] + + for bc_type in wrong_bc: + assert_raises(ValueError, CubicSpline, x, y, 0, bc_type, True) + + # Shapes mismatch when giving arbitrary derivative values: + Y = np.c_[y, y] + bc1 = ("clamped", (1, 0)) + bc2 = ("clamped", (1, [0, 0, 0])) + bc3 = ("clamped", (1, [[0, 0]])) + assert_raises(ValueError, CubicSpline, x, Y, 0, bc1, True) + assert_raises(ValueError, CubicSpline, x, Y, 0, bc2, True) + assert_raises(ValueError, CubicSpline, x, Y, 0, bc3, True) + + +def test_CubicHermiteSpline_correctness(): + x = [0, 2, 7] + y = [-1, 2, 3] + dydx = [0, 3, 7] + s = CubicHermiteSpline(x, y, dydx) + assert_allclose(s(x), y, rtol=1e-15) + assert_allclose(s(x, 1), dydx, rtol=1e-15) + + +def test_CubicHermiteSpline_error_handling(): + x = [1, 2, 3] + y = [0, 3, 5] + dydx = [1, -1, 2, 3] + assert_raises(ValueError, CubicHermiteSpline, x, y, dydx) + + dydx_with_nan = [1, 0, np.nan] + assert_raises(ValueError, CubicHermiteSpline, x, y, dydx_with_nan)