Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added kernels and pytests #10

Merged
merged 9 commits into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion jax_sph/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,14 @@ def set_defaults(cfg: DictConfig = OmegaConf.create({})) -> DictConfig:
### kernel
cfg.kernel = OmegaConf.create({})

# Kernel name. One of "QSK" (quintic spline kernel) or "WC2K" (Wendland C2 kernel)
# Kernel name, choose one of:
# "CSK" (cubic spline kernel)
# "QSK" (quintic spline kernel)
# "WC2K" (Wendland C2 kernel)
# "WC4K" (Wendland C4 kernel)
# "WC6K" (Wendland C4 kernel)
# "GK" (gaussian kernel)
# "SGK" (super gaussian kernel)
cfg.kernel.name = "QSK" # previously: kernel
# Smoothing length factor
cfg.kernel.h_factor = 1.0 # new. Should default to 1.3 WC2K and 1.0 QSK
Expand Down
123 changes: 123 additions & 0 deletions jax_sph/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,31 @@ def grad_w(self, r):
return grad(self.w)(r)


class CubicKernel(BaseKernel):
"""The cubic kernel function of Monaghan."""

def __init__(self, h, dim=3):
self._one_over_h = 1.0 / h

self._normalized_cutoff = 2.0
self.cutoff = self._normalized_cutoff * h
if dim == 1:
self._sigma = 2.0 / 3.0 * self._one_over_h
elif dim == 2:
self._sigma = 10.0 / 7.0 / jnp.pi * self._one_over_h**2
elif dim == 3:
self._sigma = 1.0 / jnp.pi * self._one_over_h**3

def w(self, r):
q = r * self._one_over_h
c1 = jnp.where(1 - q >= 0, 1, 0)
c2 = jnp.where(jnp.logical_and(2 - q < 1, 2 - q >= 0), 1, 0)
q1 = 1 - 1.5 * q**2 * (1 - q / 2)
q2 = 0.25 * (2 - q) ** 3

return self._sigma * (q1 * c1 + q2 * c2)


class QuinticKernel(BaseKernel):
"""The quintic kernel function of Morris."""

Expand Down Expand Up @@ -76,3 +101,101 @@ def w(self, r):
q2 = 2.0 * q + 1.0

return self._sigma * (q1**4 * q2)


class WendlandC4Kernel(BaseKernel):
"""The 5th-order C4 kernel function of Wendland."""

def __init__(self, h, dim=3):
self._one_over_h = 1.0 / h
self.dim = dim

self._normalized_cutoff = 2.0
self.cutoff = self._normalized_cutoff * h
if dim == 1:
self._sigma = 3.0 / 4.0 * self._one_over_h
elif dim == 2:
self._sigma = 9.0 / 4.0 / jnp.pi * self._one_over_h**2
elif dim == 3:
self._sigma = 495.0 / 256.0 / jnp.pi * self._one_over_h**3

def w(self, r):
if self.dim == 1:
q = r * self._one_over_h
q1 = jnp.maximum(0.0, 1.0 - 0.5 * q)
q2 = 2.0 * q**2 + 2.5 * q + 1.0

return self._sigma * (q1**5 * q2)
else:
q = r * self._one_over_h
q1 = jnp.maximum(0.0, 1.0 - 0.5 * q)
q2 = 35.0 / 12.0 * q**2 + 3 * q + 1.0

return self._sigma * (q1**6 * q2)


class WendlandC6Kernel(BaseKernel):
"""The 5th-order C6 kernel function of Wendland."""

def __init__(self, h, dim=3):
self._one_over_h = 1.0 / h
self.dim = dim

self._normalized_cutoff = 2.0
self.cutoff = self._normalized_cutoff * h
if dim == 1:
self._sigma = 55.0 / 64.0 * self._one_over_h
elif dim == 2:
self._sigma = 78.0 / 28.0 / jnp.pi * self._one_over_h**2
elif dim == 3:
self._sigma = 1365.0 / 512.0 / jnp.pi * self._one_over_h**3

def w(self, r):
if self.dim == 1:
q = r * self._one_over_h
q1 = jnp.maximum(0.0, 1.0 - 0.5 * q)
q2 = 21.0 / 8.0 * q**3 + 19.0 / 4.0 * q**2 + 3.5 * q + 1.0

return self._sigma * (q1**7 * q2)
else:
q = r * self._one_over_h
q1 = jnp.maximum(0.0, 1.0 - 0.5 * q)
q2 = 4.0 * q**3 + 6.25 * q**2 + 4 * q + 1.0

return self._sigma * (q1**8 * q2)


class GaussianKernel(BaseKernel):
"""The gaussian kernel function of Monaghan."""

def __init__(self, h, dim=3):
self._one_over_h = 1.0 / h

self._normalized_cutoff = 3.0
self.cutoff = self._normalized_cutoff * h
self._sigma = 1.0 / jnp.pi ** (dim / 2) * self._one_over_h ** (dim)

def w(self, r):
q = r * self._one_over_h
q1 = jnp.where(3 - q >= 0, 1, 0)

return self._sigma * q1 * jnp.exp(-(q**2))


class SuperGaussianKernel(BaseKernel):
# TODO: We want this? Intendent but negativ in some regions
"""The supergaussian kernel function of Monaghan."""

def __init__(self, h, dim=3):
self._one_over_h = 1.0 / h
self.dim = dim

self._normalized_cutoff = 3.0
self.cutoff = self._normalized_cutoff * h
self._sigma = 1.0 / jnp.pi ** (dim / 2) * self._one_over_h ** (dim)

def w(self, r):
q = r * self._one_over_h
q1 = jnp.where(3 - q >= 0, 1, 0)

return self._sigma * q1 * jnp.exp(-(q**2)) * (self.dim / 2 + 1 - q**2)
29 changes: 24 additions & 5 deletions jax_sph/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,15 @@
from jax_md import space

from jax_sph.eos import RIEMANNEoS, TaitEoS
from jax_sph.kernel import QuinticKernel, WendlandC2Kernel
from jax_sph.kernel import (
CubicKernel,
GaussianKernel,
QuinticKernel,
SuperGaussianKernel,
WendlandC2Kernel,
WendlandC4Kernel,
WendlandC6Kernel,
)
from jax_sph.utils import Tag, wall_tags

EPS = jnp.finfo(float).eps
Expand Down Expand Up @@ -478,10 +486,21 @@ def __init__(
self.is_heat_conduction = is_heat_conduction

_beta_fn = limiter_fn_wrapper(eta_limiter, c_ref)
if kernel == "QSK":
self._kernel_fn = QuinticKernel(h=dx, dim=dim)
elif kernel == "WC2K":
self._kernel_fn = WendlandC2Kernel(h=1.3 * dx, dim=dim)
match kernel:
case "CSK":
self._kernel_fn = CubicKernel(h=dx, dim=dim)
case "QSK":
self._kernel_fn = QuinticKernel(h=dx, dim=dim)
case "WC2K":
self._kernel_fn = WendlandC2Kernel(h=1.3 * dx, dim=dim)
case "WC4K":
self._kernel_fn = WendlandC4Kernel(h=1.3 * dx, dim=dim)
case "WC6K":
self._kernel_fn = WendlandC6Kernel(h=1.3 * dx, dim=dim)
case "GK":
self._kernel_fn = GaussianKernel(h=dx, dim=dim)
case "SGK":
self._kernel_fn = SuperGaussianKernel(h=dx, dim=dim)

self._gwbc_fn = gwbc_fn_wrapper(is_free_slip, is_heat_conduction, eos)
self._free_weight, self._heat_bc = gwbc_fn_riemann_wrapper(
Expand Down
121 changes: 121 additions & 0 deletions notebooks/kernel_plots.ipynb

Large diffs are not rendered by default.

19 changes: 17 additions & 2 deletions tests/test_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,26 @@
import pytest
from jax import vmap

from jax_sph.kernel import QuinticKernel, WendlandC2Kernel
from jax_sph.kernel import (
CubicKernel,
GaussianKernel,
QuinticKernel,
WendlandC2Kernel,
WendlandC4Kernel,
WendlandC6Kernel,
)


@pytest.mark.parametrize(
"Kernel, dx_factor", [(QuinticKernel, 1), (WendlandC2Kernel, 1.3)]
"Kernel, dx_factor",
[
(CubicKernel, 1),
(QuinticKernel, 1),
(WendlandC2Kernel, 1.3),
(WendlandC4Kernel, 1.3),
(WendlandC6Kernel, 1.3),
(GaussianKernel, 1),
],
)
def test_kernel_1d(Kernel, dx_factor):
"""Test the interpolation kernels in 1 dimension."""
Expand Down
115 changes: 115 additions & 0 deletions tests/test_pf2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""Test a full run of the solver on the Poiseuille flow case from the validations."""

import os

import jax.numpy as jnp
import numpy as np
import pytest
from jax import config
from omegaconf import OmegaConf

from main import load_embedded_configs


def u_series_exp(y, t, n_max=10):
"""Analytical solution to unsteady Poiseuille flow (low Re)

Based on Series expansion as shown in:
"Modeling Low Reynolds Number Incompressible Flows Using SPH"
ba Morris et al. 1997
"""

eta = 100.0 # dynamic viscosity
rho = 1.0 # denstiy
nu = eta / rho # kinematic viscosity
u_max = 1.25 # max velocity in middle of channel
d = 1.0 # channel width
fx = -8 * nu * u_max / d**2
offset = fx / (2 * nu) * y * (y - d)

def term(n):
base = np.pi * (2 * n + 1) / d
prefactor = 4 * fx / (nu * base**3 * d)
sin_term = np.sin(base * y)
exp_term = np.exp(-(base**2) * nu * t)
return prefactor * sin_term * exp_term

res = offset
for i in range(n_max):
res += term(i)

return res


@pytest.fixture
def setup_simulation():
y_axis = np.linspace(0, 1, 21)
t_dimless = [0.0005, 0.001, 0.005]
# get analytical solution
ref_solutions = []
for t_val in t_dimless:
ref_solutions.append(u_series_exp(y_axis, t_val))
return y_axis, t_dimless, ref_solutions


def run_simulation(tmp_path, tvf, solver):
"""Emulate `main.py`."""
data_path = tmp_path / f"pf_test_{tvf}"

cli_args = OmegaConf.create(
{
"config": "cases/pf.yaml",
"case": {"dx": 0.0333333},
"solver": {"name": solver, "tvf": tvf, "dt": 0.000002, "t_end": 0.005},
"io": {"write_every": 250, "data_path": str(data_path)},
}
)
cfg = load_embedded_configs(cli_args)

# Specify cuda device. These setting must be done before importing jax-md.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # see issue #152 from TensorFlow
os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu)
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(cfg.xla_mem_fraction)

if cfg.dtype == "float64":
config.update("jax_enable_x64", True)

from jax_sph.simulate import simulate

simulate(cfg)

return data_path


def get_solution(data_path, t_dimless, y_axis):
from jax_sph.utils import sph_interpolator

dir = os.listdir(data_path)[0]
cfg = OmegaConf.load(data_path / dir / "config.yaml")
step_max = np.array(np.rint(cfg.solver.t_end / cfg.solver.dt), dtype=int)
digits = len(str(step_max))

y_axis += 3 * cfg.case.dx
rs = 0.2 * jnp.ones([y_axis.shape[0], 2])
rs = rs.at[:, 1].set(y_axis)
solutions = []
for i in range(len(t_dimless)):
file_name = (
"traj_" + str(int(t_dimless[i] / cfg.solver.dt)).zfill(digits) + ".h5"
)
src_path = data_path / dir / file_name
interp_vel_fn = sph_interpolator(cfg, src_path)
solutions.append(interp_vel_fn(src_path, rs, prop="u", dim_ind=0))
return solutions


@pytest.mark.parametrize("tvf, solver", [(0.0, "SPH"), (1.0, "SPH")]) # (0.0, "RIE")
def test_pf2d(tvf, solver, tmp_path, setup_simulation):
"""Test whether the poiseuille flow simulation matches the analytical solution"""
y_axis, t_dimless, ref_solutions = setup_simulation
data_path = run_simulation(tmp_path, tvf, solver)
# print(f"tmp_path = {tmp_path}, subdirs = {subdirs}")
solutions = get_solution(data_path, t_dimless, y_axis)
# print(f"solution: {solutions[-1]} \nref_solution: {ref_solutions[-1]}")
for sol, ref_sol in zip(solutions, ref_solutions):
assert np.allclose(sol, ref_sol, atol=1e-2), "Velocity profile does not match."
2 changes: 1 addition & 1 deletion validation/tgv2d.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/bin/bash
# Generate validation data and validate 2D TGV
# with number of particles per direction nx = [20, 50, 100]
# with number of particles per direction nx = [50, 100]
# Reference result from:
# "A Transport Velocty [...]", Adami 2012

Expand Down
Loading