sketching out what TJPs might look like #280

Open
wants to merge 8 commits into dougal-dev

1 change: 1 addition & 0 deletions autograd/numpy/__init__.py
@@ -4,6 +4,7 @@
from . import numpy_vspaces
from . import numpy_vjps
from . import numpy_jvps
from . import numpy_tjps
from . import linalg
from . import fft
from . import random
133 changes: 133 additions & 0 deletions autograd/numpy/numpy_tjps.py
@@ -0,0 +1,133 @@
from __future__ import absolute_import
import numpy as onp
from functools import partial
from ..util import func # TODO(mattjj): should this import use autograd.util, not ..util?
from autograd.tracer import primitive, getval
from autograd.vspace import vspace
from autograd.core import SparseObject
from autograd.tjp import deftjp, vjps_are_tjps
from . import numpy_wrapper as anp
from .numpy_boxes import ArrayBox
from .numpy_vjps import balanced_eq, replace_zero  # helpers reused from the VJP definitions

# ----- Binary ufuncs -----

# The only difference from the VJP definitions is the modified unbroadcast
# function below, which handles any leading (Jacobian) dimensions of g.
# Otherwise, the expressions used in the VJPs already broadcast correctly
# along leading dimensions of g.
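#
# A concrete sketch: if x has shape (3,) and y is a scalar, out_vs has shape
# (3,), but a TJP cotangent G may have shape (5, 3), with 5 a leading Jacobian
# dimension. unbroadcast(vspace(y), out_vs, G) then sums away the (3,) axis
# and returns a (5,)-shaped result for y, leaving the leading axis intact.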

deftjp(anp.add, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, g))
deftjp(anp.add, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, g), argnum=1)
deftjp(anp.multiply, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, y * g))
deftjp(anp.multiply, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, x * g), argnum=1)
deftjp(anp.subtract, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, g))
deftjp(anp.subtract, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, -g), argnum=1)
deftjp(anp.divide, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, g / y))
deftjp(anp.divide, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, - g * x / y**2), argnum=1)
deftjp(anp.maximum, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, g * balanced_eq(x, ans, y)))
deftjp(anp.maximum, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, g * balanced_eq(y, ans, x)), argnum=1)
deftjp(anp.minimum, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, g * balanced_eq(x, ans, y)))
deftjp(anp.minimum, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, g * balanced_eq(y, ans, x)), argnum=1)
deftjp(anp.fmax, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, g * balanced_eq(x, ans, y)))
deftjp(anp.fmax, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, g * balanced_eq(y, ans, x)), argnum=1)
deftjp(anp.fmin, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, g * balanced_eq(x, ans, y)))
deftjp(anp.fmin, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, g * balanced_eq(y, ans, x)), argnum=1)
deftjp(anp.logaddexp, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, g * anp.exp(x-ans)))
deftjp(anp.logaddexp, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, g * anp.exp(y-ans)), argnum=1)
deftjp(anp.logaddexp2, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, g * 2**(x-ans)))
deftjp(anp.logaddexp2, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, g * 2**(y-ans)), argnum=1)
deftjp(anp.true_divide, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, g / y))
deftjp(anp.true_divide, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, - g * x / y**2), argnum=1)
deftjp(anp.mod, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, g))
deftjp(anp.remainder, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, g))
deftjp(anp.mod, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, -g * anp.floor(x/y)), argnum=1)
deftjp(anp.remainder, lambda ans, vs, out_vs, x, y : lambda g: unbroadcast(vs, out_vs, -g * anp.floor(x/y)), argnum=1)
deftjp(anp.power,
lambda ans, vs, out_vs, x, y : lambda g:
unbroadcast(vs, out_vs, g * y * x ** anp.where(y, y - 1, 1.)))
deftjp(anp.power,
lambda ans, vs, out_vs, x, y : lambda g:
unbroadcast(vs, out_vs, g * anp.log(replace_zero(x, 1.)) * x ** y), argnum=1)

# ----- Simple grads -----

# Some VJP implementations already broadcast along leading dimensions of g, so
# they work as TJP definitions too. We use the vjps_are_tjps function for that.
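#
# For example, the VJP of anp.sin is lambda g: g * anp.cos(x); elementwise
# multiplication broadcasts over any leading dimensions of g automatically,
# so the same closure serves unchanged as a TJP.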

vjps_are_tjps(anp.absolute)
vjps_are_tjps(anp.reciprocal)
vjps_are_tjps(anp.exp)
vjps_are_tjps(anp.exp2)
vjps_are_tjps(anp.expm1)
vjps_are_tjps(anp.log)
vjps_are_tjps(anp.log2)
vjps_are_tjps(anp.log10)
vjps_are_tjps(anp.log1p)
vjps_are_tjps(anp.sin)
vjps_are_tjps(anp.cos)
vjps_are_tjps(anp.tan)
vjps_are_tjps(anp.arcsin)
vjps_are_tjps(anp.arccos)
vjps_are_tjps(anp.arctan)
vjps_are_tjps(anp.sinh)
vjps_are_tjps(anp.cosh)
vjps_are_tjps(anp.tanh)
vjps_are_tjps(anp.arcsinh)
vjps_are_tjps(anp.arccosh)
vjps_are_tjps(anp.arctanh)
vjps_are_tjps(anp.rad2deg)
vjps_are_tjps(anp.degrees)
vjps_are_tjps(anp.deg2rad)
vjps_are_tjps(anp.radians)
vjps_are_tjps(anp.square)
vjps_are_tjps(anp.sqrt)
vjps_are_tjps(anp.sinc)

vjps_are_tjps(anp.conj)
vjps_are_tjps(anp.conjugate)

# ----- Trickier grads -----
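#
# For dot, the cotangent G has shape (leading Jacobian dims) + out_vs.shape, so
# these TJPs contract G's trailing axes against the other argument (transposing
# where needed) and leave any leading axes untouched.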

def tjp_dot_arg0(ans, vs, out_vs, A, B):
A_ndim, B_ndim = vs.ndim, anp.ndim(B)
if B_ndim == 0 or B_ndim == 1 or A_ndim == 0:
contract_num = max(0, B_ndim - (A_ndim != 0))
return lambda G: anp.tensordot(G, B, contract_num)
else:
return lambda G: anp.tensordot(G, anp.swapaxes(B, -1, -2), B_ndim - 1)
deftjp(anp.dot, tjp_dot_arg0)

def tjp_dot_arg1(ans, vs, out_vs, A, B):
A_ndim, B_ndim = anp.ndim(A), vs.ndim
needs_transpose = B_ndim > 1 and A_ndim != 0
swap = (lambda x: anp.swapaxes(x, -1, -2)) if needs_transpose else (lambda x: x)
if A_ndim == 0 or A_ndim == 1 or B_ndim == 0:
contract_num = max(0, A_ndim - (B_ndim != 0))
return lambda G: swap(anp.tensordot(G, A, contract_num))
else:
return lambda G: swap(anp.tensordot(
G, A, [range(-A_ndim - B_ndim + 2, -B_ndim + 1), range(A_ndim - 1)]))
deftjp(anp.dot, tjp_dot_arg1, argnum=1)

def tjp_transpose(ans, in_vs, out_vs, x, axes=None):
    # tuple() the argsort result: adding a raw ndarray to a tuple would
    # broadcast elementwise rather than concatenate in the line below.
    axes = tuple(reversed(range(in_vs.ndim))) if axes is None else tuple(anp.argsort(axes))
return lambda g: anp.transpose(g, tuple(range(anp.ndim(g) - len(axes))) + axes)
deftjp(anp.transpose, tjp_transpose)

# ----- Utility functions -----

def unbroadcast(vs, out_vs, result):
result_vs = vspace(result)
leading_dims = result_vs.ndim - out_vs.ndim
broadcast_idx = leading_dims
while anp.ndim(result) > leading_dims + vs.ndim:
result = anp.sum(result, axis=broadcast_idx)
for axis, size in enumerate(vs.shape):
if size == 1:
result = anp.sum(result, axis=leading_dims + axis, keepdims=True)
if result_vs.iscomplex and not vs.iscomplex:
result = anp.real(result)
return result

# ----- Extra functions used internally -----

# TODO untake
18 changes: 17 additions & 1 deletion autograd/numpy/numpy_vspaces.py
@@ -8,7 +8,7 @@ def __init__(self, value):
self.dtype = value.dtype

@property
-    def size(self): return np.prod(self.shape)
+    def size(self): return int(np.prod(self.shape))
@property
def ndim(self): return len(self.shape)
def zeros(self): return np.zeros(self.shape, dtype=self.dtype)
@@ -26,6 +26,22 @@ def randn(self):
def _inner_prod(self, x, y):
return np.dot(np.ravel(x), np.ravel(y))

def _product(self, other_vspace):
return self._contract(other_vspace, ndim=0)

def _contract(self, other_vspace, ndim=None):
ndim = other_vspace.ndim if ndim is None else ndim
        split = self.ndim - ndim  # also correct for ndim == 0 (outer product) and scalar vspaces
        if not self.shape[split:] == other_vspace.shape[:ndim]:
            raise ValueError("cannot contract shapes {} and {} over {} dims".format(
                self.shape, other_vspace.shape, ndim))

result = self.__new__(self.__class__)
        result.shape = self.shape[:split] + other_vspace.shape[ndim:]
result.dtype = np.promote_types(self.dtype, other_vspace.dtype)
return result

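    # Reshapes the identity matrix into a (self.shape + self.shape)-tensor: the
    # Jacobian of the identity map on this vspace. Seeding a TJP with it
    # recovers the full Jacobian (see jacobian in autograd/tjp.py).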
def _kronecker_tensor(self):
return np.reshape(np.eye(self.size), self.shape + self.shape)

class ComplexArrayVSpace(ArrayVSpace):
iscomplex = True

69 changes: 69 additions & 0 deletions autograd/tjp.py
@@ -0,0 +1,69 @@
from collections import defaultdict
from functools import partial  # used by deftjps below
from .tracer import trace, primitive, Node, toposort
from .vspace import vspace
from .core import add_outgrads, primitive_vjps

def make_tjp(fun, x):
start_node = TJPNode.new_root(x)
end_value, end_node = trace(start_node, fun, x)
if end_node is None:
in_vs, out_vs = start_node.vspace, vspace(end_value)
        def tjp(G): return vspace(G)._contract(out_vs)._product(in_vs).zeros()
else:
def tjp(G): return tjp_backward_pass(G, end_node)
return tjp, end_value

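# The same reverse accumulation as the standard VJP backward pass, except the
# cotangent G may carry leading (Jacobian) dimensions, which every TJP
# preserves.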
def tjp_backward_pass(G, end_node):
assert_vspace_compatible(G, end_node.vspace)
outgrads = {end_node : (G, False)}
for node in toposort(end_node):
cur_outgrad = outgrads.pop(node)
for parent, tjp in node.parents_and_tjps:
outgrad = tjp(cur_outgrad[0])
assert_vspace_compatible(outgrad, parent.vspace)
outgrads[parent] = add_outgrads(vspace(outgrad), outgrads.get(parent), outgrad)
return cur_outgrad[0]

class TJPNode(Node):
__slots__ = ['vspace', 'parents', 'parents_and_tjps']
def __init__(self, value, fun, args, kwargs, parent_argnums, parents):
self.vspace = vspace(value)
self.parents = parents
self.parents_and_tjps = [
(parent, primitive_tjp(fun, argnum, value, parent.vspace,
self.vspace, args, kwargs))
for argnum, parent in zip(parent_argnums, parents)]

def initialize_root(self, value):
self.vspace = vspace(value)
self.parents = []
self.parents_and_tjps = []

primitive_tjps = defaultdict(dict)

def primitive_tjp(fun, argnum, ans, in_vs, out_vs, args, kwargs):
return primitive_tjps[fun][argnum](ans, in_vs, out_vs, args, kwargs)

def deftjp(fun, tjpmaker, argnum=0):
def tjp_fixed_args(ans, vs, gvs, args, kwargs):
return tjpmaker(ans, vs, gvs, *args, **kwargs)
primitive_tjps[fun][argnum] = tjp_fixed_args

def deftjps(fun, tjpmaker, argnums):
for argnum in argnums:
deftjp(fun, partial(tjpmaker, argnum), argnum)

def vjps_are_tjps(fun):
primitive_tjps[fun] = primitive_vjps[fun]

def assert_vspace_compatible(x, vs):
assert vs.ndim == 0 or vspace(x).shape[-vs.ndim:] == vs.shape

# convenience-wrapper stuff

from .util import unary_to_nary

@unary_to_nary
def jacobian(fun, x):
tjp, ans = make_tjp(fun, x)
return tjp(vspace(ans)._kronecker_tensor())
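
# Example usage (a sketch): jacobian seeds the TJP with the identity Kronecker
# tensor, so the result has shape ans.shape + x.shape.
#
#   import autograd.numpy as np
#   from autograd.tjp import jacobian
#   J = jacobian(np.sin)(np.zeros(3))   # J == np.eye(3), since cos(0) == 1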
25 changes: 25 additions & 0 deletions tests/test_tjps.py
@@ -0,0 +1,25 @@
import autograd.numpy as np
import autograd.numpy.random as npr
from autograd.tjp import jacobian
from autograd import jacobian as _jacobian

from itertools import product


def allclose(x, y): return x.shape == y.shape and np.allclose(x, y)

def test_dot():
npr.seed(0)
shapes = [(), (2,), (2, 2), (2, 2, 2)]
array_pairs = [(npr.normal(size=s1), npr.normal(size=s2))
for s1, s2 in product(shapes, shapes)]
argnums = [0, 1]

def check(A, B, argnum):
res1 = jacobian(np.dot, argnum)(A, B)
res2 = _jacobian(np.dot, argnum)(A, B)
assert allclose(res1, res2)

for A, B in array_pairs:
for argnum in argnums:
yield check, A, B, argnum
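
# These are nose-style generator tests, matching the conventions of the rest of
# the autograd test suite; run with e.g. `nosetests tests/test_tjps.py`.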