# Equipping your Library with Backprop
Equipping your library with backprop involves providing "backprop-aware" versions of your library functions. In fact, it is possible to make a library fully backprop-aware by providing only backprop versions of your functions, since a backprop-aware function can always be used as a normal function with `evalBP`. Alternatively, you can re-export all of your functions in a separate module with "backprop-aware" versions.
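For instance, here is a minimal sketch (using the `square` function defined further down this page) of how a user can run a backprop-aware function, either as a normal function or to get its gradient:

```haskell
-- A minimal sketch: evalBP runs a lifted function "normally", and
-- gradBP computes its gradient.  `square` is defined later on this page.
normalUse :: Double
normalUse = evalBP square 9      -- 81.0

gradientUse :: Double
gradientUse = gradBP square 9    -- 18.0
```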
If you have a function:

```haskell
myFunc :: a -> b
```

Then its lifted version would have type:

```haskell
myFunc :: Reifies s W => BVar s a -> BVar s b
```
That is, instead of a function directly taking an `a` and returning a `b`, it's a function taking a `BVar` containing an `a` and returning a `BVar` containing a `b`.
Functions taking multiple arguments can be translated pretty straightforwardly:

```haskell
func1   :: a -> b -> c
func1BP :: Reifies s W => BVar s a -> BVar s b -> BVar s c
```

And so can functions returning multiple results:

```haskell
func2   :: a -> (b, c)
func2BP :: Reifies s W => BVar s a -> (BVar s b, BVar s c)
```
Note that almost all operations involving `BVar`'d items require the contents to have a `Backprop` instance by default. APIs that only require `Num` instances (or explicitly specified addition functions) are available in Numeric.Backprop.Num and Numeric.Backprop.Explicit.
A `BVar s a -> BVar s b` really encodes two things:

- An `a -> b` (the actual function)
- An `a -> b -> a` (the "scaled gradient" function)
The documentation for Numeric.Backprop.Op gives detail about what these entail, with rendered math and examples.
The second function requires some elaboration. Let's say you are writing a lifted version of your function `y = f(x)` (whose derivative is `dy/dx`), and that your final result at the end of your computation is `z = g(f(x))` (whose derivative is `dz/dx`). In that case, because of the chain rule, `dz/dx = dz/dy * dy/dx`.

The scaled gradient is the function that, given `dz/dy`, returns `dz/dx` (that is, returns `dz/dy * dy/dx`).
For example, for the mathematical operation `y = f(x) = x^2`, considering `z = g(f(x))`, we get `dz/dx = dz/dy * 2x`. In fact, for all functions taking and returning scalars (just normal single numbers), `dz/dx = dz/dy * f'(x)`.
With that in mind, let's write our "squared" op:

```haskell
square :: (Reifies s W, Backprop a, Num a) => BVar s a -> BVar s a
square = liftOp1 . op1 $ \x ->
    ( x^2                     -- actual result
    , \dzdy -> dzdy * 2 * x   -- scaled gradient function
    )
```
Keeping along the same pattern, for `y = f(x) = sin(x)`, considering `z = g(f(x))`, we get `dz/dx = dz/dy * cos(x)`. So, we have:

```haskell
liftedSin :: (Reifies s W, Backprop a, Floating a) => BVar s a -> BVar s a
liftedSin = liftOp1 . op1 $ \x ->
    ( sin x, \dzdy -> dzdy * cos x )
```
In general, for functions that take and return scalars:

```haskell
liftedF :: (Reifies s W, Backprop a, Num a) => BVar s a -> BVar s a
liftedF = liftOp1 . op1 $ \x ->
    ( f x, \dzdy -> dzdy * dfdx x )
```
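For instance, following this pattern for the exponential function (a small sketch; since `exp` is its own derivative, the scaled gradient is just `dzdy * exp x`):

```haskell
-- Sketch: y = exp x, so dy/dx = exp x.
liftedExp :: (Reifies s W, Backprop a, Floating a) => BVar s a -> BVar s a
liftedExp = liftOp1 . op1 $ \x ->
    ( exp x, \dzdy -> dzdy * exp x )
```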
A simple non-trivial example is `sumElements`, which we can define to take the hmatrix library's `R n` type (an n-vector of `Double`). In this case, we have to think about `g(sum(x))`, and the types guide our thinking:

```haskell
sumElements           :: R n -> Double
sumElementsScaledGrad :: R n -> Double -> R n
```
The simplest way for me to do this is to just take it element by element:

```
y = f(<a,b,c>) = a + b + c
z = g(y) = g(a + b + c)
dz/da = dz/dy * dy/da
```

- From here, we can recognize that `dy/da = 1`, since `y = a + b + c`.
- We can then conclude that `dz/da = dz/dy`, where `dz/dy` is the second argument of `sumElementsScaledGrad`.
- By the same logic, we see that `dz/db = dz/dy` and `dz/dc = dz/dy` as well.
- We can generalize this to say that the gradient of the sum of vector components is the same for every component: just `dz/dy`.
So in the end:

```haskell
sumElements :: Reifies s W => BVar s (R n) -> BVar s Double
sumElements = liftOp1 . op1 $ \xs ->
    ( H.sumElements xs, \dzdy -> H.konst dzdy )   -- a constant vector
```
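As a quick sanity check (a sketch, assuming a `Backprop` instance for `R n` is in scope, as provided by hmatrix-backprop), the gradient of the sum should be a vector of all ones regardless of the input:

```haskell
-- Sketch: the gradient of sumElements is a constant vector of ones.
sumGrad :: R 3
sumGrad = gradBP sumElements (H.vec3 1 2 3)   -- konst 1.0
```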
Lifting multiple-argument functions is the same thing, except using `liftOp2` and `op2`, or `liftOpN` and `opN`.
A `BVar s a -> BVar s b -> BVar s c` encodes two things:

- The actual function, an `a -> b -> c`
- The scaled gradient, an `a -> b -> c -> (a, b)`

The `c` is again `dzdy`, and the final `(a, b)` is how `dzdy` affects both of the inputs.
For a simple example, take `x + y`. Working it out:

```
y = f(x1,x2) = x1 + x2
z = g(f(x1,x2)) = g(x1 + x2)
dz/dx1 = dz/dy * dy/dx1
```

- We recognize here that `dy/dx1` is `1`, because `y = x1 + x2`.
- So we can say `dz/dx1 = dz/dy`.
- The same logic applies for `dz/dx2`.
- So we give `dz/dx1 = dz/dy`, and `dz/dx2 = dz/dy`.
```haskell
add :: (Reifies s W, Backprop a, Num a) => BVar s a -> BVar s a -> BVar s a
add = liftOp2 . op2 $ \x1 x2 ->
    ( x1 + x2, \dzdy -> (dzdy, dzdy) )
```
And for multiplication, we can work it out:

```
y = f(x1,x2) = x1 * x2
z = g(f(x1,x2)) = g(x1 * x2)
dz/dx1 = dz/dy * dy/dx1
```

- We recognize here that `dy/dx1` is `x2`, because `y = x1 * x2`.
- So we can say `dz/dx1 = dz/dy * x2`.

```
dz/dx2 = dz/dy * dy/dx2
```

- We recognize that `dy/dx2 = x1`, because `y = x1 * x2`.
- So we can say `dz/dx2 = dz/dy * x1`.
```haskell
mul :: (Reifies s W, Backprop a, Num a) => BVar s a -> BVar s a -> BVar s a
mul = liftOp2 . op2 $ \x1 x2 ->
    ( x1 * x2, \dzdy -> (dzdy * x2, x1 * dzdy) )
```
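As a sketch of how these line up with the math above, `gradBP2` runs a two-argument backprop-aware function and returns the gradient with respect to each input:

```haskell
-- Sketch: gradients of add and mul at the point (3, 5).
addGrads :: (Double, Double)
addGrads = gradBP2 add 3 5   -- (1.0, 1.0)

mulGrads :: (Double, Double)
mulGrads = gradBP2 mul 3 5   -- (5.0, 3.0)
```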
For non-trivial examples involving linear algebra, see the source for the hmatrix-backprop library. For example:

```haskell
-- | Dot product
dot :: Reifies s W => BVar s (R n) -> BVar s (R n) -> BVar s Double
dot = liftOp2 . op2 $ \x y ->
    ( x `H.dot` y
    , \dzdy -> (H.konst dzdy * y, x * H.konst dzdy)
    )

-- | Matrix-vector multiplication
(#>) :: Reifies s W => BVar s (L m n) -> BVar s (R n) -> BVar s (R m)
(#>) = liftOp2 . op2 $ \mat vec ->
    ( mat H.#> vec
    , \dzdy -> (dzdy `H.outer` vec, H.tr mat H.#> dzdy)
    )
```
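As a sketch of how these are used (again assuming the `Backprop` instance for `R n` from hmatrix-backprop is in scope):

```haskell
-- Sketch: forward result and gradients of the lifted dot product.
dotValue :: Double
dotValue = evalBP2 dot (H.vec3 1 2 3) (H.vec3 4 5 6)   -- 32.0

dotGrads :: (R 3, R 3)
dotGrads = gradBP2 dot (H.vec3 1 2 3) (H.vec3 4 5 6)
-- (vector [4.0,5.0,6.0], vector [1.0,2.0,3.0])
```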
You can return tuples inside `BVar`s:

```haskell
splitAt :: (Reifies s W, Backprop a) => Int -> BVar s [a] -> BVar s ([a], [a])
splitAt n = liftOp1 . op1 $ \xs ->
    let (ys, zs) = Data.List.splitAt n xs
    in  ((ys, zs), \(dys,dzs) -> dys ++ dzs)
```
This works as expected. However, it is strongly recommended, for the benefit of your users, that you return a tuple of `BVar`s instead of a `BVar` of tuples:

```haskell
splitAt :: (Reifies s W, Backprop a) => Int -> BVar s [a] -> (BVar s [a], BVar s [a])
splitAt n xs = (yszs ^^. _1, yszs ^^. _2)
  where
    yszs = liftOp1 (op1 $ \xs' ->
             let (ys, zs) = Data.List.splitAt n xs'
             in  ((ys, zs), \(dys,dzs) -> dys ++ dzs)
           ) xs
```

Using `_1` and `_2` from the microlens or lens packages.
If your function witnesses an isomorphism, there are handy combinators for making this easy to write. This is especially useful in the case of data constructors:
```haskell
newtype Foo = MkFoo { getFoo :: Double }

mkFoo :: Reifies s W => BVar s Double -> BVar s Foo
mkFoo = isoVar MkFoo getFoo

-- also:
mkFoo :: BVar s Double -> BVar s Foo
mkFoo = coerceVar

data Bar = MkBar { bar1 :: Double, bar2 :: Float }

mkBar :: Reifies s W => BVar s Double -> BVar s Float -> BVar s Bar
mkBar = isoVar2 MkBar (\b -> (bar1 b, bar2 b))

tuple :: (Reifies s W, Backprop a, Backprop b) => BVar s a -> BVar s b -> BVar s (a, b)
tuple = isoVar2 (,) id
```
If you do decide to go to the extreme, and provide only a BVar-based interface to your library (and no non-BVar based one), then you might have a function for which you cannot define the gradient -- maybe no gradient exists, or you haven't put in the time to write one. In this case, you can use `noGrad` and `noGrad1`:

```haskell
negateNoGrad :: (Num a, Backprop a, Reifies s W) => BVar s a -> BVar s a
negateNoGrad = liftOp1 (noGrad1 negate)
```
This function can still be used with `evalBP` to get the correct answer. It can even be used with `gradBP`, as long as its result is never used in the final answer. However, if it is used in the final answer, then computing the gradient will throw a runtime exception.

Be sure to warn your users! Like any partial function, this is not recommended except in extreme circumstances.
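For instance (a sketch):

```haskell
-- Sketch: the forward pass still works through evalBP...
stillFine :: Double
stillFine = evalBP negateNoGrad 5    -- -5.0

-- ...but asking for its gradient throws a runtime exception:
-- boom = gradBP negateNoGrad 5
```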
This should all work if your operations are all "pure". However, what about cases where your operations have to be performed in some Applicative or Monadic context? For example, what if `add :: X -> X -> IO X`?
One option is to newtype-wrap your operations, and then give those a Backprop instance:

```haskell
newtype IOX = IOX (IO X)

instance Backprop IOX where
    zero (IOX x) = IOX (fmap zeroForX x)
    -- or: zero (IOX x) = IOX (zeroForX =<< x)
    add (IOX x) (IOX y) = IOX $ do
        x' <- x
        y' <- y
        addForX x' y'
```
And you can define your functions in terms of this:

```haskell
addX :: Reifies s W => BVar s IOX -> BVar s IOX -> BVar s IOX
addX = liftOp2 . op2 $ \(IOX x) (IOX y) ->
    ( IOX (do x' <- x; y' <- y; addForX x' y')
    , \dzdy -> (dzdy, dzdy)
    )
```
This should work fine as long as you never "branch" on any results of your actions. You must not ever need to peek inside the action in order to decide what operations to do next. In other words, this works if the operations you need to perform are all known and fixed beforehand.

A newtype wrapper is provided to give you this behavior automatically -- it's ApBP, from Numeric.Backprop and Numeric.Backprop.Class:

```haskell
type IOX = ApBP IO X
```
However, this will not work if you need to do things like compare contents, etc. to decide what operations to use.
At the moment, this is not supported. Please open an issue or contact me (@mstksg) if this becomes an issue!