Merge pull request #287 from j-towns/tensordot-adjoints
Tensordot adjoints
mattjj authored Sep 6, 2017
2 parents 4e517a0 + 82068d0 commit 0c72e22
Showing 4 changed files with 140 additions and 68 deletions.
autograd/numpy/numpy_jvps.py (10 changes: 7 additions & 3 deletions)
@@ -1,6 +1,7 @@
from . import numpy_wrapper as anp
from .numpy_vjps import (untake, balanced_eq, match_complex, replace_zero,
-                         dot_0_adjoint, dot_1_adjoint)
+                         dot_adjoint_0, dot_adjoint_1, tensordot_adjoint_0,
+                         tensordot_adjoint_1)
from autograd.core import (defjvp, defjvps, def_linear_wrt_arg, defjvp_argnum,
def_multilinear, vspace)
from ..util import func
@@ -185,8 +186,11 @@ def fwd_grad_chooser(g, ans, gvs, vs, x, axis=None, keepdims=False):
def_multilinear(anp.tensordot)
def_multilinear(anp.outer)

-def_multilinear(dot_0_adjoint)
-def_multilinear(dot_1_adjoint)
+def_multilinear(dot_adjoint_0)
+def_multilinear(dot_adjoint_1)
+
+def_multilinear(tensordot_adjoint_0)
+def_multilinear(tensordot_adjoint_1)

def fwd_grad_concatenate_args(argnum, g, ans, gvs, vs, *axis_args, **kwargs):
result = []
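Aside (not part of the diff): def_multilinear registers a function that is linear in each argument separately; the JVP with respect to argument i is then just the function re-applied with that argument's tangent swapped in. The dot and tensordot adjoints are themselves multilinear (in B and G, resp. A and G), which is what makes this registration valid. A standalone sketch of the underlying property in plain numpy:

import numpy as np

rng = np.random.RandomState(0)
A, dA = rng.randn(3, 4), rng.randn(3, 4)
B = rng.randn(4, 5)

# tensordot is linear in A, so its directional derivative in the direction dA
# is exactly np.tensordot(dA, B, 1); def_multilinear builds JVPs for tensordot
# and for the tensordot_adjoint_* primitives from this identity.
eps = 1e-6
numeric = (np.tensordot(A + eps * dA, B, 1) - np.tensordot(A, B, 1)) / eps
assert np.allclose(numeric, np.tensordot(dA, B, 1), atol=1e-4)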
autograd/numpy/numpy_vjps.py (141 changes: 86 additions & 55 deletions)
@@ -280,21 +280,27 @@ def grad_inner(argnum, ans, vs, gvs, A, B):
axes = ([], [])
else:
axes = ([A.ndim - 1], [B.ndim - 1])
-    return grad_tensordot(argnum, ans, vs, gvs, A, B, axes=axes)
+    if argnum == 0:
+        return lambda G: tensordot_adjoint_0(B, G, axes, vs)
+    elif argnum == 1:
+        return lambda G: tensordot_adjoint_1(A, G, axes, vs)
defvjps(anp.inner, grad_inner, [0, 1])

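Aside (not part of the diff): the axes chosen above make np.inner literally a special case of np.tensordot, which is why grad_inner can reuse the tensordot adjoints. A one-line check in plain numpy:

import numpy as np
A, B = np.arange(6.).reshape(2, 3), np.arange(12.).reshape(4, 3)
assert np.allclose(np.inner(A, B), np.tensordot(A, B, axes=([1], [1])))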
def grad_matmul(argnum, ans, vs, gvs, A, B):
if anp.ndim(A) == 0 or anp.ndim(B) == 0:
raise ValueError("Scalar operands are not allowed, use '*' instead")
elif anp.ndim(A) == 1 or anp.ndim(B) == 1 or (anp.ndim(A) == 2 and anp.ndim(B) == 2):
axes = ([A.ndim - 1], [max(0, B.ndim - 2)])
-        return grad_tensordot(argnum, ans, vs, gvs, A, B, axes=axes)
+        if argnum == 0:
+            return lambda G: tensordot_adjoint_0(B, G, axes, vs)
+        elif argnum == 1:
+            return lambda G: tensordot_adjoint_1(A, G, axes, vs)
else:
return grad_einsum(argnum + 1, ans, vs, gvs, ("...ij,...jk->...ik", A, B), None)
defvjps(anp.matmul, grad_matmul, [0, 1])

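Likewise, non-broadcasting matmul is tensordot over the inner axis (the axes tuple above), so the same adjoints apply; only the stacked ndim > 2 case still falls back to grad_einsum. A quick numpy check:

import numpy as np
A, B = np.arange(6.).reshape(2, 3), np.arange(12.).reshape(3, 4)
assert np.allclose(np.matmul(A, B), np.tensordot(A, B, axes=([1], [0])))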
@primitive
-def dot_0_adjoint(B, G, A_vs):
+def dot_adjoint_0(B, G, A_vs):
# The adjoint of the operator
# A |--> np.dot(A, B)
A_ndim, B_ndim = A_vs.ndim, onp.ndim(B)
@@ -305,7 +311,7 @@ def dot_0_adjoint(B, G, A_vs):
return onp.tensordot(G, onp.swapaxes(B, -1, -2), B_ndim - 1)

@primitive
-def dot_1_adjoint(A, G, B_vs):
+def dot_adjoint_1(A, G, B_vs):
# The adjoint of the operator
# B |--> np.dot(A, B)
A_ndim, B_ndim = onp.ndim(A), B_vs.ndim
@@ -318,60 +324,85 @@ def dot_1_adjoint(A, G, B_vs):
return swap(onp.tensordot(
G, A, [range(-A_ndim - B_ndim + 2, -B_ndim + 1), range(A_ndim - 1)]))

-defvjp(anp.dot, lambda ans, vs, gvs, A, B: lambda g: dot_0_adjoint(B, g, vs))
-defvjp(anp.dot, lambda ans, vs, gvs, A, B: lambda g: dot_1_adjoint(A, g, vs), 1)
+defvjp(anp.dot, lambda ans, vs, gvs, A, B: lambda g: dot_adjoint_0(B, g, vs))
+defvjp(anp.dot, lambda ans, vs, gvs, A, B: lambda g: dot_adjoint_1(A, g, vs), 1)

-defvjp(dot_0_adjoint, lambda ans, vs, gvs, B, g, A_vs: lambda A: dot_1_adjoint(A, g, vs))
-defvjp(dot_0_adjoint, lambda ans, vs, gvs, B, g, *args: lambda A: anp.dot(A, B), 1)
+defvjp(dot_adjoint_0, lambda ans, vs, gvs, B, g, A_vs: lambda A: dot_adjoint_1(A, g, vs))
+defvjp(dot_adjoint_0, lambda ans, vs, gvs, B, g, *args: lambda A: anp.dot(A, B), 1)

-defvjp(dot_1_adjoint, lambda ans, vs, gvs, A, g, B_vs: lambda B: dot_0_adjoint(B, g, vs))
-defvjp(dot_1_adjoint, lambda ans, vs, gvs, A, g, *args: lambda B: anp.dot(A, B), 1)
+defvjp(dot_adjoint_1, lambda ans, vs, gvs, A, g, B_vs: lambda B: dot_adjoint_0(B, g, vs))
+defvjp(dot_adjoint_1, lambda ans, vs, gvs, A, g, *args: lambda B: anp.dot(A, B), 1)

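Aside: a 2-D sanity check of the identity these primitives implement. dot_adjoint_0 is the adjoint of A |--> dot(A, B), which for matrices is G |--> dot(G, B.T); the vspace arguments in the real code handle general ndim and are omitted in this sketch:

import numpy as np

rng = np.random.RandomState(0)
A, B = rng.randn(3, 4), rng.randn(4, 5)
G = rng.randn(3, 5)  # cotangent, same shape as np.dot(A, B)

# Defining property of the adjoint: <G, A @ B> == <G @ B.T, A>.
assert np.allclose(np.vdot(G, np.dot(A, B)), np.vdot(np.dot(G, B.T), A))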
-def grad_tensordot(argnum, ans, vs, gvs, A, B, axes=2):
-    def vjp(g):
-        axes_ = axes
-        if anp.size(A) == anp.size(B) == 0:
-            return g * B if argnum == 0 else g * A
-
-        A_ndim = anp.ndim(A)
-        g_axes = onp.arange(anp.ndim(g))
-        if type(axes_) is int:
-            axes_ = max(axes_, 0)
-            if argnum == 0:
-                B_axes = onp.arange(anp.ndim(B))
-                return anp.tensordot(g, B, [g_axes[A_ndim-axes_:], B_axes[axes_:]])
-            else:
-                A_axes = onp.arange(A_ndim)
-                return anp.tensordot(A, g, [A_axes[:A_ndim-axes_], g_axes[:A_ndim-axes_]])
-        elif type(axes_[0]) is int:
-            B_ndim = anp.ndim(B)
-            axes_ = [axes_[0] % A_ndim, axes_[1] % B_ndim]
-            if argnum == 0:
-                B_axes = onp.arange(B_ndim)
-                return anp.tensordot(g, B, [g_axes[A_ndim-1:], onp.delete(B_axes, axes_[1])])
-            else:
-                A_axes = onp.arange(A_ndim)
-                return anp.tensordot(A, g, [onp.delete(A_axes, axes_[0]), g_axes[:A_ndim-1]])
-        else:
-            B_ndim = anp.ndim(B)
-            A_axes = onp.arange(A_ndim)
-            B_axes = onp.arange(B_ndim)
-            summed_axes = [onp.asarray(axes_[0]) % A_ndim,
-                           onp.asarray(axes_[1]) % B_ndim]
-            other_axes = [onp.delete(A_axes, summed_axes[0]),
-                          onp.delete(B_axes, summed_axes[1])]
-            if argnum == 0:
-                out = anp.tensordot(g, B, [g_axes[len(other_axes[0]):], other_axes[1]])
-                perm = onp.argsort(onp.concatenate(
-                    (other_axes[0], summed_axes[0][onp.argsort(summed_axes[1])])))
-                return anp.transpose(out, perm)
-            else:
-                out = anp.tensordot(A, g, [other_axes[0], g_axes[:len(other_axes[0])]])
-                perm = onp.argsort(onp.concatenate(
-                    (summed_axes[1][onp.argsort(summed_axes[0])], other_axes[1])))
-                return anp.transpose(out, perm)
-    return vjp
-defvjps(anp.tensordot, grad_tensordot, [0, 1])
+@primitive
+def tensordot_adjoint_0(B, G, axes, A_vs):
+    # The adjoint of the operator
+    # A |--> np.tensordot(A, B, axes)
+    if onp.ndim(B) == 0:
+        return G * B
+
+    A_ndim = A_vs.ndim
+    G_axes = onp.arange(onp.ndim(G))
+    if type(axes) is int:
+        axes = max(axes, 0)
+        B_axes = onp.arange(onp.ndim(B))
+        return onp.tensordot(G, B, [G_axes[A_ndim-axes:], B_axes[axes:]])
+    elif type(axes[0]) is int:
+        B_ndim = onp.ndim(B)
+        axes = [axes[0] % A_ndim, axes[1] % B_ndim]
+        B_axes = onp.arange(B_ndim)
+        return onp.tensordot(G, B, [G_axes[A_ndim-1:], onp.delete(B_axes, axes[1])])
+    else:
+        B_ndim = onp.ndim(B)
+        A_axes = onp.arange(A_ndim)
+        B_axes = onp.arange(B_ndim)
+        summed_axes = [onp.asarray(axes[0]) % A_ndim,
+                       onp.asarray(axes[1]) % B_ndim]
+        other_axes = [onp.delete(A_axes, summed_axes[0]),
+                      onp.delete(B_axes, summed_axes[1])]
+        out = onp.tensordot(G, B, [G_axes[len(other_axes[0]):], other_axes[1]])
+        perm = onp.argsort(onp.concatenate(
+            (other_axes[0], summed_axes[0][onp.argsort(summed_axes[1])])))
+        return onp.transpose(out, perm)
+
+@primitive
+def tensordot_adjoint_1(A, G, axes, B_vs):
+    # The adjoint of the operator
+    # B |--> np.tensordot(A, B, axes)
+    if onp.ndim(A) == 0:
+        return G * A
+
+    A_ndim = onp.ndim(A)
+    G_axes = onp.arange(onp.ndim(G))
+    if type(axes) is int:
+        axes = max(axes, 0)
+        A_axes = onp.arange(A_ndim)
+        return onp.tensordot(A, G, [A_axes[:A_ndim-axes], G_axes[:A_ndim-axes]])
+    elif type(axes[0]) is int:
+        B_ndim = B_vs.ndim
+        axes = [axes[0] % A_ndim, axes[1] % B_ndim]
+        A_axes = onp.arange(A_ndim)
+        return onp.tensordot(A, G, [onp.delete(A_axes, axes[0]), G_axes[:A_ndim-1]])
+    else:
+        B_ndim = B_vs.ndim
+        A_axes = onp.arange(A_ndim)
+        B_axes = onp.arange(B_ndim)
+        summed_axes = [onp.asarray(axes[0]) % A_ndim,
+                       onp.asarray(axes[1]) % B_ndim]
+        other_axes = [onp.delete(A_axes, summed_axes[0]),
+                      onp.delete(B_axes, summed_axes[1])]
+        out = onp.tensordot(A, G, [other_axes[0], G_axes[:len(other_axes[0])]])
+        perm = onp.argsort(onp.concatenate(
+            (summed_axes[1][onp.argsort(summed_axes[0])], other_axes[1])))
+        return onp.transpose(out, perm)
+
+defvjp(anp.tensordot, lambda ans, vs, gvs, A, B, axes=2: lambda G: tensordot_adjoint_0(B, G, axes, vs))
+defvjp(anp.tensordot, lambda ans, vs, gvs, A, B, axes=2: lambda G: tensordot_adjoint_1(A, G, axes, vs), 1)
+
+defvjp(tensordot_adjoint_0, lambda ans, vs, gvs, B, G, axes, A_vs: lambda A: tensordot_adjoint_1(A, G, axes, vs))
+defvjp(tensordot_adjoint_0, lambda ans, vs, gvs, B, G, axes, A_vs: lambda A: anp.tensordot(A, B, axes), 1)
+
+defvjp(tensordot_adjoint_1, lambda ans, vs, gvs, A, G, axes, B_vs: lambda B: tensordot_adjoint_0(B, G, axes, vs))
+defvjp(tensordot_adjoint_1, lambda ans, vs, gvs, A, G, axes, B_vs: lambda B: anp.tensordot(A, B, axes), 1)

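Aside: the defining property of these primitives is <G, tensordot(A, B, axes)> == <tensordot_adjoint_0(B, G, axes), A>, and symmetrically for argument 1. A standalone numpy sketch that reproduces the integer-axes branch above and checks that property (shapes are illustrative):

import numpy as np

rng = np.random.RandomState(0)
A, B = rng.randn(2, 3, 4), rng.randn(3, 4, 5)
n = 2                        # number of summed dimensions, as in axes=2
out = np.tensordot(A, B, n)  # shape (2, 5)
G = rng.randn(*out.shape)

# Integer-axes branch of tensordot_adjoint_0: contract G with the free
# dimensions of B; the result has the shape of A.
G_axes, B_axes = np.arange(G.ndim), np.arange(B.ndim)
adj_A = np.tensordot(G, B, [G_axes[A.ndim - n:], B_axes[n:]])
assert adj_A.shape == A.shape
assert np.allclose(np.vdot(G, out), np.vdot(adj_A, A))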
defvjp(anp.outer, lambda ans, vs, gvs, a, b : lambda g: anp.dot(g, b.T))
defvjp(anp.outer, lambda ans, vs, gvs, a, b : lambda g: anp.dot(a.T, g), argnum=1)
benchmarks/bench_numpy_vjps.py (43 changes: 40 additions & 3 deletions)
@@ -18,9 +18,6 @@
B = npr.randn(2, 3, 5, 4)
g = npr.randn(2, 3, 4, 2, 3, 4)

-def time_dot():
-    np.dot(A, B)
-
def time_dot_0():
dot_0(A, B, g)

@@ -44,3 +41,43 @@ def time_dot_1_1():

def time_dot_1_2():
dot_1_2(A, B, g)
+
+tensordot_0 = lambda A, B, G: make_vjp(np.tensordot)(A, B, 2)[0](G)
+tensordot_1 = lambda A, B, G: make_vjp(np.tensordot, argnum=1)(A, B, 2)[0](G)
+
+tensordot_0_0 = lambda A, B, G: make_vjp(tensordot_0)(A, B, G)[0](A)
+tensordot_0_1 = lambda A, B, G: make_vjp(tensordot_0)(A, B, G)[0](A)
+tensordot_0_2 = lambda A, B, G: make_vjp(tensordot_0)(A, B, G)[0](A)
+
+tensordot_1_0 = lambda A, B, G: make_vjp(tensordot_1)(A, B, G)[0](B)
+tensordot_1_1 = lambda A, B, G: make_vjp(tensordot_1)(A, B, G)[0](B)
+tensordot_1_2 = lambda A, B, G: make_vjp(tensordot_1)(A, B, G)[0](B)
+
+A = npr.randn(2, 3, 5, 4)
+B = npr.randn(5, 4, 2, 3)
+G = npr.randn(2, 3, 2, 3)
+
+def time_tensordot_0():
+    tensordot_0(A, B, G)
+
+def time_tensordot_1():
+    tensordot_1(A, B, G)
+
+def time_tensordot_0_0():
+    tensordot_0_0(A, B, G)
+
+def time_tensordot_0_1():
+    tensordot_0_1(A, B, G)
+
+def time_tensordot_0_2():
+    tensordot_0_2(A, B, G)
+
+def time_tensordot_1_0():
+    tensordot_1_0(A, B, G)
+
+def time_tensordot_1_1():
+    tensordot_1_1(A, B, G)
+
+def time_tensordot_1_2():
+    tensordot_1_2(A, B, G)
+
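Note: the time_* functions follow the asv (airspeed velocity) naming convention of the existing benchmarks. To time one of them ad hoc, outside any benchmark runner, a minimal sketch, assuming the module is importable as benchmarks.bench_numpy_vjps:

import timeit

# Hypothetical ad-hoc timing; an asv runner would normally collect these.
setup = "from benchmarks.bench_numpy_vjps import time_tensordot_0_0"
print(timeit.timeit("time_tensordot_0_0()", setup=setup, number=100))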
tests/test_systematic.py (14 changes: 7 additions & 7 deletions)
@@ -111,25 +111,25 @@ def test_matmul(): combo_check(np.matmul, [0, 1])(
[R(3), R(2, 3), R(2, 2, 3)],
[R(3), R(3, 4), R(2, 3, 4)])
def test_matmul_broadcast(): combo_check(np.matmul, [0, 1])([R(1, 2, 2)], [R(3, 2, 1)])
-def test_tensordot_1(): combo_check(np.tensordot, [0, 1])(
+def test_tensordot_1(): combo_check(np.tensordot, [0, 1], order=3)(
[R(1, 3), R(2, 3, 2)],
[R(3), R(3, 1), R(3, 4, 2)],
axes=[ [(1,), (0,)] ])
-def test_tensordot_2(): combo_check(np.tensordot, [0, 1])(
+def test_tensordot_2(): combo_check(np.tensordot, [0, 1], order=3)(
[R(3), R(3, 1), R(3, 4, 2)],
[R(1, 3), R(2, 3, 2)],
axes=[ [(0,), (1,)] ])
-def test_tensordot_3(): combo_check(np.tensordot, [0, 1])(
+def test_tensordot_3(): combo_check(np.tensordot, [0, 1], order=3)(
[R(2, 3), R(2, 3, 4)],
[R(1, 2, 3), R(2, 2, 3, 4)],
axes=[ [(0, 1), (1, 2)] , [(1, 0), (2, 1)] ])
-def test_tensordot_4(): combo_check(np.tensordot, [0, 1])(
+def test_tensordot_4(): combo_check(np.tensordot, [0, 1], order=3)(
[R(2, 2), R(4, 2, 2)],
[R(2, 2), R(2, 2, 4)],
axes=[1, 2])
-def test_tensordot_5(): combo_check(np.tensordot, [0, 1])([R(4)], [R()], axes=[0])
-def test_tensordot_6(): combo_check(np.tensordot, [0, 1])([R(2,6)], [R(6,3)], axes=[[[-1], [0]]])
-def test_tensordot_7(): combo_check(np.tensordot, [0, 1])([R(2,6)], [R(6,3)], axes=[[-1, 0]])
+def test_tensordot_5(): combo_check(np.tensordot, [0, 1], order=3)([R(4)], [R()], axes=[0])
+def test_tensordot_6(): combo_check(np.tensordot, [0, 1], order=3)([R(2,6)], [R(6,3)], axes=[[[-1], [0]]])
+def test_tensordot_7(): combo_check(np.tensordot, [0, 1], order=3)([R(2,6)], [R(6,3)], axes=[[-1, 0]])

# Need custom tests because gradient is undefined when arguments are identical.
def test_maximum(): combo_check(np.maximum, [0, 1])(
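Aside on the order=3 changes: the tensordot adjoints are now primitives with VJPs of their own, so the tests presumably raise the derivative-check order to exercise the adjoint-of-adjoint wiring added above. A hand-rolled sketch of that kind of nesting (the function and reductions are illustrative, not from the test suite):

import numpy.random as npr
import autograd.numpy as np
from autograd import grad

A, B = npr.randn(2, 4), npr.randn(4, 3)
f = lambda A: np.sum(np.tensordot(A, B, 1) ** 4)
g1 = grad(f)                             # calls tensordot_adjoint_0
g2 = grad(lambda A: np.sum(g1(A) ** 2))  # differentiates the adjoint itself
g3 = grad(lambda A: np.sum(g2(A) ** 2))  # third level of reverse mode
assert g3(A).shape == A.shape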
