
Equipping your Library with Backprop

Justin Le edited this page May 2, 2018 · 25 revisions

Equipping your library with backprop involves providing "backprop-aware" versions of your library functions. In fact, you can make a library fully backprop-aware by providing only backprop versions of your functions, because any backprop-aware function can be used as a normal function with evalBP. Alternatively, you can re-export all of your functions in a separate module with "backprop-aware" versions.
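As a minimal sketch (assuming the backprop package's Numeric.Backprop module; poly, polyPlain and polyDeriv are hypothetical names), one backprop-aware definition can act as the plain function through evalBP and as a differentiable function through gradBP:

```haskell
{-# LANGUAGE FlexibleContexts #-}

import Numeric.Backprop

-- A hypothetical library function, written only in its backprop-aware form.
-- BVar s Double has a Num instance, so ordinary arithmetic works on it.
poly :: Reifies s W => BVar s Double -> BVar s Double
poly x = x * x + 3 * x

-- The plain version, recovered with evalBP.
polyPlain :: Double -> Double
polyPlain = evalBP poly

-- Its derivative (2x + 3), recovered with gradBP.
polyDeriv :: Double -> Double
polyDeriv = gradBP poly
```

Here polyPlain 2 evaluates to 10.0 and polyDeriv 2 to 7.0, so users who never differentiate lose nothing.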

The Types

If you have a function:

myFunc :: a -> b

Then its lifted version would have type:

myFunc :: Reifies s W => BVar s a -> BVar s b

That is, instead of a function directly taking an a and returning a b, it's a function taking a BVar containing an a, and returning a BVar containing a b.

Functions taking multiple arguments can be translated pretty straightforwardly:

func1   ::                       a ->        b ->        c
func1BP :: Reifies s W => BVar s a -> BVar s b -> BVar s c

And also functions returning multiple results:

func2   ::                       a -> (       b,        c)
func2BP :: Reifies s W => BVar s a -> (BVar s b, BVar s c)
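For concreteness, here is a sketch of both shapes at a = b = c = Double, assuming the backprop package. The function bodies are hypothetical and use BVar's Num instance rather than a hand-written gradient:

```haskell
{-# LANGUAGE FlexibleContexts #-}

import Numeric.Backprop

-- Hypothetical two-argument function; plain version: func1 x y = x * y + y
func1BP :: Reifies s W => BVar s Double -> BVar s Double -> BVar s Double
func1BP x y = x * y + y

-- Hypothetical function with two results; plain version: func2 x = (2 * x, x * x)
func2BP :: Reifies s W => BVar s Double -> (BVar s Double, BVar s Double)
func2BP x = (2 * x, x * x)

-- Consuming both results of func2BP inside a larger computation.
sumOfResults :: Reifies s W => BVar s Double -> BVar s Double
sumOfResults x = let (b, c) = func2BP x in b + c
```

With these, gradBP2 func1BP 3 5 gives (5.0, 4.0), and gradBP sumOfResults 3 gives 8.0 (the derivative of 2x + x^2 at 3).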

Note that, by default, almost all operations on BVars require their contents to have a Backprop instance. Alternative APIs that need only Num instances, or explicitly specified addition functions, are available in Numeric.Backprop.Num and Numeric.Backprop.Explicit respectively.

Lifting operations

A BVar s a -> BVar s b really encodes two things:

  1. An a -> b (the actual function)
  2. An a -> b -> a (the "scaled gradient" function)

The documentation for Numeric.Backprop.Op describes these in more detail, with rendered math and examples.

The second function requires some elaboration. Let's say you are writing a lifted version of your function y = f(x) (whose derivative is dy/dx), and that your final result at the end of your computation is z = g(f(x)) (whose derivative is dz/dx). In that case, because of the chain rule, dz/dx = dz/dy * dy/dx.

The scaled gradient is the function that, given dz/dy, returns dz/dx (that is, dz/dy * dy/dx).

For example, for the mathematical operation y = f(x) = x^2, considering z = g(f(x)), we get dz/dx = dz/dy * 2x. In fact, for every function taking and returning scalars (plain single numbers), dz/dx = dz/dy * f'(x).
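To make the scaled-gradient idea concrete without the library, here is a plain-Haskell sketch. All names are hypothetical, not part of backprop: it represents a scalar operation with the same shape as the function passed to op1 (the result paired with its scaled gradient) and chains two such operations by hand, which is exactly the chain rule above:

```haskell
-- A scalar operation: given x, return (f x, scaled gradient dz/dy -> dz/dx).
-- This has the same shape as the function you pass to op1.
type ScalarOp = Double -> (Double, Double -> Double)

squareOp :: ScalarOp
squareOp x = (x ^ (2 :: Int), \dzdy -> dzdy * 2 * x)

sinOp :: ScalarOp
sinOp x = (sin x, \dzdy -> dzdy * cos x)

-- Chain rule: run f, then g on its result. Gradients flow backwards:
-- g's scaled gradient runs first, then f's is applied to its output.
composeOp :: ScalarOp -> ScalarOp -> ScalarOp
composeOp g f x =
    let (y, fGrad) = f x
        (z, gGrad) = g y
    in  (z, fGrad . gGrad)

-- The ordinary derivative of an op: seed the backward pass with 1.
derivative :: ScalarOp -> Double -> Double
derivative op x = snd (op x) 1
```

For instance, derivative (composeOp sinOp squareOp) x computes cos(x^2) * 2x, the derivative of sin(x^2). Backprop performs this bookkeeping for you, across arbitrary graphs.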

With that in mind, let's write our "squared" op:

square :: (Reifies s W, Backprop a, Num a) => BVar s a -> BVar s a
square = liftOp1 . op1 $ \x ->
    ( x^2              , \dzdy -> dzdy * 2 * x)
  --  ^- actual result   ^- scaled gradient function
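A usage sketch, assuming the backprop package (fourthPower is a hypothetical name): once square is lifted, composing it with other lifted functions applies the chain rule automatically.

```haskell
{-# LANGUAGE FlexibleContexts #-}

import Numeric.Backprop

square :: (Reifies s W, Backprop a, Num a) => BVar s a -> BVar s a
square = liftOp1 . op1 $ \x ->
    ( x^2, \dzdy -> dzdy * 2 * x )

-- Hypothetical composite: x^4, whose derivative is 4x^3.
fourthPower :: Reifies s W => BVar s Double -> BVar s Double
fourthPower = square . square
```

Here gradBP square 3 gives 6.0, and gradBP fourthPower 2 gives 32.0.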

Following the same pattern, for y = f(x) = sin(x), considering z = g(f(x)), we get dz/dx = dz/dy * cos(x). So we have:

liftedSin :: (Reifies s W, Backprop a, Floating a) => BVar s a -> BVar s a
liftedSin = liftOp1 . op1 $ \x ->
    ( sin x, \dzdy -> dzdy * cos x )
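A hand-written scaled gradient is easy to get wrong, so a quick numerical check can help. Below is a sketch assuming the backprop package. gradCheck is a hypothetical helper, not part of backprop, that compares gradBP against a central finite-difference estimate:

```haskell
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes       #-}

import Numeric.Backprop

liftedSin :: (Reifies s W, Backprop a, Floating a) => BVar s a -> BVar s a
liftedSin = liftOp1 . op1 $ \x ->
    ( sin x, \dzdy -> dzdy * cos x )

-- Hypothetical helper: does the backprop gradient at x agree with
-- a central finite difference, up to a small tolerance?
gradCheck
    :: (forall s. Reifies s W => BVar s Double -> BVar s Double)
    -> Double
    -> Bool
gradCheck f x = abs (analytic - numeric) < 1e-6
  where
    h        = 1e-5
    analytic = gradBP f x
    numeric  = (evalBP f (x + h) - evalBP f (x - h)) / (2 * h)
```

The step size and tolerance here are illustrative choices; gradCheck liftedSin 0.7 should return True.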

In general, for functions that take and return scalars:

liftedF :: (Reifies s W, Backprop a, Num a) => BVar s a -> BVar s a
liftedF = liftOp1 . op1 $ \x ->
    ( f x, \dzdy -> dzdy * dfdx x )   -- f: your function; dfdx: its derivative
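The general pattern can be captured once. The following sketch assumes the backprop package; liftScalar, liftedExp and liftedCube are hypothetical names, not library functions. It takes the plain function and its derivative and builds the lifted version:

```haskell
{-# LANGUAGE FlexibleContexts #-}

import Numeric.Backprop

-- Hypothetical helper: lift a scalar function f, given its derivative f'.
liftScalar
    :: (Reifies s W, Backprop a, Num a)
    => (a -> a) -> (a -> a) -> BVar s a -> BVar s a
liftScalar f f' = liftOp1 . op1 $ \x ->
    ( f x, \dzdy -> dzdy * f' x )

-- exp is its own derivative.
liftedExp :: (Reifies s W, Backprop a, Floating a) => BVar s a -> BVar s a
liftedExp = liftScalar exp exp

-- x^3, with derivative 3x^2.
liftedCube :: (Reifies s W, Backprop a, Num a) => BVar s a -> BVar s a
liftedCube = liftScalar (\x -> x * x * x) (\x -> 3 * x * x)
```

With this, gradBP liftedCube 2 gives 12.0 (at type Double).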

Non-trivial example

A simple non-trivial example is sumElements, which we can define on the hmatrix library's R n type (an n-vector of Double). Here we have to think about g(sum(x)), and the types guide our thinking:

sumElements           :: R n -> Double
sumElementsScaledGrad :: R n -> Double -> R n

Personally, I find it simplest to work element by element:

  • y = f(<a,b,c>) = a + b + c
  • z = g(y) = g(a + b + c)
  • dz/da = dz/dy * dy/da
  • From here, we can recognize that dy/da = 1, since y = a + b + c.
  • We can then conclude that dz/da = dz/dy, where dz/dy is the second argument of sumElementsScaledGrad.
  • By the same logic, we see that dz/db = dz/dy, as well, and dz/dc = dz/dy, as well.
  • We can generalize this to say that the gradient of the sum of vector components is the same for every component: just dz/dy.

So in the end:

sumElements :: (Reifies s W, KnownNat n) => BVar s (R n) -> BVar s Double
sumElements = liftOp1 . op1 $ \xs ->
    ( H.sumElements xs, \dzdy -> H.konst dzdy )    -- a constant vector
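The same derivation works for other containers. The following sketch assumes only the backprop package, with no hmatrix. It defines a hypothetical sumList over a plain [Double] whose scaled gradient is a list filled with dz/dy:

```haskell
{-# LANGUAGE FlexibleContexts #-}

import Numeric.Backprop

-- Hypothetical list version of sumElements: every element's
-- gradient is the incoming dz/dy.
sumList :: Reifies s W => BVar s [Double] -> BVar s Double
sumList = liftOp1 . op1 $ \xs ->
    ( sum xs, \dzdy -> map (const dzdy) xs )
```

Here gradBP sumList [1,2,3] gives [1.0,1.0,1.0]. If the sum is scaled afterwards, as in gradBP (\xs -> 5 * sumList xs) [1,2,3], every component receives 5.0.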

Multiple-argument functions

Lifting multiple-argument functions is the same thing, except using liftOp2 and op2, or liftOpN and opN.

A BVar s a -> BVar s b -> BVar s c encodes two things:

  1. The actual a -> b -> c
  2. The scaled gradient, a -> b -> c -> (a, b).

Here c is again dz/dy, and the returned (a, b) is how dz/dy affects each of the two inputs.

For a simple example, take addition, y = x1 + x2. Working it out:

  1. y = f(x1,x2) = x1 + x2
  2. z = g(f(x1,x2)) = g(x1 + x2)
  3. dz/dx1 = dz/dy * dy/dx1
  4. We recognize here that dy/dx1 is 1, because y = x1 + x2.
  5. So we can say dz/dx1 = dz/dy
  6. The same logic applies to dz/dx2.
  7. So we give dz/dx1 = dz/dy, and dz/dx2 = dz/dy.
add :: (Reifies s W, Backprop a, Num a) => BVar s a -> BVar s a -> BVar s a
add = liftOp2 . op2 $ \x1 x2 ->
    ( x1 + x2, \dzdy -> (dzdy, dzdy) )
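A usage sketch, assuming the backprop package: gradBP2 differentiates a two-argument function with respect to both inputs. Because add passes dz/dy through unchanged, scaling the output scales both gradients equally.

```haskell
{-# LANGUAGE FlexibleContexts #-}

import Numeric.Backprop

add :: (Reifies s W, Backprop a, Num a) => BVar s a -> BVar s a -> BVar s a
add = liftOp2 . op2 $ \x1 x2 ->
    ( x1 + x2, \dzdy -> (dzdy, dzdy) )

-- Hypothetical composite: 10 * (x1 + x2).
scaledAdd :: Reifies s W => BVar s Double -> BVar s Double -> BVar s Double
scaledAdd x1 x2 = 10 * add x1 x2
```

Here gradBP2 add 3 4 gives (1.0, 1.0) and gradBP2 scaledAdd 3 4 gives (10.0, 10.0).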

And for multiplication, we can work it out:

  1. y = f(x1,x2) = x1 * x2
  2. z = g(f(x1,x2)) = g(x1 * x2)
  3. dz/dx1 = dz/dy * dy/dx1
  4. We recognize here that dy/dx1 is x2, because y = x1 * x2.
  5. So we can say dz/dx1 = dz/dy * x2
  6. dz/dx2 = dz/dy * dy/dx2
  7. We recognize that dy/dx2 = x1, because y = x1 * x2.
  8. So we can say dz/dx2 = dz/dy * x1
mul :: (Reifies s W, Backprop a, Num a) => BVar s a -> BVar s a -> BVar s a
mul = liftOp2 . op2 $ \x1 x2 ->
    ( x1 * x2, \dzdy -> (dzdy * x2, x1 * dzdy) )
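A usage sketch, assuming the backprop package (squareViaMul is a hypothetical name). It also shows that when the same BVar feeds both inputs, backprop adds the two gradient contributions, recovering the product rule:

```haskell
{-# LANGUAGE FlexibleContexts #-}

import Numeric.Backprop

mul :: (Reifies s W, Backprop a, Num a) => BVar s a -> BVar s a -> BVar s a
mul = liftOp2 . op2 $ \x1 x2 ->
    ( x1 * x2, \dzdy -> (dzdy * x2, x1 * dzdy) )

-- Using one input twice: d(x * x)/dx = x + x.
squareViaMul :: Reifies s W => BVar s Double -> BVar s Double
squareViaMul x = mul x x
```

Here gradBP2 mul 3 4 gives (4.0, 3.0), and gradBP squareViaMul 5 gives 10.0.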

For non-trivial examples involving linear algebra, see the source of the hmatrix-backprop library. Two examples in that style:

-- | Dot product
dot :: (Reifies s W, KnownNat n) => BVar s (R n) -> BVar s (R n) -> BVar s Double
dot = liftOp2 . op2 $ \x y ->
    ( x `H.dot` y
    , \dzdy -> (H.konst dzdy * y, x * H.konst dzdy)
    )
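The dot-product gradient can be tried without hmatrix. The following sketch assumes only the backprop package. dotList is a hypothetical list version and assumes both lists have the same length. As with dot, each input's gradient is the other input scaled by dz/dy:

```haskell
{-# LANGUAGE FlexibleContexts #-}

import Numeric.Backprop

-- Hypothetical list version of dot (inputs assumed to be equal length).
dotList :: Reifies s W => BVar s [Double] -> BVar s [Double] -> BVar s Double
dotList = liftOp2 . op2 $ \xs ys ->
    ( sum (zipWith (*) xs ys)
    , \dzdy -> (map (* dzdy) ys, map (* dzdy) xs)
    )
```

Here gradBP2 dotList [1,2,3] [4,5,6] gives ([4.0,5.0,6.0], [1.0,2.0,3.0]).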


-- | matrix-vector multiplication
(#>) :: (Reifies s W, KnownNat m, KnownNat n) => BVar s (L m n) -> BVar s (R n) -> BVar s (R m)
(#>) = liftOp2 . op2 $ \mat vec ->
    ( mat H.#> vec
    , \dzdy -> (dzdy `H.outer` vec, H.tr mat H.#> dzdy)
    )
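The two gradients above follow from the same element-by-element reasoning used for sumElements. Writing y = M v for mat and vec, and z for the final result:

```latex
y_i = \sum_j M_{ij} v_j
\\
\frac{\partial z}{\partial M_{ij}}
  = \frac{\partial z}{\partial y_i} \, \frac{\partial y_i}{\partial M_{ij}}
  = \frac{\partial z}{\partial y_i} \, v_j
  \quad\Longrightarrow\quad
  \frac{\partial z}{\partial M} = \frac{\partial z}{\partial y} \, v^{\top}
\\
\frac{\partial z}{\partial v_j}
  = \sum_i \frac{\partial z}{\partial y_i} \, \frac{\partial y_i}{\partial v_j}
  = \sum_i M_{ij} \, \frac{\partial z}{\partial y_i}
  \quad\Longrightarrow\quad
  \frac{\partial z}{\partial v} = M^{\top} \, \frac{\partial z}{\partial y}
```

The first result is the outer product dzdy `H.outer` vec, and the second is H.tr mat H.#> dzdy.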

Returning multiple items

You can return tuples inside BVars:

splitAt :: (Reifies s W, Backprop a) => Int -> BVar s [a] -> BVar s ([a], [a])
splitAt n = liftOp1 . op1 $ \xs ->
    let (ys, zs) = Data.List.splitAt n xs
    in  ((ys, zs), \(dys,dzs) -> dys ++ dzs)

This works as expected. However, for the benefit of your users, it is strongly recommended that you return a tuple of BVars instead of a BVar of a tuple:

splitAt :: (Reifies s W, Backprop a) => Int -> BVar s [a] -> (BVar s [a], BVar s [a])
splitAt n xs = (yszs ^^. _1, yszs ^^. _2)
  where
    yszs = liftOp1 (op1 splitOp) xs
    splitOp xs' =
        let (ys, zs) = Data.List.splitAt n xs'
        in  ((ys, zs), \(dys, dzs) -> dys ++ dzs)

This uses _1 and _2 from the microlens or lens packages.
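A usage sketch, assuming the backprop and microlens packages. sumList and weightedHalves are hypothetical names, and Prelude's splitAt is hidden to avoid a name clash. The caller pattern-matches the two halves directly, and gradients from both flow back through the single split:

```haskell
{-# LANGUAGE FlexibleContexts #-}

import Prelude hiding (splitAt)
import qualified Data.List
import Lens.Micro (_1, _2)
import Numeric.Backprop

-- The tuple-of-BVars splitAt from above.
splitAt :: (Reifies s W, Backprop a) => Int -> BVar s [a] -> (BVar s [a], BVar s [a])
splitAt n xs = (yszs ^^. _1, yszs ^^. _2)
  where
    yszs = liftOp1 (op1 splitOp) xs
    splitOp xs' =
        let (ys, zs) = Data.List.splitAt n xs'
        in  ((ys, zs), \(dys, dzs) -> dys ++ dzs)

-- Hypothetical lifted sum: every element gets dz/dy.
sumList :: Reifies s W => BVar s [Double] -> BVar s Double
sumList = liftOp1 . op1 $ \xs -> (sum xs, \dzdy -> map (const dzdy) xs)

-- Hypothetical user code: weight the first two elements by 2, the rest by 3.
weightedHalves :: Reifies s W => BVar s [Double] -> BVar s Double
weightedHalves xs = 2 * sumList ys + 3 * sumList zs
  where
    (ys, zs) = splitAt 2 xs
```

Here gradBP weightedHalves [1,2,3,4] gives [2.0,2.0,3.0,3.0].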

Special Constructors (Iso)

If your function witnesses an isomorphism, the isoVar family of combinators makes it easy to lift. This is especially useful for data constructors:

data Foo = MkFoo { getFoo :: Double }

mkFoo :: Reifies s W => BVar s Double -> BVar s Foo
mkFoo = isoVar MkFoo getFoo

data Bar = MkBar { bar1 :: Double, bar2 :: Float }

mkBar :: Reifies s W => BVar s Double -> BVar s Float -> BVar s Bar
mkBar = isoVar2 MkBar (\b -> (bar1 b, bar2 b))
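To run gradients through Foo and Bar, the types themselves need Backprop instances. The following sketch assumes the backprop package and relies on the Backprop class's Generic-based default methods, so empty instances suffice:

```haskell
{-# LANGUAGE DeriveGeneric    #-}
{-# LANGUAGE FlexibleContexts #-}

import GHC.Generics (Generic)
import Numeric.Backprop

data Foo = MkFoo { getFoo :: Double }
  deriving (Show, Eq, Generic)

-- Generic defaults supply zero, add and one.
instance Backprop Foo

mkFoo :: Reifies s W => BVar s Double -> BVar s Foo
mkFoo = isoVar MkFoo getFoo

data Bar = MkBar { bar1 :: Double, bar2 :: Float }
  deriving (Show, Eq, Generic)

instance Backprop Bar

mkBar :: Reifies s W => BVar s Double -> BVar s Float -> BVar s Bar
mkBar = isoVar2 MkBar (\b -> (bar1 b, bar2 b))
```

Here evalBP mkFoo 3 gives MkFoo 3.0, and gradBP2 mkBar 1 2 gives (1.0, 1.0), since an isomorphism passes each field's gradient straight back to its input.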

NoGrad
