Skip to content

Commit

Permalink
Merge pull request #13 from tumaer/neighbors
Browse files Browse the repository at this point in the history
Neighbors
  • Loading branch information
arturtoshev authored Jun 9, 2024
2 parents cf3eb54 + 3d51272 commit 669d8b3
Show file tree
Hide file tree
Showing 28 changed files with 2,355 additions and 639 deletions.
28 changes: 17 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,17 @@

</div>

JAX-SPH [(Toshev et al., 2024)](https://arxiv.org/abs/2403.04750) is a modular JAX-based weakly compressible SPH framework, which implements the following SPH routines:
- Standard SPH [(Adami et al., 2012)](https://www.sciencedirect.com/science/article/pii/S002199911200229X)
- Transport velocity SPH [(Adami et al., 2013)](https://www.sciencedirect.com/science/article/pii/S002199911300096X)
- Riemann SPH [(Zhang et al., 2017)](https://www.sciencedirect.com/science/article/abs/pii/S0021999117300438)

![HT_T.gif](https://s9.gifyu.com/images/SUwUD.gif)

## Table of Contents

1. [**Installation**](#installation)
1. [**Getting Started**](#getting-started)
1. [**Setting up a case**](#setting-up-a-case)
1. [**Contributing**](#contributing)
1. [**Citation**](#citation)
1. [**Acknowledgements**](#acknowledgements)

## Installation

### Standalone library
Expand Down Expand Up @@ -84,16 +88,18 @@ python main.py config=cases/ht.yaml
```

### Notebooks
We provide four notebooks demonstrating how to use JAX-SPH:
We provide various notebooks demonstrating how to use JAX-SPH:
- [`tutorial.ipynb`](notebooks/tutorial.ipynb) [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tumaer/jax-sph/blob/main/notebooks/tutorial.ipynb), with a general overview of JAX-SPH and an example how to run the channel flow with hot bottom wall.
- [`iclr24_grads.ipynb`](notebooks/tutorial.ipynb) [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tumaer/jax-sph/blob/main/notebooks/iclr24_grads.ipynb), with a validation of the gradients through the solver.
- [`iclr24_inverse.ipynb`](notebooks/tutorial.ipynb) [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tumaer/jax-sph/blob/main/notebooks/iclr24_inverse.ipynb), solving the inverse problem of finding the initial state of a 100-step-long SPH simulation.
- [`iclr24_sitl.ipynb`](notebooks/tutorial.ipynb) [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tumaer/jax-sph/blob/main/notebooks/iclr24_sitl.ipynb), including training and testing a Solver-in-the-Loop model using the [LagrangeBench](https://github.com/tumaer/lagrangebench) library.
- [`iclr24_grads.ipynb`](notebooks/iclr24_grads.ipynb) [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tumaer/jax-sph/blob/main/notebooks/iclr24_grads.ipynb), with a validation of the gradients through the solver.
- [`iclr24_inverse.ipynb`](notebooks/iclr24_inverse.ipynb) [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tumaer/jax-sph/blob/main/notebooks/iclr24_inverse.ipynb), solving the inverse problem of finding the initial state of a 100-step-long SPH simulation.
- [`iclr24_sitl.ipynb`](notebooks/iclr24_sitl.ipynb) [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tumaer/jax-sph/blob/main/notebooks/iclr24_sitl.ipynb), including training and testing a Solver-in-the-Loop model using the [LagrangeBench](https://github.com/tumaer/lagrangebench) library.
- [`neighbors.ipynb`](notebooks/neighbors.ipynb) [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tumaer/jax-sph/blob/main/notebooks/neighbors.ipynb), explaining the difference between the three neighbor search implementations and comparing their performance.
- [`kernel_plots.ipynb`](notebooks/kernel_plots.ipynb) [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tumaer/jax-sph/blob/main/notebooks/kernel_plots.ipynb), visualizing the SPH kernels.

## Setting up a case
## Setting up a Case
To set up a case, just add a `my_case.py` and a `my_case.yaml` file to the `cases/` directory. Every *.py case should inherit from `SimulationSetup` in `jax_sph/case_setup.py` or another case, and every *.yaml config file should either contain a complete set of parameters (see `jax_sph/defaults.py`) or extend `JAX_SPH_DEFAULTS`. Running a case in relaxation mode `case.mode=rlx` overwrites certain parts of the selected case. Passed CLI arguments overwrite any argument.

## Development and Contribution
## Contributing
If you wish to contribute, please run
```bash
pre-commit install
Expand Down
30 changes: 26 additions & 4 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,36 @@
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
Welcome to JAX-SPH's documentation!
===================================
JAX-SPH
========

.. image:: https://s9.gifyu.com/images/SUwUD.gif
:alt: GIF


What is ``JAX-SPH``?
--------------------

JAX-SPH `(Toshev et al., 2024) <https://arxiv.org/abs/2403.04750>`_ is a Smoothed Particle Hydrodynamics (SPH) code written in `JAX <https://jax.readthedocs.io/>`_. JAX-SPH is designed to be simple, fast, and compatible with deep learning workflows. We currently support the following SPH routines:

* Standard SPH `(Adami et al., 2012) <https://www.sciencedirect.com/science/article/pii/S002199911200229X>`_
* Transport velocity SPH `(Adami et al., 2013) <https://www.sciencedirect.com/science/article/pii/S002199911300096X>`_
* Riemann SPH `(Zhang et al., 2017) <https://www.sciencedirect.com/science/article/abs/pii/S0021999117300438>`_

Check out our `GitHub repository <https://github.com/tumaer/jax-sph>`_ for more information including installation instructions and tutorial notebooks.

.. toctree::
:maxdepth: 1
:caption: Getting Started

pages/tutorials
pages/defaults

.. toctree::
:maxdepth: 2
:caption: Contents:
:caption: API

pages/case_setup
pages/solver
pages/simulate
pages/utils
pages/utils
48 changes: 48 additions & 0 deletions docs/pages/defaults.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
Defaults
===================================

The defaults are defined through a function ``jax_sph.defaults.set_defaults()``, which
takes a potentially empty ``omegaconf.DictConfig`` object and creates or overwrites the
default values. One can also directly call ``from jax_sph.defaults import defaults``,
with ``defaults=set_defaults()``, to get the default DictConfig, which we unpack below.

.. exec_code::
:hide_code:
:linenos_output:
:language_output: python
:caption: JAX-SPH default values


with open("jax_sph/defaults.py", "r") as file:
defaults_full = file.read()

# parse defaults: remove imports, only keep the set_defaults function

defaults_full = defaults_full.split("\n")

# remove imports
defaults_full = [line for line in defaults_full if not line.startswith("import")]
defaults_full = [line for line in defaults_full if len(line.replace(" ", "")) > 0]

# remove other functions
keep = False
defaults = []
for i, line in enumerate(defaults_full):
if line.startswith("def"):
if "set_defaults" in line:
keep = True
else:
keep = False

if keep:
defaults.append(line)

# remove function declaration and return
defaults = defaults[2:-2]

# remove indent
defaults = [line[4:] for line in defaults]


print("\n".join(defaults))

8 changes: 8 additions & 0 deletions docs/pages/tutorials.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Tutorials
=========

Currently, there are two places to look for tutorials:

* The README of our `GitHub repository <https://github.com/tumaer/jax-sph>`_.
* The `notebooks <https://github.com/tumaer/jax-sph/tree/main/notebooks>`_ in the same
repository.
2 changes: 1 addition & 1 deletion jax_sph/case_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
import jax.numpy as jnp
import numpy as np
from jax import vmap
from jax_md import space

from jax_sph.eos import RIEMANNEoS, TaitEoS
from jax_sph.io_state import read_h5
from jax_sph.jax_md import space
from jax_sph.utils import (
Tag,
get_noise_masked,
Expand Down
64 changes: 32 additions & 32 deletions jax_sph/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ def set_defaults(cfg: DictConfig = OmegaConf.create({})) -> DictConfig:
### global and hardware-related configs

# .yaml case configuration file
cfg.config = None # previously: case
cfg.config = None
# Seed for random number generator
cfg.seed = 123
# Whether to disable jitting compilation
cfg.no_jit = False
# Which GPU to use. -1 for CPU
cfg.gpu = 0
# Data type. One of "float32" or "float64"
cfg.dtype = "float64" # previously: no_f64
cfg.dtype = "float64"
# XLA memory fraction to be preallocated. The JAX default is 0.75.
# Should be specified before importing the library.
cfg.xla_mem_fraction = 0.75
Expand All @@ -30,30 +30,30 @@ def set_defaults(cfg: DictConfig = OmegaConf.create({})) -> DictConfig:
# Simulation mode. One of "sim" (run simulation) or "rlx" (run relaxation)
cfg.case.mode = "sim"
# Dimension of the simulation. One of 2 or 3
cfg.case.dim = 3 # previously: dim
cfg.case.dim = 3
# Average distance between particles [0.001, 0.1]
cfg.case.dx = 0.05 # previously: dx
cfg.case.dx = 0.05
# Initial state h5 path. Overrides `r0_type`. Can be useful to restart a simulation.
cfg.case.state0_path = None # previously: state0-path
cfg.case.state0_path = None
# Which properties to adopt from state0_path. Include all to restart a simulation.
cfg.case.state0_keys = ["r"]
# Position initialization type. One of "cartesian" or "relaxed". Cartesian can have
# `r0_noise_factor` and relaxed requires a state to be present in `data_relaxed`.
cfg.case.r0_type = "cartesian" # previously: r0-type
cfg.case.r0_type = "cartesian"
# How much Gaussian noise to add to r0. ( _ * dx)
cfg.case.r0_noise_factor = 0.0 # previously: r0-noise-factor
cfg.case.r0_noise_factor = 0.0
# Magnitude of external force field
cfg.case.g_ext_magnitude = 0.0 # previously: g-ext-magnitude
cfg.case.g_ext_magnitude = 0.0
# Reference dynamic viscosity. Inversely proportional to Re.
cfg.case.viscosity = 0.01 # previously: viscosity
cfg.case.viscosity = 0.01
# Estimate max flow velocity to calculate artificial speed of sound.
cfg.case.u_ref = 1.0 # previously: u_ref
cfg.case.u_ref = 1.0
# Reference speed of sound factor w.r.t. u_ref.
cfg.case.c_ref_factor = 10.0 # previously: p-bg-factor
cfg.case.c_ref_factor = 10.0
# Reference density
cfg.case.rho_ref = 1.0
# Reference temperature
cfg.case.T_ref = 1.0 # previously: T-ref
cfg.case.T_ref = 1.0
# Reference thermal conductivity
cfg.case.kappa_ref = 0.0
# Reference heat capacity at constant pressure
Expand All @@ -65,29 +65,29 @@ def set_defaults(cfg: DictConfig = OmegaConf.create({})) -> DictConfig:
cfg.solver = OmegaConf.create({})

# Solver name. One of "SPH" (standard SPH) or "RIE" (Riemann SPH)
cfg.solver.name = "SPH" # previously: solver
cfg.solver.name = "SPH"
# Transport velocity inclusion factor [0,...,1]
cfg.solver.tvf = 0.0 # previously: tvf
cfg.solver.tvf = 0.0
# CFL condition factor
cfg.solver.cfl = 0.25 # previously: cfl
cfg.solver.cfl = 0.25
# Density evolution vs density summation
cfg.solver.density_evolution = False # previously: density-evolution
cfg.solver.density_evolution = False
# Density renormalization when density evolution
cfg.solver.density_renormalize = False # previously: density-renormalize
cfg.solver.density_renormalize = False
# Integration time step. If None, it is calculated from the CFL condition.
cfg.solver.dt = None # previously: dt
cfg.solver.dt = None
# Physical time length of simulation
cfg.solver.t_end = 0.2 # previously: t-end
cfg.solver.t_end = 0.2
# Parameter alpha of artificial viscosity term
cfg.solver.artificial_alpha = 0.0 # previously: artificial-alpha
cfg.solver.artificial_alpha = 0.0
# Whether to turn on free-slip boundary condition
cfg.solver.free_slip = False # previously: free-slip
cfg.solver.free_slip = False
# Riemann dissipation limiter parameter, -1 = off
cfg.solver.eta_limiter = 3 # previously: eta-limiter
cfg.solver.eta_limiter = 3
# Thermal conductivity (non-dimensional)
cfg.solver.kappa = 0 # previously: kappa
cfg.solver.kappa = 0
# Whether to apply the heat conduction term
cfg.solver.heat_conduction = False # previously: heat-conduction
cfg.solver.heat_conduction = False
# Whether to apply boundaty conditions
cfg.solver.is_bc_trick = False # new

Expand All @@ -102,37 +102,37 @@ def set_defaults(cfg: DictConfig = OmegaConf.create({})) -> DictConfig:
# "WC6K" (Wendland C4 kernel)
# "GK" (gaussian kernel)
# "SGK" (super gaussian kernel)
cfg.kernel.name = "QSK" # previously: kernel
cfg.kernel.name = "QSK"
# Smoothing length factor
cfg.kernel.h_factor = 1.0 # new. Should default to 1.3 WC2K and 1.0 QSK

### equation of state
cfg.eos = OmegaConf.create({})

# EoS name. One of "Tait" or "RIEMANN"
cfg.eos.name = "Tait" # previously: eos
cfg.eos.name = "Tait"
# power in the Tait equation of state
cfg.eos.gamma = 1.0
# background pressure factor w.r.t. p_ref
cfg.eos.p_bg_factor = 0.0 # previously: p-bg-factor
cfg.eos.p_bg_factor = 0.0

### neighbor list
cfg.nl = OmegaConf.create({})

# Neighbor list backend. One of "jaxmd_vmap", "jaxmd_scan", "matscipy"
cfg.nl.backend = "jaxmd_vmap" # previously: nl-backend
cfg.nl.backend = "jaxmd_vmap"
# Number of partitions for neighbor list. Applies to jaxmd_scan only.
cfg.nl.num_partitions = 1 # previously: num-partitions
cfg.nl.num_partitions = 1

### output writing
cfg.io = OmegaConf.create({})

# In which format to write states. A subset of ["h5", "vtk"]
cfg.io.write_type = [] # previously: write-h5, write-vtk
cfg.io.write_type = []
# Every `write_every` step will be saved
cfg.io.write_every = 1 # previously: write-every
cfg.io.write_every = 1
# Where to write and read data
cfg.io.data_path = "./" # previously: data-path
cfg.io.data_path = "./"
# What to print to stdout. As list of possible properties.
cfg.io.print_props = ["Ekin", "u_max"]

Expand Down
Loading

0 comments on commit 669d8b3

Please sign in to comment.