Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Warp structs #56

Merged
merged 4 commits into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 51 additions & 23 deletions xlb/operator/stepper/nse_stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,18 +89,23 @@ def _construct_warp(self):
_f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype)
_missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool

# Get the boundary condition ids
_equilibrium_bc = wp.uint8(self.equilibrium_bc.id)
_do_nothing_bc = wp.uint8(self.do_nothing_bc.id)
_halfway_bounce_back_bc = wp.uint8(self.halfway_bounce_back_bc.id)
_fullway_bounce_back_bc = wp.uint8(self.fullway_bounce_back_bc.id)
@wp.struct
class BoundaryConditionIDStruct:
# Note the names are hardcoded here based on various BC operator names with "id_" at the beginning
# One needs to manually add the names of additional BC's as they are added.
# TODO: Anyway to improve this
mehdiataei marked this conversation as resolved.
Show resolved Hide resolved
id_EquilibriumBC: wp.uint8
id_DoNothingBC: wp.uint8
id_HalfwayBounceBackBC: wp.uint8
id_FullwayBounceBackBC: wp.uint8

@wp.kernel
def kernel2d(
f_0: wp.array3d(dtype=Any),
f_1: wp.array3d(dtype=Any),
boundary_mask: wp.array3d(dtype=Any),
missing_mask: wp.array3d(dtype=Any),
bc_struct: BoundaryConditionIDStruct,
timestep: int,
):
# Get the global index
Expand All @@ -117,17 +122,17 @@ def kernel2d(
else:
_missing_mask[l] = wp.uint8(0)

# Apply streaming boundary conditions
if (_boundary_id == wp.uint8(0)) or _boundary_id == _fullway_bounce_back_bc:
# Regular streaming
f_post_stream = self.stream.warp_functional(f_0, index)
elif _boundary_id == _equilibrium_bc:
# Apply streaming (pull method)
f_post_stream = self.stream.warp_functional(f_0, index)

# Apply post-streaming type boundary conditions
if _boundary_id == bc_struct.id_EquilibriumBC:
# Equilibrium boundary condition
f_post_stream = self.equilibrium_bc.warp_functional(f_0, _missing_mask, index)
elif _boundary_id == _do_nothing_bc:
elif _boundary_id == bc_struct.id_DoNothingBC:
# Do nothing boundary condition
f_post_stream = self.do_nothing_bc.warp_functional(f_0, _missing_mask, index)
elif _boundary_id == _halfway_bounce_back_bc:
elif _boundary_id == bc_struct.id_HalfwayBounceBackBC:
# Half way boundary condition
f_post_stream = self.halfway_bounce_back_bc.warp_functional(f_0, _missing_mask, index)

Expand All @@ -145,8 +150,8 @@ def kernel2d(
u,
)

# Apply collision type boundary conditions
if _boundary_id == _fullway_bounce_back_bc:
# Apply post-collision type boundary conditions
if _boundary_id == bc_struct.id_FullwayBounceBackBC:
# Full way boundary condition
f_post_collision = self.fullway_bounce_back_bc.warp_functional(
f_post_stream,
Expand All @@ -165,6 +170,7 @@ def kernel3d(
f_1: wp.array4d(dtype=Any),
boundary_mask: wp.array4d(dtype=Any),
missing_mask: wp.array4d(dtype=Any),
bc_struct: BoundaryConditionIDStruct,
timestep: int,
):
# Get the global index
Expand All @@ -181,17 +187,17 @@ def kernel3d(
else:
_missing_mask[l] = wp.uint8(0)

# Apply streaming boundary conditions
if (_boundary_id == wp.uint8(0)) or _boundary_id == _fullway_bounce_back_bc:
# Regular streaming
f_post_stream = self.stream.warp_functional(f_0, index)
elif _boundary_id == _equilibrium_bc:
# Apply streaming (pull method)
f_post_stream = self.stream.warp_functional(f_0, index)

# Apply post-streaming boundary conditions
if _boundary_id == bc_struct.id_EquilibriumBC:
# Equilibrium boundary condition
f_post_stream = self.equilibrium_bc.warp_functional(f_0, _missing_mask, index)
elif _boundary_id == _do_nothing_bc:
elif _boundary_id == bc_struct.id_DoNothingBC:
# Do nothing boundary condition
f_post_stream = self.do_nothing_bc.warp_functional(f_0, _missing_mask, index)
elif _boundary_id == _halfway_bounce_back_bc:
elif _boundary_id == bc_struct.id_HalfwayBounceBackBC:
# Half way boundary condition
f_post_stream = self.halfway_bounce_back_bc.warp_functional(f_0, _missing_mask, index)

Expand All @@ -205,7 +211,7 @@ def kernel3d(
f_post_collision = self.collision.warp_functional(f_post_stream, feq, rho, u)

# Apply collision type boundary conditions
if _boundary_id == _fullway_bounce_back_bc:
if _boundary_id == bc_struct.id_FullwayBounceBackBC:
# Full way boundary condition
f_post_collision = self.fullway_bounce_back_bc.warp_functional(
f_post_stream,
Expand All @@ -220,10 +226,31 @@ def kernel3d(
# Return the correct kernel
kernel = kernel3d if self.velocity_set.d == 3 else kernel2d

return None, kernel
return BoundaryConditionIDStruct, kernel

@Operator.register_backend(ComputeBackend.WARP)
def warp_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep):
# Get the boundary condition ids
from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry

bc_to_id = boundary_condition_registry.bc_to_id

bc_struct = self.warp_functional()
bc_attribute_list = []
for bc in self.boundary_conditions:
# Setting the Struct attributes based on the BC class names
attribute_str = bc.__class__.__name__
setattr(bc_struct, "id_" + attribute_str, bc_to_id[attribute_str])
bc_attribute_list.append("id_" + attribute_str)

# Unused attributes of the struct are set to inernal (id=0)
ll = vars(bc_struct)
for var in ll:
if var not in bc_attribute_list and not var.startswith("_"):
# set unassigned boundaries to the maximum integer in uint8
attribute_str = bc.__class__.__name__
setattr(bc_struct, var, 255)

# Launch the warp kernel
wp.launch(
self.warp_kernel,
Expand All @@ -232,6 +259,7 @@ def warp_implementation(self, f_0, f_1, boundary_mask, missing_mask, timestep):
f_1,
boundary_mask,
missing_mask,
bc_struct,
timestep,
],
dim=f_0.shape[1:],
Expand Down
75 changes: 23 additions & 52 deletions xlb/operator/stepper/stepper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Base class for all stepper operators
from xlb.operator import Operator
from xlb.operator.equilibrium import Equilibrium
from xlb import DefaultConfig


Expand All @@ -26,63 +25,35 @@ def __init__(self, operators, boundary_conditions):
compute_backend = DefaultConfig.default_backend if not compute_backends else compute_backends.pop()

# Add boundary conditions
# Warp cannot handle lists of functions currently
# Because of this we manually unpack the boundary conditions
############################################
# Warp cannot handle lists of functions currently
# TODO: Fix this later
############################################
from xlb.operator.boundary_condition.bc_equilibrium import EquilibriumBC
from xlb.operator.boundary_condition.bc_do_nothing import DoNothingBC
from xlb.operator.boundary_condition.bc_halfway_bounce_back import (
HalfwayBounceBackBC,
)
from xlb.operator.boundary_condition.bc_fullway_bounce_back import (
FullwayBounceBackBC,
)

self.equilibrium_bc = None
self.do_nothing_bc = None
self.halfway_bounce_back_bc = None
self.fullway_bounce_back_bc = None

for bc in boundary_conditions:
if isinstance(bc, EquilibriumBC):
self.equilibrium_bc = bc
elif isinstance(bc, DoNothingBC):
self.do_nothing_bc = bc
elif isinstance(bc, HalfwayBounceBackBC):
self.halfway_bounce_back_bc = bc
elif isinstance(bc, FullwayBounceBackBC):
self.fullway_bounce_back_bc = bc
from xlb.operator.boundary_condition.bc_halfway_bounce_back import HalfwayBounceBackBC
from xlb.operator.boundary_condition.bc_fullway_bounce_back import FullwayBounceBackBC

# Define a list of tuples with attribute names and their corresponding classes
conditions = [
("equilibrium_bc", EquilibriumBC),
("do_nothing_bc", DoNothingBC),
("halfway_bounce_back_bc", HalfwayBounceBackBC),
("fullway_bounce_back_bc", FullwayBounceBackBC),
]

# this fall-back BC is just to ensure Warp codegen does not produce error when a particular BC is not used in an example.
bc_fallback = boundary_conditions[0]

# Iterate over each boundary condition
for attr_name, bc_class in conditions:
for bc in boundary_conditions:
if isinstance(bc, bc_class):
setattr(self, attr_name, bc)
break
elif not hasattr(self, attr_name):
setattr(self, attr_name, bc_fallback)

if self.equilibrium_bc is None:
# Select the equilibrium operator based on its type
self.equilibrium_bc = EquilibriumBC(
rho=1.0,
u=(0.0, 0.0, 0.0),
equilibrium_operator=next((op for op in self.operators if isinstance(op, Equilibrium)), None),
velocity_set=velocity_set,
precision_policy=precision_policy,
compute_backend=compute_backend,
)
if self.do_nothing_bc is None:
self.do_nothing_bc = DoNothingBC(
velocity_set=velocity_set,
precision_policy=precision_policy,
compute_backend=compute_backend,
)
if self.halfway_bounce_back_bc is None:
self.halfway_bounce_back_bc = HalfwayBounceBackBC(
velocity_set=velocity_set,
precision_policy=precision_policy,
compute_backend=compute_backend,
)
if self.fullway_bounce_back_bc is None:
self.fullway_bounce_back_bc = FullwayBounceBackBC(
velocity_set=velocity_set,
precision_policy=precision_policy,
compute_backend=compute_backend,
)
############################################

# Initialize operator
Expand Down
Loading