Equipping your Library with Backprop
Equipping your library with backprop involves providing "backprop-aware" versions of your library functions. In fact, it is possible to make a library fully backprop-compatible by providing only backprop versions of your functions, since you can use a backprop-aware function as a normal function with evalBP. Alternatively, you can provide "backprop-aware" versions of all of your functions in a separate module.
If you have a function:
myFunc :: a -> b
Then its lifted version would have type:
myFunc :: Reifies s W => BVar s a -> BVar s b
That is, instead of a function directly taking an a and returning a b, it's a function taking a BVar containing an a, and returning a BVar containing a b.
Functions taking multiple arguments can be translated pretty straightforwardly:
func1 :: a -> b -> c
func1BP :: Reifies s W => BVar s a -> BVar s b -> BVar s c
The same goes for functions returning multiple results:
func2 :: a -> (b, c)
func2BP :: Reifies s W => BVar s a -> (BVar s b, BVar s c)
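For intuition about what such lifted versions have to carry, here is a library-free sketch in plain Haskell. It is not backprop's API, and `mulWithGrad` and `sinCosWithGrad` are hypothetical names. Each function returns its result together with a scaled gradient. With two inputs, the gradient function returns one gradient per input. With two outputs, it receives one incoming gradient per output and adds the contributions that flow through each.

```haskell
-- Library-free sketch (NOT backprop's API): each function returns its
-- result together with a "scaled gradient" that maps the gradient of the
-- final result z with respect to the output(s) back to the input(s).

-- Two inputs, one output: y = x1 * x2.
-- Given dz/dy, return (dz/dx1, dz/dx2) = (dz/dy * x2, dz/dy * x1).
mulWithGrad :: Double -> Double -> (Double, Double -> (Double, Double))
mulWithGrad x1 x2 = (x1 * x2, \dzdy -> (dzdy * x2, dzdy * x1))

-- One input, two outputs: (b, c) = (sin x, cos x).
-- Given (dz/db, dz/dc), return dz/dx = dz/db * cos x + dz/dc * (-sin x):
-- the contributions through both outputs are added together.
sinCosWithGrad :: Double -> ((Double, Double), (Double, Double) -> Double)
sinCosWithGrad x =
  ( (sin x, cos x)
  , \(dzdb, dzdc) -> dzdb * cos x + dzdc * negate (sin x)
  )
```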
Note that at the moment, almost all operations involving BVar'd items require that the contents have a Num instance. This is not a fundamental limitation of the underlying algorithm; it was a conscious design decision intended to significantly reduce the complexity of the external API. However, the internal implementation does not require Num, and one day a Num-free version might be exported if there is enough demand. If this is important to you, please contact me (@mstksg) or open an issue!
A BVar s a -> BVar s b really encodes two things:
- An a -> b (the actual function)
- An a -> b -> a (the "scaled gradient" function)
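As a rough model of that two-part encoding, one can bundle both pieces as a -> (b, b -> a) and then project each back out. This is a simplification, not backprop's actual Op type. `ScaledOp`, `cubeOp`, `runForward` and `runScaledGrad` are hypothetical names used only for illustration.

```haskell
-- Library-free sketch (NOT backprop's Op type): a lifted a -> b is
-- modelled as the actual function paired with its scaled gradient,
-- bundled per input as  a -> (b, b -> a).
type ScaledOp a b = a -> (b, b -> a)

-- Hypothetical example: y = x^3, so dz/dx = dz/dy * 3x^2.
cubeOp :: ScaledOp Double Double
cubeOp x = (x ^ 3, \dzdy -> dzdy * 3 * x ^ 2)

-- The "actual function" part: a -> b.
runForward :: ScaledOp a b -> a -> b
runForward op = fst . op

-- The "scaled gradient" part: a -> b -> a.
runScaledGrad :: ScaledOp a b -> a -> b -> a
runScaledGrad op = snd . op
```

For example, runScaledGrad cubeOp 2 1 gives 12, the ordinary derivative 3x^2 at x = 2; passing 0.5 instead of 1 scales the result to 6.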
The documentation for Numeric.Backprop.Op gives detail about what these entail, with rendered math and examples.
The second function requires some elaboration. Let's say you are writing a lifted version of your function y = f(x) (whose derivative is dy/dx), and that your final result at the end of your computation is z = g(f(x)) (whose derivative is dz/dx). In that case, because of the chain rule, dz/dx = dz/dy * dy/dx.
The scaled gradient is the function that, given the input x and dz/dy, returns dz/dx (that is, returns dz/dy * dy/dx).
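The chain rule is what lets these scaled gradients compose. The sketch below is again library-free and uses hypothetical names (`ScaledOp`, `composeOp`, `squareOp`, `sinOp`). It runs two functions forward, then feeds dz/dw back through their scaled gradients in reverse order, which for sin(x^2) yields cos(x^2) * 2x.

```haskell
-- Library-free sketch: a function paired with its scaled gradient.
type ScaledOp a b = a -> (b, b -> a)

-- Chain two scaled-gradient functions: run f, then g, on the way forward;
-- on the way back, g turns dz/dw into dz/dy, and f turns dz/dy into dz/dx.
composeOp :: ScaledOp b c -> ScaledOp a b -> ScaledOp a c
composeOp g f x =
  let (y, scaleF) = f x
      (w, scaleG) = g y
  in  (w, scaleF . scaleG)

-- y = x^2, dz/dx = dz/dy * 2x
squareOp :: ScaledOp Double Double
squareOp x = (x * x, \dzdy -> dzdy * 2 * x)

-- y = sin x, dz/dx = dz/dy * cos x
sinOp :: ScaledOp Double Double
sinOp x = (sin x, \dzdy -> dzdy * cos x)
```

Calling snd (composeOp sinOp squareOp 1.5) 1 returns cos(2.25) * 2 * 1.5, the derivative of sin(x^2) at x = 1.5.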
For example, for the mathematical operation y = f(x) = x^2, considering z = g(f(x)), we get dz/dx = dz/dy * 2x. In fact, for all functions taking and returning scalars (just normal single numbers), dz/dx = dz/dy * f'(x).
With that in mind, let's write our "squared" op:
square :: (Reifies s W, Num a) => BVar s a -> BVar s a
square = liftOp1 . op1 $ \x ->
    ( x^2              , \dzdy -> dzdy * 2 * x)
    -- ^- actual result  ^- scaled gradient function
Following the same pattern, for y = f(x) = sin(x), considering z = g(f(x)), we get dz/dx = dz/dy * cos(x). So, we have:
liftedSin :: (Reifies s W, Floating a) => BVar s a -> BVar s a
liftedSin = liftOp1 . op1 $ \x ->
    ( sin x, \dzdy -> dzdy * cos x )
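A quick way to catch mistakes in a hand-written scaled gradient is to compare it with a central finite difference. The sketch below does this for the square and sin rules above in plain Haskell, without backprop. `checkScaledGrad`, `squareGrad` and `sinGrad` are hypothetical names, and the step size 1e-5 and tolerance 1e-6 are arbitrary assumptions that suit well-scaled inputs.

```haskell
-- Library-free sketch: compare a hand-written scaled gradient against a
-- central finite difference of (dz/dy * f x).
--   f       : the forward function
--   grad    : the hand-written scaled gradient, x -> dz/dy -> dz/dx
--   x, dzdy : the point and incoming gradient to test at
checkScaledGrad :: (Double -> Double) -> (Double -> Double -> Double)
                -> Double -> Double -> Bool
checkScaledGrad f grad x dzdy = abs (analytic - numeric) < 1e-6
  where
    h        = 1e-5  -- assumed step size
    analytic = grad x dzdy
    numeric  = dzdy * (f (x + h) - f (x - h)) / (2 * h)

-- The square and sin rules from the text, written as plain functions.
squareGrad :: Double -> Double -> Double
squareGrad x dzdy = dzdy * 2 * x

sinGrad :: Double -> Double -> Double
sinGrad x dzdy = dzdy * cos x
```

A deliberately wrong rule such as \x dzdy -> dzdy * x for x^2 fails this check, which is the point of running it.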
In general, for functions that take and return scalars, where f is your function and dfdx is its derivative:
liftedF :: (Reifies s W, Num a) => BVar s a -> BVar s a
liftedF = liftOp1 . op1 $ \x ->
    ( f x, \dzdy -> dzdy * dfdx x )
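The template can be captured once as a helper that takes f and its derivative. This is a hypothetical, library-free `scalarOp` over plain Haskell functions (with `tanhOp` and `expOp` as illustrative instances); with backprop itself you would hand the same pair to op1, as liftedF does.

```haskell
-- Library-free sketch: a function paired with its scaled gradient.
type ScaledOp a b = a -> (b, b -> a)

-- Hypothetical helper mirroring the liftedF template: for scalar
-- functions, the scaled gradient is always dz/dy * f'(x).
scalarOp :: Num a => (a -> a) -> (a -> a) -> ScaledOp a a
scalarOp f dfdx x = (f x, \dzdy -> dzdy * dfdx x)

-- tanh' x = 1 - tanh x ^ 2
tanhOp :: ScaledOp Double Double
tanhOp = scalarOp tanh (\x -> 1 - tanh x ^ 2)

-- exp is its own derivative
expOp :: ScaledOp Double Double
expOp = scalarOp exp exp
```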
A simple non-trivial example is sumElements, which we can define to take the hmatrix library's R n type (an n-vector of Double). Here we have to think about g(sum(x)), and the types guide our thinking:
sumElements :: R n -> Double
sumElementsScaledGrad :: R n -> Double -> R n
Personally, I find the simplest way to do this is to take it element by element:
y = f(<a,b,c>) = a + b + c
z = g(y) = g(a + b + c)
dz/da = dz/dy * dy/da
- From here, we can recognize that dy/da = 1, since y = a + b + c.
- We can then conclude that dz/da = dz/dy, where dz/dy is the second argument of sumElementsScaledGrad.
- By the same logic, we see that dz/db = dz/dy as well, and dz/dc = dz/dy as well.
- We can generalize this to say that the gradient of the sum of vector components is the same for every component: just dz/dy.
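The element-by-element argument can be checked with a library-free sketch that uses plain lists in place of hmatrix's R n, an assumption made so the example needs no extra packages. `sumWithGrad` and `sumSquaredGrad` are hypothetical names. The scaled gradient copies dz/dy into every component, and with a downstream g(y) = y^2 each component receives dz/dy = 2y.

```haskell
-- Library-free sketch using plain lists instead of hmatrix's R n:
-- y = sum xs, and every component's gradient is just dz/dy.
sumWithGrad :: [Double] -> (Double, Double -> [Double])
sumWithGrad xs = (sum xs, \dzdy -> map (const dzdy) xs)

-- Hypothetical downstream g: z = y^2, so dz/dy = 2 * y, and each
-- dz/dx_i = dz/dy * 1 = 2 * sum xs.
sumSquaredGrad :: [Double] -> [Double]
sumSquaredGrad xs =
  let (y, scaledGrad) = sumWithGrad xs
  in  scaledGrad (2 * y)
```

For [1, 2, 3], z = (1 + 2 + 3)^2 and every component's gradient is 12.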
So in the end:
sumElements :: (KnownNat n, Reifies s W) => BVar s (R n) -> BVar s Double
sumElements = liftOp1 . op1 $ \xs ->
    ( H.sumElements xs, \dzdy -> H.konst dzdy )  -- dz/dy copied into a constant vector