Equipping your Library with Backprop
Equipping your library with backprop involves providing "backprop-aware" versions of your library functions. In fact, it is possible to implement a library fully by providing only backprop-aware versions of your functions, since a backprop-aware function can always be used as a normal function with `evalBP`. Alternatively, you can re-export all of your functions in a separate module with "backprop-aware" versions.
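As a sketch of what that buys you, here are the two "runner" functions involved (signatures from Numeric.Backprop, lightly simplified):

```haskell
-- evalBP "un-lifts" a backprop-aware function, running it normally:
-- evalBP :: (forall s. Reifies s W => BVar s a -> BVar s b) -> a -> b
--
-- gradBP runs it backwards, computing the gradient of the input:
-- gradBP :: (Backprop a, Backprop b)
--        => (forall s. Reifies s W => BVar s a -> BVar s b) -> a -> a
```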
If you have a function:

```haskell
myFunc :: a -> b
```

Then its lifted version would have the type:

```haskell
myFunc :: Reifies s W => BVar s a -> BVar s b
```
That is, instead of a function directly taking an `a` and returning a `b`, it's a function taking a `BVar` containing an `a` and returning a `BVar` containing a `b`.
Functions taking multiple arguments can be translated pretty straightforwardly:

```haskell
func1   :: a -> b -> c
func1BP :: Reifies s W => BVar s a -> BVar s b -> BVar s c
```

And so can functions returning multiple results:

```haskell
func2   :: a -> (b, c)
func2BP :: Reifies s W => BVar s a -> (BVar s b, BVar s c)
```
Note that almost all operations involving `BVar`'d items require the contents to have a `Backprop` instance by default. APIs that rely only on `Num` instances (or on explicitly specified addition functions) are available in Numeric.Backprop.Num and Numeric.Backprop.Explicit.
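As a side note on that `Backprop` constraint: for your own types, an instance can usually be derived generically. A minimal sketch (`MyParams` is a hypothetical type, not part of this page's running example):

```haskell
{-# LANGUAGE DeriveGeneric, DeriveAnyClass #-}

import GHC.Generics (Generic)
import Numeric.Backprop

-- All of Backprop's methods have Generic-based defaults, so an empty
-- derived instance is enough for a product type like this:
data MyParams = MyParams { weight :: Double, bias :: Double }
  deriving (Show, Generic, Backprop)
```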
A `BVar s a -> BVar s b` really encodes two things:

- An `a -> b` (the actual function)
- An `a -> b -> a` (the "scaled gradient" function)
The documentation for Numeric.Backprop.Op gives detail about what these entail, with rendered math and examples.
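In code, that pair is packaged up as an `Op` and then lifted to work on `BVar`s. A sketch of the relevant signatures (lightly simplified from Numeric.Backprop):

```haskell
-- The function you hand to op1 returns the result together with the
-- scaled gradient function -- exactly the two things listed above:
-- op1     :: (a -> (b, b -> a)) -> Op '[a] b
-- liftOp1 :: (Reifies s W, Backprop a) => Op '[a] b -> BVar s a -> BVar s b
```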
The second function requires some elaboration. Let's say you are writing a lifted version of your function `y = f(x)` (whose derivative is `dy/dx`), and that the final result at the end of your computation is `z = g(f(x))` (whose derivative is `dz/dx`). In that case, because of the chain rule, `dz/dx = dz/dy * dy/dx`.

The scaled gradient is the function that, given `dz/dy`, returns `dz/dx` (that is, returns `dz/dy * dy/dx`).
For example, take the mathematical operation `y = f(x) = x^2`. Considering `z = g(f(x))`, we have `dz/dx = dz/dy * 2x`. In fact, for all functions taking and returning scalars (just normal single numbers), `dz/dx = dz/dy * f'(x)`.
With that in mind, let's write our "squared" op:

```haskell
square :: (Reifies s W, Backprop a, Num a) => BVar s a -> BVar s a
square = liftOp1 . op1 $ \x ->
    ( x^2                     -- actual result
    , \dzdy -> dzdy * 2 * x   -- scaled gradient function
    )
```
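A quick (hypothetical) GHCi session to check that `square` behaves as derived above:

```haskell
-- >>> evalBP square (3 :: Double)   -- forward pass: 3^2
-- 9.0
-- >>> gradBP square (3 :: Double)   -- gradient: 2x at x = 3
-- 6.0
```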
Keeping with the same pattern, for `y = f(x) = sin(x)`, considering `z = g(f(x))`, `dz/dx = dz/dy * cos(x)`. So we have:

```haskell
liftedSin :: (Reifies s W, Backprop a, Floating a) => BVar s a -> BVar s a
liftedSin = liftOp1 . op1 $ \x ->
    ( sin x, \dzdy -> dzdy * cos x )
```
In general, for functions that take and return scalars:

```haskell
-- (f and dfdx are placeholders for your function and its derivative)
liftedF :: (Reifies s W, Backprop a, Num a) => BVar s a -> BVar s a
liftedF = liftOp1 . op1 $ \x ->
    ( f x, \dzdy -> dzdy * dfdx x )
```
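For instance, instantiating this pattern for `exp` (a sketch not in the running text; here f = exp, and f'(x) = exp x):

```haskell
liftedExp :: (Reifies s W, Backprop a, Floating a) => BVar s a -> BVar s a
liftedExp = liftOp1 . op1 $ \x ->
    ( exp x, \dzdy -> dzdy * exp x )  -- exp is its own derivative
```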
A simple non-trivial example is `sumElements`, which we can define to take the hmatrix library's `R n` type (an n-vector of `Double`). Here we have to think about `g(sum(x))`, and the types guide our thinking:

```haskell
sumElements           :: R n -> Double
sumElementsScaledGrad :: R n -> Double -> R n
```
The simplest way for me to do this is to take it element by element:

```
y = f(<a,b,c>) = a + b + c
z = g(y) = g(a + b + c)
dz/da = dz/dy * dy/da
```

- From here, we can recognize that `dy/da = 1`, since `y = a + b + c`.
- We can then conclude that `dz/da = dz/dy`, where `dz/dy` is the second argument of `sumElementsScaledGrad`.
- By the same logic, we see that `dz/db = dz/dy` and `dz/dc = dz/dy` as well.
- We can generalize this to say that the gradient of the sum of vector components is the same for every component: just `dz/dy`.
So in the end:

```haskell
-- (assuming H = Numeric.LinearAlgebra.Static and HU = Numeric.LinearAlgebra;
--  the Backprop instance for R n comes from the hmatrix-backprop package)
sumElements :: (KnownNat n, Reifies s W) => BVar s (R n) -> BVar s Double
sumElements = liftOp1 . op1 $ \xs ->
    ( HU.sumElements (H.extract xs)  -- the sum itself
    , \dzdy -> H.konst dzdy )        -- gradient: a constant vector
```
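As a sanity check, the gradient of a sum with respect to every element is 1, so the full gradient should be a vector of ones. A hypothetical GHCi session (`vec3` is from Numeric.LinearAlgebra.Static):

```haskell
-- >>> gradBP sumElements (H.vec3 1 2 3)
-- (vector [1.0,1.0,1.0])   -- a vector of three ones
```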
Lifting multiple-argument functions is the same thing, except using `liftOp2` and `op2`, or `liftOpN` and `opN`.
A `BVar s a -> BVar s b -> BVar s c` encodes two things:

- The actual function, `a -> b -> c`
- The scaled gradient, `a -> b -> c -> (a, b)`

The `c` is again `dzdy`, and the final `(a, b)` is how `dzdy` affects each of the two inputs.
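As with the single-argument case, this pair is packaged with `op2` and lifted with `liftOp2`; a sketch of the signatures (lightly simplified):

```haskell
-- op2     :: (a -> b -> (c, c -> (a, b))) -> Op '[a, b] c
-- liftOp2 :: (Reifies s W, Backprop a, Backprop b)
--         => Op '[a, b] c -> BVar s a -> BVar s b -> BVar s c
```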
For a simple example, consider `x + y`. Working it out:

```
y = f(x1,x2) = x1 + x2
z = g(f(x1,x2)) = g(x1 + x2)
dz/dx1 = dz/dy * dy/dx1
```

- We recognize here that `dy/dx1` is `1`, because `y = x1 + x2`.
- So we can say `dz/dx1 = dz/dy`.
- The same logic applies for `dz/dx2`.
- So we give `dz/dx1 = dz/dy` and `dz/dx2 = dz/dy`.
```haskell
add :: (Reifies s W, Backprop a, Num a) => BVar s a -> BVar s a -> BVar s a
add = liftOp2 . op2 $ \x1 x2 ->
    ( x1 + x2, \dzdy -> (dzdy, dzdy) )
```
And for multiplication, we can work it out:

```
y = f(x1,x2) = x1 * x2
z = g(f(x1,x2)) = g(x1 * x2)
dz/dx1 = dz/dy * dy/dx1
```

- We recognize here that `dy/dx1` is `x2`, because `y = x1 * x2`.
- So we can say `dz/dx1 = dz/dy * x2`.

```
dz/dx2 = dz/dy * dy/dx2
```

- We recognize that `dy/dx2 = x1`, because `y = x1 * x2`.
- So we can say `dz/dx2 = dz/dy * x1`.
```haskell
mul :: (Reifies s W, Backprop a, Num a) => BVar s a -> BVar s a -> BVar s a
mul = liftOp2 . op2 $ \x1 x2 ->
    ( x1 * x2, \dzdy -> (dzdy * x2, x1 * dzdy) )
```
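Both can be checked with `gradBP2`, the two-argument analogue of `gradBP` (hypothetical GHCi session):

```haskell
-- >>> gradBP2 add (3 :: Double) (5 :: Double)
-- (1.0,1.0)
-- >>> gradBP2 mul (3 :: Double) (5 :: Double)  -- (dz/dx1, dz/dx2) = (x2, x1)
-- (5.0,3.0)
```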
For non-trivial examples involving linear algebra, see the source of the hmatrix-backprop library.

```haskell
-- | Dot product
dot :: (KnownNat n, Reifies s W) => BVar s (R n) -> BVar s (R n) -> BVar s Double
dot = liftOp2 . op2 $ \x y ->
    ( x `H.dot` y
    , \dzdy -> (H.konst dzdy * y, x * H.konst dzdy)
    )
```
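The gradients for `dot` fall out of the same element-by-element reasoning used for `sumElements`; a sketch of the derivation:

```haskell
-- y = f(u,v) = u1*v1 + u2*v2 + ... + un*vn
-- dy/du_i = v_i   and   dy/dv_i = u_i
-- so dz/du = dzdy * v and dz/dv = dzdy * u, scaled elementwise --
-- which is exactly the (H.konst dzdy * y, x * H.konst dzdy) pair above.
```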
```haskell
-- | Matrix-vector multiplication
(#>) :: (KnownNat m, KnownNat n, Reifies s W)
     => BVar s (L m n) -> BVar s (R n) -> BVar s (R m)
(#>) = liftOp2 . op2 $ \mat vec ->
    ( mat H.#> vec
    , \dzdy -> (dzdy `H.outer` vec, H.tr mat H.#> dzdy)
    )
```
You can return tuples inside `BVar`s:

```haskell
splitAt :: (Reifies s W, Backprop a) => Int -> BVar s [a] -> BVar s ([a], [a])
splitAt n = liftOp1 . op1 $ \xs ->
    let (ys, zs) = Data.List.splitAt n xs
    in  ((ys, zs), \(dys, dzs) -> dys ++ dzs)
```
This works as expected. However, for the benefit of your users, it is strongly recommended that you return a tuple of `BVar`s instead of a `BVar` of tuples:

```haskell
splitAt :: (Reifies s W, Backprop a) => Int -> BVar s [a] -> (BVar s [a], BVar s [a])
splitAt n xs = (yszs ^^. _1, yszs ^^. _2)
  where
    yszs = liftOp1 (op1 $ \xs' ->
             let (ys, zs) = Data.List.splitAt n xs'
             in  ((ys, zs), \(dys, dzs) -> dys ++ dzs)
           ) xs
```
Here, `_1` and `_2` are from the microlens or lens packages.
If your function witnesses an isomorphism, there are handy combinators for making this easy to write. This is especially useful in the case of data constructors:
```haskell
data Foo = MkFoo { getFoo :: Double }

mkFoo :: Reifies s W => BVar s Double -> BVar s Foo
mkFoo = isoVar MkFoo getFoo

data Bar = MkBar { bar1 :: Double, bar2 :: Float }

mkBar :: Reifies s W => BVar s Double -> BVar s Float -> BVar s Bar
mkBar = isoVar2 MkBar (\b -> (bar1 b, bar2 b))
```
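The isomorphism runs both ways, so the same combinator also gives the inverse direction. A sketch (assuming a `Backprop` instance for `Foo`, e.g. derived via `Generic` as noted earlier):

```haskell
-- assuming: instance Backprop Foo
unFoo :: Reifies s W => BVar s Foo -> BVar s Double
unFoo = isoVar getFoo MkFoo
```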