Skip to content

Commit

Permalink
Merge pull request #62 from mehdiataei/major-refactoring
Browse files Browse the repository at this point in the history
Removed the need to have separate JAX/Warp constants
  • Loading branch information
hsalehipour authored Sep 13, 2024
2 parents fe8c945 + e2028cb commit a396328
Show file tree
Hide file tree
Showing 39 changed files with 213 additions and 130 deletions.
13 changes: 11 additions & 2 deletions examples/cfd/flow_past_sphere_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import numpy as np
import jax.numpy as jnp
import time
import jax


class FlowOverSphere:
Expand Down Expand Up @@ -118,7 +119,11 @@ def post_process(self, i):
else:
f_0 = self.f_0

macro = Macroscopic(compute_backend=ComputeBackend.JAX)
macro = Macroscopic(
compute_backend=ComputeBackend.JAX,
precision_policy=self.precision_policy,
velocity_set=xlb.velocity_set.D3Q19(precision_policy=self.precision_policy, backend=ComputeBackend.JAX),
)
rho, u = macro(f_0)

# remove boundary cells
Expand All @@ -135,9 +140,13 @@ def post_process(self, i):
if __name__ == "__main__":
# Running the simulation
grid_shape = (512 // 2, 128 // 2, 128 // 2)
velocity_set = xlb.velocity_set.D3Q19()
backend = ComputeBackend.WARP
precision_policy = PrecisionPolicy.FP32FP32

if precision_policy == PrecisionPolicy.FP64FP64 or precision_policy == PrecisionPolicy.FP64FP32:
jax.config.update("jax_enable_x64", True)

velocity_set = xlb.velocity_set.D3Q19(precision_policy=precision_policy, backend=backend)
omega = 1.6

simulation = FlowOverSphere(omega, grid_shape, velocity_set, backend, precision_policy)
Expand Down
14 changes: 12 additions & 2 deletions examples/cfd/lid_driven_cavity_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from xlb.operator.macroscopic import Macroscopic
from xlb.utils import save_fields_vtk, save_image
import warp as wp
import jax
import jax.numpy as jnp
import xlb.velocity_set


class LidDrivenCavity2D:
Expand Down Expand Up @@ -80,7 +82,11 @@ def post_process(self, i):
else:
f_0 = self.f_0

macro = Macroscopic(compute_backend=ComputeBackend.JAX)
macro = Macroscopic(
compute_backend=ComputeBackend.JAX,
precision_policy=self.precision_policy,
velocity_set=xlb.velocity_set.D2Q9(precision_policy=self.precision_policy, backend=ComputeBackend.JAX),
)

rho, u = macro(f_0)

Expand All @@ -100,8 +106,12 @@ def post_process(self, i):
grid_size = 500
grid_shape = (grid_size, grid_size)
backend = ComputeBackend.WARP
velocity_set = xlb.velocity_set.D2Q9()
precision_policy = PrecisionPolicy.FP32FP32

if precision_policy == PrecisionPolicy.FP64FP64 or precision_policy == PrecisionPolicy.FP64FP32:
jax.config.update("jax_enable_x64", True)

velocity_set = xlb.velocity_set.D2Q9(precision_policy=precision_policy, backend=backend)
omega = 1.6

simulation = LidDrivenCavity2D(omega, grid_shape, velocity_set, backend, precision_policy)
Expand Down
7 changes: 6 additions & 1 deletion examples/cfd/lid_driven_cavity_2d_distributed.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import xlb
import jax
from xlb.compute_backend import ComputeBackend
from xlb.precision_policy import PrecisionPolicy
from xlb.operator.stepper import IncompressibleNavierStokesStepper
Expand Down Expand Up @@ -26,8 +27,12 @@ def setup_stepper(self, omega):
grid_size = 512
grid_shape = (grid_size, grid_size)
backend = ComputeBackend.JAX # Must be JAX for distributed multi-GPU computations. Distributed computations on WARP are not supported yet!
velocity_set = xlb.velocity_set.D2Q9()
precision_policy = PrecisionPolicy.FP32FP32

if precision_policy == PrecisionPolicy.FP64FP64 or precision_policy == PrecisionPolicy.FP64FP32:
jax.config.update("jax_enable_x64", True)

velocity_set = xlb.velocity_set.D2Q9(precision_policy=precision_policy, backend=backend)
omega = 1.6

simulation = LidDrivenCavity2D_distributed(omega, grid_shape, velocity_set, backend, precision_policy)
Expand Down
8 changes: 6 additions & 2 deletions examples/cfd/windtunnel_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,11 @@ def post_process(self, i):
else:
f_0 = self.f_0

macro = Macroscopic(compute_backend=ComputeBackend.JAX)
macro = Macroscopic(
compute_backend=ComputeBackend.JAX,
precision_policy=self.precision_policy,
velocity_set=xlb.velocity_set.D3Q27(precision_policy=self.precision_policy, backend=ComputeBackend.JAX),
)

rho, u = macro(f_0)

Expand Down Expand Up @@ -215,8 +219,8 @@ def plot_drag_coefficient(self):

# Configuration
backend = ComputeBackend.WARP
velocity_set = xlb.velocity_set.D3Q27()
precision_policy = PrecisionPolicy.FP32FP32
velocity_set = xlb.velocity_set.D3Q27(precision_policy=precision_policy, backend=backend)
wind_speed = 0.02
num_steps = 100000
print_interval = 1000
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@


def init_xlb_env(velocity_set):
vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.JAX)
xlb.init(
default_precision_policy=xlb.PrecisionPolicy.FP32FP32,
default_backend=ComputeBackend.JAX,
velocity_set=velocity_set(),
velocity_set=vel_set,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@


def init_xlb_env(velocity_set):
vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP)
xlb.init(
default_precision_policy=xlb.PrecisionPolicy.FP32FP32,
default_backend=ComputeBackend.WARP,
velocity_set=velocity_set(),
default_backend=ComputeBackend.JAX,
velocity_set=vel_set,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@


def init_xlb_env(velocity_set):
vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.JAX)
xlb.init(
default_precision_policy=xlb.PrecisionPolicy.FP32FP32,
default_backend=ComputeBackend.JAX,
velocity_set=velocity_set(),
velocity_set=vel_set,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@


def init_xlb_env(velocity_set):
vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP)
xlb.init(
default_precision_policy=xlb.PrecisionPolicy.FP32FP32,
default_backend=ComputeBackend.WARP,
velocity_set=velocity_set(),
default_backend=ComputeBackend.JAX,
velocity_set=vel_set,
)


Expand Down
4 changes: 2 additions & 2 deletions tests/boundary_conditions/mask/test_bc_indices_masker_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@


def init_xlb_env(velocity_set):
vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.JAX)
xlb.init(
default_precision_policy=xlb.PrecisionPolicy.FP32FP32,
default_backend=ComputeBackend.JAX,
velocity_set=velocity_set(),
velocity_set=vel_set,
)


@pytest.mark.parametrize(
"dim,velocity_set,grid_shape",
[
Expand Down
5 changes: 3 additions & 2 deletions tests/boundary_conditions/mask/test_bc_indices_masker_warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@


def init_xlb_env(velocity_set):
vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP)
xlb.init(
default_precision_policy=xlb.PrecisionPolicy.FP32FP32,
default_backend=ComputeBackend.WARP,
velocity_set=velocity_set(),
default_backend=ComputeBackend.JAX,
velocity_set=vel_set,
)


Expand Down
13 changes: 7 additions & 6 deletions tests/grids/test_grid_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,18 @@
import jax.numpy as jnp


def init_xlb_env():
def init_xlb_env(velocity_set):
vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP)
xlb.init(
default_precision_policy=xlb.PrecisionPolicy.FP32FP32,
default_backend=ComputeBackend.JAX,
velocity_set=xlb.velocity_set.D2Q9, # does not affect the test
velocity_set=vel_set,
)


@pytest.mark.parametrize("grid_size", [50, 100, 150])
def test_jax_2d_grid_initialization(grid_size):
init_xlb_env()
init_xlb_env(xlb.velocity_set.D2Q9)
grid_shape = (grid_size, grid_size)
my_grid = grid_factory(grid_shape)
f = my_grid.create_field(cardinality=9)
Expand All @@ -34,7 +35,7 @@ def test_jax_2d_grid_initialization(grid_size):

@pytest.mark.parametrize("grid_size", [50, 100, 150])
def test_jax_3d_grid_initialization(grid_size):
init_xlb_env()
init_xlb_env(xlb.velocity_set.D3Q19)
grid_shape = (grid_size, grid_size, grid_size)
my_grid = grid_factory(grid_shape)
f = my_grid.create_field(cardinality=9)
Expand All @@ -54,7 +55,7 @@ def test_jax_3d_grid_initialization(grid_size):


def test_jax_grid_create_field_fill_value():
init_xlb_env()
init_xlb_env(xlb.velocity_set.D2Q9)
grid_shape = (100, 100)
fill_value = 3.14
my_grid = grid_factory(grid_shape)
Expand All @@ -66,7 +67,7 @@ def test_jax_grid_create_field_fill_value():

@pytest.fixture(autouse=True)
def setup_xlb_env():
init_xlb_env()
init_xlb_env(xlb.velocity_set.D2Q9)


if __name__ == "__main__":
Expand Down
14 changes: 7 additions & 7 deletions tests/grids/test_grid_warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,18 @@
from xlb.precision_policy import Precision


def init_xlb_env():
def init_xlb_env(velocity_set):
vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP)
xlb.init(
default_precision_policy=xlb.PrecisionPolicy.FP32FP32,
default_backend=ComputeBackend.WARP,
velocity_set=xlb.velocity_set.D2Q9,
default_backend=ComputeBackend.JAX,
velocity_set=vel_set,
)


@pytest.mark.parametrize("grid_size", [50, 100, 150])
def test_warp_grid_create_field(grid_size):
for grid_shape in [(grid_size, grid_size), (grid_size, grid_size, grid_size)]:
init_xlb_env()
init_xlb_env(xlb.velocity_set.D3Q19)
my_grid = grid_factory(grid_shape)
f = my_grid.create_field(cardinality=9, dtype=Precision.FP32)

Expand All @@ -27,7 +27,7 @@ def test_warp_grid_create_field(grid_size):


def test_warp_grid_create_field_fill_value():
init_xlb_env()
init_xlb_env(xlb.velocity_set.D2Q9)
grid_shape = (100, 100)
fill_value = 3.14
my_grid = grid_factory(grid_shape)
Expand All @@ -42,7 +42,7 @@ def test_warp_grid_create_field_fill_value():

@pytest.fixture(autouse=True)
def setup_xlb_env():
init_xlb_env()
init_xlb_env(xlb.velocity_set.D2Q9)


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions tests/kernels/collision/test_bgk_collision_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@


def init_xlb_env(velocity_set):
vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.JAX)
xlb.init(
default_precision_policy=xlb.PrecisionPolicy.FP32FP32,
default_backend=ComputeBackend.JAX,
velocity_set=velocity_set(),
velocity_set=vel_set,
)


@pytest.mark.parametrize(
"dim,velocity_set,grid_shape,omega",
[
Expand Down
5 changes: 3 additions & 2 deletions tests/kernels/collision/test_bgk_collision_warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@


def init_xlb_env(velocity_set):
vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP)
xlb.init(
default_precision_policy=xlb.PrecisionPolicy.FP32FP32,
default_backend=ComputeBackend.WARP,
velocity_set=velocity_set(),
default_backend=ComputeBackend.JAX,
velocity_set=vel_set,
)


Expand Down
4 changes: 2 additions & 2 deletions tests/kernels/equilibrium/test_equilibrium_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@


def init_xlb_env(velocity_set):
vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP)
xlb.init(
default_precision_policy=xlb.PrecisionPolicy.FP32FP32,
default_backend=ComputeBackend.JAX,
velocity_set=velocity_set(),
velocity_set=vel_set,
)


@pytest.mark.parametrize(
"dim,velocity_set,grid_shape",
[
Expand Down
7 changes: 3 additions & 4 deletions tests/kernels/equilibrium/test_equilibrium_warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,12 @@


def init_xlb_env(velocity_set):
vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP)
xlb.init(
default_precision_policy=xlb.PrecisionPolicy.FP32FP32,
default_backend=ComputeBackend.WARP,
velocity_set=velocity_set(),
default_backend=ComputeBackend.JAX,
velocity_set=vel_set,
)


@pytest.mark.parametrize(
"dim,velocity_set,grid_shape",
[
Expand Down
3 changes: 2 additions & 1 deletion tests/kernels/macroscopic/test_macroscopic_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@


def init_xlb_env(velocity_set):
vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.JAX)
xlb.init(
default_precision_policy=xlb.PrecisionPolicy.FP32FP32,
default_backend=ComputeBackend.JAX,
velocity_set=velocity_set(),
velocity_set=vel_set,
)


Expand Down
5 changes: 3 additions & 2 deletions tests/kernels/macroscopic/test_macroscopic_warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@


def init_xlb_env(velocity_set):
vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP)
xlb.init(
default_precision_policy=xlb.PrecisionPolicy.FP32FP32,
default_backend=ComputeBackend.WARP,
velocity_set=velocity_set(),
default_backend=ComputeBackend.JAX,
velocity_set=vel_set,
)


Expand Down
Loading

0 comments on commit a396328

Please sign in to comment.