Sea surface height (SSH) is a gateway variable to other important ocean properties, e.g. sea surface temperature and geostrophic currents. Many large ocean models attempt to simulate it, e.g. NEMO, MOM6 and MITgcm. However, they are very expensive and quite difficult to run. Smaller models such as the Quasi-Geostrophic (QG) and Shallow Water (SW) equations are useful approximations. This repo showcases how we can use some modern tools to construct dynamical systems for PDEs.
What makes this different from the many existing implementations is that we use JAX. JAX is basically NumPy on steroids: the API is very similar, but we also get modern tooling (automatic differentiation, JIT compilation, accelerator support) along with speed. Most importantly, JAX is differentiable. Having a differentiable model is important because it allows us to:
- Learn some of the hyperparameters if necessary
- Embed this in machine learning models where differentiability is needed
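As a minimal sketch of the first point, the snippet below recovers a single parameter by differentiating through a toy time-stepping loop. The decay model, the names `simulate` and `loss`, the learning rate and the step counts are all illustrative assumptions and not part of this package.

```python
import jax
import jax.numpy as jnp


def simulate(decay_rate, u0, dt=0.01, n_steps=100):
    """Integrate du/dt = -decay_rate * u with explicit Euler steps."""
    def step(u, _):
        return u - dt * decay_rate * u, None

    u_final, _ = jax.lax.scan(step, u0, None, length=n_steps)
    return u_final


def loss(decay_rate, u0, u_obs):
    """Mean squared mismatch between simulated and observed final state."""
    return jnp.mean((simulate(decay_rate, u0) - u_obs) ** 2)


u0 = jnp.ones(8)
true_rate = 2.0
u_obs = simulate(true_rate, u0)      # synthetic "observation"

# Gradient of the loss with respect to the first argument (decay_rate).
grad_fn = jax.jit(jax.grad(loss))

rate = 0.5                           # deliberately wrong initial guess
for _ in range(500):
    rate = rate - 1.0 * grad_fn(rate, u0, u_obs)   # plain gradient descent
```

After the loop, `rate` should have moved from the initial guess toward the value used to generate `u_obs`. The same pattern applies when the loss comes from a machine learning model that wraps the simulator.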
Why Not PyTorch?
We could easily just use PyTorch. However, JAX has some advantages over other frameworks like PyTorch and TensorFlow:
- Familiar NumPy-like API, which is nice for newcomers in the scientific community
- CPU/GPU/TPU capabilities with minimal code changes
- Gradients as operators on functions instead of storing the transformations in the tensors
- Functional style, which is easier to read for newcomers
- Auto-vectorization, so we can easily parallelize the operators over multiple dimensions without code changes (note: TensorFlow has this)
- JIT compilation, which speeds up the code by a lot (note: both PyTorch and TensorFlow have this)
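The gradient, auto-vectorization and JIT points in the list above can be sketched in a few lines. The function `kinetic_energy` and the toy ensemble are illustrative assumptions; `jax.grad`, `jax.vmap` and `jax.jit` are the standard JAX transformations.

```python
import jax
import jax.numpy as jnp


def kinetic_energy(u):
    """Total kinetic energy 0.5 * sum(u**2) of a 1D velocity field."""
    return 0.5 * jnp.sum(u ** 2)


# grad is an operator on functions: it returns a new function, dKE/du.
dke_du = jax.grad(kinetic_energy)

# vmap maps that function over a leading batch axis (e.g. an ensemble).
batched_dke_du = jax.vmap(dke_du)

# jit compiles the composed function with XLA for CPU, GPU or TPU.
fast_batched_dke_du = jax.jit(batched_dke_du)

ensemble = jnp.arange(12.0).reshape(3, 4)   # 3 members, 4 grid points each
grads = fast_batched_dke_du(ensemble)       # same shape as the ensemble
```

Since the derivative of `0.5 * u**2` is `u`, `grads` equals `ensemble` here, which makes the composition easy to check by hand.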
This library will be relatively general, but it will serve as a development platform for the following applications:
- Generate Simulations
- Surrogate Models
- Data Assimilation
Without making it too complicated, we settled on a few key objects that make up the package.
Domain
This will be the object that defines the grids where all of the fields live. It will be easy to access the coordinates, boundaries, grids and cell volumes. We don't need to store the grid all of the time; instead, we generate it when needed.
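A rough sketch of what such an object could look like is shown below. The class layout, field names (`xmin`, `xmax`, `dx`) and property names are assumptions made for illustration and may differ from the package's actual `Domain`. Only the bounds and spacing are stored; the coordinates and grid are generated on access.

```python
import math
from typing import NamedTuple, Tuple

import jax.numpy as jnp


class Domain(NamedTuple):
    """Hypothetical rectilinear domain: stores only bounds and spacing."""

    xmin: Tuple[float, ...]
    xmax: Tuple[float, ...]
    dx: Tuple[float, ...]

    @property
    def coords(self):
        # One 1D coordinate vector per axis, end points included.
        return tuple(
            jnp.linspace(lo, hi, round((hi - lo) / d) + 1)
            for lo, hi, d in zip(self.xmin, self.xmax, self.dx)
        )

    @property
    def grid(self):
        # Full mesh built on demand; shape (*n_points, n_dims).
        return jnp.stack(jnp.meshgrid(*self.coords, indexing="ij"), axis=-1)

    @property
    def cell_volume(self):
        # Product of the grid spacings along each axis.
        return math.prod(self.dx)


domain = Domain(xmin=(0.0, 0.0), xmax=(1.0, 2.0), dx=(0.25, 0.5))
grid = domain.grid   # generated here, not stored on the object
```

Using a `NamedTuple` keeps the object immutable and makes it a JAX pytree without extra registration, which is one possible design choice among several.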
Operators
This will be a suite of functions for different gradient calculations and combined operations for well-known equations. We will primarily focus on finite difference operators with the finiteDiffX package. At a later date, we can introduce spectral and finite volume methods.
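To make the idea of a finite difference operator concrete, here is a small sketch written with plain `jax.numpy` rather than the finiteDiffX API. The function names `ddx_central` and `laplacian` and the periodic boundary treatment via `jnp.roll` are assumptions for illustration only.

```python
import jax.numpy as jnp


def ddx_central(u, dx, axis=0):
    """Second-order central difference on a periodic grid."""
    return (jnp.roll(u, -1, axis=axis) - jnp.roll(u, 1, axis=axis)) / (2.0 * dx)


def laplacian(u, dx, dy):
    """Five-point Laplacian on a periodic 2D grid (a combined operation)."""
    d2x = (jnp.roll(u, -1, 0) - 2.0 * u + jnp.roll(u, 1, 0)) / dx**2
    d2y = (jnp.roll(u, -1, 1) - 2.0 * u + jnp.roll(u, 1, 1)) / dy**2
    return d2x + d2y


# Check the first derivative against an analytic solution: d/dx sin(x) = cos(x).
n = 64
dx = 2.0 * jnp.pi / n
x = jnp.arange(n) * dx
dudx = ddx_central(jnp.sin(x), dx)
max_err = jnp.max(jnp.abs(dudx - jnp.cos(x)))

# The Laplacian of a field that only varies in x reduces to d2/dx2.
u2d = jnp.sin(x)[:, None] * jnp.ones((1, n))
lap = laplacian(u2d, dx, dx)
```

The error `max_err` shrinks quadratically as the grid is refined, as expected for a second-order stencil.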
Integrators
We will use the diffrax package to do the time integration. We'll use the method-of-lines technique to formulate all of our PDEs: the spatial operators calculate the RHS of the equation for the state at a given time, and the time integrator advances that state.
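The method-of-lines idea can be sketched without any extra dependency. The example below discretises 1D linear advection in space and steps it with a hand-written fixed-step RK4 loop, which stands in for diffrax here. The RHS follows the `(t, y, args)` signature that diffrax vector fields use; the names `rhs`, `rk4_step` and `integrate` and all numerical values are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def rhs(t, u, args):
    """Spatially discretised RHS of du/dt = -c du/dx (periodic, central)."""
    c, dx = args
    return -c * (jnp.roll(u, -1) - jnp.roll(u, 1)) / (2.0 * dx)


def rk4_step(f, t, u, dt, args):
    """One classical fourth-order Runge-Kutta step."""
    k1 = f(t, u, args)
    k2 = f(t + 0.5 * dt, u + 0.5 * dt * k1, args)
    k3 = f(t + 0.5 * dt, u + 0.5 * dt * k2, args)
    k4 = f(t + dt, u + dt * k3, args)
    return u + dt / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)


def integrate(f, u0, t0, dt, n_steps, args):
    """Advance u0 by n_steps fixed steps using lax.scan."""
    def body(carry, _):
        t, u = carry
        return (t + dt, rk4_step(f, t, u, dt, args)), None

    init = (jnp.asarray(t0, dtype=u0.dtype), u0)
    (t_end, u_end), _ = jax.lax.scan(body, init, None, length=n_steps)
    return t_end, u_end


n = 128
dx = 1.0 / n
x = jnp.arange(n) * dx
u0 = jnp.sin(2.0 * jnp.pi * x)
c, dt, n_steps = 1.0, 0.002, 500      # c * dt * n_steps = 1, one full period
t_end, u_end = integrate(rhs, u0, 0.0, dt, n_steps, (c, dx))
```

After one full period the wave should return close to its initial shape, up to the small phase error of the central difference stencil. Because the whole loop is a `lax.scan`, it can be wrapped in `jax.jit` or differentiated with `jax.grad`.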
Params, State & Equations of Motion
We will have a general API for how we store parameters, initialize states and pass them both through the equation of motion. To handle what's differentiable and what is not, we will use the equinox package.
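The sketch below illustrates the split between parameters, state and equation of motion using plain `NamedTuple` pytrees in place of equinox modules. The containers `Params` and `State`, the function names and the advection-diffusion equation are hypothetical choices for illustration. Differentiability is controlled here by argument position (`argnums`), a simpler stand-in for the filtering that equinox provides.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Params(NamedTuple):
    """Hypothetical container for physical constants we may want to learn."""
    c: jnp.ndarray    # phase speed
    nu: jnp.ndarray   # diffusivity


class State(NamedTuple):
    """Hypothetical container for the prognostic field."""
    u: jnp.ndarray

    @classmethod
    def init(cls, x):
        # Simple sinusoidal initial condition.
        return cls(u=jnp.sin(2.0 * jnp.pi * x))


def equation_of_motion(t, state, params, dx):
    """RHS of du/dt = -c du/dx + nu d2u/dx2 on a periodic grid."""
    u = state.u
    dudx = (jnp.roll(u, -1) - jnp.roll(u, 1)) / (2.0 * dx)
    d2udx2 = (jnp.roll(u, -1) - 2.0 * u + jnp.roll(u, 1)) / dx**2
    return State(u=-params.c * dudx + params.nu * d2udx2)


def tendency_norm(params, state, dx):
    """Scalar diagnostic: squared norm of the tendency."""
    rhs = equation_of_motion(0.0, state, params, dx)
    return jnp.sum(rhs.u ** 2)


n = 32
dx = 1.0 / n
x = jnp.arange(n) * dx
params = Params(c=jnp.asarray(1.0), nu=jnp.asarray(0.01))
state = State.init(x)
rhs = equation_of_motion(0.0, state, params, dx)

# Differentiate with respect to params only; state and dx are held fixed.
grads = jax.grad(tendency_norm, argnums=0)(params, state, dx)
```

The returned `grads` has the same structure as `params`, with one gradient entry per field, so it can be passed directly to an optimizer that works on pytrees.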
Configs
We will use the hydra package to keep track of the configurations and to initialize parameters for experiments.
We can directly install it via pip from the GitHub repository:
pip install "git+https://github.com/jejjohnson/jaxsw.git"
We can also clone the git repository
git clone https://github.com/jejjohnson/jaxsw.git
cd jaxsw
The easiest way to get started is to use poetry, which also installs all necessary dev packages:
poetry install
We can also install via pip:
pip install .
We also have a conda environment with all of the equivalent dependencies.
conda env create -f environments/jax_linux_cpu.yaml
conda activate jaxsw
- qg_utils: Useful functions for dealing with QG equations.
- jaxdf: Nice API for defining operators for PDEs.
- jax-cfd: Nice API for defining PDEs.
- invobs-data-assimilation: Nice API for dynamical systems.
- MASSH: The differentiable QG and SW models applied to sea surface height interpolation.
- qgm_pytorch: Quasi-Geostrophic model in PyTorch.
- QGNet: QG implementation in PyTorch with convolutions.