Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Missing mask in JAX and Warp #57

Merged
merged 12 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/cfd/lid_driven_cavity_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from xlb.helper import create_nse_fields, initialize_eq
from xlb.operator.boundary_masker import IndicesBoundaryMasker
from xlb.operator.stepper import IncompressibleNavierStokesStepper
from xlb.operator.boundary_condition import FullwayBounceBackBC, EquilibriumBC
from xlb.operator.boundary_condition import HalfwayBounceBackBC, EquilibriumBC
from xlb.operator.macroscopic import Macroscopic
from xlb.utils import save_fields_vtk, save_image
import warp as wp
Expand Down Expand Up @@ -48,7 +48,7 @@ def define_boundary_indices(self):
def setup_boundary_conditions(self):
lid, walls = self.define_boundary_indices()
bc_top = EquilibriumBC(rho=1.0, u=(0.02, 0.0), indices=lid)
bc_walls = FullwayBounceBackBC(indices=walls)
bc_walls = HalfwayBounceBackBC(indices=walls)
self.boundary_conditions = [bc_top, bc_walls]

def setup_boundary_masks(self):
Expand Down Expand Up @@ -99,7 +99,7 @@ def post_process(self, i):
# Running the simulation
grid_size = 500
grid_shape = (grid_size, grid_size)
backend = ComputeBackend.JAX
backend = ComputeBackend.WARP
velocity_set = xlb.velocity_set.D2Q9()
precision_policy = PrecisionPolicy.FP32FP32
omega = 1.6
Expand Down
4 changes: 2 additions & 2 deletions examples/cfd/windtunnel_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ def run(self, num_steps, print_interval, post_process_interval=100):
elapsed_time = time.time() - start_time
print(f"Iteration: {i + 1}/{num_steps} | Time elapsed: {elapsed_time:.2f}s")

if i % post_process_interval == 0 or i == num_steps - 1:
self.post_process(i)
if i % post_process_interval == 0 or i == num_steps - 1:
self.post_process(i)

def post_process(self, i):
# Write the results. We'll use JAX backend for the post-processing
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def test_bc_equilibrium_warp(dim, velocity_set, grid_shape):
cardinality=velocity_set.q, dtype=xlb.Precision.FP32, fill_value=2.0
) # Arbitrary value so that we can check if the values are changed outside the boundary

f = equilibrium_bc(f_pre, f_post, boundary_mask, missing_mask, f)
f = equilibrium_bc(f_pre, f_post, boundary_mask, missing_mask)

f = f.numpy()
f_post = f_post.numpy()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from xlb.compute_backend import ComputeBackend
from xlb.grid import grid_factory
from xlb import DefaultConfig
from xlb.operator.boundary_masker import IndicesBoundaryMasker


def init_xlb_env(velocity_set):
Expand Down Expand Up @@ -35,7 +36,7 @@ def test_fullway_bounce_back_warp(dim, velocity_set, grid_shape):

boundary_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8)

indices_boundary_masker = xlb.operator.boundary_masker.IndicesBoundaryMasker()
indices_boundary_masker = IndicesBoundaryMasker()

# Make indices for boundary conditions (sphere)
sphere_radius = grid_shape[0] // 4
Expand Down Expand Up @@ -64,7 +65,7 @@ def test_fullway_bounce_back_warp(dim, velocity_set, grid_shape):
cardinality=velocity_set.q, dtype=xlb.Precision.FP32, fill_value=2.0
) # Arbitrary value so that we can check if the values are changed outside the boundary

f_pre = fullway_bc(f_pre, f_post, boundary_mask, missing_mask, f_pre)
f_pre = fullway_bc(f_pre, f_post, boundary_mask, missing_mask)

f = f_pre.numpy()
f_post = f_post.numpy()
Expand Down
59 changes: 25 additions & 34 deletions xlb/operator/boundary_condition/bc_do_nothing.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,43 +58,34 @@ def _construct_warp(self):
# Construct the funcional to get streamed indices

@wp.func
def functional2d(
f: wp.array3d(dtype=Any),
def functional(
f_pre: Any,
f_post: Any,
missing_mask: Any,
index: Any,
):
_f = _f_vec()
for l in range(self.velocity_set.q):
_f[l] = f[l, index[0], index[1]]
return _f

@wp.func
def functional3d(
f: wp.array4d(dtype=Any),
missing_mask: Any,
index: Any,
):
_f = _f_vec()
for l in range(self.velocity_set.q):
_f[l] = f[l, index[0], index[1], index[2]]
return _f
return f_pre

@wp.kernel
def kernel2d(
f_pre: wp.array3d(dtype=Any),
f_post: wp.array3d(dtype=Any),
boundary_mask: wp.array3d(dtype=wp.uint8),
missing_mask: wp.array3d(dtype=wp.uint8),
f: wp.array3d(dtype=Any),
):
# Get the global index
i, j = wp.tid()
index = wp.vec2i(i, j)

# Get the boundary id and missing mask
_f_pre = _f_vec()
_f_post = _f_vec()
_boundary_id = boundary_mask[0, index[0], index[1]]
_missing_mask = _missing_mask_vec()
for l in range(self.velocity_set.q):
# q-sized vector of populations
_f_pre[l] = f_pre[l, index[0], index[1]]
_f_post[l] = f_post[l, index[0], index[1]]

# TODO fix vec bool
if missing_mask[l, index[0], index[1]]:
_missing_mask[l] = wp.uint8(1)
Expand All @@ -103,15 +94,13 @@ def kernel2d(

# Apply the boundary condition
if _boundary_id == wp.uint8(DoNothingBC.id):
_f = functional3d(f_pre, _missing_mask, index)
_f = functional(_f_pre, _f_post, _missing_mask)
else:
_f = _f_vec()
for l in range(self.velocity_set.q):
_f[l] = f_post[l, index[0], index[1]]
_f = _f_post

# Write the result
for l in range(self.velocity_set.q):
f[l, index[0], index[1]] = _f[l]
f_post[l, index[0], index[1]] = _f[l]

# Construct the warp kernel
@wp.kernel
Expand All @@ -120,16 +109,21 @@ def kernel3d(
f_post: wp.array4d(dtype=Any),
boundary_mask: wp.array4d(dtype=wp.uint8),
missing_mask: wp.array4d(dtype=wp.bool),
f: wp.array4d(dtype=Any),
):
# Get the global index
i, j, k = wp.tid()
index = wp.vec3i(i, j, k)

# Get the boundary id and missing mask
_f_pre = _f_vec()
_f_post = _f_vec()
_boundary_id = boundary_mask[0, index[0], index[1], index[2]]
_missing_mask = _missing_mask_vec()
for l in range(self.velocity_set.q):
# q-sized vector of populations
_f_pre[l] = f_pre[l, index[0], index[1], index[2]]
_f_post[l] = f_post[l, index[0], index[1], index[2]]

# TODO fix vec bool
if missing_mask[l, index[0], index[1], index[2]]:
_missing_mask[l] = wp.uint8(1)
Expand All @@ -138,27 +132,24 @@ def kernel3d(

# Apply the boundary condition
if _boundary_id == wp.uint8(DoNothingBC.id):
_f = functional3d(f_pre, _missing_mask, index)
_f = functional(_f_pre, _f_post, _missing_mask)
else:
_f = _f_vec()
for l in range(self.velocity_set.q):
_f[l] = f_post[l, index[0], index[1], index[2]]
_f = _f_post

# Write the result
for l in range(self.velocity_set.q):
f[l, index[0], index[1], index[2]] = _f[l]
f_post[l, index[0], index[1], index[2]] = _f[l]

functional = functional3d if self.velocity_set.d == 3 else functional2d
kernel = kernel3d if self.velocity_set.d == 3 else kernel2d

return functional, kernel

@Operator.register_backend(ComputeBackend.WARP)
def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask, f):
def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask):
# Launch the warp kernel
wp.launch(
self.warp_kernel,
inputs=[f_pre, f_post, boundary_mask, missing_mask, f],
inputs=[f_pre, f_post, boundary_mask, missing_mask],
dim=f_pre.shape[1:],
)
return f
return f_post
52 changes: 24 additions & 28 deletions xlb/operator/boundary_condition/bc_equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,10 @@ def _construct_warp(self):

# Construct the funcional to get streamed indices
@wp.func
def functional2d(
f: wp.array3d(dtype=Any),
def functional(
f_pre: Any,
f_post: Any,
missing_mask: Any,
index: Any,
):
_f = self.equilibrium_operator.warp_functional(_rho, _u)
return _f
Expand All @@ -93,16 +93,21 @@ def kernel2d(
f_post: wp.array3d(dtype=Any),
boundary_mask: wp.array3d(dtype=wp.uint8),
missing_mask: wp.array3d(dtype=wp.bool),
f: wp.array3d(dtype=Any),
):
# Get the global index
i, j = wp.tid()
index = wp.vec2i(i, j)

# Get the boundary id and missing mask
_f_pre = _f_vec()
_f_post = _f_vec()
_boundary_id = boundary_mask[0, index[0], index[1]]
_missing_mask = _missing_mask_vec()
for l in range(self.velocity_set.q):
# q-sized vector of populations
_f_pre[l] = f_pre[l, index[0], index[1]]
_f_post[l] = f_post[l, index[0], index[1]]

# TODO fix vec bool
if missing_mask[l, index[0], index[1]]:
_missing_mask[l] = wp.uint8(1)
Expand All @@ -111,24 +116,13 @@ def kernel2d(

# Apply the boundary condition
if _boundary_id == wp.uint8(EquilibriumBC.id):
_f = functional2d(f_post, _missing_mask, index)
_f = functional(_f_pre, _f_post, _missing_mask)
else:
_f = _f_vec()
for l in range(self.velocity_set.q):
_f[l] = f_post[l, index[0], index[1]]
_f = _f_post

# Write the result
for l in range(self.velocity_set.q):
f[l, index[0], index[1]] = _f[l]

@wp.func
def functional3d(
f: wp.array4d(dtype=Any),
missing_mask: Any,
index: Any,
):
_f = self.equilibrium_operator.warp_functional(_rho, _u)
return _f
f_post[l, index[0], index[1]] = _f[l]

# Construct the warp kernel
@wp.kernel
Expand All @@ -137,16 +131,21 @@ def kernel3d(
f_post: wp.array4d(dtype=Any),
boundary_mask: wp.array4d(dtype=wp.uint8),
missing_mask: wp.array4d(dtype=wp.bool),
f: wp.array4d(dtype=Any),
):
# Get the global index
i, j, k = wp.tid()
index = wp.vec3i(i, j, k)

# Get the boundary id and missing mask
_f_pre = _f_vec()
_f_post = _f_vec()
_boundary_id = boundary_mask[0, index[0], index[1], index[2]]
_missing_mask = _missing_mask_vec()
for l in range(self.velocity_set.q):
# q-sized vector of populations
_f_pre[l] = f_pre[l, index[0], index[1], index[2]]
_f_post[l] = f_post[l, index[0], index[1], index[2]]

# TODO fix vec bool
if missing_mask[l, index[0], index[1], index[2]]:
_missing_mask[l] = wp.uint8(1)
Expand All @@ -155,27 +154,24 @@ def kernel3d(

# Apply the boundary condition
if _boundary_id == wp.uint8(EquilibriumBC.id):
_f = functional3d(f_post, _missing_mask, index)
_f = functional(_f_pre, _f_post, _missing_mask)
else:
_f = _f_vec()
for l in range(self.velocity_set.q):
_f[l] = f_post[l, index[0], index[1], index[2]]
_f = _f_post

# Write the result
for l in range(self.velocity_set.q):
f[l, index[0], index[1], index[2]] = _f[l]
f_post[l, index[0], index[1], index[2]] = _f[l]

kernel = kernel3d if self.velocity_set.d == 3 else kernel2d
functional = functional3d if self.velocity_set.d == 3 else functional2d

return functional, kernel

@Operator.register_backend(ComputeBackend.WARP)
def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask, f):
def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask):
# Launch the warp kernel
wp.launch(
self.warp_kernel,
inputs=[f_pre, f_post, boundary_mask, missing_mask, f],
inputs=[f_pre, f_post, boundary_mask, missing_mask],
dim=f_pre.shape[1:],
)
return f
return f_post
14 changes: 7 additions & 7 deletions xlb/operator/boundary_condition/bc_fullway_bounce_back.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def kernel2d(
f_post: wp.array3d(dtype=Any),
boundary_mask: wp.array3d(dtype=wp.uint8),
missing_mask: wp.array3d(dtype=wp.bool),
f: wp.array3d(dtype=Any),
): # Get the global index
i, j = wp.tid()
index = wp.vec2i(i, j)
Expand All @@ -88,6 +87,7 @@ def kernel2d(
_f_post = _f_vec()
_mask = _missing_mask_vec()
for l in range(self.velocity_set.q):
# q-sized vector of populations
_f_pre[l] = f_pre[l, index[0], index[1]]
_f_post[l] = f_post[l, index[0], index[1]]

Expand All @@ -105,7 +105,7 @@ def kernel2d(

# Write the result to the output
for l in range(self.velocity_set.q):
f[l, index[0], index[1]] = _f[l]
f_post[l, index[0], index[1]] = _f[l]

# Construct the warp kernel
@wp.kernel
Expand All @@ -114,7 +114,6 @@ def kernel3d(
f_post: wp.array4d(dtype=Any),
boundary_mask: wp.array4d(dtype=wp.uint8),
missing_mask: wp.array4d(dtype=wp.bool),
f: wp.array4d(dtype=Any),
):
# Get the global index
i, j, k = wp.tid()
Expand All @@ -128,6 +127,7 @@ def kernel3d(
_f_post = _f_vec()
_mask = _missing_mask_vec()
for l in range(self.velocity_set.q):
# q-sized vector of populations
_f_pre[l] = f_pre[l, index[0], index[1], index[2]]
_f_post[l] = f_post[l, index[0], index[1], index[2]]

Expand All @@ -145,18 +145,18 @@ def kernel3d(

# Write the result to the output
for l in range(self.velocity_set.q):
f[l, index[0], index[1], index[2]] = _f[l]
f_post[l, index[0], index[1], index[2]] = _f[l]

kernel = kernel3d if self.velocity_set.d == 3 else kernel2d

return functional, kernel

@Operator.register_backend(ComputeBackend.WARP)
def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask, f):
def warp_implementation(self, f_pre, f_post, boundary_mask, missing_mask):
# Launch the warp kernel
wp.launch(
self.warp_kernel,
inputs=[f_pre, f_post, boundary_mask, missing_mask, f],
inputs=[f_pre, f_post, boundary_mask, missing_mask],
dim=f_pre.shape[1:],
)
return f
return f_post
Loading
Loading