Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removed the need to have separate JAX/Warp constants #62

Merged
merged 6 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions examples/cfd/flow_past_sphere_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import numpy as np
import jax.numpy as jnp
import time
import jax


class FlowOverSphere:
Expand Down Expand Up @@ -118,7 +119,11 @@ def post_process(self, i):
else:
f_0 = self.f_0

macro = Macroscopic(compute_backend=ComputeBackend.JAX)
macro = Macroscopic(
compute_backend=ComputeBackend.JAX,
precision_policy=self.precision_policy,
velocity_set=xlb.velocity_set.D3Q19(precision_policy=self.precision_policy, backend=ComputeBackend.JAX),
)
rho, u = macro(f_0)

# remove boundary cells
Expand All @@ -135,9 +140,13 @@ def post_process(self, i):
if __name__ == "__main__":
# Running the simulation
grid_shape = (512 // 2, 128 // 2, 128 // 2)
velocity_set = xlb.velocity_set.D3Q19()
backend = ComputeBackend.WARP
precision_policy = PrecisionPolicy.FP32FP32

if precision_policy == PrecisionPolicy.FP64FP64 or precision_policy == PrecisionPolicy.FP64FP32:
jax.config.update("jax_enable_x64", True)

velocity_set = xlb.velocity_set.D3Q19(precision_policy=precision_policy, backend=backend)
omega = 1.6

simulation = FlowOverSphere(omega, grid_shape, velocity_set, backend, precision_policy)
Expand Down
14 changes: 12 additions & 2 deletions examples/cfd/lid_driven_cavity_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from xlb.operator.macroscopic import Macroscopic
from xlb.utils import save_fields_vtk, save_image
import warp as wp
import jax
import jax.numpy as jnp
import xlb.velocity_set


class LidDrivenCavity2D:
Expand Down Expand Up @@ -80,7 +82,11 @@ def post_process(self, i):
else:
f_0 = self.f_0

macro = Macroscopic(compute_backend=ComputeBackend.JAX)
macro = Macroscopic(
compute_backend=ComputeBackend.JAX,
precision_policy=self.precision_policy,
velocity_set=xlb.velocity_set.D2Q9(precision_policy=self.precision_policy, backend=ComputeBackend.JAX),
)

rho, u = macro(f_0)

Expand All @@ -100,8 +106,12 @@ def post_process(self, i):
grid_size = 500
grid_shape = (grid_size, grid_size)
backend = ComputeBackend.WARP
velocity_set = xlb.velocity_set.D2Q9()
precision_policy = PrecisionPolicy.FP32FP32

if precision_policy == PrecisionPolicy.FP64FP64 or precision_policy == PrecisionPolicy.FP64FP32:
jax.config.update("jax_enable_x64", True)

velocity_set = xlb.velocity_set.D2Q9(precision_policy=precision_policy, backend=backend)
omega = 1.6

simulation = LidDrivenCavity2D(omega, grid_shape, velocity_set, backend, precision_policy)
Expand Down
7 changes: 6 additions & 1 deletion examples/cfd/lid_driven_cavity_2d_distributed.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import xlb
import jax
from xlb.compute_backend import ComputeBackend
from xlb.precision_policy import PrecisionPolicy
from xlb.operator.stepper import IncompressibleNavierStokesStepper
Expand Down Expand Up @@ -26,8 +27,12 @@ def setup_stepper(self, omega):
grid_size = 512
grid_shape = (grid_size, grid_size)
backend = ComputeBackend.JAX # Must be JAX for distributed multi-GPU computations. Distributed computations on WARP are not supported yet!
velocity_set = xlb.velocity_set.D2Q9()
precision_policy = PrecisionPolicy.FP32FP32

if precision_policy == PrecisionPolicy.FP64FP64 or precision_policy == PrecisionPolicy.FP64FP32:
jax.config.update("jax_enable_x64", True)

velocity_set = xlb.velocity_set.D2Q9(precision_policy=precision_policy, backend=backend)
omega = 1.6

simulation = LidDrivenCavity2D_distributed(omega, grid_shape, velocity_set, backend, precision_policy)
Expand Down
8 changes: 6 additions & 2 deletions examples/cfd/windtunnel_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,11 @@ def post_process(self, i):
else:
f_0 = self.f_0

macro = Macroscopic(compute_backend=ComputeBackend.JAX)
macro = Macroscopic(
compute_backend=ComputeBackend.JAX,
precision_policy=self.precision_policy,
velocity_set=xlb.velocity_set.D3Q27(precision_policy=self.precision_policy, backend=ComputeBackend.JAX),
)

rho, u = macro(f_0)

Expand Down Expand Up @@ -215,8 +219,8 @@ def plot_drag_coefficient(self):

# Configuration
backend = ComputeBackend.WARP
velocity_set = xlb.velocity_set.D3Q27()
precision_policy = PrecisionPolicy.FP32FP32
velocity_set = xlb.velocity_set.D3Q27(precision_policy=precision_policy, backend=backend)
wind_speed = 0.02
num_steps = 100000
print_interval = 1000
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@


def init_xlb_env(velocity_set):
vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.JAX)
xlb.init(
default_precision_policy=xlb.PrecisionPolicy.FP32FP32,
default_backend=ComputeBackend.JAX,
velocity_set=velocity_set(),
velocity_set=vel_set,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@


def init_xlb_env(velocity_set):
vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP)
xlb.init(
default_precision_policy=xlb.PrecisionPolicy.FP32FP32,
default_backend=ComputeBackend.WARP,
velocity_set=velocity_set(),
default_backend=ComputeBackend.JAX,
velocity_set=vel_set,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@


def init_xlb_env(velocity_set):
vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.JAX)
xlb.init(
default_precision_policy=xlb.PrecisionPolicy.FP32FP32,
default_backend=ComputeBackend.JAX,
velocity_set=velocity_set(),
velocity_set=vel_set,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@


def init_xlb_env(velocity_set):
vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP)
xlb.init(
default_precision_policy=xlb.PrecisionPolicy.FP32FP32,
default_backend=ComputeBackend.WARP,
velocity_set=velocity_set(),
default_backend=ComputeBackend.JAX,
velocity_set=vel_set,
)


Expand Down
4 changes: 2 additions & 2 deletions tests/boundary_conditions/mask/test_bc_indices_masker_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@


def init_xlb_env(velocity_set):
vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.JAX)
xlb.init(
default_precision_policy=xlb.PrecisionPolicy.FP32FP32,
default_backend=ComputeBackend.JAX,
velocity_set=velocity_set(),
velocity_set=vel_set,
)


@pytest.mark.parametrize(
"dim,velocity_set,grid_shape",
[
Expand Down
5 changes: 3 additions & 2 deletions tests/boundary_conditions/mask/test_bc_indices_masker_warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@


def init_xlb_env(velocity_set):
vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP)
xlb.init(
default_precision_policy=xlb.PrecisionPolicy.FP32FP32,
default_backend=ComputeBackend.WARP,
velocity_set=velocity_set(),
default_backend=ComputeBackend.JAX,
velocity_set=vel_set,
)


Expand Down
13 changes: 7 additions & 6 deletions tests/grids/test_grid_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,18 @@
import jax.numpy as jnp


def init_xlb_env():
def init_xlb_env(velocity_set):
vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP)
xlb.init(
default_precision_policy=xlb.PrecisionPolicy.FP32FP32,
default_backend=ComputeBackend.JAX,
velocity_set=xlb.velocity_set.D2Q9, # does not affect the test
velocity_set=vel_set,
)


@pytest.mark.parametrize("grid_size", [50, 100, 150])
def test_jax_2d_grid_initialization(grid_size):
init_xlb_env()
init_xlb_env(xlb.velocity_set.D2Q9)
grid_shape = (grid_size, grid_size)
my_grid = grid_factory(grid_shape)
f = my_grid.create_field(cardinality=9)
Expand All @@ -34,7 +35,7 @@ def test_jax_2d_grid_initialization(grid_size):

@pytest.mark.parametrize("grid_size", [50, 100, 150])
def test_jax_3d_grid_initialization(grid_size):
init_xlb_env()
init_xlb_env(xlb.velocity_set.D3Q19)
grid_shape = (grid_size, grid_size, grid_size)
my_grid = grid_factory(grid_shape)
f = my_grid.create_field(cardinality=9)
Expand All @@ -54,7 +55,7 @@ def test_jax_3d_grid_initialization(grid_size):


def test_jax_grid_create_field_fill_value():
init_xlb_env()
init_xlb_env(xlb.velocity_set.D2Q9)
grid_shape = (100, 100)
fill_value = 3.14
my_grid = grid_factory(grid_shape)
Expand All @@ -66,7 +67,7 @@ def test_jax_grid_create_field_fill_value():

@pytest.fixture(autouse=True)
def setup_xlb_env():
init_xlb_env()
init_xlb_env(xlb.velocity_set.D2Q9)


if __name__ == "__main__":
Expand Down
14 changes: 7 additions & 7 deletions tests/grids/test_grid_warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,18 @@
from xlb.precision_policy import Precision


def init_xlb_env():
def init_xlb_env(velocity_set):
vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP)
xlb.init(
default_precision_policy=xlb.PrecisionPolicy.FP32FP32,
default_backend=ComputeBackend.WARP,
velocity_set=xlb.velocity_set.D2Q9,
default_backend=ComputeBackend.JAX,
velocity_set=vel_set,
)


@pytest.mark.parametrize("grid_size", [50, 100, 150])
def test_warp_grid_create_field(grid_size):
for grid_shape in [(grid_size, grid_size), (grid_size, grid_size, grid_size)]:
init_xlb_env()
init_xlb_env(xlb.velocity_set.D3Q19)
my_grid = grid_factory(grid_shape)
f = my_grid.create_field(cardinality=9, dtype=Precision.FP32)

Expand All @@ -27,7 +27,7 @@ def test_warp_grid_create_field(grid_size):


def test_warp_grid_create_field_fill_value():
init_xlb_env()
init_xlb_env(xlb.velocity_set.D2Q9)
grid_shape = (100, 100)
fill_value = 3.14
my_grid = grid_factory(grid_shape)
Expand All @@ -42,7 +42,7 @@ def test_warp_grid_create_field_fill_value():

@pytest.fixture(autouse=True)
def setup_xlb_env():
init_xlb_env()
init_xlb_env(xlb.velocity_set.D2Q9)


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions tests/kernels/collision/test_bgk_collision_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@


def init_xlb_env(velocity_set):
vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.JAX)
xlb.init(
default_precision_policy=xlb.PrecisionPolicy.FP32FP32,
default_backend=ComputeBackend.JAX,
velocity_set=velocity_set(),
velocity_set=vel_set,
)


@pytest.mark.parametrize(
"dim,velocity_set,grid_shape,omega",
[
Expand Down
5 changes: 3 additions & 2 deletions tests/kernels/collision/test_bgk_collision_warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@


def init_xlb_env(velocity_set):
vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP)
xlb.init(
default_precision_policy=xlb.PrecisionPolicy.FP32FP32,
default_backend=ComputeBackend.WARP,
velocity_set=velocity_set(),
default_backend=ComputeBackend.JAX,
velocity_set=vel_set,
)


Expand Down
4 changes: 2 additions & 2 deletions tests/kernels/equilibrium/test_equilibrium_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@


def init_xlb_env(velocity_set):
vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP)
xlb.init(
default_precision_policy=xlb.PrecisionPolicy.FP32FP32,
default_backend=ComputeBackend.JAX,
velocity_set=velocity_set(),
velocity_set=vel_set,
)


@pytest.mark.parametrize(
"dim,velocity_set,grid_shape",
[
Expand Down
7 changes: 3 additions & 4 deletions tests/kernels/equilibrium/test_equilibrium_warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,12 @@


def init_xlb_env(velocity_set):
vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP)
xlb.init(
default_precision_policy=xlb.PrecisionPolicy.FP32FP32,
default_backend=ComputeBackend.WARP,
velocity_set=velocity_set(),
default_backend=ComputeBackend.JAX,
velocity_set=vel_set,
)


@pytest.mark.parametrize(
"dim,velocity_set,grid_shape",
[
Expand Down
3 changes: 2 additions & 1 deletion tests/kernels/macroscopic/test_macroscopic_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@


def init_xlb_env(velocity_set):
vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.JAX)
xlb.init(
default_precision_policy=xlb.PrecisionPolicy.FP32FP32,
default_backend=ComputeBackend.JAX,
velocity_set=velocity_set(),
velocity_set=vel_set,
)


Expand Down
5 changes: 3 additions & 2 deletions tests/kernels/macroscopic/test_macroscopic_warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@


def init_xlb_env(velocity_set):
vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP)
xlb.init(
default_precision_policy=xlb.PrecisionPolicy.FP32FP32,
default_backend=ComputeBackend.WARP,
velocity_set=velocity_set(),
default_backend=ComputeBackend.JAX,
velocity_set=vel_set,
)


Expand Down
Loading
Loading