Skip to content

Commit

Permalink
Merge branch 'main' into example_contour
Browse files Browse the repository at this point in the history
  • Loading branch information
hsalehipour authored Nov 27, 2023
2 parents 4d4f553 + f6f1ba5 commit 88185cb
Show file tree
Hide file tree
Showing 18 changed files with 828 additions and 476 deletions.
27 changes: 12 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,23 +90,22 @@ The following examples showcase the capabilities of XLB:

To use XLB, you must first install JAX and other dependencies using the following commands:

```bash
# Please refer to https://github.com/google/jax for the latest installation documentation

pip install --upgrade pip

# For CPU run
pip install --upgrade "jax[cpu]"
Please refer to https://github.com/google/jax for the latest installation documentation. The following table is taken from [JAX's Github page](https://github.com/google/jax).

# For GPU run
| Hardware | Instructions |
|------------|-----------------------------------------------------------------------------------------------------------------|
| CPU | `pip install -U "jax[cpu]"` |
| NVIDIA GPU on x86_64 | `pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html` |
| Google TPU | `pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html` |
| AMD GPU | Use [Docker](https://hub.docker.com/r/rocm/jax) or [build from source](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). |
| Apple GPU | Follow [Apple's instructions](https://developer.apple.com/metal/jax/). |

# CUDA 12 and cuDNN 8.8 or newer.
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
**Note:** We encountered challenges when executing XLB on Apple GPUs due to the lack of support for certain operations in the Metal backend. We advise using the CPU backend on Mac OS. We will be testing XLB on Apple's GPUs in the future and will update this section accordingly.

# CUDA 11 and cuDNN 8.6 or newer.
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Run dependencies
Install dependencies:
```bash
pip install jmp pyvista numpy matplotlib Rtree trimesh jmp
```

Expand All @@ -118,6 +117,4 @@ export PYTHONPATH=.
python3 examples/cavity2d.py
```
## Citing XLB
Accompanying publication coming soon:

**M. Ataei, H. Salehipour**. XLB: Hardware-Accelerated, Scalable, and Differentiable Lattice Boltzmann Simulation Framework based on JAX. TBA
Accompanying paper will be available soon.
12 changes: 4 additions & 8 deletions examples/CFD/airfoil3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@
# from IPython import display
import matplotlib.pylab as plt
from src.models import BGKSim, KBCSim
from src.lattice import LatticeD3Q19, LatticeD3Q27
from src.boundary_conditions import *
from src.lattice import *
import numpy as np
from src.utils import *
from jax.config import config
from jax import config
import os
#os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
import jax
Expand Down Expand Up @@ -159,15 +159,13 @@ def output_data(self, **kwargs):
airfoil_thickness = 30
airfoil_angle = 20
airfoil = makeNacaAirfoil(length=airfoil_length, thickness=airfoil_thickness, angle=airfoil_angle).T

precision = 'f32/f32'
lattice = LatticeD3Q27(precision=precision)

lattice = LatticeD3Q27(precision)

nx = airfoil.shape[0]
ny = airfoil.shape[1]

print("airfoil shape: ", airfoil.shape)

ny = 3 * ny
nx = 5 * nx
nz = 101
Expand All @@ -178,7 +176,6 @@ def output_data(self, **kwargs):

visc = prescribed_vel * clength / Re
omega = 1.0 / (3. * visc + 0.5)
print('omega = ', omega)

os.system('rm -rf ./*.vtk && rm -rf ./*.png')

Expand All @@ -195,5 +192,4 @@ def output_data(self, **kwargs):
}

sim = Airfoil(**kwargs)
print('Domain size: ', sim.nx, sim.ny, sim.nz)
sim.run(20000)
17 changes: 8 additions & 9 deletions examples/CFD/cavity2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,18 @@
4. Visualization: The simulation outputs data in VTK format for visualization. It also provides images of the velocity field and saves the boundary conditions at each time step. The data can be visualized using software like Paraview.
"""
from src.boundary_conditions import *
from jax.config import config
from src.utils import *
from jax import config
import numpy as np
from src.lattice import LatticeD2Q9
from src.models import BGKSim, KBCSim
import jax.numpy as jnp
import os

from src.boundary_conditions import *
from src.models import BGKSim, KBCSim
from src.lattice import LatticeD2Q9
from src.utils import *

# Use 8 CPU devices
# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
import jax

class Cavity(KBCSim):
def __init__(self, **kwargs):
Expand Down Expand Up @@ -71,11 +71,10 @@ def output_data(self, **kwargs):
clength = nx - 1

checkpoint_rate = 1000
checkpoint_dir = "./checkpoints"
checkpoint_dir = os.path.abspath("./checkpoints")

visc = prescribed_vel * clength / Re
omega = 1.0 / (3.0 * visc + 0.5)
print("omega = ", omega)

os.system("rm -rf ./*.vtk && rm -rf ./*.png")

Expand All @@ -90,7 +89,7 @@ def output_data(self, **kwargs):
'print_info_rate': 100,
'checkpoint_rate': checkpoint_rate,
'checkpoint_dir': checkpoint_dir,
'restore_checkpoint': True,
'restore_checkpoint': False,
}

sim = Cavity(**kwargs)
Expand Down
92 changes: 58 additions & 34 deletions examples/CFD/cavity3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
In this example you'll be introduced to the following concepts:
1. Lattice: The simulation employs a D2Q9 lattice. It's a 2D lattice model with nine discrete velocity directions, which is typically used for 2D simulations.
1. Lattice: The simulation employs a D3Q27 lattice. It's a 3D lattice model with 27 discrete velocity directions.
2. Boundary Conditions: The code implements two types of boundary conditions:
Expand All @@ -14,74 +14,98 @@
4. Visualization: The simulation outputs data in VTK format for visualization. The data can be visualized using software like Paraview.
"""

import os

# Use 8 CPU devices
# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
from src.models import BGKSim, KBCSim
from src.lattice import LatticeD3Q27

import numpy as np
from src.utils import *
from jax.config import config
from jax import config
import json, codecs

from src.models import BGKSim, KBCSim
from src.lattice import LatticeD3Q19, LatticeD3Q27
from src.boundary_conditions import *

precision = 'f32/f32'

config.update('jax_enable_x64', True)

class Cavity(KBCSim):
# Note: We have used BGK with D3Q19 (or D3Q27) for Re=(1000, 3200) and KBC with D3Q27 for Re=10,000
def __init__(self, **kwargs):
super().__init__(**kwargs)

def set_boundary_conditions(self):
# Note:
# We have used halfway BB for Re=(1000, 3200) and regularized BC for Re=10,000

# apply inlet boundary condition to the top wall
moving_wall = self.boundingBoxIndices['top']
vel_wall = np.zeros(moving_wall.shape, dtype=self.precisionPolicy.compute_dtype)
vel_wall[:, 0] = prescribed_vel
# self.BCs.append(BounceBackHalfway(tuple(moving_wall.T), self.gridInfo, self.precisionPolicy, vel_wall))
self.BCs.append(Regularized(tuple(moving_wall.T), self.gridInfo, self.precisionPolicy, 'velocity', vel_wall))

# concatenate the indices of the left, right, and bottom walls
walls = np.concatenate(
(self.boundingBoxIndices['left'], self.boundingBoxIndices['right'],
self.boundingBoxIndices['front'], self.boundingBoxIndices['back'],
self.boundingBoxIndices['bottom']))
# apply bounce back boundary condition to the walls
self.BCs.append(BounceBack(tuple(walls.T), self.gridInfo, self.precisionPolicy))

# apply inlet equilibrium boundary condition to the top wall
moving_wall = self.boundingBoxIndices['top']

rho_wall = np.ones((moving_wall.shape[0], 1), dtype=self.precisionPolicy.compute_dtype)
vel_wall = np.zeros(moving_wall.shape, dtype=self.precisionPolicy.compute_dtype)
vel_wall[:, 0] = prescribed_vel
self.BCs.append(EquilibriumBC(tuple(moving_wall.T), self.gridInfo, self.precisionPolicy, rho_wall, vel_wall))
# self.BCs.append(BounceBackHalfway(tuple(walls.T), self.gridInfo, self.precisionPolicy))
vel_wall = np.zeros(walls.shape, dtype=self.precisionPolicy.compute_dtype)
self.BCs.append(Regularized(tuple(walls.T), self.gridInfo, self.precisionPolicy, 'velocity', vel_wall))
return

def output_data(self, **kwargs):
# 1: -1 to remove boundary voxels (not needed for visualization when using full-way bounce-back)
rho = np.array(kwargs['rho'][1:-1, 1:-1, 1:-1])
u = np.array(kwargs['u'][1:-1, 1:-1, 1:-1, :])
rho = np.array(kwargs['rho'])
u = np.array(kwargs['u'])
timestep = kwargs['timestep']
u_prev = kwargs['u_prev'][1:-1, 1:-1, 1:-1, :]
u_prev = kwargs['u_prev']

u_old = np.linalg.norm(u_prev, axis=2)
u_new = np.linalg.norm(u, axis=2)

err = np.sum(np.abs(u_old - u_new))
print('error= {:07.6f}'.format(err))
fields = {"rho": rho[..., 0], "u_x": u[..., 0], "u_y": u[..., 1], "u_z": u[..., 2]}
save_fields_vtk(timestep, fields)
# save_fields_vtk(timestep, fields)

# output profiles of velocity at mid-plane for benchmarking
output_filename = "./profiles_" + f"{timestep:07d}.json"
ux_mid = 0.5*(u[nx//2, ny//2, :, 0] + u[nx//2+1, ny//2+1, :, 0])
uz_mid = 0.5*(u[:, ny//2, nz//2, 2] + u[:, ny//2+1, nz//2+1, 2])
ldc_ref_result = {'ux(x=y=0)': list(ux_mid/prescribed_vel),
'uz(z=y=0)': list(uz_mid/prescribed_vel)}
json.dump(ldc_ref_result, codecs.open(output_filename, 'w', encoding='utf-8'),
separators=(',', ':'),
sort_keys=True,
indent=4)

# Calculate the velocity magnitude
u_mag = np.linalg.norm(u, axis=2)
# u_mag = np.linalg.norm(u, axis=2)
# live_volume_randering(timestep, u_mag)

if __name__ == '__main__':
# Note:
# We have used BGK with D3Q19 (or D3Q27) for Re=(1000, 3200) and KBC with D3Q27 for Re=10,000
precision = 'f64/f64'
lattice = LatticeD3Q27(precision)

nx = 101
ny = 101
nz = 101
nx = 256
ny = 256
nz = 256

Re = 10000.0
prescribed_vel = 0.06
clength = nx - 2

Re = 50000.0
prescribed_vel = 0.1
clength = nx - 1
# characteristic time
tc = prescribed_vel/clength
niter_max = int(500//tc)

visc = prescribed_vel * clength / Re
omega = 1.0 / (3. * visc + 0.5)
print('omega = ', omega)

os.system("rm -rf ./*.vtk && rm -rf ./*.png")

kwargs = {
Expand All @@ -91,9 +115,9 @@ def output_data(self, **kwargs):
'ny': ny,
'nz': nz,
'precision': precision,
'io_rate': 100,
'print_info_rate': 100,
'downsampling_factor': 2
'io_rate': int(10//tc),
'print_info_rate': int(10//tc),
'downsampling_factor': 1
}
sim = Cavity(**kwargs)
sim.run(2000)
sim.run(niter_max)
7 changes: 3 additions & 4 deletions examples/CFD/channel3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""

from src.boundary_conditions import *
from jax.config import config
from jax import config
from src.utils import *
import numpy as np
from src.lattice import LatticeD3Q27
Expand Down Expand Up @@ -55,7 +55,7 @@ def get_dns_data():
}
return dns_dic

class turbulentChannel(KBCSim):
class TurbulentChannel(KBCSim):
def __init__(self, **kwargs):
super().__init__(**kwargs)

Expand All @@ -68,7 +68,7 @@ def set_boundary_conditions(self):
def initialize_macroscopic_fields(self):
rho = self.precisionPolicy.cast_to_output(1.0)
u = self.distributed_array_init((self.nx, self.ny, self.nz, self.dim),
self.precisionPolicy.compute_dtype, initVal=1e-2 * np.random.random((self.nx, self.ny, self.nz, self.dim)))
self.precisionPolicy.compute_dtype, init_val=1e-2 * np.random.random((self.nx, self.ny, self.nz, self.dim)))
u = self.precisionPolicy.cast_to_output(u)
return rho, u

Expand Down Expand Up @@ -141,7 +141,6 @@ def output_data(self, **kwargs):
zz = np.minimum(zz, zz.max() - zz)
yplus = zz * u_tau / visc

print("omega = ", omega)
os.system("rm -rf ./*.vtk && rm -rf ./*.png")

kwargs = {
Expand Down
13 changes: 7 additions & 6 deletions examples/CFD/couette2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@
This script performs a 2D simulation of Couette flow using the lattice Boltzmann method (LBM).
"""

from src.models import BGKSim
from src.boundary_conditions import *
from src.lattice import LatticeD2Q9
import os
import jax.numpy as jnp
import numpy as np
from src.utils import *
from jax.config import config
import os
from jax import config


from src.models import BGKSim
from src.boundary_conditions import *
from src.lattice import LatticeD2Q9

# config.update('jax_disable_jit', True)
# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4'
Expand Down Expand Up @@ -60,7 +62,6 @@ def output_data(self, **kwargs):
visc = prescribed_vel * clength / Re

omega = 1.0 / (3.0 * visc + 0.5)
print("omega = ", omega)
assert omega < 1.98, "omega must be less than 2.0"
os.system("rm -rf ./*.vtk && rm -rf ./*.png")

Expand Down
Loading

0 comments on commit 88185cb

Please sign in to comment.