Skip to content

Commit

Permalink
Warp structs (#56)
Browse files Browse the repository at this point in the history
* somewhat improved bc handling using structs

* minor clean up. Warp MLUPs is not affected.

* added missing ruff formatting
  • Loading branch information
hsalehipour authored Aug 2, 2024
1 parent e07976e commit 329bd4c
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 75 deletions.
74 changes: 51 additions & 23 deletions xlb/operator/stepper/nse_stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,18 +89,23 @@ def _construct_warp(self):
_f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype)
_missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool

# Get the boundary condition ids
_equilibrium_bc = wp.uint8(self.equilibrium_bc.id)
_do_nothing_bc = wp.uint8(self.do_nothing_bc.id)
_halfway_bounce_back_bc = wp.uint8(self.halfway_bounce_back_bc.id)
_fullway_bounce_back_bc = wp.uint8(self.fullway_bounce_back_bc.id)
@wp.struct
class BoundaryConditionIDStruct:
# Note the names are hardcoded here based on various BC operator names with "id_" at the beginning
# One needs to manually add the names of additional BC's as they are added.
# TODO: Anyway to improve this
id_EquilibriumBC: wp.uint8
id_DoNothingBC: wp.uint8
id_HalfwayBounceBackBC: wp.uint8
id_FullwayBounceBackBC: wp.uint8

@wp.kernel
def kernel2d(
f_0: wp.array3d(dtype=Any),
f_1: wp.array3d(dtype=Any),
boundary_mask: wp.array3d(dtype=Any),
missing_mask: wp.array3d(dtype=Any),
bc_struct: BoundaryConditionIDStruct,
timestep: int,
):
# Get the global index
Expand All @@ -117,17 +122,17 @@ def kernel2d(
else:
_missing_mask[l] = wp.uint8(0)

# Apply streaming boundary conditions
if (_boundary_id == wp.uint8(0)) or _boundary_id == _fullway_bounce_back_bc:
# Regular streaming
f_post_stream = self.stream.warp_functional(f_0, index)
elif _boundary_id == _equilibrium_bc:
# Apply streaming (pull method)
f_post_stream = self.stream.warp_functional(f_0, index)

# Apply post-streaming type boundary conditions
if _boundary_id == bc_struct.id_EquilibriumBC:
# Equilibrium boundary condition
f_post_stream = self.equilibrium_bc.warp_functional(f_0, _missing_mask, index)
elif _boundary_id == _do_nothing_bc:
elif _boundary_id == bc_struct.id_DoNothingBC:
# Do nothing boundary condition
f_post_stream = self.do_nothing_bc.warp_functional(f_0, _missing_mask, index)
elif _boundary_id == _halfway_bounce_back_bc:
elif _boundary_id == bc_struct.id_HalfwayBounceBackBC:
# Half way boundary condition
f_post_stream = self.halfway_bounce_back_bc.warp_functional(f_0, _missing_mask, index)

Expand All @@ -145,8 +150,8 @@ def kernel2d(
u,
)

# Apply collision type boundary conditions
if _boundary_id == _fullway_bounce_back_bc:
# Apply post-collision type boundary conditions
if _boundary_id == bc_struct.id_FullwayBounceBackBC:
# Full way boundary condition
f_post_collision = self.fullway_bounce_back_bc.warp_functional(
f_post_stream,
Expand All @@ -165,6 +170,7 @@ def kernel3d(
f_1: wp.array4d(dtype=Any),
boundary_mask: wp.array4d(dtype=Any),
missing_mask: wp.array4d(dtype=Any),
bc_struct: BoundaryConditionIDStruct,
timestep: int,
):
# Get the global index
Expand All @@ -181,17 +187,17 @@ def kernel3d(
else:
_missing_mask[l] = wp.uint8(0)

# Apply streaming boundary conditions
if (_boundary_id == wp.uint8(0)) or _boundary_id == _fullway_bounce_back_bc:
# Regular streaming
f_post_stream = self.stream.warp_functional(f_0, index)
elif _boundary_id == _equilibrium_bc:
# Apply streaming (pull method)
f_post_stream = self.stream.warp_functional(f_0, index)

# Apply post-streaming boundary conditions
if _boundary_id == bc_struct.id_EquilibriumBC:
# Equilibrium boundary condition
f_post_stream = self.equilibrium_bc.warp_functional(f_0, _missing_mask, index)
elif _boundary_id == _do_nothing_bc:
elif _boundary_id == bc_struct.id_DoNothingBC:
# Do nothing boundary condition
f_post_stream = self.do_nothing_bc.warp_functional(f_0, _missing_mask, index)
elif _boundary_id == _halfway_bounce_back_bc:
elif _boundary_id == bc_struct.id_HalfwayBounceBackBC:
# Half way boundary condition
f_post_stream = self.halfway_bounce_back_bc.warp_functional(f_0, _missing_mask, index)

Expand All @@ -205,7 +211,7 @@ def kernel3d(
f_post_collision = self.collision.warp_functional(f_post_stream, feq, rho, u)

# Apply collision type boundary conditions
if _boundary_id == _fullway_bounce_back_bc:
if _boundary_id == bc_struct.id_FullwayBounceBackBC:
# Full way boundary condition
f_post_collision = self.fullway_bounce_back_bc.warp_functional(
f_post_stream,
Expand All @@ -220,10 +226,31 @@ def kernel3d(
# Return the correct kernel
kernel = kernel3d if self.velocity_set.d == 3 else kernel2d

return None, kernel
return BoundaryConditionIDStruct, kernel

@Operator.register_backend(ComputeBackend.WARP)
def warp_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep):
# Get the boundary condition ids
from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry

bc_to_id = boundary_condition_registry.bc_to_id

bc_struct = self.warp_functional()
bc_attribute_list = []
for bc in self.boundary_conditions:
# Setting the Struct attributes based on the BC class names
attribute_str = bc.__class__.__name__
setattr(bc_struct, "id_" + attribute_str, bc_to_id[attribute_str])
bc_attribute_list.append("id_" + attribute_str)

# Unused attributes of the struct are set to inernal (id=0)
ll = vars(bc_struct)
for var in ll:
if var not in bc_attribute_list and not var.startswith("_"):
# set unassigned boundaries to the maximum integer in uint8
attribute_str = bc.__class__.__name__
setattr(bc_struct, var, 255)

# Launch the warp kernel
wp.launch(
self.warp_kernel,
Expand All @@ -232,6 +259,7 @@ def warp_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep):
f_1,
boundary_mask,
missing_mask,
bc_struct,
timestep,
],
dim=f_0.shape[1:],
Expand Down
75 changes: 23 additions & 52 deletions xlb/operator/stepper/stepper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Base class for all stepper operators
from xlb.operator import Operator
from xlb.operator.equilibrium import Equilibrium
from xlb import DefaultConfig


Expand All @@ -26,63 +25,35 @@ def __init__(self, operators, boundary_conditions):
compute_backend = DefaultConfig.default_backend if not compute_backends else compute_backends.pop()

# Add boundary conditions
# Warp cannot handle lists of functions currently
# Because of this we manually unpack the boundary conditions
############################################
# Warp cannot handle lists of functions currently
# TODO: Fix this later
############################################
from xlb.operator.boundary_condition.bc_equilibrium import EquilibriumBC
from xlb.operator.boundary_condition.bc_do_nothing import DoNothingBC
from xlb.operator.boundary_condition.bc_halfway_bounce_back import (
HalfwayBounceBackBC,
)
from xlb.operator.boundary_condition.bc_fullway_bounce_back import (
FullwayBounceBackBC,
)

self.equilibrium_bc = None
self.do_nothing_bc = None
self.halfway_bounce_back_bc = None
self.fullway_bounce_back_bc = None

for bc in boundary_conditions:
if isinstance(bc, EquilibriumBC):
self.equilibrium_bc = bc
elif isinstance(bc, DoNothingBC):
self.do_nothing_bc = bc
elif isinstance(bc, HalfwayBounceBackBC):
self.halfway_bounce_back_bc = bc
elif isinstance(bc, FullwayBounceBackBC):
self.fullway_bounce_back_bc = bc
from xlb.operator.boundary_condition.bc_halfway_bounce_back import HalfwayBounceBackBC
from xlb.operator.boundary_condition.bc_fullway_bounce_back import FullwayBounceBackBC

# Define a list of tuples with attribute names and their corresponding classes
conditions = [
("equilibrium_bc", EquilibriumBC),
("do_nothing_bc", DoNothingBC),
("halfway_bounce_back_bc", HalfwayBounceBackBC),
("fullway_bounce_back_bc", FullwayBounceBackBC),
]

# this fall-back BC is just to ensure Warp codegen does not produce error when a particular BC is not used in an example.
bc_fallback = boundary_conditions[0]

# Iterate over each boundary condition
for attr_name, bc_class in conditions:
for bc in boundary_conditions:
if isinstance(bc, bc_class):
setattr(self, attr_name, bc)
break
elif not hasattr(self, attr_name):
setattr(self, attr_name, bc_fallback)

if self.equilibrium_bc is None:
# Select the equilibrium operator based on its type
self.equilibrium_bc = EquilibriumBC(
rho=1.0,
u=(0.0, 0.0, 0.0),
equilibrium_operator=next((op for op in self.operators if isinstance(op, Equilibrium)), None),
velocity_set=velocity_set,
precision_policy=precision_policy,
compute_backend=compute_backend,
)
if self.do_nothing_bc is None:
self.do_nothing_bc = DoNothingBC(
velocity_set=velocity_set,
precision_policy=precision_policy,
compute_backend=compute_backend,
)
if self.halfway_bounce_back_bc is None:
self.halfway_bounce_back_bc = HalfwayBounceBackBC(
velocity_set=velocity_set,
precision_policy=precision_policy,
compute_backend=compute_backend,
)
if self.fullway_bounce_back_bc is None:
self.fullway_bounce_back_bc = FullwayBounceBackBC(
velocity_set=velocity_set,
precision_policy=precision_policy,
compute_backend=compute_backend,
)
############################################

# Initialize operator
Expand Down

0 comments on commit 329bd4c

Please sign in to comment.