diff --git a/CHANGELOG.md b/CHANGELOG.md index e9311c4..d6b626b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,10 +1,44 @@ Changelog ========= +v0.2.4 +------ +- Fixes for scalar valued query points +- Fixes for interpolating vector valued functions + +**Full Changelog**: https://github.com/f0uriest/interpax/compare/v0.2.3...v0.2.4 + + +v0.2.3 +------ +- Add type annotations + +**Full Changelog**: https://github.com/f0uriest/interpax/compare/v0.2.2...v0.2.3 + + +v0.2.2 +------ +- Add ``approx_df`` to public API + +**Full Changelog**: https://github.com/f0uriest/interpax/compare/v0.2.1...v0.2.2 + + +v0.2.1 +------ +- More efficient nearest neighbor search +- Correct slopes for linear interpolation in 2d, 3d +- Fix for cubic2 splines in 2d and 3d +Forward and reverse mode AD now fully working and tested + +**Full Changelog**: https://github.com/f0uriest/interpax/compare/v0.2.0...v0.2.1 + + v0.2.0 ------- - Adds convenience classes for spline interpolation that cache the derivative calculation. +**Full Changelog**: https://github.com/f0uriest/interpax/compare/v0.1.0...v0.2.0 + v0.1.0 ------ diff --git a/interpax/_spline.py b/interpax/_spline.py index 76eb5f1..380c307 100644 --- a/interpax/_spline.py +++ b/interpax/_spline.py @@ -513,9 +513,9 @@ def derivative0(): delta = xq - x[i - 1] fq = jnp.where( (dx == 0), - jnp.take(f, i, axis), - jnp.take(f, i - 1, axis) + delta * dxi * df, - ) + jnp.take(f, i, axis).T, + jnp.take(f, i - 1, axis).T + (delta * dxi * df.T), + ).T return fq def derivative1(): @@ -523,7 +523,7 @@ def derivative1(): df = jnp.take(f, i, axis) - jnp.take(f, i - 1, axis) dx = x[i] - x[i - 1] dxi = jnp.where(dx == 0, 0, 1 / dx) - return df * dxi + return (df.T * dxi).T def derivative2(): return jnp.zeros((xq.size, *f.shape[1:])) @@ -544,11 +544,11 @@ def derivative2(): f0 = jnp.take(f, i - 1, axis) f1 = jnp.take(f, i, axis) - fx0 = jnp.take(fx, i - 1, axis) * dx - fx1 = jnp.take(fx, i, axis) * dx + fx0 = (jnp.take(fx, i - 1, axis).T * dx).T + fx1 = (jnp.take(fx, i, axis).T * dx).T - F = jnp.vstack([f0, f1, fx0, fx1]) - coef = jnp.matmul(A_CUBIC, F) + F = jnp.stack([f0, f1, fx0, fx1], axis=0).T + coef = jnp.vectorize(jnp.matmul, signature="(n,n),(n)->(n)")(A_CUBIC, F).T ttx = _get_t_der(t, derivative, dxi) fq = jnp.einsum("ji...,ij->i...", coef, ttx) @@ -666,12 +666,12 @@ def derivative0(): [[x[i], x[i - 1], x[i], x[i - 1]], [y[j], y[j], y[j - 1], y[j - 1]]] ) neighbors_f = jnp.array( - [f[i, j], f[i - 1, j], f[i, j - 1], f[i - 1, j - 1]] + [f[i, j].T, f[i - 1, j].T, f[i, j - 1].T, f[i - 1, j - 1].T] ) xyq = jnp.array([xq, yq]) dist = jnp.linalg.norm(neighbors_x - xyq[:, None, :], axis=0) idx = jnp.argmin(dist, axis=0) - return jax.vmap(jnp.take)(neighbors_f.T, idx) + return jax.vmap(lambda a, b: jnp.take(a, b, axis=-1))(neighbors_f.T, idx) def derivative1(): return jnp.zeros((xq.size, *f.shape[2:])) @@ -708,7 +708,7 @@ def derivative1(): tx = jax.lax.switch(derivative_x, [dx0, dx1, dx2]) ty = jax.lax.switch(derivative_y, [dy0, dy1, dy2]) F = jnp.array([[f00, f01], [f10, f11]]) - fq = dxi * dyi * jnp.einsum("ijk...,ik,jk->k...", F, tx, ty) + fq = (dxi * dyi * jnp.einsum("ijk...,ik,jk->k...", F, tx, ty).T).T elif method in CUBIC_METHODS: if fx is None: @@ -740,15 +740,16 @@ def derivative1(): for ff in fs.keys(): for jj in [0, 1]: for ii in [0, 1]: - fsq[ff + str(ii) + str(jj)] = fs[ff][i - 1 + ii, j - 1 + jj] + s = ff + str(ii) + str(jj) + fsq[s] = fs[ff][i - 1 + ii, j - 1 + jj] if "x" in ff: - fsq[ff + str(ii) + str(jj)] *= dx + fsq[s] = (dx * fsq[s].T).T if "y" in ff: - fsq[ff + str(ii) + str(jj)] *= dy + fsq[s] = (dy * fsq[s].T).T - F = jnp.vstack([foo for foo in fsq.values()]) - coef = jnp.matmul(A_BICUBIC, F) - coef = jnp.moveaxis(coef.reshape((4, 4, -1), order="F"), -1, 0) + F = jnp.stack([foo for foo in fsq.values()], axis=0).T + coef = jnp.vectorize(jnp.matmul, signature="(n,n),(n)->(n)")(A_BICUBIC, F).T + coef = jnp.moveaxis(coef.reshape((4, 4, *coef.shape[1:]), order="F"), 2, 0) ttx = _get_t_der(tx, derivative_x, dxi) tty = _get_t_der(ty, derivative_y, dyi) fq = jnp.einsum("ijk...,ij,ik->i...", coef, ttx, tty) @@ -900,20 +901,20 @@ def derivative0(): ) neighbors_f = jnp.array( [ - f[i, j, k], - f[i - 1, j, k], - f[i, j - 1, k], - f[i - 1, j - 1, k], - f[i, j, k - 1], - f[i - 1, j, k - 1], - f[i, j - 1, k - 1], - f[i - 1, j - 1, k - 1], + f[i, j, k].T, + f[i - 1, j, k].T, + f[i, j - 1, k].T, + f[i - 1, j - 1, k].T, + f[i, j, k - 1].T, + f[i - 1, j, k - 1].T, + f[i, j - 1, k - 1].T, + f[i - 1, j - 1, k - 1].T, ] ) xyzq = jnp.array([xq, yq, zq]) dist = jnp.linalg.norm(neighbors_x - xyzq[:, None, :], axis=0) idx = jnp.argmin(dist, axis=0) - return jax.vmap(jnp.take)(neighbors_f.T, idx) + return jax.vmap(lambda a, b: jnp.take(a, b, axis=-1))(neighbors_f.T, idx) def derivative1(): return jnp.zeros((xq.size, *f.shape[3:])) @@ -966,7 +967,7 @@ def derivative1(): tz = jax.lax.switch(derivative_z, [dz0, dz1, dz2]) F = jnp.array([[[f000, f001], [f010, f011]], [[f100, f101], [f110, f111]]]) - fq = dxi * dyi * dzi * jnp.einsum("lijk...,lk,ik,jk->k...", F, tx, ty, tz) + fq = (dxi * dyi * dzi * jnp.einsum("lijk...,lk,ik,jk->k...", F, tx, ty, tz).T).T elif method in CUBIC_METHODS: if fx is None: @@ -1026,19 +1027,18 @@ def derivative1(): for kk in [0, 1]: for jj in [0, 1]: for ii in [0, 1]: - fsq[ff + str(ii) + str(jj) + str(kk)] = fs[ff][ - i - 1 + ii, j - 1 + jj, k - 1 + kk - ] + s = ff + str(ii) + str(jj) + str(kk) + fsq[s] = fs[ff][i - 1 + ii, j - 1 + jj, k - 1 + kk] if "x" in ff: - fsq[ff + str(ii) + str(jj) + str(kk)] *= dx + fsq[s] = (dx * fsq[s].T).T if "y" in ff: - fsq[ff + str(ii) + str(jj) + str(kk)] *= dy + fsq[s] = (dy * fsq[s].T).T if "z" in ff: - fsq[ff + str(ii) + str(jj) + str(kk)] *= dz + fsq[s] = (dz * fsq[s].T).T - F = jnp.vstack([foo for foo in fsq.values()]) - coef = jnp.matmul(A_TRICUBIC, F) - coef = jnp.moveaxis(coef.reshape((4, 4, 4, -1), order="F"), -1, 0) + F = jnp.stack([foo for foo in fsq.values()], axis=0).T + coef = jnp.vectorize(jnp.matmul, signature="(n,n),(n)->(n)")(A_TRICUBIC, F).T + coef = jnp.moveaxis(coef.reshape((4, 4, 4, *coef.shape[1:]), order="F"), 3, 0) ttx = _get_t_der(tx, derivative_x, dxi) tty = _get_t_der(ty, derivative_y, dyi) ttz = _get_t_der(tz, derivative_z, dzi) @@ -1129,13 +1129,13 @@ def loclip(fq, lo): # lo is either False (no extrapolation) or a fixed value to fill in if isbool(lo): lo = jnp.nan - return jnp.where(xq < x[0], lo, fq) + return jnp.where(xq < x[0], lo, fq.T).T def hiclip(fq, hi): # hi is either False (no extrapolation) or a fixed value to fill in if isbool(hi): hi = jnp.nan - return jnp.where(xq > x[-1], hi, fq) + return jnp.where(xq > x[-1], hi, fq.T).T def noclip(fq, *_): return fq diff --git a/tests/test_interpolate.py b/tests/test_interpolate.py index 20aa9f0..128dac5 100644 --- a/tests/test_interpolate.py +++ b/tests/test_interpolate.py @@ -65,6 +65,38 @@ def test_interp1d(self, x): fq = interp(x, xp, fp, method="monotonic-0") np.testing.assert_allclose(fq, f(x), rtol=1e-4, atol=1e-2) + @pytest.mark.unit + def test_interp1d_vector_valued(self): + """Test for interpolating vector valued function.""" + xp = np.linspace(0, 2 * np.pi, 100) + x = np.linspace(0, 2 * np.pi, 300)[10:-10] + f = lambda x: np.array([np.sin(x), np.cos(x)]) + fp = f(xp).T + + fq = interp1d(x, xp, fp, method="nearest") + np.testing.assert_allclose(fq, f(x).T, rtol=1e-2, atol=1e-1) + + fq = interp1d(x, xp, fp, method="linear") + np.testing.assert_allclose(fq, f(x).T, rtol=1e-4, atol=1e-3) + + fq = interp1d(x, xp, fp, method="cubic") + np.testing.assert_allclose(fq, f(x).T, rtol=1e-6, atol=1e-5) + + fq = interp1d(x, xp, fp, method="cubic2") + np.testing.assert_allclose(fq, f(x).T, rtol=1e-6, atol=1e-5) + + fq = interp1d(x, xp, fp, method="cardinal") + np.testing.assert_allclose(fq, f(x).T, rtol=1e-6, atol=1e-5) + + fq = interp1d(x, xp, fp, method="catmull-rom") + np.testing.assert_allclose(fq, f(x).T, rtol=1e-6, atol=1e-5) + + fq = interp1d(x, xp, fp, method="monotonic") + np.testing.assert_allclose(fq, f(x).T, rtol=1e-4, atol=1e-3) + + fq = interp1d(x, xp, fp, method="monotonic-0") + np.testing.assert_allclose(fq, f(x).T, rtol=1e-4, atol=1e-2) + @pytest.mark.unit def test_interp1d_extrap_periodic(self): """Test extrapolation and periodic BC of 1d interpolation.""" @@ -156,6 +188,27 @@ def test_interp2d(self, x, y): ) np.testing.assert_allclose(fq, f(x, y), rtol=rtol, atol=atol) + @pytest.mark.unit + def test_interp2d_vector_valued(self): + """Test for interpolating vector valued function.""" + xp = np.linspace(0, 3 * np.pi, 99) + yp = np.linspace(0, 2 * np.pi, 40) + x = np.linspace(0, 3 * np.pi, 200) + y = np.linspace(0, 2 * np.pi, 200) + xxp, yyp = np.meshgrid(xp, yp, indexing="ij") + + f = lambda x, y: np.array([np.sin(x) * np.cos(y), np.sin(x) + np.cos(y)]) + fp = f(xxp.T, yyp.T).T + + fq = interp2d(x, y, xp, yp, fp, method="nearest") + np.testing.assert_allclose(fq, f(x, y).T, rtol=1e-2, atol=1.2e-1) + + fq = interp2d(x, y, xp, yp, fp, method="linear") + np.testing.assert_allclose(fq, f(x, y).T, rtol=1e-3, atol=1e-2) + + fq = interp2d(x, y, xp, yp, fp, method="cubic") + np.testing.assert_allclose(fq, f(x, y).T, rtol=1e-5, atol=2e-3) + class TestInterp3D: """Tests for interp3d function.""" @@ -213,6 +266,31 @@ def test_interp3d(self, x, y, z): fq = interp(x, y, z, xp, yp, zp, fp, method="cardinal") np.testing.assert_allclose(fq, f(x, y, z), rtol=rtol, atol=atol) + @pytest.mark.unit + def test_interp3d_vector_valued(self): + """Test for interpolating vector valued function.""" + x = np.linspace(0, np.pi, 1000) + y = np.linspace(0, 2 * np.pi, 1000) + z = np.linspace(0, 3, 1000) + xp = np.linspace(0, np.pi, 20) + yp = np.linspace(0, 2 * np.pi, 30) + zp = np.linspace(0, 3, 25) + xxp, yyp, zzp = np.meshgrid(xp, yp, zp, indexing="ij") + + f = lambda x, y, z: np.array( + [np.sin(x) * np.cos(y) * z**2, 0.1 * (x + y - z)] + ) + fp = f(xxp.T, yyp.T, zzp.T).T + + fq = interp3d(x, y, z, xp, yp, zp, fp, method="nearest") + np.testing.assert_allclose(fq, f(x, y, z).T, rtol=1e-2, atol=1) + + fq = interp3d(x, y, z, xp, yp, zp, fp, method="linear") + np.testing.assert_allclose(fq, f(x, y, z).T, rtol=1e-3, atol=1e-1) + + fq = interp3d(x, y, z, xp, yp, zp, fp, method="cubic") + np.testing.assert_allclose(fq, f(x, y, z).T, rtol=1e-5, atol=5e-3) + @pytest.mark.unit def test_fft_interp1d():