Skip to content

Commit

Permalink
Merge pull request #13 from f0uriest/rc/vector_valued
Browse files Browse the repository at this point in the history
Fixes for interpolating vector valued functions
  • Loading branch information
f0uriest authored Nov 28, 2023
2 parents 4160329 + 777f49c commit 9633d30
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 38 deletions.
34 changes: 34 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,44 @@
Changelog
=========

v0.2.4
------
- Fixes for scalar valued query points
- Fixes for interpolating vector valued functions

**Full Changelog**: https://github.com/f0uriest/interpax/compare/v0.2.3...v0.2.4


v0.2.3
------
- Add type annotations

**Full Changelog**: https://github.com/f0uriest/interpax/compare/v0.2.2...v0.2.3


v0.2.2
------
- Add ``approx_df`` to public API

**Full Changelog**: https://github.com/f0uriest/interpax/compare/v0.2.1...v0.2.2


v0.2.1
------
- More efficient nearest neighbor search
- Correct slopes for linear interpolation in 2d, 3d
- Fix for cubic2 splines in 2d and 3d
Forward and reverse mode AD now fully working and tested

**Full Changelog**: https://github.com/f0uriest/interpax/compare/v0.2.0...v0.2.1


v0.2.0
-------
- Adds convenience classes for spline interpolation that cache the derivative calculation.

**Full Changelog**: https://github.com/f0uriest/interpax/compare/v0.1.0...v0.2.0


v0.1.0
------
Expand Down
76 changes: 38 additions & 38 deletions interpax/_spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,17 +513,17 @@ def derivative0():
delta = xq - x[i - 1]
fq = jnp.where(
(dx == 0),
jnp.take(f, i, axis),
jnp.take(f, i - 1, axis) + delta * dxi * df,
)
jnp.take(f, i, axis).T,
jnp.take(f, i - 1, axis).T + (delta * dxi * df.T),
).T
return fq

def derivative1():
i = jnp.clip(jnp.searchsorted(x, xq, side="right"), 1, len(x) - 1)
df = jnp.take(f, i, axis) - jnp.take(f, i - 1, axis)
dx = x[i] - x[i - 1]
dxi = jnp.where(dx == 0, 0, 1 / dx)
return df * dxi
return (df.T * dxi).T

def derivative2():
return jnp.zeros((xq.size, *f.shape[1:]))
Expand All @@ -544,11 +544,11 @@ def derivative2():

f0 = jnp.take(f, i - 1, axis)
f1 = jnp.take(f, i, axis)
fx0 = jnp.take(fx, i - 1, axis) * dx
fx1 = jnp.take(fx, i, axis) * dx
fx0 = (jnp.take(fx, i - 1, axis).T * dx).T
fx1 = (jnp.take(fx, i, axis).T * dx).T

F = jnp.vstack([f0, f1, fx0, fx1])
coef = jnp.matmul(A_CUBIC, F)
F = jnp.stack([f0, f1, fx0, fx1], axis=0).T
coef = jnp.vectorize(jnp.matmul, signature="(n,n),(n)->(n)")(A_CUBIC, F).T
ttx = _get_t_der(t, derivative, dxi)
fq = jnp.einsum("ji...,ij->i...", coef, ttx)

Expand Down Expand Up @@ -666,12 +666,12 @@ def derivative0():
[[x[i], x[i - 1], x[i], x[i - 1]], [y[j], y[j], y[j - 1], y[j - 1]]]
)
neighbors_f = jnp.array(
[f[i, j], f[i - 1, j], f[i, j - 1], f[i - 1, j - 1]]
[f[i, j].T, f[i - 1, j].T, f[i, j - 1].T, f[i - 1, j - 1].T]
)
xyq = jnp.array([xq, yq])
dist = jnp.linalg.norm(neighbors_x - xyq[:, None, :], axis=0)
idx = jnp.argmin(dist, axis=0)
return jax.vmap(jnp.take)(neighbors_f.T, idx)
return jax.vmap(lambda a, b: jnp.take(a, b, axis=-1))(neighbors_f.T, idx)

def derivative1():
return jnp.zeros((xq.size, *f.shape[2:]))
Expand Down Expand Up @@ -708,7 +708,7 @@ def derivative1():
tx = jax.lax.switch(derivative_x, [dx0, dx1, dx2])
ty = jax.lax.switch(derivative_y, [dy0, dy1, dy2])
F = jnp.array([[f00, f01], [f10, f11]])
fq = dxi * dyi * jnp.einsum("ijk...,ik,jk->k...", F, tx, ty)
fq = (dxi * dyi * jnp.einsum("ijk...,ik,jk->k...", F, tx, ty).T).T

elif method in CUBIC_METHODS:
if fx is None:
Expand Down Expand Up @@ -740,15 +740,16 @@ def derivative1():
for ff in fs.keys():
for jj in [0, 1]:
for ii in [0, 1]:
fsq[ff + str(ii) + str(jj)] = fs[ff][i - 1 + ii, j - 1 + jj]
s = ff + str(ii) + str(jj)
fsq[s] = fs[ff][i - 1 + ii, j - 1 + jj]
if "x" in ff:
fsq[ff + str(ii) + str(jj)] *= dx
fsq[s] = (dx * fsq[s].T).T
if "y" in ff:
fsq[ff + str(ii) + str(jj)] *= dy
fsq[s] = (dy * fsq[s].T).T

F = jnp.vstack([foo for foo in fsq.values()])
coef = jnp.matmul(A_BICUBIC, F)
coef = jnp.moveaxis(coef.reshape((4, 4, -1), order="F"), -1, 0)
F = jnp.stack([foo for foo in fsq.values()], axis=0).T
coef = jnp.vectorize(jnp.matmul, signature="(n,n),(n)->(n)")(A_BICUBIC, F).T
coef = jnp.moveaxis(coef.reshape((4, 4, *coef.shape[1:]), order="F"), 2, 0)
ttx = _get_t_der(tx, derivative_x, dxi)
tty = _get_t_der(ty, derivative_y, dyi)
fq = jnp.einsum("ijk...,ij,ik->i...", coef, ttx, tty)
Expand Down Expand Up @@ -900,20 +901,20 @@ def derivative0():
)
neighbors_f = jnp.array(
[
f[i, j, k],
f[i - 1, j, k],
f[i, j - 1, k],
f[i - 1, j - 1, k],
f[i, j, k - 1],
f[i - 1, j, k - 1],
f[i, j - 1, k - 1],
f[i - 1, j - 1, k - 1],
f[i, j, k].T,
f[i - 1, j, k].T,
f[i, j - 1, k].T,
f[i - 1, j - 1, k].T,
f[i, j, k - 1].T,
f[i - 1, j, k - 1].T,
f[i, j - 1, k - 1].T,
f[i - 1, j - 1, k - 1].T,
]
)
xyzq = jnp.array([xq, yq, zq])
dist = jnp.linalg.norm(neighbors_x - xyzq[:, None, :], axis=0)
idx = jnp.argmin(dist, axis=0)
return jax.vmap(jnp.take)(neighbors_f.T, idx)
return jax.vmap(lambda a, b: jnp.take(a, b, axis=-1))(neighbors_f.T, idx)

def derivative1():
return jnp.zeros((xq.size, *f.shape[3:]))
Expand Down Expand Up @@ -966,7 +967,7 @@ def derivative1():
tz = jax.lax.switch(derivative_z, [dz0, dz1, dz2])

F = jnp.array([[[f000, f001], [f010, f011]], [[f100, f101], [f110, f111]]])
fq = dxi * dyi * dzi * jnp.einsum("lijk...,lk,ik,jk->k...", F, tx, ty, tz)
fq = (dxi * dyi * dzi * jnp.einsum("lijk...,lk,ik,jk->k...", F, tx, ty, tz).T).T

elif method in CUBIC_METHODS:
if fx is None:
Expand Down Expand Up @@ -1026,19 +1027,18 @@ def derivative1():
for kk in [0, 1]:
for jj in [0, 1]:
for ii in [0, 1]:
fsq[ff + str(ii) + str(jj) + str(kk)] = fs[ff][
i - 1 + ii, j - 1 + jj, k - 1 + kk
]
s = ff + str(ii) + str(jj) + str(kk)
fsq[s] = fs[ff][i - 1 + ii, j - 1 + jj, k - 1 + kk]
if "x" in ff:
fsq[ff + str(ii) + str(jj) + str(kk)] *= dx
fsq[s] = (dx * fsq[s].T).T
if "y" in ff:
fsq[ff + str(ii) + str(jj) + str(kk)] *= dy
fsq[s] = (dy * fsq[s].T).T
if "z" in ff:
fsq[ff + str(ii) + str(jj) + str(kk)] *= dz
fsq[s] = (dz * fsq[s].T).T

F = jnp.vstack([foo for foo in fsq.values()])
coef = jnp.matmul(A_TRICUBIC, F)
coef = jnp.moveaxis(coef.reshape((4, 4, 4, -1), order="F"), -1, 0)
F = jnp.stack([foo for foo in fsq.values()], axis=0).T
coef = jnp.vectorize(jnp.matmul, signature="(n,n),(n)->(n)")(A_TRICUBIC, F).T
coef = jnp.moveaxis(coef.reshape((4, 4, 4, *coef.shape[1:]), order="F"), 3, 0)
ttx = _get_t_der(tx, derivative_x, dxi)
tty = _get_t_der(ty, derivative_y, dyi)
ttz = _get_t_der(tz, derivative_z, dzi)
Expand Down Expand Up @@ -1129,13 +1129,13 @@ def loclip(fq, lo):
# lo is either False (no extrapolation) or a fixed value to fill in
if isbool(lo):
lo = jnp.nan
return jnp.where(xq < x[0], lo, fq)
return jnp.where(xq < x[0], lo, fq.T).T

def hiclip(fq, hi):
# hi is either False (no extrapolation) or a fixed value to fill in
if isbool(hi):
hi = jnp.nan
return jnp.where(xq > x[-1], hi, fq)
return jnp.where(xq > x[-1], hi, fq.T).T

def noclip(fq, *_):
return fq
Expand Down
78 changes: 78 additions & 0 deletions tests/test_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,38 @@ def test_interp1d(self, x):
fq = interp(x, xp, fp, method="monotonic-0")
np.testing.assert_allclose(fq, f(x), rtol=1e-4, atol=1e-2)

@pytest.mark.unit
def test_interp1d_vector_valued(self):
"""Test for interpolating vector valued function."""
xp = np.linspace(0, 2 * np.pi, 100)
x = np.linspace(0, 2 * np.pi, 300)[10:-10]
f = lambda x: np.array([np.sin(x), np.cos(x)])
fp = f(xp).T

fq = interp1d(x, xp, fp, method="nearest")
np.testing.assert_allclose(fq, f(x).T, rtol=1e-2, atol=1e-1)

fq = interp1d(x, xp, fp, method="linear")
np.testing.assert_allclose(fq, f(x).T, rtol=1e-4, atol=1e-3)

fq = interp1d(x, xp, fp, method="cubic")
np.testing.assert_allclose(fq, f(x).T, rtol=1e-6, atol=1e-5)

fq = interp1d(x, xp, fp, method="cubic2")
np.testing.assert_allclose(fq, f(x).T, rtol=1e-6, atol=1e-5)

fq = interp1d(x, xp, fp, method="cardinal")
np.testing.assert_allclose(fq, f(x).T, rtol=1e-6, atol=1e-5)

fq = interp1d(x, xp, fp, method="catmull-rom")
np.testing.assert_allclose(fq, f(x).T, rtol=1e-6, atol=1e-5)

fq = interp1d(x, xp, fp, method="monotonic")
np.testing.assert_allclose(fq, f(x).T, rtol=1e-4, atol=1e-3)

fq = interp1d(x, xp, fp, method="monotonic-0")
np.testing.assert_allclose(fq, f(x).T, rtol=1e-4, atol=1e-2)

@pytest.mark.unit
def test_interp1d_extrap_periodic(self):
"""Test extrapolation and periodic BC of 1d interpolation."""
Expand Down Expand Up @@ -156,6 +188,27 @@ def test_interp2d(self, x, y):
)
np.testing.assert_allclose(fq, f(x, y), rtol=rtol, atol=atol)

@pytest.mark.unit
def test_interp2d_vector_valued(self):
"""Test for interpolating vector valued function."""
xp = np.linspace(0, 3 * np.pi, 99)
yp = np.linspace(0, 2 * np.pi, 40)
x = np.linspace(0, 3 * np.pi, 200)
y = np.linspace(0, 2 * np.pi, 200)
xxp, yyp = np.meshgrid(xp, yp, indexing="ij")

f = lambda x, y: np.array([np.sin(x) * np.cos(y), np.sin(x) + np.cos(y)])
fp = f(xxp.T, yyp.T).T

fq = interp2d(x, y, xp, yp, fp, method="nearest")
np.testing.assert_allclose(fq, f(x, y).T, rtol=1e-2, atol=1.2e-1)

fq = interp2d(x, y, xp, yp, fp, method="linear")
np.testing.assert_allclose(fq, f(x, y).T, rtol=1e-3, atol=1e-2)

fq = interp2d(x, y, xp, yp, fp, method="cubic")
np.testing.assert_allclose(fq, f(x, y).T, rtol=1e-5, atol=2e-3)


class TestInterp3D:
"""Tests for interp3d function."""
Expand Down Expand Up @@ -213,6 +266,31 @@ def test_interp3d(self, x, y, z):
fq = interp(x, y, z, xp, yp, zp, fp, method="cardinal")
np.testing.assert_allclose(fq, f(x, y, z), rtol=rtol, atol=atol)

@pytest.mark.unit
def test_interp3d_vector_valued(self):
"""Test for interpolating vector valued function."""
x = np.linspace(0, np.pi, 1000)
y = np.linspace(0, 2 * np.pi, 1000)
z = np.linspace(0, 3, 1000)
xp = np.linspace(0, np.pi, 20)
yp = np.linspace(0, 2 * np.pi, 30)
zp = np.linspace(0, 3, 25)
xxp, yyp, zzp = np.meshgrid(xp, yp, zp, indexing="ij")

f = lambda x, y, z: np.array(
[np.sin(x) * np.cos(y) * z**2, 0.1 * (x + y - z)]
)
fp = f(xxp.T, yyp.T, zzp.T).T

fq = interp3d(x, y, z, xp, yp, zp, fp, method="nearest")
np.testing.assert_allclose(fq, f(x, y, z).T, rtol=1e-2, atol=1)

fq = interp3d(x, y, z, xp, yp, zp, fp, method="linear")
np.testing.assert_allclose(fq, f(x, y, z).T, rtol=1e-3, atol=1e-1)

fq = interp3d(x, y, z, xp, yp, zp, fp, method="cubic")
np.testing.assert_allclose(fq, f(x, y, z).T, rtol=1e-5, atol=5e-3)


@pytest.mark.unit
def test_fft_interp1d():
Expand Down

0 comments on commit 9633d30

Please sign in to comment.