TLDR: We implemented automatic functional differentiation (as in variational calculus) in JAX. One can call g = jax.grad(F)(f) to get the derivative of the functional F at the function f, where g is itself a callable Python function.
Autofd can be installed via pip:
pip install autofd
A minimal example of how to use this package:
import jax
import jax.numpy as jnp
from jaxtyping import Float32, Array
from autofd import function
import autofd.operators as o
# define a function
@function
def f(x: Float32[Array, ""]) -> Float32[Array, ""]:
  return -x**2

# define a functional
def F(f):
  return o.integrate(o.compose(jnp.exp, f))

# take the functional derivative
dFdf = jax.grad(F)(f)

# dFdf is invokable!
dFdf(1.)
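As a sanity check that does not depend on autofd, the same functional can be discretized on a grid and differentiated with plain JAX. This is only a sketch under stated assumptions: the interval [-5, 5], the grid size, and the names `xs`, `F_discrete`, and `dFdf_grid` are illustrative choices, not part of the package. Analytically, the functional derivative of $F[f]=\int e^{f(x)}dx$ is $e^{f(x)}$, so the value at x = 1 should be close to $e^{-1}$.

```python
import jax
import jax.numpy as jnp

# Assumed finite grid standing in for the integration domain.
xs = jnp.linspace(-5.0, 5.0, 1001)
dx = xs[1] - xs[0]

def f(x):
    return -x**2

def F_discrete(fvals):
    # Riemann-sum approximation of the integral of exp(f(x)).
    return jnp.sum(jnp.exp(fvals)) * dx

# Gradient with respect to the sampled function values.
grad_vals = jax.grad(F_discrete)(f(xs))

# Dividing by dx turns the ordinary gradient into an
# approximation of the functional derivative on the grid.
dFdf_grid = grad_vals / dx

# Grid index closest to x = 1.
i = int(jnp.argmin(jnp.abs(xs - 1.0)))
print(dFdf_grid[i], jnp.exp(-1.0))  # both approximately 0.3679
```

The division by `dx` is the discrete counterpart of moving from a gradient over finitely many coordinates to a derivative with respect to a function.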
In mathematics, we can see functions as generalizations of vectors. In layman's terms, we can see a vector as a list of bins with different heights, e.g. v=[0.34, 0.2, 0.1, 0.43, 0.14]. This list can be indexed using integers: v[2] is 0.1. If we decrease the size of each bin while increasing the number of bins to infinity, we eventually obtain an infinite-dimensional vector that can be continuously indexed. In this case when we use
Since we see functions as infinite-dimensional vectors, the manipulations that we apply to vectors can also be generalized. For example:
- Summation becomes integration: $\sum_i v_i \rightarrow \int v(x) dx$.
- Difference becomes differentiation: $v_i - v_{i-1} \rightarrow \nabla v(x)$.
- Linear operation: $u_j=\sum_{i}w_{ji}v_i \rightarrow u(y)=\int w(y,x)v(x)dx$.
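The three correspondences above can be illustrated numerically by sampling a function on a fine grid. The sketch below is an assumption-laden illustration, not autofd code: it picks v(x) = x² on [0, 1], a midpoint grid with 1000 points, and a hypothetical kernel w(y, x) = y·x.

```python
import jax.numpy as jnp

n = 1000
dx = 1.0 / n
x = (jnp.arange(n) + 0.5) * dx  # midpoints of n bins on [0, 1]
v = x**2                        # the "vector" of samples of v(x) = x^2

# Summation -> integration: sum_i v_i * dx approximates 1/3.
integral = jnp.sum(v) * dx

# Difference -> differentiation: (v_i - v_{i-1}) / dx approximates
# v'(x) = 2x evaluated halfway between neighbouring samples.
derivative = (v[1:] - v[:-1]) / dx

# Linear operation -> integral transform with kernel w(y, x) = y * x:
# u(y) = integral of y * x * x^2 dx over [0, 1] = y / 4.
y = jnp.linspace(0.0, 1.0, 5)
w = y[:, None] * x[None, :]
u = (w @ v) * dx
```

In each case the factor `dx` is what separates the finite-dimensional operation from its continuous limit.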
In JAX, we can easily write Python functions that process Arrays that represent vectors or tensors. With the above generalizations, we can also write functions that process infinite-dimensional arrays (functions), which we call functions of functions, or higher-order functions. There are many higher-order functions in JAX, for example jax.linearize, jax.grad, jax.vjp, etc. Even in pure Python, higher-order functions are very common: the decorator pattern in Python is implemented via higher-order functions.
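To make the notion concrete, here is a small sketch of higher-order functions in plain JAX and plain Python. The helper names `square`, `twice`, and `add_one` are made up for illustration; only jax.grad, jax.linearize, and jax.vjp come from JAX.

```python
import functools
import jax
import jax.numpy as jnp

def square(x):
    return x**2

# jax.grad takes a function and returns a new function.
dsquare = jax.grad(square)
print(dsquare(3.0))  # derivative 2x at x = 3, i.e. 6.0

# jax.linearize returns the output and a function that applies
# the Jacobian-vector product at the given point.
out, f_jvp = jax.linearize(jnp.sin, jnp.array(0.0))
tangent_out = f_jvp(jnp.array(1.0))  # cos(0) * 1

# jax.vjp returns the output and a function that applies
# the vector-Jacobian product.
out2, f_vjp = jax.vjp(jnp.sin, jnp.array(0.0))
(cotangent_in,) = f_vjp(jnp.ones_like(out2))  # cos(0) * 1

# A decorator is a higher-order function in pure Python:
# it takes a function and returns a wrapped function.
def twice(fn):
    @functools.wraps(fn)
    def wrapper(x):
        return fn(fn(x))
    return wrapper

@twice
def add_one(x):
    return x + 1

print(add_one(1))  # 3
```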
Functions of functions have many names; in general we call them higher-order functions. Specifically, when a higher-order function maps a function to another function, it is often called an operator, e.g.
Just like we can compute the derivative of a function
Please see the paper http://arxiv.org/abs/2311.18727.
@misc{lin2023automatic,
  title={Automatic Functional Differentiation in JAX},
  author={Min Lin},
  year={2023},
  eprint={2311.18727},
  archivePrefix={arXiv},
  primaryClass={cs.PL}
}