
Equipping your Library with Backprop

Randall Britten edited this page May 3, 2018 · 25 revisions

Equipping your library with backprop involves providing "backprop-aware" versions of your library functions. In fact, you can make a library fully backprop-aware by providing only backprop versions of your functions, because a backprop-aware function can still be used as a normal function with evalBP. Alternatively, you can re-export all of your functions from a separate module with "backprop-aware" versions.

The Types

If you have a function:

myFunc :: a -> b

Then its lifted version would have type:

myFunc :: Reifies s W => BVar s a -> BVar s b

That is, instead of a function directly taking an a and returning a b, it's a function taking a BVar containing an a, and returning a BVar containing a b.

Functions taking multiple arguments can be translated pretty straightforwardly:

func1   ::                       a ->        b ->        c
func1BP :: Reifies s W => BVar s a -> BVar s b -> BVar s c

And also functions returning multiple results:

func2   ::                       a -> (       b,        c)
func2BP :: Reifies s W => BVar s a -> (BVar s b, BVar s c)

Note that, by default, almost all operations involving BVar'd items require the contents to have a Backprop instance. Alternative APIs to backprop that require Num instances instead (or explicitly specified addition functions) are available in Numeric.Backprop.Num and Numeric.Backprop.Explicit.

Lifting operations

A BVar s a -> BVar s b really encodes two things:

  1. An a -> b (the actual function)
  2. An a -> b -> a (the "scaled gradient" function)

The documentation for Numeric.Backprop.Op gives detail about what these entail, with rendered math and examples.

The second function requires some elaboration. Let's say you are writing a lifted version of your function y = f(x) (whose derivative is dy/dx), and that your final result at the end of your computation is z = g(f(x)) (whose derivative is dz/dx). In that case, because of the chain rule, dz/dx = dz/dy * dy/dx.

The scaled gradient is the function which, given dz/dy, returns dz/dx (that is, returns dz/dy * dy/dx).

For example, for the mathematical operation y = f(x) = x^2, considering z = g(f(x)), we get dz/dx = dz/dy * 2x. In fact, for all functions taking and returning scalars (just normal single numbers), dz/dx = dz/dy * f'(x).

With that in mind, let's write our "squared" op:

square :: (Reifies s W, Backprop a, Num a) => BVar s a -> BVar s a
square = liftOp1 . op1 $ \x ->
    ( x^2              , \dzdy -> dzdy * 2 * x)
  --  ^- actual result   ^- scaled gradient function

Following the same pattern, for y = f(x) = sin(x), considering z = g(f(x)), we get dz/dx = dz/dy * cos(x). So, we have:

liftedSin :: (Reifies s W, Backprop a, Floating a) => BVar s a -> BVar s a
liftedSin = liftOp1 . op1 $ \x ->
    ( sin x, \dzdy -> dzdy * cos x )
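Conceptually, chaining lifted functions amounts to composing their scaled-gradient functions in reverse order, which is the chain rule from above. The sketch below illustrates that idea without the backprop package: each op is modelled as a plain function returning its result and its scaled gradient (the same shape as the argument given to op1). The names OpFn, composeOp, squareOp, sinOp and sinOfSquare are hypothetical, and this is not how backprop is implemented internally.

```haskell
-- Sketch only: an "op" modelled as a plain function returning its result
-- together with its scaled-gradient function (the shape of op1's argument).
-- OpFn, composeOp, squareOp, sinOp and sinOfSquare are hypothetical names.
type OpFn a b = a -> (b, b -> a)

-- Chain rule: run f then g forwards; on the way back, g's scaled gradient
-- turns the output gradient into dz/dy, and f's turns dz/dy into dz/dx.
composeOp :: OpFn b c -> OpFn a b -> OpFn a c
composeOp g f x =
  let (y, backF) = f x
      (z, backG) = g y
  in  (z, backF . backG)

squareOp :: Num a => OpFn a a
squareOp x = (x ^ 2, \dzdy -> dzdy * 2 * x)

sinOp :: Floating a => OpFn a a
sinOp x = (sin x, \dzdy -> dzdy * cos x)

-- z = sin (x^2), so dz/dx = cos (x^2) * 2x
sinOfSquare :: Floating a => OpFn a a
sinOfSquare = composeOp sinOp squareOp
```

Seeding the backward function with 1 (that is, dz/dz) yields dz/dx, so `snd (sinOfSquare 1.5) 1` should equal `cos 2.25 * 3`.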

In general, for a function f that takes and returns scalars, with derivative dfdx:

liftedF :: (Reifies s W, Backprop a, Num a) => (a -> a) -> (a -> a) -> BVar s a -> BVar s a
liftedF f dfdx = liftOp1 . op1 $ \x ->
    ( f x, \dzdy -> dzdy * dfdx x )
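A hand-written derivative is easy to get wrong, so it can help to compare it against a central finite difference before lifting it. This is a minimal sketch that does not use backprop; scalarOp and gradCheck are hypothetical helpers, and the tolerance is an arbitrary choice for well-behaved functions.

```haskell
-- Sketch only: the liftedF pattern as a plain function, plus a numerical
-- check. scalarOp and gradCheck are hypothetical helpers.
scalarOp :: Num a => (a -> a) -> (a -> a) -> a -> (a, a -> a)
scalarOp f dfdx x = (f x, \dzdy -> dzdy * dfdx x)

-- Compare the scaled gradient, seeded with dz/dy = 1, against the central
-- finite difference (f (x + h) - f (x - h)) / (2h).
gradCheck :: (Double -> Double) -> (Double -> Double) -> Double -> Bool
gradCheck f dfdx x = abs (analytic - numeric) < 1e-6 * max 1 (abs numeric)
  where
    h        = 1e-5
    analytic = snd (scalarOp f dfdx x) 1
    numeric  = (f (x + h) - f (x - h)) / (2 * h)
```

For example, `gradCheck sin cos 0.7` should hold, while passing the wrong derivative, as in `gradCheck sin sin 0.7`, should not.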

Non-trivial example

A simple non-trivial example is sumElements, which we can define to take the hmatrix library's R n type (an n-vector of Double). Here we have to think about g(sum(x)), and the types guide our thinking:

sumElements           :: R n -> Double
sumElementsScaledGrad :: R n -> Double -> R n

Personally, I find it simplest to take it element by element:

  • y = f(<a,b,c>) = a + b + c
  • z = g(y) = g(a + b + c)
  • dz/da = dz/dy * dy/da
  • From here, we can recognize that dy/da = 1, since y = a + b + c.
  • We can then conclude that dz/da = dz/dy, where dz/dy is the second argument of sumElementsScaledGrad.
  • By the same logic, dz/db = dz/dy and dz/dc = dz/dy as well.
  • We can generalize this to say that the gradient of the sum of vector components is the same for every component: just dz/dy.
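The element-by-element argument can be checked numerically. In this sketch plain lists of Double stand in for hmatrix's R n (so no vector length is tracked), and sumOp and numericGrad are hypothetical names, not part of hmatrix or backprop.

```haskell
-- Sketch only: lists of Double stand in for R n. sumOp and numericGrad are
-- hypothetical names.
sumOp :: [Double] -> (Double, Double -> [Double])
sumOp xs = (sum xs, \dzdy -> map (const dzdy) xs)   -- every component gets dz/dy

-- Central finite difference of a scalar function with respect to each component.
numericGrad :: ([Double] -> Double) -> [Double] -> [Double]
numericGrad f xs =
  [ (f (bump i h) - f (bump i (-h))) / (2 * h) | i <- [0 .. length xs - 1] ]
  where
    h = 1e-5
    -- perturb only component i by d
    bump i d = [ if j == i then x + d else x | (j, x) <- zip [0 ..] xs ]
```

With g(y) = 3y we have dz/dy = 3, so `snd (sumOp [1, -2, 4.5]) 3` gives `[3, 3, 3]`, the list analogue of `H.konst dzdy`, and `numericGrad (\v -> 3 * sum v) [1, -2, 4.5]` should agree up to rounding.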

So in the end:

sumElements :: (Reifies s W, KnownNat n) => BVar s (R n) -> BVar s Double
sumElements = liftOp1 . op1 $ \xs ->
    ( H.sumElements xs, \dzdy -> H.konst dzdy )    -- a constant vector

Multiple-argument functions

Lifting multiple-argument functions is the same thing, except using liftOp2 and op2, or liftOpN and opN.

A BVar s a -> BVar s b -> BVar s c encodes two things:

  1. The actual a -> b -> c
  2. The scaled gradient, a -> b -> c -> (a, b).

Here c is again dzdy, and the resulting (a, b) describes how dzdy affects each of the two inputs.

As a simple example, take x1 + x2. Working it out:

  1. y = f(x1,x2) = x1 + x2
  2. z = g(f(x1,x2)) = g(x1 + x2)
  3. dz/dx1 = dz/dy * dy/dx1
  4. We recognize here that dy/dx1 is 1, because y = x1 + x2.
  5. So we can say dz/dx1 = dz/dy
  6. The same logic applies for dz/dx2.
  7. So we give dz/dx1 = dz/dy, and dz/dx2 = dz/dy.

add :: (Reifies s W, Backprop a, Num a) => BVar s a -> BVar s a -> BVar s a
add = liftOp2 . op2 $ \x1 x2 ->
    ( x1 + x2, \dzdy -> (dzdy, dzdy) )

And for multiplication, we can work it out:

  1. y = f(x1,x2) = x1 * x2
  2. z = g(f(x1,x2)) = g(x1 * x2)
  3. dz/dx1 = dz/dy * dy/dx1
  4. We recognize here that dy/dx1 is x2, because y = x1 * x2.
  5. So we can say dz/dx1 = dz/dy * x2
  6. dz/dx2 = dz/dy * dy/dx2
  7. We recognize that dy/dx2 = x1, because y = x1 * x2.
  8. So we can say dz/dx2 = dz/dy * x1

mul :: (Reifies s W, Backprop a, Num a) => BVar s a -> BVar s a -> BVar s a
mul = liftOp2 . op2 $ \x1 x2 ->
    ( x1 * x2, \dzdy -> (dzdy * x2, x1 * dzdy) )
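Both derivations can be sanity-checked without backprop. This sketch writes add and mul as plain functions with the same shape as op2's argument, a -> b -> (c, c -> (a, b)), and compares them against per-argument finite differences; addOp, mulOp and partials are hypothetical names.

```haskell
-- Sketch only: two-argument ops as plain functions shaped like op2's
-- argument. addOp, mulOp and partials are hypothetical names.
addOp :: Num a => a -> a -> (a, a -> (a, a))
addOp x1 x2 = (x1 + x2, \dzdy -> (dzdy, dzdy))

mulOp :: Num a => a -> a -> (a, a -> (a, a))
mulOp x1 x2 = (x1 * x2, \dzdy -> (dzdy * x2, dzdy * x1))

-- Central finite differences of a two-argument function, one per argument.
partials :: (Double -> Double -> Double) -> Double -> Double -> (Double, Double)
partials f x1 x2 =
  ( (f (x1 + h) x2 - f (x1 - h) x2) / (2 * h)
  , (f x1 (x2 + h) - f x1 (x2 - h)) / (2 * h) )
  where h = 1e-5
```

At x1 = 3, x2 = 5, the multiplication gradient seeded with 1 is (5, 3), matching dz/dx1 = dz/dy * x2 and dz/dx2 = dz/dy * x1, and `partials (*) 3 5` should agree up to rounding.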

For non-trivial examples involving linear algebra, see the source for the hmatrix-backprop library.

-- | Dot product
dot :: (Reifies s W, KnownNat n) => BVar s (R n) -> BVar s (R n) -> BVar s Double
dot = liftOp2 . op2 $ \u v ->
    ( u `H.dot` v
    , \dzdy -> (H.konst dzdy * v, u * H.konst dzdy)
    )
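For intuition, here is the dot-product gradient on plain lists: `H.konst dzdy * v` amounts to scaling every component of v by dzdy. Lists stand in for R n here, and dotOp is a hypothetical name; this is a sketch, not hmatrix-backprop's code.

```haskell
-- Sketch only: lists of Double stand in for R n. dotOp is a hypothetical name.
dotOp :: [Double] -> [Double] -> (Double, Double -> ([Double], [Double]))
dotOp u v =
  ( sum (zipWith (*) u v)
  , \dzdy -> (map (dzdy *) v, map (* dzdy) u) )   -- dz/du = dz/dy * v, dz/dv = dz/dy * u
```

For u = [1, 2, 3] and v = [4, 5, 6] the result is 32, and seeding the backward function with 2 gives ([8, 10, 12], [2, 4, 6]).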


-- | matrix-vector multiplication
(#>) :: (Reifies s W, KnownNat m, KnownNat n) => BVar s (L m n) -> BVar s (R n) -> BVar s (R m)
(#>) = liftOp2 . op2 $ \mat vec ->
    ( mat H.#> vec
    , \dzdy -> (dzdy `H.outer` vec, H.tr mat H.#> dzdy)
    )
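The same gradient with explicit loops may make the outer product and transpose easier to see. In this sketch a matrix is a list of rows (m rows of n columns) standing in for hmatrix's L m n, with no dimension checking; Matrix, Vector, mulMV, outer and mulMVOp are hypothetical names.

```haskell
import Data.List (transpose)

-- Sketch only: lists stand in for L m n and R n; no sizes are checked.
-- Matrix, Vector, mulMV, outer and mulMVOp are hypothetical names.
type Matrix = [[Double]]   -- list of rows
type Vector = [Double]

mulMV :: Matrix -> Vector -> Vector
mulMV mat vec = [ sum (zipWith (*) row vec) | row <- mat ]

outer :: Vector -> Vector -> Matrix
outer u v = [ [ a * b | b <- v ] | a <- u ]

-- dz/dmat = dzdy `outer` vec (an m x n matrix),
-- dz/dvec = transpose mat #> dzdy (an n-vector)
mulMVOp :: Matrix -> Vector -> (Vector, Vector -> (Matrix, Vector))
mulMVOp mat vec =
  ( mulMV mat vec
  , \dzdy -> (outer dzdy vec, mulMV (transpose mat) dzdy) )
```

For mat = [[1, 2, 3], [4, 5, 6]], vec = [1, 0, -1] and dzdy = [1, 2], the gradients are [[1, 0, -1], [2, 0, -2]] for the matrix and [9, 12, 15] for the vector, which matches differentiating z = dzdy · (mat #> vec) by hand.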

Returning multiple items

You can return tuples inside BVars:

splitAt :: (Reifies s W, Backprop a) => Int -> BVar s [a] -> BVar s ([a], [a])
splitAt n = liftOp1 . op1 $ \xs ->
    let (ys, zs) = Data.List.splitAt n xs
    in  ((ys, zs), \(dys,dzs) -> dys ++ dzs)

This works as expected. However, it is strongly recommended, for the benefit of your users, that you return a tuple of BVars instead of a BVar of tuples:

splitAt :: (Reifies s W, Backprop a) => Int -> BVar s [a] -> (BVar s [a], BVar s [a])
splitAt n xs = (yszs ^^. _1, yszs ^^. _2)
  where
    yszs = liftOp1 (op1 $ \xs' ->
      let (ys, zs) = Data.List.splitAt n xs'
      in  ((ys, zs), \(dys,dzs) -> dys ++ dzs)
         ) xs

Here _1 and _2 are the tuple lenses from the microlens or lens packages.
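To see what the split's scaled gradient does, here it is on plain lists, together with a projection of the first half. The projection sketch assumes that the component you do not use receives a zero gradient, which is the role zero plays in the Backprop class; splitAtOp and fstOp are hypothetical names, not backprop's implementation of ^^.

```haskell
-- Sketch only: the op1-shaped function behind splitAt, on plain lists.
-- splitAtOp and fstOp are hypothetical names.
splitAtOp :: Int -> [a] -> (([a], [a]), ([a], [a]) -> [a])
splitAtOp n xs = (splitAt n xs, \(dys, dzs) -> dys ++ dzs)

-- Projecting the first half: its gradient flows back to the first component,
-- and the untouched second component receives zeros.
fstOp :: Num a => ([a], [a]) -> ([a], [a] -> ([a], [a]))
fstOp (ys, zs) = (ys, \dys -> (dys, map (const 0) zs))
```

If only the first half of [1, 2, 3, 4, 5] (split at 2) feeds into the result, seeding with [1, 1] and chaining both backward functions gives [1, 1, 0, 0, 0].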

Isomorphisms

If your function witnesses an isomorphism, there are handy combinators for making this easy to write. This is especially useful in the case of data constructors:

newtype Foo = MkFoo { getFoo :: Double }

mkFoo :: Reifies s W => BVar s Double -> BVar s Foo
mkFoo = isoVar MkFoo getFoo

-- also:
mkFoo :: BVar s Double -> BVar s Foo
mkFoo = coerceVar

data Bar = MkBar { bar1 :: Double, bar2 :: Float }

mkBar :: Reifies s W => BVar s Double -> BVar s Float -> BVar s Bar
mkBar = isoVar2 MkBar (\b -> (bar1 b, bar2 b))

tuple :: (Reifies s W, Backprop a, Backprop b) => BVar s a -> BVar s b -> BVar s (a, b)
tuple = isoVar2 (,) id
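One way to picture these combinators: the forward pass applies the "to" direction and the gradient is mapped back with the "from" direction. That only makes sense when the isomorphism merely repackages components, as data constructors do; for something like x ↦ 2x the inverse is not the gradient. The sketch below models this on plain values; isoOp, isoOp2, mkFooOp and mkBarOp are hypothetical names, not backprop's definitions.

```haskell
-- Sketch only: structural isomorphisms as op-shaped plain functions.
-- isoOp, isoOp2, mkFooOp and mkBarOp are hypothetical names.
isoOp :: (a -> b) -> (b -> a) -> a -> (b, b -> a)
isoOp to from x = (to x, from)   -- the output gradient is unpacked back

isoOp2 :: (a -> b -> c) -> (c -> (a, b)) -> a -> b -> (c, c -> (a, b))
isoOp2 to from x y = (to x y, from)

newtype Foo = MkFoo { getFoo :: Double } deriving (Show, Eq)

data Bar = MkBar { bar1 :: Double, bar2 :: Float } deriving (Show, Eq)

mkFooOp :: Double -> (Foo, Foo -> Double)
mkFooOp = isoOp MkFoo getFoo

mkBarOp :: Double -> Float -> (Bar, Bar -> (Double, Float))
mkBarOp = isoOp2 MkBar (\b -> (bar1 b, bar2 b))
```

For instance, a gradient of `MkBar 3 4` arriving at the constructed Bar is handed back to the two fields as `(3, 4)`.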

NoGrad

If you do decide to go to the extreme and provide only a BVar-based interface to your library (and no non-BVar-based one), you might have a function whose gradient you cannot define: maybe no gradient exists, or you haven't put in the time to write one. In this case, you can use noGrad1 (or noGrad, for an Op over several arguments):

negateNoGrad :: (Reifies s W, Num a, Backprop a) => BVar s a -> BVar s a
negateNoGrad = liftOp1 (noGrad1 negate)

This function can still be used with evalBP to get the correct answer. It can even be used with gradBP if the result is never used in the final answer.

However, if it is used in the final answer, then computing the gradient will throw a runtime exception.

Be sure to warn your users! Like any partial function, this is not recommended except in extreme circumstances.
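The failure mode described above relies on laziness: the forward result is usable, and the error only surfaces once the gradient is demanded. This base-only sketch imitates that behaviour with a plain function; noGradOp, negateNoGradOp and gradientFails are hypothetical names, not backprop's noGrad1.

```haskell
import Control.Exception (ErrorCall, evaluate, try)

-- Sketch only: an op1-shaped function whose scaled gradient is undefined.
-- noGradOp, negateNoGradOp and gradientFails are hypothetical names.
noGradOp :: (a -> b) -> a -> (b, b -> a)
noGradOp f x = (f x, \_ -> error "noGradOp: no gradient defined")

negateNoGradOp :: Num a => a -> (a, a -> a)
negateNoGradOp = noGradOp negate

-- Forcing the gradient throws an ErrorCall; the forward result is unaffected.
gradientFails :: IO Bool
gradientFails = do
  r <- try (evaluate (snd (negateNoGradOp (3 :: Double)) 1)) :: IO (Either ErrorCall Double)
  pure (either (const True) (const False) r)
```

`fst (negateNoGradOp 3)` evaluates to -3 without trouble, while `gradientFails` returns True.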

Monadic Operations

This should all work if your operations are all "pure". However, what about the cases where your operations have to be performed in some Applicative or Monadic context?

For example, what if add :: X -> X -> IO X?

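One option is to newtype-wrap your actions, and then give the wrapper a Backprop instance. Here zeroForX, oneForX and addForX stand for your library's own operations on X:

newtype IOX = IOX (IO X)

instance Backprop IOX where
    zero (IOX x) = IOX (fmap zeroForX x)
    -- or, if computing the zero itself needs IO:
    -- zero (IOX x) = IOX (zeroForX =<< x)

    add (IOX x) (IOX y) = IOX $ do
      x' <- x
      y' <- y
      addForX x' y'

    one (IOX x) = IOX (fmap oneForX x)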

And you can define your functions in terms of this:

addX :: Reifies s W => BVar s IOX -> BVar s IOX -> BVar s IOX
addX = liftOp2 . op2 $ \(IOX x) (IOX y) ->
    ( IOX (do x' <- x; y' <- y; addForX x' y')
    , \dzdy -> (dzdy, dzdy)   -- as with plain addition, dz/dy goes to both inputs
    )

This should work fine as long as you never "branch" on any results of your actions. You must never need to peek inside an action in order to decide what operations to do next. In other words, this works only if the operations you need to perform are all known and fixed beforehand.

A newtype wrapper that gives you this behavior automatically is provided: ApBP, from Numeric.Backprop and Numeric.Backprop.Class.

type IOX = ApBP IO X
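The idea behind such a wrapper can be sketched with any Applicative, assuming it lifts zero, add and one pointwise through the Applicative. The version below is specialised to Num instead of the Backprop class, and ApNum, zeroAp, oneAp, addAp and addXOp are hypothetical names, not the library's ApBP code.

```haskell
import Control.Applicative (liftA2)

-- Sketch only: pointwise lifting of zero/add/one through an Applicative.
-- ApNum, zeroAp, oneAp, addAp and addXOp are hypothetical names.
newtype ApNum f a = ApNum { runApNum :: f a }

zeroAp :: (Functor f, Num a) => ApNum f a -> ApNum f a
zeroAp (ApNum x) = ApNum (fmap (const 0) x)

oneAp :: (Functor f, Num a) => ApNum f a -> ApNum f a
oneAp (ApNum x) = ApNum (fmap (const 1) x)

-- Both actions always run, in a fixed order; no result is inspected to
-- decide which action comes next.
addAp :: (Applicative f, Num a) => ApNum f a -> ApNum f a -> ApNum f a
addAp (ApNum x) (ApNum y) = ApNum (liftA2 (+) x y)

-- An addX-like op, op2-shaped: forward adds, backward copies dz/dy to both inputs.
addXOp :: (Applicative f, Num a)
       => ApNum f a -> ApNum f a -> (ApNum f a, ApNum f a -> (ApNum f a, ApNum f a))
addXOp x y = (addAp x y, \dzdy -> (dzdy, dzdy))
```

With IO, `runApNum (addAp (ApNum (pure 2)) (ApNum (pure 3)))` runs both actions and yields 5; with Maybe, a Nothing on either side propagates. Because liftA2 has no way to look at a result before choosing the next action, this style cannot express the "branching" case described below.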

However, this will not work if you need to do things like compare contents, etc. to decide what operations to use.

At the moment, this is not supported. Please open an issue or contact me (@mstksg) if you need it!
