diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index 3fad2b1..b26c558 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -89,11 +89,15 @@ def _construct_warp(self): _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool - # Get the boundary condition ids - _equilibrium_bc = wp.uint8(self.equilibrium_bc.id) - _do_nothing_bc = wp.uint8(self.do_nothing_bc.id) - _halfway_bounce_back_bc = wp.uint8(self.halfway_bounce_back_bc.id) - _fullway_bounce_back_bc = wp.uint8(self.fullway_bounce_back_bc.id) + @wp.struct + class BoundaryConditionIDStruct: + # Note the names are hardcoded here based on various BC operator names with "id_" at the beginning + # One needs to manually add the names of additional BC's as they are added. + # TODO: Anyway to improve this + id_EquilibriumBC: wp.uint8 + id_DoNothingBC: wp.uint8 + id_HalfwayBounceBackBC: wp.uint8 + id_FullwayBounceBackBC: wp.uint8 @wp.kernel def kernel2d( @@ -101,6 +105,7 @@ def kernel2d( f_1: wp.array3d(dtype=Any), boundary_mask: wp.array3d(dtype=Any), missing_mask: wp.array3d(dtype=Any), + bc_struct: BoundaryConditionIDStruct, timestep: int, ): # Get the global index @@ -117,17 +122,17 @@ def kernel2d( else: _missing_mask[l] = wp.uint8(0) - # Apply streaming boundary conditions - if (_boundary_id == wp.uint8(0)) or _boundary_id == _fullway_bounce_back_bc: - # Regular streaming - f_post_stream = self.stream.warp_functional(f_0, index) - elif _boundary_id == _equilibrium_bc: + # Apply streaming (pull method) + f_post_stream = self.stream.warp_functional(f_0, index) + + # Apply post-streaming type boundary conditions + if _boundary_id == bc_struct.id_EquilibriumBC: # Equilibrium boundary condition f_post_stream = self.equilibrium_bc.warp_functional(f_0, _missing_mask, index) - elif _boundary_id == _do_nothing_bc: + elif _boundary_id == bc_struct.id_DoNothingBC: # Do nothing boundary condition f_post_stream = self.do_nothing_bc.warp_functional(f_0, _missing_mask, index) - elif _boundary_id == _halfway_bounce_back_bc: + elif _boundary_id == bc_struct.id_HalfwayBounceBackBC: # Half way boundary condition f_post_stream = self.halfway_bounce_back_bc.warp_functional(f_0, _missing_mask, index) @@ -145,8 +150,8 @@ def kernel2d( u, ) - # Apply collision type boundary conditions - if _boundary_id == _fullway_bounce_back_bc: + # Apply post-collision type boundary conditions + if _boundary_id == bc_struct.id_FullwayBounceBackBC: # Full way boundary condition f_post_collision = self.fullway_bounce_back_bc.warp_functional( f_post_stream, @@ -165,6 +170,7 @@ def kernel3d( f_1: wp.array4d(dtype=Any), boundary_mask: wp.array4d(dtype=Any), missing_mask: wp.array4d(dtype=Any), + bc_struct: BoundaryConditionIDStruct, timestep: int, ): # Get the global index @@ -181,17 +187,17 @@ def kernel3d( else: _missing_mask[l] = wp.uint8(0) - # Apply streaming boundary conditions - if (_boundary_id == wp.uint8(0)) or _boundary_id == _fullway_bounce_back_bc: - # Regular streaming - f_post_stream = self.stream.warp_functional(f_0, index) - elif _boundary_id == _equilibrium_bc: + # Apply streaming (pull method) + f_post_stream = self.stream.warp_functional(f_0, index) + + # Apply post-streaming boundary conditions + if _boundary_id == bc_struct.id_EquilibriumBC: # Equilibrium boundary condition f_post_stream = self.equilibrium_bc.warp_functional(f_0, _missing_mask, index) - elif _boundary_id == _do_nothing_bc: + elif _boundary_id == bc_struct.id_DoNothingBC: # Do nothing boundary condition f_post_stream = self.do_nothing_bc.warp_functional(f_0, _missing_mask, index) - elif _boundary_id == _halfway_bounce_back_bc: + elif _boundary_id == bc_struct.id_HalfwayBounceBackBC: # Half way boundary condition f_post_stream = self.halfway_bounce_back_bc.warp_functional(f_0, _missing_mask, index) @@ -205,7 +211,7 @@ def kernel3d( f_post_collision = self.collision.warp_functional(f_post_stream, feq, rho, u) # Apply collision type boundary conditions - if _boundary_id == _fullway_bounce_back_bc: + if _boundary_id == bc_struct.id_FullwayBounceBackBC: # Full way boundary condition f_post_collision = self.fullway_bounce_back_bc.warp_functional( f_post_stream, @@ -220,10 +226,31 @@ def kernel3d( # Return the correct kernel kernel = kernel3d if self.velocity_set.d == 3 else kernel2d - return None, kernel + return BoundaryConditionIDStruct, kernel @Operator.register_backend(ComputeBackend.WARP) def warp_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep): + # Get the boundary condition ids + from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry + + bc_to_id = boundary_condition_registry.bc_to_id + + bc_struct = self.warp_functional() + bc_attribute_list = [] + for bc in self.boundary_conditions: + # Setting the Struct attributes based on the BC class names + attribute_str = bc.__class__.__name__ + setattr(bc_struct, "id_" + attribute_str, bc_to_id[attribute_str]) + bc_attribute_list.append("id_" + attribute_str) + + # Unused attributes of the struct are set to inernal (id=0) + ll = vars(bc_struct) + for var in ll: + if var not in bc_attribute_list and not var.startswith("_"): + # set unassigned boundaries to the maximum integer in uint8 + attribute_str = bc.__class__.__name__ + setattr(bc_struct, var, 255) + # Launch the warp kernel wp.launch( self.warp_kernel, @@ -232,6 +259,7 @@ def warp_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep): f_1, boundary_mask, missing_mask, + bc_struct, timestep, ], dim=f_0.shape[1:], diff --git a/xlb/operator/stepper/stepper.py b/xlb/operator/stepper/stepper.py index adc2564..2127ea6 100644 --- a/xlb/operator/stepper/stepper.py +++ b/xlb/operator/stepper/stepper.py @@ -1,6 +1,5 @@ # Base class for all stepper operators from xlb.operator import Operator -from xlb.operator.equilibrium import Equilibrium from xlb import DefaultConfig @@ -26,63 +25,35 @@ def __init__(self, operators, boundary_conditions): compute_backend = DefaultConfig.default_backend if not compute_backends else compute_backends.pop() # Add boundary conditions - # Warp cannot handle lists of functions currently - # Because of this we manually unpack the boundary conditions ############################################ + # Warp cannot handle lists of functions currently # TODO: Fix this later ############################################ from xlb.operator.boundary_condition.bc_equilibrium import EquilibriumBC from xlb.operator.boundary_condition.bc_do_nothing import DoNothingBC - from xlb.operator.boundary_condition.bc_halfway_bounce_back import ( - HalfwayBounceBackBC, - ) - from xlb.operator.boundary_condition.bc_fullway_bounce_back import ( - FullwayBounceBackBC, - ) - - self.equilibrium_bc = None - self.do_nothing_bc = None - self.halfway_bounce_back_bc = None - self.fullway_bounce_back_bc = None - - for bc in boundary_conditions: - if isinstance(bc, EquilibriumBC): - self.equilibrium_bc = bc - elif isinstance(bc, DoNothingBC): - self.do_nothing_bc = bc - elif isinstance(bc, HalfwayBounceBackBC): - self.halfway_bounce_back_bc = bc - elif isinstance(bc, FullwayBounceBackBC): - self.fullway_bounce_back_bc = bc + from xlb.operator.boundary_condition.bc_halfway_bounce_back import HalfwayBounceBackBC + from xlb.operator.boundary_condition.bc_fullway_bounce_back import FullwayBounceBackBC + + # Define a list of tuples with attribute names and their corresponding classes + conditions = [ + ("equilibrium_bc", EquilibriumBC), + ("do_nothing_bc", DoNothingBC), + ("halfway_bounce_back_bc", HalfwayBounceBackBC), + ("fullway_bounce_back_bc", FullwayBounceBackBC), + ] + + # this fall-back BC is just to ensure Warp codegen does not produce error when a particular BC is not used in an example. + bc_fallback = boundary_conditions[0] + + # Iterate over each boundary condition + for attr_name, bc_class in conditions: + for bc in boundary_conditions: + if isinstance(bc, bc_class): + setattr(self, attr_name, bc) + break + elif not hasattr(self, attr_name): + setattr(self, attr_name, bc_fallback) - if self.equilibrium_bc is None: - # Select the equilibrium operator based on its type - self.equilibrium_bc = EquilibriumBC( - rho=1.0, - u=(0.0, 0.0, 0.0), - equilibrium_operator=next((op for op in self.operators if isinstance(op, Equilibrium)), None), - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - if self.do_nothing_bc is None: - self.do_nothing_bc = DoNothingBC( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - if self.halfway_bounce_back_bc is None: - self.halfway_bounce_back_bc = HalfwayBounceBackBC( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) - if self.fullway_bounce_back_bc is None: - self.fullway_bounce_back_bc = FullwayBounceBackBC( - velocity_set=velocity_set, - precision_policy=precision_policy, - compute_backend=compute_backend, - ) ############################################ # Initialize operator