Reimplementation of SEGNN in jax. Original work by Johannes Brandstetter, Rob Hesselink, Elise van der Pol, Erik Bekkers and Max Welling.
40-50% faster inference and training compared to the original torch implementation. It can also be used together with JAX-MD.
python -m pip install segnn-jax
Or clone this repository and build locally
python -m pip install -e .
Upgrade jax to the GPU version
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
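After upgrading, you can check whether JAX actually picked up the GPU by querying its default backend. This minimal sketch uses the public `jax.devices()` and `jax.default_backend()` calls. The helper name `report_backend` is hypothetical.

```python
import jax


def report_backend() -> str:
    """Print the devices visible to JAX and return the default backend name."""
    print("devices:", jax.devices())
    return jax.default_backend()  # e.g. "cpu" or "gpu"


print("default backend:", report_backend())
```

If this prints `cpu` after installing the CUDA wheel, JAX did not find a usable CUDA installation.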
For completeness, the N-body (charged and gravity) and QM9 datasets from the original paper are included.
Charged uses 5 bodies and gravity 100 bodies. QM9 graphs vary in size, so in jax the samples are padded to the maximum size, which avoids recompiling the jitted model for every new graph size. The loss is MSE for charged and gravity and MAE for QM9.
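Padding to a fixed size can be sketched as follows. Each graph's node features are zero-padded to a fixed number of nodes, and a boolean mask marks the real nodes. Padded nodes are then excluded from the graph-level readout, so they cannot change the QM9 prediction.

This is an illustrative sketch, not the actual segnn-jax data pipeline. The following are assumptions:

- The helper names (`pad_nodes`, `masked_readout`, `mse`, `mae`) are hypothetical.
- A sum readout is used.
- Only node features are padded; edges are not shown.

```python
from typing import Tuple

import jax.numpy as jnp
import numpy as np


def pad_nodes(features: np.ndarray, max_nodes: int) -> Tuple[np.ndarray, np.ndarray]:
    """Zero-pad a (n_nodes, n_features) array to (max_nodes, n_features).

    Also returns a boolean mask that is True for real nodes and False for padding.
    """
    n_nodes, n_features = features.shape
    if n_nodes > max_nodes:
        raise ValueError(f"graph has {n_nodes} nodes, more than max_nodes={max_nodes}")
    padded = np.zeros((max_nodes, n_features), dtype=features.dtype)
    padded[:n_nodes] = features
    mask = np.arange(max_nodes) < n_nodes
    return padded, mask


def masked_readout(node_out: jnp.ndarray, mask: jnp.ndarray) -> jnp.ndarray:
    """Sum per-node outputs into one graph-level value, ignoring padded nodes."""
    return jnp.sum(jnp.where(mask[:, None], node_out, 0.0), axis=0)


def mse(pred: jnp.ndarray, target: jnp.ndarray) -> jnp.ndarray:
    """Mean squared error (the loss used for charged and gravity)."""
    return jnp.mean((pred - target) ** 2)


def mae(pred: jnp.ndarray, target: jnp.ndarray) -> jnp.ndarray:
    """Mean absolute error (the loss used for QM9)."""
    return jnp.mean(jnp.abs(pred - target))
```

The mask is applied with `jnp.where` rather than by slicing. Slicing to the real node count would produce a different shape per graph and defeat the point of padding.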
Times are remeasured on a Quadro RTX 4000, for the model only, on batches of 100 graphs, in (global) single precision.
| | torch (original) loss | torch (original) inference [ms] | jax (ours) loss | jax (ours) inference [ms] |
|---|---|---|---|---|
| charged (position) | .0043 | 21.22 | .0045 | 3.77 |
| gravity (position) | .265 | 60.55 | .264 | 41.72 |
| QM9 (alpha) | .066* | 82.53 | .082 | 105.98** |
** padded (naive)
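Fair timing of a JAX model needs two precautions:

- The first call triggers jit compilation and should not be timed.
- JAX dispatches work asynchronously, so each call must be waited on with `block_until_ready()` before the clock is read.

This minimal sketch applies both precautions. It is not necessarily how the numbers above were measured. `time_inference` and `toy_model` are hypothetical stand-ins for the real benchmark and network, and the model is assumed to return a single array.

```python
import time

import jax
import jax.numpy as jnp


def time_inference(apply_fn, inputs, n_runs: int = 100) -> float:
    """Return the mean wall-clock time per jitted call, in milliseconds."""
    fn = jax.jit(apply_fn)
    fn(inputs).block_until_ready()  # warm-up call: compile once, outside the timed loop
    start = time.perf_counter()
    for _ in range(n_runs):
        fn(inputs).block_until_ready()  # wait for the asynchronous computation to finish
    return (time.perf_counter() - start) / n_runs * 1e3


def toy_model(x: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical stand-in for the network."""
    return jnp.tanh(x @ x.T)


batch = jnp.ones((100, 64), dtype=jnp.float32)  # single-precision dummy input
print(f"{time_inference(toy_model, batch):.3f} ms per batch")
```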
The experiments are only included in the GitHub repository, so it has to be cloned first.
git clone https://github.com/gerkone/segnn-jax
They are adapted from the original implementation, so torch and torch_geometric are also needed (CPU versions are enough).
python -m pip install -r experiments/requirements.txt
QM9 is automatically downloaded and processed when running the respective experiment.
The N-body datasets have to be generated locally from the directory experiments/nbody/data (this takes some time, especially for n-body gravity).
python3 -u generate_dataset.py --simulation=charged --seed=43
python3 -u generate_dataset.py --simulation=gravity --n-balls=100 --seed=43
On jax<=0.4.6, the jit-pjit merge can be deactivated, which makes training faster (on nbody). This looks like an issue with the data loading and the validation training loop implementation, and it does not affect SEGNN.
export JAX_JIT_PJIT_API_MERGE=0
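You can also set the same flag from Python instead of the shell, through `os.environ`. It must be set before jax is first imported, because JAX reads its configuration flags from the environment at import time. This is a minimal sketch, and the flag only has an effect on the jax versions mentioned above.

```python
import os

# Must run before the first `import jax` in the process:
# JAX reads its config flags from environment variables when it is imported.
os.environ["JAX_JIT_PJIT_API_MERGE"] = "0"

import jax  # noqa: E402  (deliberately imported after setting the flag)

print("jax version:", jax.__version__)
```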
python validate.py --dataset=charged --epochs=200 --max-samples=3000 --lmax-hidden=1 --lmax-attributes=1 --layers=4 --units=64 --norm=none --batch-size=100 --lr=5e-3 --weight-decay=1e-12
python validate.py --dataset=gravity --epochs=100 --target=pos --max-samples=10000 --lmax-hidden=1 --lmax-attributes=1 --layers=4 --units=64 --norm=none --batch-size=100 --lr=5e-3 --weight-decay=1e-12 --neighbours=5 --n-bodies=100
python validate.py --dataset=qm9 --epochs=1000 --target=alpha --lmax-hidden=2 --lmax-attributes=3 --layers=7 --units=128 --norm=instance --batch-size=128 --lr=5e-4 --weight-decay=1e-8 --lr-scheduling
(configurations used in validation)
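To rerun the three validation configurations from one script, you can keep the flags in one place and pass them to validate.py through `subprocess`. This sketch assumes it runs from the directory that contains validate.py. `VALIDATION_CONFIGS`, `build_command` and `run` are hypothetical helpers. The flags are copied verbatim from the commands above.

```python
import subprocess
import sys
from typing import List

# Flags shared by the charged and gravity commands above.
NBODY_COMMON = ["--lmax-hidden=1", "--lmax-attributes=1", "--layers=4", "--units=64",
                "--norm=none", "--batch-size=100", "--lr=5e-3", "--weight-decay=1e-12"]

VALIDATION_CONFIGS = {
    "charged": ["--dataset=charged", "--epochs=200", "--max-samples=3000", *NBODY_COMMON],
    "gravity": ["--dataset=gravity", "--epochs=100", "--target=pos", "--max-samples=10000",
                *NBODY_COMMON, "--neighbours=5", "--n-bodies=100"],
    "qm9": ["--dataset=qm9", "--epochs=1000", "--target=alpha", "--lmax-hidden=2",
            "--lmax-attributes=3", "--layers=7", "--units=128", "--norm=instance",
            "--batch-size=128", "--lr=5e-4", "--weight-decay=1e-8", "--lr-scheduling"],
}


def build_command(name: str) -> List[str]:
    """Return the argv list that runs validate.py with the named configuration."""
    return [sys.executable, "validate.py", *VALIDATION_CONFIGS[name]]


def run(name: str) -> None:
    """Run one configuration, raising CalledProcessError if validate.py fails."""
    subprocess.run(build_command(name), check=True)
```

For example, `run("charged")` launches the first command above, using the same Python interpreter that runs the script.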
- e3nn_jax made this reimplementation possible.
- Artur Toshev and Johannes Brandstetter, for support.