Skip to content

Commit

Permalink
Merge pull request #10 from tumaer/add-Kernels
Browse files Browse the repository at this point in the history
Added  kernels and pytests
  • Loading branch information
arturtoshev authored Jun 7, 2024
2 parents 84a61f5 + 51a8b8e commit c9fe1f1
Show file tree
Hide file tree
Showing 7 changed files with 409 additions and 9 deletions.
9 changes: 8 additions & 1 deletion jax_sph/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,14 @@ def set_defaults(cfg: DictConfig = OmegaConf.create({})) -> DictConfig:
### kernel
cfg.kernel = OmegaConf.create({})

# Kernel name. One of "QSK" (quintic spline kernel) or "WC2K" (Wendland C2 kernel)
# Kernel name, choose one of:
# "CSK" (cubic spline kernel)
# "QSK" (quintic spline kernel)
# "WC2K" (Wendland C2 kernel)
# "WC4K" (Wendland C4 kernel)
# "WC6K" (Wendland C4 kernel)
# "GK" (gaussian kernel)
# "SGK" (super gaussian kernel)
cfg.kernel.name = "QSK" # previously: kernel
# Smoothing length factor
cfg.kernel.h_factor = 1.0 # new. Should default to 1.3 WC2K and 1.0 QSK
Expand Down
123 changes: 123 additions & 0 deletions jax_sph/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,31 @@ def grad_w(self, r):
return grad(self.w)(r)


class CubicKernel(BaseKernel):
"""The cubic kernel function of Monaghan."""

def __init__(self, h, dim=3):
self._one_over_h = 1.0 / h

self._normalized_cutoff = 2.0
self.cutoff = self._normalized_cutoff * h
if dim == 1:
self._sigma = 2.0 / 3.0 * self._one_over_h
elif dim == 2:
self._sigma = 10.0 / 7.0 / jnp.pi * self._one_over_h**2
elif dim == 3:
self._sigma = 1.0 / jnp.pi * self._one_over_h**3

def w(self, r):
q = r * self._one_over_h
c1 = jnp.where(1 - q >= 0, 1, 0)
c2 = jnp.where(jnp.logical_and(2 - q < 1, 2 - q >= 0), 1, 0)
q1 = 1 - 1.5 * q**2 * (1 - q / 2)
q2 = 0.25 * (2 - q) ** 3

return self._sigma * (q1 * c1 + q2 * c2)


class QuinticKernel(BaseKernel):
"""The quintic kernel function of Morris."""

Expand Down Expand Up @@ -76,3 +101,101 @@ def w(self, r):
q2 = 2.0 * q + 1.0

return self._sigma * (q1**4 * q2)


class WendlandC4Kernel(BaseKernel):
"""The 5th-order C4 kernel function of Wendland."""

def __init__(self, h, dim=3):
self._one_over_h = 1.0 / h
self.dim = dim

self._normalized_cutoff = 2.0
self.cutoff = self._normalized_cutoff * h
if dim == 1:
self._sigma = 3.0 / 4.0 * self._one_over_h
elif dim == 2:
self._sigma = 9.0 / 4.0 / jnp.pi * self._one_over_h**2
elif dim == 3:
self._sigma = 495.0 / 256.0 / jnp.pi * self._one_over_h**3

def w(self, r):
if self.dim == 1:
q = r * self._one_over_h
q1 = jnp.maximum(0.0, 1.0 - 0.5 * q)
q2 = 2.0 * q**2 + 2.5 * q + 1.0

return self._sigma * (q1**5 * q2)
else:
q = r * self._one_over_h
q1 = jnp.maximum(0.0, 1.0 - 0.5 * q)
q2 = 35.0 / 12.0 * q**2 + 3 * q + 1.0

return self._sigma * (q1**6 * q2)


class WendlandC6Kernel(BaseKernel):
"""The 5th-order C6 kernel function of Wendland."""

def __init__(self, h, dim=3):
self._one_over_h = 1.0 / h
self.dim = dim

self._normalized_cutoff = 2.0
self.cutoff = self._normalized_cutoff * h
if dim == 1:
self._sigma = 55.0 / 64.0 * self._one_over_h
elif dim == 2:
self._sigma = 78.0 / 28.0 / jnp.pi * self._one_over_h**2
elif dim == 3:
self._sigma = 1365.0 / 512.0 / jnp.pi * self._one_over_h**3

def w(self, r):
if self.dim == 1:
q = r * self._one_over_h
q1 = jnp.maximum(0.0, 1.0 - 0.5 * q)
q2 = 21.0 / 8.0 * q**3 + 19.0 / 4.0 * q**2 + 3.5 * q + 1.0

return self._sigma * (q1**7 * q2)
else:
q = r * self._one_over_h
q1 = jnp.maximum(0.0, 1.0 - 0.5 * q)
q2 = 4.0 * q**3 + 6.25 * q**2 + 4 * q + 1.0

return self._sigma * (q1**8 * q2)


class GaussianKernel(BaseKernel):
"""The gaussian kernel function of Monaghan."""

def __init__(self, h, dim=3):
self._one_over_h = 1.0 / h

self._normalized_cutoff = 3.0
self.cutoff = self._normalized_cutoff * h
self._sigma = 1.0 / jnp.pi ** (dim / 2) * self._one_over_h ** (dim)

def w(self, r):
q = r * self._one_over_h
q1 = jnp.where(3 - q >= 0, 1, 0)

return self._sigma * q1 * jnp.exp(-(q**2))


class SuperGaussianKernel(BaseKernel):
# TODO: We want this? Intendent but negativ in some regions
"""The supergaussian kernel function of Monaghan."""

def __init__(self, h, dim=3):
self._one_over_h = 1.0 / h
self.dim = dim

self._normalized_cutoff = 3.0
self.cutoff = self._normalized_cutoff * h
self._sigma = 1.0 / jnp.pi ** (dim / 2) * self._one_over_h ** (dim)

def w(self, r):
q = r * self._one_over_h
q1 = jnp.where(3 - q >= 0, 1, 0)

return self._sigma * q1 * jnp.exp(-(q**2)) * (self.dim / 2 + 1 - q**2)
29 changes: 24 additions & 5 deletions jax_sph/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,15 @@
from jax_md import space

from jax_sph.eos import RIEMANNEoS, TaitEoS
from jax_sph.kernel import QuinticKernel, WendlandC2Kernel
from jax_sph.kernel import (
CubicKernel,
GaussianKernel,
QuinticKernel,
SuperGaussianKernel,
WendlandC2Kernel,
WendlandC4Kernel,
WendlandC6Kernel,
)
from jax_sph.utils import Tag, wall_tags

EPS = jnp.finfo(float).eps
Expand Down Expand Up @@ -478,10 +486,21 @@ def __init__(
self.is_heat_conduction = is_heat_conduction

_beta_fn = limiter_fn_wrapper(eta_limiter, c_ref)
if kernel == "QSK":
self._kernel_fn = QuinticKernel(h=dx, dim=dim)
elif kernel == "WC2K":
self._kernel_fn = WendlandC2Kernel(h=1.3 * dx, dim=dim)
match kernel:
case "CSK":
self._kernel_fn = CubicKernel(h=dx, dim=dim)
case "QSK":
self._kernel_fn = QuinticKernel(h=dx, dim=dim)
case "WC2K":
self._kernel_fn = WendlandC2Kernel(h=1.3 * dx, dim=dim)
case "WC4K":
self._kernel_fn = WendlandC4Kernel(h=1.3 * dx, dim=dim)
case "WC6K":
self._kernel_fn = WendlandC6Kernel(h=1.3 * dx, dim=dim)
case "GK":
self._kernel_fn = GaussianKernel(h=dx, dim=dim)
case "SGK":
self._kernel_fn = SuperGaussianKernel(h=dx, dim=dim)

self._gwbc_fn = gwbc_fn_wrapper(is_free_slip, is_heat_conduction, eos)
self._free_weight, self._heat_bc = gwbc_fn_riemann_wrapper(
Expand Down
121 changes: 121 additions & 0 deletions notebooks/kernel_plots.ipynb

Large diffs are not rendered by default.

19 changes: 17 additions & 2 deletions tests/test_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,26 @@
import pytest
from jax import vmap

from jax_sph.kernel import QuinticKernel, WendlandC2Kernel
from jax_sph.kernel import (
CubicKernel,
GaussianKernel,
QuinticKernel,
WendlandC2Kernel,
WendlandC4Kernel,
WendlandC6Kernel,
)


@pytest.mark.parametrize(
"Kernel, dx_factor", [(QuinticKernel, 1), (WendlandC2Kernel, 1.3)]
"Kernel, dx_factor",
[
(CubicKernel, 1),
(QuinticKernel, 1),
(WendlandC2Kernel, 1.3),
(WendlandC4Kernel, 1.3),
(WendlandC6Kernel, 1.3),
(GaussianKernel, 1),
],
)
def test_kernel_1d(Kernel, dx_factor):
"""Test the interpolation kernels in 1 dimension."""
Expand Down
115 changes: 115 additions & 0 deletions tests/test_pf2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""Test a full run of the solver on the Poiseuille flow case from the validations."""

import os

import jax.numpy as jnp
import numpy as np
import pytest
from jax import config
from omegaconf import OmegaConf

from main import load_embedded_configs


def u_series_exp(y, t, n_max=10):
"""Analytical solution to unsteady Poiseuille flow (low Re)
Based on Series expansion as shown in:
"Modeling Low Reynolds Number Incompressible Flows Using SPH"
ba Morris et al. 1997
"""

eta = 100.0 # dynamic viscosity
rho = 1.0 # denstiy
nu = eta / rho # kinematic viscosity
u_max = 1.25 # max velocity in middle of channel
d = 1.0 # channel width
fx = -8 * nu * u_max / d**2
offset = fx / (2 * nu) * y * (y - d)

def term(n):
base = np.pi * (2 * n + 1) / d
prefactor = 4 * fx / (nu * base**3 * d)
sin_term = np.sin(base * y)
exp_term = np.exp(-(base**2) * nu * t)
return prefactor * sin_term * exp_term

res = offset
for i in range(n_max):
res += term(i)

return res


@pytest.fixture
def setup_simulation():
y_axis = np.linspace(0, 1, 21)
t_dimless = [0.0005, 0.001, 0.005]
# get analytical solution
ref_solutions = []
for t_val in t_dimless:
ref_solutions.append(u_series_exp(y_axis, t_val))
return y_axis, t_dimless, ref_solutions


def run_simulation(tmp_path, tvf, solver):
"""Emulate `main.py`."""
data_path = tmp_path / f"pf_test_{tvf}"

cli_args = OmegaConf.create(
{
"config": "cases/pf.yaml",
"case": {"dx": 0.0333333},
"solver": {"name": solver, "tvf": tvf, "dt": 0.000002, "t_end": 0.005},
"io": {"write_every": 250, "data_path": str(data_path)},
}
)
cfg = load_embedded_configs(cli_args)

# Specify cuda device. These setting must be done before importing jax-md.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # see issue #152 from TensorFlow
os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu)
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(cfg.xla_mem_fraction)

if cfg.dtype == "float64":
config.update("jax_enable_x64", True)

from jax_sph.simulate import simulate

simulate(cfg)

return data_path


def get_solution(data_path, t_dimless, y_axis):
from jax_sph.utils import sph_interpolator

dir = os.listdir(data_path)[0]
cfg = OmegaConf.load(data_path / dir / "config.yaml")
step_max = np.array(np.rint(cfg.solver.t_end / cfg.solver.dt), dtype=int)
digits = len(str(step_max))

y_axis += 3 * cfg.case.dx
rs = 0.2 * jnp.ones([y_axis.shape[0], 2])
rs = rs.at[:, 1].set(y_axis)
solutions = []
for i in range(len(t_dimless)):
file_name = (
"traj_" + str(int(t_dimless[i] / cfg.solver.dt)).zfill(digits) + ".h5"
)
src_path = data_path / dir / file_name
interp_vel_fn = sph_interpolator(cfg, src_path)
solutions.append(interp_vel_fn(src_path, rs, prop="u", dim_ind=0))
return solutions


@pytest.mark.parametrize("tvf, solver", [(0.0, "SPH"), (1.0, "SPH")]) # (0.0, "RIE")
def test_pf2d(tvf, solver, tmp_path, setup_simulation):
"""Test whether the poiseuille flow simulation matches the analytical solution"""
y_axis, t_dimless, ref_solutions = setup_simulation
data_path = run_simulation(tmp_path, tvf, solver)
# print(f"tmp_path = {tmp_path}, subdirs = {subdirs}")
solutions = get_solution(data_path, t_dimless, y_axis)
# print(f"solution: {solutions[-1]} \nref_solution: {ref_solutions[-1]}")
for sol, ref_sol in zip(solutions, ref_solutions):
assert np.allclose(sol, ref_sol, atol=1e-2), "Velocity profile does not match."
2 changes: 1 addition & 1 deletion validation/tgv2d.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/bin/bash
# Generate validation data and validate 2D TGV
# with number of particles per direction nx = [20, 50, 100]
# with number of particles per direction nx = [50, 100]
# Reference result from:
# "A Transport Velocty [...]", Adami 2012

Expand Down

0 comments on commit c9fe1f1

Please sign in to comment.