
Equipping your Library with Backprop

Justin Le edited this page Apr 12, 2018 · 25 revisions

Equipping your library with backprop involves providing "backprop-aware" versions of your library functions. In fact, it is possible to make a library fully usable by providing only backprop versions of your functions, since you can use a backprop-aware function as a normal function with evalBP. Alternatively, you can re-export all of your functions in a separate module with "backprop-aware" versions.
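As a minimal sketch of that first option, the function below exists only in its backprop-aware form, yet it can still be run as an ordinary function with evalBP or differentiated with gradBP. The name quadratic is hypothetical, and the body relies on the Num instance that backprop provides for BVar.

```haskell
{-# LANGUAGE FlexibleContexts #-}
import Numeric.Backprop

-- Hypothetical library function, written only in backprop-aware form.
-- The arithmetic uses the Num instance for BVar.
quadratic :: Reifies s W => BVar s Double -> BVar s Double
quadratic x = x * x + 3 * x

main :: IO ()
main = do
  print (evalBP quadratic 2)  -- used as a plain Double -> Double: 10.0
  print (gradBP quadratic 2)  -- derivative 2x + 3 at x = 2: 7.0
```

A user who never needs gradients only ever calls evalBP quadratic and pays no attention to BVar.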

The Types

If you have a function:

myFunc :: a -> b

Then its lifted version would have type:

myFunc :: Reifies s W => BVar s a -> BVar s b

That is, instead of a function directly taking an a and returning a b, it's a function taking a BVar containing an a, and returning a BVar containing a b.

Functions taking multiple arguments can be translated pretty straightforwardly:

func1   ::                       a ->        b ->        c
func1BP :: Reifies s W => BVar s a -> BVar s b -> BVar s c

And also functions returning multiple values:

func2   ::                       a -> (       b,        c)
func2BP :: Reifies s W => BVar s a -> (BVar s b, BVar s c)
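For the two-argument case, a lifted function can be sketched with liftOp2 and op2 from Numeric.Backprop, where the gradient function returns one entry per argument. The names scaleBy and scaleByBP are hypothetical, chosen only for illustration.

```haskell
{-# LANGUAGE FlexibleContexts #-}
import Numeric.Backprop

-- Hypothetical plain library function of two arguments.
scaleBy :: Double -> Double -> Double
scaleBy k x = k * x

-- Its lifted counterpart. The gradient function returns a pair:
-- one scaled gradient for each argument.
scaleByBP :: Reifies s W => BVar s Double -> BVar s Double -> BVar s Double
scaleByBP = liftOp2 . op2 $ \k x ->
    ( scaleBy k x, \dzdy -> (dzdy * x, dzdy * k) )

main :: IO ()
main = do
  print (evalBP (\x -> scaleByBP x x) 3)  -- x * x at x = 3: 9.0
  print (gradBP (\x -> scaleByBP x x) 3)  -- both arguments contribute: 6.0
```

Feeding the same BVar into both arguments shows that the contributions from each argument are added together.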

Note that at this moment, almost all operations involving BVar'd items require that the contents have a Num instance. This is not a fundamental limitation of the underlying algorithm -- it was a conscious design decision intended to significantly reduce the complexity of the external API. However, the internal implementation does not require Num, and one day a Num-free version might be exported if there is enough demand. If this is important to you, please contact me (@mstksg) or open an issue!

Lifting operations

A BVar s a -> BVar s b really encodes two things:

  1. An a -> b (the actual function)
  2. An a -> b -> a (the "scaled gradient" function, taking the original input and dz/dy, and returning dz/dx)

The documentation for Numeric.Backprop.Op gives detail about what these entail, with rendered math and examples.

The second function requires some elaboration. Let's say you are writing a lifted version of your function y = f(x) (whose derivative is dy/dx), and that your final result at the end of your computation is z = g(f(x)) (whose derivative is dz/dx). In that case, because of the chain rule, dz/dx = dz/dy * dy/dx.

The scaled gradient is the function that, given dz/dy, returns dz/dx (that is, dz/dy * dy/dx).

For example, take the mathematical operation y = f(x) = x^2. Considering z = g(f(x)), we get dz/dx = dz/dy * 2x. In fact, for all functions taking and returning scalars (just normal single numbers), dz/dx = dz/dy * f'(x).
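The chain-rule bookkeeping can be sketched in plain Haskell, without the library. Everything here (ScalarOp, squareOp, sinOp, after, derivative) is a hypothetical stand-in and not part of the backprop API. It only illustrates how a result paired with a scaled gradient function composes.

```haskell
-- Hypothetical stand-in for a scalar op: given x, return the result y
-- together with the scaled gradient function (dz/dy -> dz/dx).
type ScalarOp = Double -> (Double, Double -> Double)

squareOp :: ScalarOp
squareOp x = (x * x, \dzdy -> dzdy * 2 * x)

sinOp :: ScalarOp
sinOp x = (sin x, \dzdy -> dzdy * cos x)

-- Chain rule for g (f x): push the incoming gradient through g's
-- scaled gradient first, then through f's.
after :: ScalarOp -> ScalarOp -> ScalarOp
after g f x =
  let (y, scaledF) = f x
      (z, scaledG) = g y
  in  (z, scaledF . scaledG)

-- Seeding the scaled gradient with 1 yields the plain derivative.
derivative :: ScalarOp -> Double -> Double
derivative op x = snd (op x) 1

main :: IO ()
main = do
  print (derivative squareOp 3)                  -- 2 * 3
  print (derivative (sinOp `after` squareOp) 1)  -- cos (1^2) * 2 * 1
```

Each op only needs to know its own local derivative. The factor dz/dy arrives from whatever comes later in the computation.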

With that in mind, let's write our "squared" op:

square :: (Reifies s W, Num a) => BVar s a -> BVar s a
square = liftOp1 . op1 $ \x ->
    ( x^2              , \dzdy -> dzdy * 2 * x)
  --  ^- actual result   ^- scaled gradient function

Following the same pattern, for y = f(x) = sin(x) and z = g(f(x)), we get dz/dx = dz/dy * cos(x). So, we have:

liftedSin :: (Reifies s W, Floating a) => BVar s a -> BVar s a
liftedSin = liftOp1 . op1 $ \x ->
    ( sin x, \dzdy -> dzdy * cos x )

In general, for functions that take and return scalars, where f is your function and dfdx is its derivative:

liftedF :: (Reifies s W, Num a) => BVar s a -> BVar s a
liftedF = liftOp1 . op1 $ \x ->
    ( f x, \dzdy -> dzdy * dfdx x )
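To make the general pattern concrete, here is a sketch that fills in f and dfdx with a hypothetical cubic and compares the gradient from gradBP against a finite-difference estimate. The types are fixed to Double, and numericDeriv is an illustrative helper, not part of backprop.

```haskell
{-# LANGUAGE FlexibleContexts #-}
import Numeric.Backprop

-- Hypothetical scalar function and its hand-written derivative.
f :: Double -> Double
f x = x ^ (3 :: Int)

dfdx :: Double -> Double
dfdx x = 3 * x ^ (2 :: Int)

liftedF :: Reifies s W => BVar s Double -> BVar s Double
liftedF = liftOp1 . op1 $ \x ->
    ( f x, \dzdy -> dzdy * dfdx x )

-- Central finite difference, used only as a sanity check.
numericDeriv :: (Double -> Double) -> Double -> Double
numericDeriv g x = (g (x + h) - g (x - h)) / (2 * h)
  where h = 1e-5

main :: IO ()
main = do
  print (gradBP liftedF 2)                 -- analytic: 3 * 2^2
  print (numericDeriv (evalBP liftedF) 2)  -- should be close to the line above
```

A check like this is a cheap way to catch a mistyped derivative when lifting many functions by hand.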

A simple non-trivial example is sumElements, which we can define to take the hmatrix library's R n type (an n-vector of Double). Here we have to think about g(sum(x)), and the types guide our thinking:

sumElements           :: R n -> Double
sumElementsScaledGrad :: R n -> Double -> R n

The simplest way I find to work this out is to take it element by element:

  • y = f(<a,b,c>) = a + b + c
  • z = g(y) = g(a + b + c)
  • dz/da = dz/dy * dy/da
  • From here, we can recognize that dy/da = 1, since y = a + b + c.
  • We can then conclude that dz/da = dz/dy, where dz/dy is the second argument of sumElementsScaledGrad
  • By the same logic, dz/db = dz/dy and dz/dc = dz/dy as well.
  • We can generalize this to say that the gradient of the sum of vector components is the same for every component: just dz/dy.

So in the end:

sumElements :: (Reifies s W, KnownNat n) => BVar s (R n) -> BVar s Double
sumElements = liftOp1 . op1 $ \xs ->
    ( H.sumElements xs, \dzdy -> H.konst dzdy )    -- a constant vector
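The same reasoning can be checked without hmatrix. The sketch below uses a plain list in place of R n, and sumElementsOp is a hypothetical name for the pair of result and scaled gradient function.

```haskell
-- Hypothetical list-based analogue of the op above: the result, plus a
-- scaled gradient that copies dz/dy into every component.
sumElementsOp :: [Double] -> (Double, Double -> [Double])
sumElementsOp xs = (sum xs, \dzdy -> map (const dzdy) xs)

main :: IO ()
main = do
  let (y, scaledGrad) = sumElementsOp [1, 2, 3]
  print y               -- a + b + c
  print (scaledGrad 5)  -- dz/dy repeated once per component
```

The gradient has the same shape as the input, which mirrors what the constant vector does in the R n version.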