
Steerable E(3) GNN in jax

Reimplementation of SEGNN in jax. Original work by Johannes Brandstetter, Rob Hesselink, Elise van der Pol, Erik Bekkers and Max Welling.

Why jax?

40-50% faster inference and training compared to the original torch implementation. It also opens up integration with JAX-MD.
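Most of that gap comes from XLA compilation via jax.jit. A minimal sketch of the idea, using a toy forward pass (not the actual SEGNN model):

```python
import jax
import jax.numpy as jnp

# Toy stand-in for a model forward pass. jax.jit traces the function
# once and compiles it with XLA; later calls with the same input shapes
# run the compiled kernel directly, which is where most of the speedup
# over an eager implementation comes from.
@jax.jit
def forward(params, x):
    return jnp.tanh(x @ params)

params = jnp.ones((64, 64))
x = jnp.ones((100, 64))
y = forward(params, x)  # first call compiles, subsequent calls are fast
```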

Installation

python -m pip install segnn-jax

Or clone this repository and build it locally:

python -m pip install -e .

GPU support

Upgrade jax to the GPU version:

pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
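To check that the GPU is actually picked up:

```python
import jax

# Should print a CUDA device, e.g. [cuda(id=0)]. If it prints only a
# CPU device, the CUDA wheel was not installed correctly.
print(jax.devices())
```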

Validation

The N-body (charged and gravity) and QM9 datasets from the original paper are included for completeness.

Results

Charged is run on 5 bodies, gravity on 100 bodies. QM9 has graphs of variable sizes, so in jax the samples are padded to the maximum size. The loss is MSE for charged and gravity, and MAE for QM9.

Times were remeasured on a Quadro RTX 4000, model inference only, on batches of 100 graphs, in (global) single precision.

|  | torch (original) |  | jax (ours) |  |
|---|---|---|---|---|
|  | Loss | Inference [ms] | Loss | Inference [ms] |
| charged (position) | .0043 | 21.22 | .0045 | 3.77 |
| gravity (position) | .265 | 60.55 | .264 | 41.72 |
| QM9 (alpha) | .066* | 82.53 | .082 | 105.98** |
* rerun under the same conditions

** padded (naive)
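The naive padding works as follows: since jit recompiles for every new input shape, each QM9 batch is padded to a fixed maximum size, trading some wasted compute for a single compilation. A toy sketch of the idea with jraph (the sizes here are made up, and this is not the actual QM9 pipeline):

```python
import jax.numpy as jnp
import jraph

# A single toy graph with 5 nodes and 8 edges.
graph = jraph.GraphsTuple(
    nodes=jnp.ones((5, 3)),
    edges=jnp.ones((8, 3)),
    senders=jnp.array([0, 0, 1, 2, 3, 3, 4, 4]),
    receivers=jnp.array([1, 2, 0, 4, 0, 1, 2, 3]),
    n_node=jnp.array([5]),
    n_edge=jnp.array([8]),
    globals=None,
)

# Pad nodes/edges/graphs up to fixed maxima. Every batch then has the
# same shapes, so jit compiles the model once instead of once per size.
padded = jraph.pad_with_graphs(graph, n_node=10, n_edge=16, n_graph=2)
```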

Validation install

The experiments are included only in the GitHub repository, so it needs to be cloned first:

git clone https://github.com/gerkone/segnn-jax

They are adapted from the original implementation, so torch and torch_geometric are additionally needed (the CPU versions are enough):

python -m pip install -r experiments/requirements.txt

Datasets

QM9 is automatically downloaded and processed when running the respective experiment.

The N-body datasets have to be generated locally from the directory experiments/nbody/data (this takes some time, especially for gravity).

Charged dataset (5 bodies, 10000 training samples)

python3 -u generate_dataset.py --simulation=charged --seed=43

Gravity dataset (100 bodies, 10000 training samples)

python3 -u generate_dataset.py --simulation=gravity --n-balls=100 --seed=43

Notes

On jax<=0.4.6, the jit-pjit merge can be deactivated, which makes training faster (on N-body). This looks like an issue with the data loading and the validation training loop implementation; it does not affect SEGNN itself.

export JAX_JIT_PJIT_API_MERGE=0
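The same flag can also be set from Python, as long as it happens before jax is first imported:

```python
import os

# jax reads this configuration variable at import time, so it has to be
# set before the first `import jax` anywhere in the process.
os.environ["JAX_JIT_PJIT_API_MERGE"] = "0"

import jax  # noqa: E402
```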

Usage

N-body (charged)

python validate.py --dataset=charged --epochs=200 --max-samples=3000 --lmax-hidden=1 --lmax-attributes=1 --layers=4 --units=64 --norm=none --batch-size=100 --lr=5e-3 --weight-decay=1e-12

N-body (gravity)

python validate.py --dataset=gravity --epochs=100 --target=pos --max-samples=10000 --lmax-hidden=1 --lmax-attributes=1 --layers=4 --units=64 --norm=none --batch-size=100 --lr=5e-3 --weight-decay=1e-12 --neighbours=5 --n-bodies=100

QM9

python validate.py --dataset=qm9 --epochs=1000 --target=alpha --lmax-hidden=2 --lmax-attributes=3 --layers=7 --units=128 --norm=instance --batch-size=128 --lr=5e-4 --weight-decay=1e-8 --lr-scheduling

(these are the configurations used to produce the validation results above)
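The model can also be built directly from Python. The sketch below is a rough, unverified outline only: SEGNN and weight_balanced_irreps are exported by the package, but the argument names and values used here are assumptions and should be checked against the segnn_jax source.

```python
import e3nn_jax as e3nn
import haiku as hk

from segnn_jax import SEGNN, weight_balanced_irreps

# HYPOTHETICAL sketch: verify signatures against segnn_jax before use.
def segnn_fn(graph):
    # Hidden irreps with weights balanced across rotation orders;
    # steerable attributes are spherical harmonics up to lmax=1.
    hidden_irreps = weight_balanced_irreps(
        64, e3nn.Irreps.spherical_harmonics(1)
    )
    return SEGNN(
        hidden_irreps=hidden_irreps,
        output_irreps=e3nn.Irreps("1x1o"),  # e.g. a predicted position
        num_layers=4,
    )(graph)

# The modules are haiku-based, so the function is wrapped in a
# transform before calling init/apply.
model = hk.without_apply_rng(hk.transform_with_state(segnn_fn))
```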

Acknowledgments