diff --git a/examples/cfd/flow_past_sphere_3d.py b/examples/cfd/flow_past_sphere_3d.py index 7d07e05..5a70d20 100644 --- a/examples/cfd/flow_past_sphere_3d.py +++ b/examples/cfd/flow_past_sphere_3d.py @@ -19,6 +19,7 @@ import numpy as np import jax.numpy as jnp import time +import jax class FlowOverSphere: @@ -118,7 +119,11 @@ def post_process(self, i): else: f_0 = self.f_0 - macro = Macroscopic(compute_backend=ComputeBackend.JAX) + macro = Macroscopic( + compute_backend=ComputeBackend.JAX, + precision_policy=self.precision_policy, + velocity_set=xlb.velocity_set.D3Q19(precision_policy=self.precision_policy, backend=ComputeBackend.JAX), + ) rho, u = macro(f_0) # remove boundary cells @@ -135,9 +140,13 @@ def post_process(self, i): if __name__ == "__main__": # Running the simulation grid_shape = (512 // 2, 128 // 2, 128 // 2) - velocity_set = xlb.velocity_set.D3Q19() backend = ComputeBackend.WARP precision_policy = PrecisionPolicy.FP32FP32 + + if precision_policy == PrecisionPolicy.FP64FP64 or precision_policy == PrecisionPolicy.FP64FP32: + jax.config.update("jax_enable_x64", True) + + velocity_set = xlb.velocity_set.D3Q19(precision_policy=precision_policy, backend=backend) omega = 1.6 simulation = FlowOverSphere(omega, grid_shape, velocity_set, backend, precision_policy) diff --git a/examples/cfd/lid_driven_cavity_2d.py b/examples/cfd/lid_driven_cavity_2d.py index 16fb4f9..f73fb0e 100644 --- a/examples/cfd/lid_driven_cavity_2d.py +++ b/examples/cfd/lid_driven_cavity_2d.py @@ -8,7 +8,9 @@ from xlb.operator.macroscopic import Macroscopic from xlb.utils import save_fields_vtk, save_image import warp as wp +import jax import jax.numpy as jnp +import xlb.velocity_set class LidDrivenCavity2D: @@ -80,7 +82,11 @@ def post_process(self, i): else: f_0 = self.f_0 - macro = Macroscopic(compute_backend=ComputeBackend.JAX) + macro = Macroscopic( + compute_backend=ComputeBackend.JAX, + precision_policy=self.precision_policy, + velocity_set=xlb.velocity_set.D2Q9(precision_policy=self.precision_policy, backend=ComputeBackend.JAX), + ) rho, u = macro(f_0) @@ -100,8 +106,12 @@ def post_process(self, i): grid_size = 500 grid_shape = (grid_size, grid_size) backend = ComputeBackend.WARP - velocity_set = xlb.velocity_set.D2Q9() precision_policy = PrecisionPolicy.FP32FP32 + + if precision_policy == PrecisionPolicy.FP64FP64 or precision_policy == PrecisionPolicy.FP64FP32: + jax.config.update("jax_enable_x64", True) + + velocity_set = xlb.velocity_set.D2Q9(precision_policy=precision_policy, backend=backend) omega = 1.6 simulation = LidDrivenCavity2D(omega, grid_shape, velocity_set, backend, precision_policy) diff --git a/examples/cfd/lid_driven_cavity_2d_distributed.py b/examples/cfd/lid_driven_cavity_2d_distributed.py index 225d6bd..7efe907 100644 --- a/examples/cfd/lid_driven_cavity_2d_distributed.py +++ b/examples/cfd/lid_driven_cavity_2d_distributed.py @@ -1,4 +1,5 @@ import xlb +import jax from xlb.compute_backend import ComputeBackend from xlb.precision_policy import PrecisionPolicy from xlb.operator.stepper import IncompressibleNavierStokesStepper @@ -26,8 +27,12 @@ def setup_stepper(self, omega): grid_size = 512 grid_shape = (grid_size, grid_size) backend = ComputeBackend.JAX # Must be JAX for distributed multi-GPU computations. Distributed computations on WARP are not supported yet! - velocity_set = xlb.velocity_set.D2Q9() precision_policy = PrecisionPolicy.FP32FP32 + + if precision_policy == PrecisionPolicy.FP64FP64 or precision_policy == PrecisionPolicy.FP64FP32: + jax.config.update("jax_enable_x64", True) + + velocity_set = xlb.velocity_set.D2Q9(precision_policy=precision_policy, backend=backend) omega = 1.6 simulation = LidDrivenCavity2D_distributed(omega, grid_shape, velocity_set, backend, precision_policy) diff --git a/examples/cfd/windtunnel_3d.py b/examples/cfd/windtunnel_3d.py index 8395579..3522122 100644 --- a/examples/cfd/windtunnel_3d.py +++ b/examples/cfd/windtunnel_3d.py @@ -151,7 +151,11 @@ def post_process(self, i): else: f_0 = self.f_0 - macro = Macroscopic(compute_backend=ComputeBackend.JAX) + macro = Macroscopic( + compute_backend=ComputeBackend.JAX, + precision_policy=self.precision_policy, + velocity_set=xlb.velocity_set.D3Q27(precision_policy=self.precision_policy, backend=ComputeBackend.JAX), + ) rho, u = macro(f_0) @@ -215,8 +219,8 @@ def plot_drag_coefficient(self): # Configuration backend = ComputeBackend.WARP - velocity_set = xlb.velocity_set.D3Q27() precision_policy = PrecisionPolicy.FP32FP32 + velocity_set = xlb.velocity_set.D3Q27(precision_policy=precision_policy, backend=backend) wind_speed = 0.02 num_steps = 100000 print_interval = 1000 diff --git a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py index 9d2e4ff..bd40dfb 100644 --- a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py +++ b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py @@ -9,10 +9,11 @@ def init_xlb_env(velocity_set): + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.JAX) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.JAX, - velocity_set=velocity_set(), + velocity_set=vel_set, ) diff --git a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py index 917e7e4..9f0cd68 100644 --- a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py +++ b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py @@ -8,10 +8,11 @@ def init_xlb_env(velocity_set): + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, - default_backend=ComputeBackend.WARP, - velocity_set=velocity_set(), + default_backend=ComputeBackend.JAX, + velocity_set=vel_set, ) diff --git a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py index 2fe0b40..dde05bb 100644 --- a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py +++ b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py @@ -9,10 +9,11 @@ def init_xlb_env(velocity_set): + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.JAX) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.JAX, - velocity_set=velocity_set(), + velocity_set=vel_set, ) diff --git a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py index b25d39e..96e2f21 100644 --- a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py +++ b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py @@ -9,10 +9,11 @@ def init_xlb_env(velocity_set): + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, - default_backend=ComputeBackend.WARP, - velocity_set=velocity_set(), + default_backend=ComputeBackend.JAX, + velocity_set=vel_set, ) diff --git a/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py b/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py index af121d3..9325890 100644 --- a/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py +++ b/tests/boundary_conditions/mask/test_bc_indices_masker_jax.py @@ -8,13 +8,13 @@ def init_xlb_env(velocity_set): + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.JAX) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.JAX, - velocity_set=velocity_set(), + velocity_set=vel_set, ) - @pytest.mark.parametrize( "dim,velocity_set,grid_shape", [ diff --git a/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py b/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py index 4d02540..43ad052 100644 --- a/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py +++ b/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py @@ -7,10 +7,11 @@ def init_xlb_env(velocity_set): + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, - default_backend=ComputeBackend.WARP, - velocity_set=velocity_set(), + default_backend=ComputeBackend.JAX, + velocity_set=vel_set, ) diff --git a/tests/grids/test_grid_jax.py b/tests/grids/test_grid_jax.py index edd9dd0..dd74da6 100644 --- a/tests/grids/test_grid_jax.py +++ b/tests/grids/test_grid_jax.py @@ -8,17 +8,18 @@ import jax.numpy as jnp -def init_xlb_env(): +def init_xlb_env(velocity_set): + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.JAX, - velocity_set=xlb.velocity_set.D2Q9, # does not affect the test + velocity_set=vel_set, ) @pytest.mark.parametrize("grid_size", [50, 100, 150]) def test_jax_2d_grid_initialization(grid_size): - init_xlb_env() + init_xlb_env(xlb.velocity_set.D2Q9) grid_shape = (grid_size, grid_size) my_grid = grid_factory(grid_shape) f = my_grid.create_field(cardinality=9) @@ -34,7 +35,7 @@ def test_jax_2d_grid_initialization(grid_size): @pytest.mark.parametrize("grid_size", [50, 100, 150]) def test_jax_3d_grid_initialization(grid_size): - init_xlb_env() + init_xlb_env(xlb.velocity_set.D3Q19) grid_shape = (grid_size, grid_size, grid_size) my_grid = grid_factory(grid_shape) f = my_grid.create_field(cardinality=9) @@ -54,7 +55,7 @@ def test_jax_3d_grid_initialization(grid_size): def test_jax_grid_create_field_fill_value(): - init_xlb_env() + init_xlb_env(xlb.velocity_set.D2Q9) grid_shape = (100, 100) fill_value = 3.14 my_grid = grid_factory(grid_shape) @@ -66,7 +67,7 @@ def test_jax_grid_create_field_fill_value(): @pytest.fixture(autouse=True) def setup_xlb_env(): - init_xlb_env() + init_xlb_env(xlb.velocity_set.D2Q9) if __name__ == "__main__": diff --git a/tests/grids/test_grid_warp.py b/tests/grids/test_grid_warp.py index 22445cc..11c8b2a 100644 --- a/tests/grids/test_grid_warp.py +++ b/tests/grids/test_grid_warp.py @@ -7,18 +7,18 @@ from xlb.precision_policy import Precision -def init_xlb_env(): +def init_xlb_env(velocity_set): + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, - default_backend=ComputeBackend.WARP, - velocity_set=xlb.velocity_set.D2Q9, + default_backend=ComputeBackend.JAX, + velocity_set=vel_set, ) - @pytest.mark.parametrize("grid_size", [50, 100, 150]) def test_warp_grid_create_field(grid_size): for grid_shape in [(grid_size, grid_size), (grid_size, grid_size, grid_size)]: - init_xlb_env() + init_xlb_env(xlb.velocity_set.D3Q19) my_grid = grid_factory(grid_shape) f = my_grid.create_field(cardinality=9, dtype=Precision.FP32) @@ -27,7 +27,7 @@ def test_warp_grid_create_field(grid_size): def test_warp_grid_create_field_fill_value(): - init_xlb_env() + init_xlb_env(xlb.velocity_set.D2Q9) grid_shape = (100, 100) fill_value = 3.14 my_grid = grid_factory(grid_shape) @@ -42,7 +42,7 @@ def test_warp_grid_create_field_fill_value(): @pytest.fixture(autouse=True) def setup_xlb_env(): - init_xlb_env() + init_xlb_env(xlb.velocity_set.D2Q9) if __name__ == "__main__": diff --git a/tests/kernels/collision/test_bgk_collision_jax.py b/tests/kernels/collision/test_bgk_collision_jax.py index 5a400e0..aebc726 100644 --- a/tests/kernels/collision/test_bgk_collision_jax.py +++ b/tests/kernels/collision/test_bgk_collision_jax.py @@ -9,13 +9,13 @@ def init_xlb_env(velocity_set): + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.JAX) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.JAX, - velocity_set=velocity_set(), + velocity_set=vel_set, ) - @pytest.mark.parametrize( "dim,velocity_set,grid_shape,omega", [ diff --git a/tests/kernels/collision/test_bgk_collision_warp.py b/tests/kernels/collision/test_bgk_collision_warp.py index 522ea33..2743050 100644 --- a/tests/kernels/collision/test_bgk_collision_warp.py +++ b/tests/kernels/collision/test_bgk_collision_warp.py @@ -9,10 +9,11 @@ def init_xlb_env(velocity_set): + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, - default_backend=ComputeBackend.WARP, - velocity_set=velocity_set(), + default_backend=ComputeBackend.JAX, + velocity_set=vel_set, ) diff --git a/tests/kernels/equilibrium/test_equilibrium_jax.py b/tests/kernels/equilibrium/test_equilibrium_jax.py index 07bafe7..50418bc 100644 --- a/tests/kernels/equilibrium/test_equilibrium_jax.py +++ b/tests/kernels/equilibrium/test_equilibrium_jax.py @@ -8,13 +8,13 @@ def init_xlb_env(velocity_set): + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.JAX, - velocity_set=velocity_set(), + velocity_set=vel_set, ) - @pytest.mark.parametrize( "dim,velocity_set,grid_shape", [ diff --git a/tests/kernels/equilibrium/test_equilibrium_warp.py b/tests/kernels/equilibrium/test_equilibrium_warp.py index 063a723..9759fb2 100644 --- a/tests/kernels/equilibrium/test_equilibrium_warp.py +++ b/tests/kernels/equilibrium/test_equilibrium_warp.py @@ -8,13 +8,12 @@ def init_xlb_env(velocity_set): + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, - default_backend=ComputeBackend.WARP, - velocity_set=velocity_set(), + default_backend=ComputeBackend.JAX, + velocity_set=vel_set, ) - - @pytest.mark.parametrize( "dim,velocity_set,grid_shape", [ diff --git a/tests/kernels/macroscopic/test_macroscopic_jax.py b/tests/kernels/macroscopic/test_macroscopic_jax.py index 50d1735..2c2ad55 100644 --- a/tests/kernels/macroscopic/test_macroscopic_jax.py +++ b/tests/kernels/macroscopic/test_macroscopic_jax.py @@ -8,10 +8,11 @@ def init_xlb_env(velocity_set): + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.JAX) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, default_backend=ComputeBackend.JAX, - velocity_set=velocity_set(), + velocity_set=vel_set, ) diff --git a/tests/kernels/macroscopic/test_macroscopic_warp.py b/tests/kernels/macroscopic/test_macroscopic_warp.py index d98a014..6a97927 100644 --- a/tests/kernels/macroscopic/test_macroscopic_warp.py +++ b/tests/kernels/macroscopic/test_macroscopic_warp.py @@ -9,10 +9,11 @@ def init_xlb_env(velocity_set): + vel_set = velocity_set(precision_policy=xlb.PrecisionPolicy.FP32FP32, backend=ComputeBackend.WARP) xlb.init( default_precision_policy=xlb.PrecisionPolicy.FP32FP32, - default_backend=ComputeBackend.WARP, - velocity_set=velocity_set(), + default_backend=ComputeBackend.JAX, + velocity_set=vel_set, ) diff --git a/xlb/default_config.py b/xlb/default_config.py index f1ca25f..20eac44 100644 --- a/xlb/default_config.py +++ b/xlb/default_config.py @@ -1,5 +1,7 @@ +import jax from xlb.compute_backend import ComputeBackend from dataclasses import dataclass +from xlb.precision_policy import PrecisionPolicy @dataclass @@ -17,7 +19,7 @@ def init(velocity_set, default_backend, default_precision_policy): if default_backend == ComputeBackend.WARP: import warp as wp - wp.init() + wp.init() # TODO: Must be removed in the future versions of WARP elif default_backend == ComputeBackend.JAX: check_multi_gpu_support() else: @@ -29,8 +31,6 @@ def default_backend() -> ComputeBackend: def check_multi_gpu_support(): - import jax - gpus = jax.devices("gpu") if len(gpus) > 1: print("Multi-GPU support is available: {} GPUs detected.".format(len(gpus))) diff --git a/xlb/operator/boundary_condition/bc_equilibrium.py b/xlb/operator/boundary_condition/bc_equilibrium.py index 6853c0e..c018a60 100644 --- a/xlb/operator/boundary_condition/bc_equilibrium.py +++ b/xlb/operator/boundary_condition/bc_equilibrium.py @@ -73,7 +73,7 @@ def jax_implementation(self, f_pre, f_post, boundary_map, missing_mask): def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) - _rho = wp.float32(self.rho) + _rho = self.compute_dtype(self.rho) _u = _u_vec(self.u[0], self.u[1], self.u[2]) if self.velocity_set.d == 3 else _u_vec(self.u[0], self.u[1]) # Construct the functional for this BC diff --git a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py index 55f094d..7bd447e 100644 --- a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py +++ b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py @@ -137,9 +137,9 @@ def _construct_warp(self): # Set local constants sound_speed = 1.0 / wp.sqrt(3.0) _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) - _c = self.velocity_set.wp_c + _c = self.velocity_set.c _q = self.velocity_set.q - _opp_indices = self.velocity_set.wp_opp_indices + _opp_indices = self.velocity_set.opp_indices @wp.func def get_normal_vectors_2d( diff --git a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py index 6af4226..729dbb4 100644 --- a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py @@ -56,7 +56,7 @@ def apply_jax(self, f_pre, f_post, boundary_map, missing_mask): def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update - _opp_indices = self.velocity_set.wp_opp_indices + _opp_indices = self.velocity_set.opp_indices _q = wp.constant(self.velocity_set.q) _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py index 5c001d9..2cb60fa 100644 --- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py @@ -63,7 +63,7 @@ def apply_jax(self, f_pre, f_post, boundary_map, missing_mask): def _construct_warp(self): # Set local constants - _opp_indices = self.velocity_set.wp_opp_indices + _opp_indices = self.velocity_set.opp_indices # Construct the functional for this BC @wp.func diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index b74c0b1..9734368 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -136,14 +136,14 @@ def _construct_warp(self): # compute Qi tensor and store it in self _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) - _rho = wp.float32(rho) + _rho = self.compute_dtype(rho) _u = _u_vec(u[0], u[1], u[2]) if _d == 3 else _u_vec(u[0], u[1]) - _opp_indices = self.velocity_set.wp_opp_indices - _w = self.velocity_set.wp_w - _c = self.velocity_set.wp_c - _c32 = self.velocity_set.wp_c32 - _qi = self.velocity_set.wp_qi - # TODO: related to _c32: this is way less than ideal. we should not be making new types + _opp_indices = self.velocity_set.opp_indices + _w = self.velocity_set.w + _c = self.velocity_set.c + _c_float = self.velocity_set.c_float + _qi = self.velocity_set.qi + # TODO: related to _c_float: this is way less than ideal. we should not be making new types @wp.func def _get_fsum( @@ -165,7 +165,7 @@ def get_normal_vectors_2d( ): for l in range(_q): if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1: - return -_u_vec(_c32[0, l], _c32[1, l]) + return -_u_vec(_c_float[0, l], _c_float[1, l]) @wp.func def get_normal_vectors_3d( @@ -173,7 +173,7 @@ def get_normal_vectors_3d( ): for l in range(_q): if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1: - return -_u_vec(_c32[0, l], _c32[1, l], _c32[2, l]) + return -_u_vec(_c_float[0, l], _c_float[1, l], _c_float[2, l]) @wp.func def bounceback_nonequilibrium( diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index 56c6868..782eb4c 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -180,11 +180,11 @@ def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update # _u_vec = wp.vec(_d, dtype=self.compute_dtype) _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) - _rho = wp.float32(rho) + _rho = self.compute_dtype(rho) _u = _u_vec(u[0], u[1], u[2]) if _d == 3 else _u_vec(u[0], u[1]) - _opp_indices = self.velocity_set.wp_opp_indices - _c = self.velocity_set.wp_c - _c32 = self.velocity_set.wp_c32 + _opp_indices = self.velocity_set.opp_indices + _c = self.velocity_set.c + _c_float = self.velocity_set.c_float # TODO: this is way less than ideal. we should not be making new types @wp.func @@ -193,7 +193,7 @@ def get_normal_vectors_2d( ): l = lattice_direction if wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1: - normals = -_u_vec(_c32[0, l], _c32[1, l]) + normals = -_u_vec(_c_float[0, l], _c_float[1, l]) return normals @wp.func @@ -216,7 +216,7 @@ def get_normal_vectors_3d( ): for l in range(_q): if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1: - return -_u_vec(_c32[0, l], _c32[1, l], _c32[2, l]) + return -_u_vec(_c_float[0, l], _c_float[1, l], _c_float[2, l]) @wp.func def bounceback_nonequilibrium( diff --git a/xlb/operator/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_masker/indices_boundary_masker.py index 208f50f..54cb7aa 100644 --- a/xlb/operator/boundary_masker/indices_boundary_masker.py +++ b/xlb/operator/boundary_masker/indices_boundary_masker.py @@ -97,7 +97,7 @@ def jax_implementation(self, bclist, boundary_map, missing_mask, start_index=Non def _construct_warp(self): # Make constants for warp - _c = self.velocity_set.wp_c + _c = self.velocity_set.c _q = wp.constant(self.velocity_set.q) # Construct the warp 2D kernel diff --git a/xlb/operator/boundary_masker/mesh_boundary_masker.py b/xlb/operator/boundary_masker/mesh_boundary_masker.py index 366c9d6..c43ea02 100644 --- a/xlb/operator/boundary_masker/mesh_boundary_masker.py +++ b/xlb/operator/boundary_masker/mesh_boundary_masker.py @@ -47,7 +47,7 @@ def jax_implementation( def _construct_warp(self): # Make constants for warp - _c = self.velocity_set.wp_c + _c = self.velocity_set.c _q = wp.constant(self.velocity_set.q) # Construct the warp kernel diff --git a/xlb/operator/collision/bgk.py b/xlb/operator/collision/bgk.py index 9dbfabd..196e3ba 100644 --- a/xlb/operator/collision/bgk.py +++ b/xlb/operator/collision/bgk.py @@ -23,7 +23,7 @@ def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray, rho, u): def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update - _w = self.velocity_set.wp_w + _w = self.velocity_set.w _omega = wp.constant(self.compute_dtype(self.omega)) _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) diff --git a/xlb/operator/collision/kbc.py b/xlb/operator/collision/kbc.py index ddd7ecc..748ebea 100644 --- a/xlb/operator/collision/kbc.py +++ b/xlb/operator/collision/kbc.py @@ -251,7 +251,7 @@ def entropic_scalar_product( feq: Any, ): e = wp.cw_div(wp.cw_mul(x, y), feq) - e_sum = wp.float32(0.0) + e_sum = self.compute_dtype(0.0) for i in range(self.velocity_set.q): e_sum += e[i] return e_sum diff --git a/xlb/operator/equilibrium/quadratic_equilibrium.py b/xlb/operator/equilibrium/quadratic_equilibrium.py index 3af6b4a..0cce91b 100644 --- a/xlb/operator/equilibrium/quadratic_equilibrium.py +++ b/xlb/operator/equilibrium/quadratic_equilibrium.py @@ -26,8 +26,8 @@ def jax_implementation(self, rho, u): def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update - _c = self.velocity_set.wp_c - _w = self.velocity_set.wp_w + _c = self.velocity_set.c + _w = self.velocity_set.w _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) @@ -52,10 +52,10 @@ def functional( cu *= self.compute_dtype(3.0) # Compute usqr - usqr = 1.5 * wp.dot(u, u) + usqr = self.compute_dtype(1.5) * wp.dot(u, u) # Compute feq - feq[l] = rho * _w[l] * (1.0 + cu * (1.0 + 0.5 * cu) - usqr) + feq[l] = rho * _w[l] * (self.compute_dtype(1.0) + cu * (self.compute_dtype(1.0) + self.compute_dtype(0.5) * cu) - usqr) return feq diff --git a/xlb/operator/force/momentum_transfer.py b/xlb/operator/force/momentum_transfer.py index 66dba13..d25baf7 100644 --- a/xlb/operator/force/momentum_transfer.py +++ b/xlb/operator/force/momentum_transfer.py @@ -87,8 +87,8 @@ def jax_implementation(self, f, boundary_map, missing_mask): def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update - _c = self.velocity_set.wp_c - _opp_indices = self.velocity_set.wp_opp_indices + _c = self.velocity_set.c + _opp_indices = self.velocity_set.opp_indices _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool _no_slip_id = self.no_slip_bc_instance.id @@ -144,7 +144,7 @@ def kernel2d( if _missing_mask[l] == wp.uint8(1): phi = f_post_collision[_opp_indices[l]] + f_post_stream[l] for d in range(self.velocity_set.d): - m[d] += phi * wp.float32(_c[d, _opp_indices[l]]) + m[d] += phi * self.compute_dtype(_c[d, _opp_indices[l]]) wp.atomic_add(force, 0, m) @@ -193,7 +193,7 @@ def kernel3d( if _missing_mask[l] == wp.uint8(1): phi = f_post_collision[_opp_indices[l]] + f_post_stream[l] for d in range(self.velocity_set.d): - m[d] += phi * wp.float32(_c[d, _opp_indices[l]]) + m[d] += phi * self.compute_dtype(_c[d, _opp_indices[l]]) wp.atomic_add(force, 0, m) diff --git a/xlb/operator/macroscopic/second_moment.py b/xlb/operator/macroscopic/second_moment.py index db8fce6..5209d69 100644 --- a/xlb/operator/macroscopic/second_moment.py +++ b/xlb/operator/macroscopic/second_moment.py @@ -56,7 +56,7 @@ def jax_implementation( def _construct_warp(self): # Make constants for warp - _cc = self.velocity_set.wp_cc + _cc = self.velocity_set.cc _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) _pi_dim = self.velocity_set.d * (self.velocity_set.d + 1) // 2 _pi_vec = wp.vec( diff --git a/xlb/operator/macroscopic/zero_first_moments.py b/xlb/operator/macroscopic/zero_first_moments.py index fbf7c93..48cf108 100644 --- a/xlb/operator/macroscopic/zero_first_moments.py +++ b/xlb/operator/macroscopic/zero_first_moments.py @@ -46,7 +46,7 @@ def jax_implementation(self, f): def _construct_warp(self): # Make constants for warp - _c = self.velocity_set.wp_c + _c = self.velocity_set.c _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index ba3e294..60f610a 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -90,9 +90,9 @@ def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool - _c = self.velocity_set.wp_c + _c = self.velocity_set.c _q = self.velocity_set.q - _opp_indices = self.velocity_set.wp_opp_indices + _opp_indices = self.velocity_set.opp_indices sound_speed = 1.0 / wp.sqrt(3.0) @wp.struct diff --git a/xlb/operator/stream/stream.py b/xlb/operator/stream/stream.py index f91b567..d96c307 100644 --- a/xlb/operator/stream/stream.py +++ b/xlb/operator/stream/stream.py @@ -52,7 +52,7 @@ def _streaming_jax_i(f, c): def _construct_warp(self): # Set local constants TODO: This is a hack and should be fixed with warp update - _c = self.velocity_set.wp_c + _c = self.velocity_set.c _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) # Construct the warp functional diff --git a/xlb/velocity_set/d2q9.py b/xlb/velocity_set/d2q9.py index 700806c..5324618 100644 --- a/xlb/velocity_set/d2q9.py +++ b/xlb/velocity_set/d2q9.py @@ -13,7 +13,7 @@ class D2Q9(VelocitySet): Lattice Boltzmann Method for simulating fluid flows in two dimensions. """ - def __init__(self): + def __init__(self, precision_policy, backend): # Construct the velocity vectors and weights cx = [0, 0, 0, 1, -1, 1, -1, 1, -1] cy = [0, 1, -1, 0, 1, -1, 0, 1, -1] @@ -21,4 +21,4 @@ def __init__(self): w = np.array([4 / 9, 1 / 9, 1 / 9, 1 / 9, 1 / 36, 1 / 36, 1 / 9, 1 / 36, 1 / 36]) # Call the parent constructor - super().__init__(2, 9, c, w) + super().__init__(2, 9, c, w, precision_policy=precision_policy, backend=backend) diff --git a/xlb/velocity_set/d3q19.py b/xlb/velocity_set/d3q19.py index 97db1d9..4a48c2f 100644 --- a/xlb/velocity_set/d3q19.py +++ b/xlb/velocity_set/d3q19.py @@ -14,7 +14,7 @@ class D3Q19(VelocitySet): Lattice Boltzmann Method for simulating fluid flows in three dimensions. """ - def __init__(self): + def __init__(self, precision_policy, backend): # Construct the velocity vectors and weights c = np.array([ci for ci in itertools.product([-1, 0, 1], repeat=3) if np.sum(np.abs(ci)) <= 2]).T w = np.zeros(19) @@ -27,4 +27,4 @@ def __init__(self): w[i] = 1.0 / 36.0 # Initialize the lattice - super().__init__(3, 19, c, w) + super().__init__(3, 19, c, w, precision_policy=precision_policy, backend=backend) diff --git a/xlb/velocity_set/d3q27.py b/xlb/velocity_set/d3q27.py index 702acf4..b056d53 100644 --- a/xlb/velocity_set/d3q27.py +++ b/xlb/velocity_set/d3q27.py @@ -14,7 +14,7 @@ class D3Q27(VelocitySet): Lattice Boltzmann Method for simulating fluid flows in three dimensions. """ - def __init__(self): + def __init__(self, precision_policy, backend): # Construct the velocity vectors and weights c = np.array(list(itertools.product([0, -1, 1], repeat=3))).T w = np.zeros(27) @@ -29,4 +29,4 @@ def __init__(self): w[i] = 1.0 / 216.0 # Initialize the Lattice - super().__init__(3, 27, c, w) + super().__init__(3, 27, c, w, precision_policy=precision_policy, backend=backend) diff --git a/xlb/velocity_set/velocity_set.py b/xlb/velocity_set/velocity_set.py index a93d039..2405b36 100644 --- a/xlb/velocity_set/velocity_set.py +++ b/xlb/velocity_set/velocity_set.py @@ -2,9 +2,11 @@ import math import numpy as np - import warp as wp +import jax.numpy as jnp +from xlb import DefaultConfig +from xlb.compute_backend import ComputeBackend class VelocitySet(object): """ @@ -22,35 +24,86 @@ class VelocitySet(object): The weights of the lattice. Shape: (q,) """ - def __init__(self, d, q, c, w): + def __init__(self, d, q, c, w, precision_policy, backend): # Store the dimension and the number of velocities self.d = d self.q = q + self.precision_policy = precision_policy + self.backend = backend + + # Create all properties in NumPy first + self._init_numpy_properties(c, w) + + # Convert properties to backend-specific format + if self.backend == ComputeBackend.WARP: + self._init_warp_properties() + elif self.backend == ComputeBackend.JAX: + self._init_jax_properties() + else: + raise ValueError(f"Unsupported compute backend: {self.backend}") + + # Set up backend-specific constants + self._init_backend_constants() - # Constants - self.cs = math.sqrt(3) / 3.0 - self.cs2 = 1.0 / 3.0 - self.inv_cs2 = 3.0 - - # Construct the properties of the lattice - self.c = c - self.w = w - self.cc = self._construct_lattice_moment() - self.opp_indices = self._construct_opposite_indices() - self.get_opp_index = lambda i: self.opp_indices[i] + def _init_numpy_properties(self, c, w): + """ + Initialize all properties in NumPy first. + """ + self._c = np.array(c) + self._w = np.array(w) + self._opp_indices = self._construct_opposite_indices() + self._cc = self._construct_lattice_moment() + self._c_float = self._c.astype(np.float64) + self._qi = self._construct_qi() + + # Constants in NumPy + self.cs = np.float64(math.sqrt(3) / 3.0) + self.cs2 = np.float64(1.0 / 3.0) + self.inv_cs2 = np.float64(3.0) + + # Indices self.main_indices = self._construct_main_indices() self.right_indices = self._construct_right_indices() self.left_indices = self._construct_left_indices() - self.qi = self._construct_qi() - # Make warp constants for these vectors - # TODO: Following warp updates these may not be necessary - self.wp_c = wp.constant(wp.mat((self.d, self.q), dtype=wp.int32)(self.c)) - self.wp_w = wp.constant(wp.vec(self.q, dtype=wp.float32)(self.w)) # TODO: Make type optional somehow - self.wp_opp_indices = wp.constant(wp.vec(self.q, dtype=wp.int32)(self.opp_indices)) - self.wp_cc = wp.constant(wp.mat((self.q, self.d * (self.d + 1) // 2), dtype=wp.float32)(self.cc)) - self.wp_c32 = wp.constant(wp.mat((self.d, self.q), dtype=wp.float32)(self.c)) - self.wp_qi = wp.constant(wp.mat((self.q, self.d * (self.d + 1) // 2), dtype=wp.float32)(self.qi)) + def _init_warp_properties(self): + """ + Convert NumPy properties to Warp-specific properties. + """ + dtype = self.precision_policy.compute_precision.wp_dtype + self.c = wp.constant(wp.mat((self.d, self.q), dtype=wp.int32)(self._c)) + self.w = wp.constant(wp.vec(self.q, dtype=dtype)(self._w)) + self.opp_indices = wp.constant(wp.vec(self.q, dtype=wp.int32)(self._opp_indices)) + self.cc = wp.constant(wp.mat((self.q, self.d * (self.d + 1) // 2), dtype=dtype)(self._cc)) + self.c_float = wp.constant(wp.mat((self.d, self.q), dtype=dtype)(self._c_float)) + self.qi = wp.constant(wp.mat((self.q, self.d * (self.d + 1) // 2), dtype=dtype)(self._qi)) + + def _init_jax_properties(self): + """ + Convert NumPy properties to JAX-specific properties. + """ + dtype = self.precision_policy.compute_precision.jax_dtype + self.c = jnp.array(self._c, dtype=dtype) + self.w = jnp.array(self._w, dtype=dtype) + self.opp_indices = jnp.array(self._opp_indices, dtype=jnp.int32) + self.cc = jnp.array(self._cc, dtype=dtype) + self.c_float = jnp.array(self._c_float, dtype=dtype) + self.qi = jnp.array(self._qi, dtype=dtype) + + def _init_backend_constants(self): + """ + Initialize the constants for the backend. + """ + if self.backend == ComputeBackend.WARP: + dtype = self.precision_policy.compute_precision.wp_dtype + self.cs = wp.constant(dtype(self.cs)) + self.cs2 = wp.constant(dtype(self.cs2)) + self.inv_cs2 = wp.constant(dtype(self.inv_cs2)) + elif self.backend == ComputeBackend.JAX: + dtype = self.precision_policy.compute_precision.jax_dtype + self.cs = jnp.array(self.cs, dtype=dtype) + self.cs2 = jnp.array(self.cs2, dtype=dtype) + self.inv_cs2 = jnp.array(self.inv_cs2, dtype=dtype) def warp_lattice_vec(self, dtype): return wp.vec(len(self.c), dtype=dtype) @@ -64,13 +117,11 @@ def warp_stream_mat(self, dtype): def _construct_qi(self): # Qi = cc - cs^2*I dim = self.d - Qi = self.cc.copy() + Qi = self._cc.copy() if dim == 3: - diagonal = (0, 3, 5) - offdiagonal = (1, 2, 4) + diagonal, offdiagonal = (0, 3, 5), (1, 2, 4) elif dim == 2: - diagonal = (0, 2) - offdiagonal = (1,) + diagonal, offdiagonal = (0, 2), (1,) else: raise ValueError(f"dim = {dim} not supported") @@ -92,19 +143,18 @@ def _construct_lattice_moment(self): cc: numpy.ndarray The moments of the lattice. """ - c = self.c.T + c = self._c.T # Counter for the loop cntr = 0 - + c = self._c.T # nt: number of independent elements of a symmetric tensor nt = self.d * (self.d + 1) // 2 - cc = np.zeros((self.q, nt)) - for a in range(0, self.d): + cntr = 0 + for a in range(self.d): for b in range(a, self.d): cc[:, cntr] = c[:, a] * c[:, b] cntr += 1 - return cc def _construct_opposite_indices(self): @@ -119,9 +169,8 @@ def _construct_opposite_indices(self): opposite: numpy.ndarray The indices of the opposite velocities. """ - c = self.c.T - opposite = np.array([c.tolist().index((-c[i]).tolist()) for i in range(self.q)]) - return opposite + c = self._c.T + return np.array([c.tolist().index((-c[i]).tolist()) for i in range(self.q)]) def _construct_main_indices(self): """ @@ -134,10 +183,9 @@ def _construct_main_indices(self): numpy.ndarray The indices of the main velocities. """ - c = self.c.T + c = self._c.T if self.d == 2: return np.nonzero((np.abs(c[:, 0]) + np.abs(c[:, 1]) == 1))[0] - elif self.d == 3: return np.nonzero((np.abs(c[:, 0]) + np.abs(c[:, 1]) + np.abs(c[:, 2]) == 1))[0] @@ -151,8 +199,7 @@ def _construct_right_indices(self): numpy.ndarray The indices of the right velocities. """ - c = self.c.T - return np.nonzero(c[:, 0] == 1)[0] + return np.nonzero(self._c.T[:, 0] == 1)[0] def _construct_left_indices(self): """ @@ -164,8 +211,7 @@ def _construct_left_indices(self): numpy.ndarray The indices of the left velocities. """ - c = self.c.T - return np.nonzero(c[:, 0] == -1)[0] + return np.nonzero(self._c.T[:, 0] == -1)[0] def __str__(self): """ @@ -178,3 +224,4 @@ def __repr__(self): This function returns the name of the lattice in the format of DxQy. """ return "D{}Q{}".format(self.d, self.q) +