Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

An interp that supports gradient #194

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions autograd/numpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import absolute_import
from . import numpy_wrapper
from . import numpy_grads
from . import numpy_grads_interp
from . import numpy_extra
from .numpy_wrapper import *
from . import linalg
Expand Down
68 changes: 68 additions & 0 deletions autograd/numpy/numpy_grads_interp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from . import numpy_wrapper as anp

def _interp_vjp(x, xp, yp, left, right, period, g):
from autograd import vector_jacobian_product
func = vector_jacobian_product(_interp, argnum=2)
return func(x, xp, yp, left, right, period, g)

def _interp(x, xp, yp, left=None, right=None, period=None):
""" A partial rewrite of interp that is differentiable against yp """
if period is not None:
xp = anp.concatenate([[xp[-1] - period], xp, [xp[0] + period]])
yp = anp.concatenate([anp.array([yp[-1]]), yp, anp.array([yp[0]])])
return _interp(x % period, xp, yp, left, right, None)

if left is None: left = yp[0]
if right is None: right = yp[-1]

xp = anp.concatenate([[xp[0]], xp, [xp[-1]]])

yp = anp.concatenate([anp.array([left]), yp, anp.array([right])])
m = make_matrix(x, xp)
y = anp.inner(m, yp)
return y

anp.interp.defvjp(lambda g, ans, vs, gvs, x, xp, yp, left=None, right=None, period=None:
_interp_vjp(x, xp, yp, left, right, period, g), argnum=2)


# The following are internal functions

import numpy as np

def W(r, D):
""" Convolution kernel for linear interpolation.
D is the differences of xp.
"""
mask = D == 0
D[mask] = 1.0
Wleft = 1.0 + r[1:] / D
Wright = 1.0 - r[:-1] / D
# edges
Wleft = np.where(mask, 0, Wleft)
Wright = np.where(mask, 0, Wright)
Wleft = np.concatenate([[0], Wleft])
Wright = np.concatenate([Wright, [0]])
W = np.where(r < 0, Wleft, Wright)
W = np.where(r == 0, 1.0, W)
W = np.where(W < 0, 0, W)
return W

def make_matrix(x, xp):
D = np.diff(xp)
w = []
v0 = np.zeros(len(xp))
v0[0] = 1.0
v1 = np.zeros(len(xp))
v1[-1] = 1.0
for xi in x:
# left, use left
if xi < xp[0]: v = v0
# right , use right
elif xi > xp[-1]: v = v1
else:
v = W(xi - xp, D)
v[0] = 0
v[-1] = 0
w.append(v)
return np.array(w)
41 changes: 41 additions & 0 deletions tests/test_numpy_interp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from __future__ import absolute_import
import warnings

import autograd.numpy as np
import autograd.numpy.random as npr
from autograd.util import *
from autograd import grad

from numpy.testing import assert_allclose

def test_interp():
x = np.arange(-20, 20, 0.1)
xp = np.arange(10) * 1.0
npr.seed(1)
yp = xp ** 0.5 + npr.normal(size=xp.shape)
def fun(yp): return to_scalar(np.interp(x, xp, yp))
def dfun(yp): return to_scalar(grad(fun)(yp))

check_grads(fun, yp)
check_grads(dfun, yp)

def test_interp_edge():
x = np.arange(-20, 20, 0.1)
xp = np.arange(10) * 1.0
npr.seed(1)
yp = xp ** 0.5 + npr.normal(size=xp.shape)
def fun(yp): return to_scalar(np.interp(x, xp, yp, left=-1, right=-1))
def dfun(yp): return to_scalar(grad(fun)(yp))
check_grads(fun, yp)
check_grads(dfun, yp)

def test_interp_period():
x = np.arange(-20, 20, 0.5)
xp = np.arange(10) * 1.0
npr.seed(1)
yp = xp ** 0.5 + npr.normal(size=xp.shape)
def fun(yp): return to_scalar(np.interp(x, xp, yp, period=10))
def dfun(yp): return to_scalar(grad(fun)(yp))

check_grads(fun, yp)
check_grads(dfun, yp)