Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ZouHe bc added in Warp and JAX #58

Merged
merged 3 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions xlb/operator/boundary_condition/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
from xlb.operator.boundary_condition.bc_do_nothing import DoNothingBC as DoNothingBC
from xlb.operator.boundary_condition.bc_halfway_bounce_back import HalfwayBounceBackBC as HalfwayBounceBackBC
from xlb.operator.boundary_condition.bc_fullway_bounce_back import FullwayBounceBackBC as FullwayBounceBackBC
from xlb.operator.boundary_condition.bc_zouhe import ZouHeBC as ZouHeBC
41 changes: 5 additions & 36 deletions xlb/operator/boundary_condition/bc_do_nothing.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,7 @@ def jax_implementation(self, f_pre, f_post, boundary_mask, missing_mask):
return jnp.where(boundary, f_pre, f_post)

def _construct_warp(self):
# Set local constants TODO: This is a hack and should be fixed with warp update
_f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype)
_missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool

# Construct the funcional to get streamed indices

# Construct the functional for this BC
@wp.func
def functional(
f_pre: Any,
Expand All @@ -76,21 +71,8 @@ def kernel2d(
i, j = wp.tid()
index = wp.vec2i(i, j)

# Get the boundary id and missing mask
_f_pre = _f_vec()
_f_post = _f_vec()
_boundary_id = boundary_mask[0, index[0], index[1]]
_missing_mask = _missing_mask_vec()
for l in range(self.velocity_set.q):
# q-sized vector of populations
_f_pre[l] = f_pre[l, index[0], index[1]]
_f_post[l] = f_post[l, index[0], index[1]]

# TODO fix vec bool
if missing_mask[l, index[0], index[1]]:
_missing_mask[l] = wp.uint8(1)
else:
_missing_mask[l] = wp.uint8(0)
# read tid data
_f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_mask, missing_mask, index)

# Apply the boundary condition
if _boundary_id == wp.uint8(DoNothingBC.id):
Expand All @@ -114,21 +96,8 @@ def kernel3d(
i, j, k = wp.tid()
index = wp.vec3i(i, j, k)

# Get the boundary id and missing mask
_f_pre = _f_vec()
_f_post = _f_vec()
_boundary_id = boundary_mask[0, index[0], index[1], index[2]]
_missing_mask = _missing_mask_vec()
for l in range(self.velocity_set.q):
# q-sized vector of populations
_f_pre[l] = f_pre[l, index[0], index[1], index[2]]
_f_post[l] = f_post[l, index[0], index[1], index[2]]

# TODO fix vec bool
if missing_mask[l, index[0], index[1], index[2]]:
_missing_mask[l] = wp.uint8(1)
else:
_missing_mask[l] = wp.uint8(0)
# read tid data
_f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_mask, missing_mask, index)

# Apply the boundary condition
if _boundary_id == wp.uint8(DoNothingBC.id):
Expand Down
38 changes: 5 additions & 33 deletions xlb/operator/boundary_condition/bc_equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,11 @@ def jax_implementation(self, f_pre, f_post, boundary_mask, missing_mask):

def _construct_warp(self):
# Set local constants TODO: This is a hack and should be fixed with warp update
_f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype)
_u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype)
_rho = wp.float32(self.rho)
_u = _u_vec(self.u[0], self.u[1], self.u[2]) if self.velocity_set.d == 3 else _u_vec(self.u[0], self.u[1])
_missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool

# Construct the funcional to get streamed indices
# Construct the functional for this BC
@wp.func
def functional(
f_pre: Any,
Expand All @@ -98,21 +96,8 @@ def kernel2d(
i, j = wp.tid()
index = wp.vec2i(i, j)

# Get the boundary id and missing mask
_f_pre = _f_vec()
_f_post = _f_vec()
_boundary_id = boundary_mask[0, index[0], index[1]]
_missing_mask = _missing_mask_vec()
for l in range(self.velocity_set.q):
# q-sized vector of populations
_f_pre[l] = f_pre[l, index[0], index[1]]
_f_post[l] = f_post[l, index[0], index[1]]

# TODO fix vec bool
if missing_mask[l, index[0], index[1]]:
_missing_mask[l] = wp.uint8(1)
else:
_missing_mask[l] = wp.uint8(0)
# read tid data
_f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_mask, missing_mask, index)

# Apply the boundary condition
if _boundary_id == wp.uint8(EquilibriumBC.id):
Expand All @@ -136,21 +121,8 @@ def kernel3d(
i, j, k = wp.tid()
index = wp.vec3i(i, j, k)

# Get the boundary id and missing mask
_f_pre = _f_vec()
_f_post = _f_vec()
_boundary_id = boundary_mask[0, index[0], index[1], index[2]]
_missing_mask = _missing_mask_vec()
for l in range(self.velocity_set.q):
# q-sized vector of populations
_f_pre[l] = f_pre[l, index[0], index[1], index[2]]
_f_post[l] = f_post[l, index[0], index[1], index[2]]

# TODO fix vec bool
if missing_mask[l, index[0], index[1], index[2]]:
_missing_mask[l] = wp.uint8(1)
else:
_missing_mask[l] = wp.uint8(0)
# read tid data
_f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_mask, missing_mask, index)

# Apply the boundary condition
if _boundary_id == wp.uint8(EquilibriumBC.id):
Expand Down
45 changes: 7 additions & 38 deletions xlb/operator/boundary_condition/bc_fullway_bounce_back.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,8 @@ def _construct_warp(self):
_opp_indices = self.velocity_set.wp_opp_indices
_q = wp.constant(self.velocity_set.q)
_f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype)
_missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool

# Construct the funcional to get streamed indices
# Construct the functional for this BC
@wp.func
def functional(
f_pre: Any,
Expand All @@ -79,27 +78,12 @@ def kernel2d(
i, j = wp.tid()
index = wp.vec2i(i, j)

# Get the boundary id and missing mask
_boundary_id = boundary_mask[0, index[0], index[1]]

# Make vectors for the lattice
_f_pre = _f_vec()
_f_post = _f_vec()
_mask = _missing_mask_vec()
for l in range(self.velocity_set.q):
# q-sized vector of populations
_f_pre[l] = f_pre[l, index[0], index[1]]
_f_post[l] = f_post[l, index[0], index[1]]

# TODO fix vec bool
if missing_mask[l, index[0], index[1]]:
_mask[l] = wp.uint8(1)
else:
_mask[l] = wp.uint8(0)
# read tid data
_f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_mask, missing_mask, index)

# Check if the boundary is active
if _boundary_id == wp.uint8(FullwayBounceBackBC.id):
_f = functional(_f_pre, _f_post, _mask)
_f = functional(_f_pre, _f_post, _missing_mask)
else:
_f = _f_post

Expand All @@ -119,27 +103,12 @@ def kernel3d(
i, j, k = wp.tid()
index = wp.vec3i(i, j, k)

# Get the boundary id and missing mask
_boundary_id = boundary_mask[0, index[0], index[1], index[2]]

# Make vectors for the lattice
_f_pre = _f_vec()
_f_post = _f_vec()
_mask = _missing_mask_vec()
for l in range(self.velocity_set.q):
# q-sized vector of populations
_f_pre[l] = f_pre[l, index[0], index[1], index[2]]
_f_post[l] = f_post[l, index[0], index[1], index[2]]

# TODO fix vec bool
if missing_mask[l, index[0], index[1], index[2]]:
_mask[l] = wp.uint8(1)
else:
_mask[l] = wp.uint8(0)
# read tid data
_f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_mask, missing_mask, index)

# Check if the boundary is active
if _boundary_id == wp.uint8(FullwayBounceBackBC.id):
_f = functional(_f_pre, _f_post, _mask)
_f = functional(_f_pre, _f_post, _missing_mask)
else:
_f = _f_post

Expand Down
39 changes: 6 additions & 33 deletions xlb/operator/boundary_condition/bc_halfway_bounce_back.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,10 @@ def apply_jax(self, f_pre, f_post, boundary_mask, missing_mask):
)

def _construct_warp(self):
# Set local constants TODO: This is a hack and should be fixed with warp update
_c = self.velocity_set.wp_c
# Set local constants
_opp_indices = self.velocity_set.wp_opp_indices
_f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype)
_missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool

# Construct the functional for this BC
@wp.func
def functional(
f_pre: Any,
Expand Down Expand Up @@ -92,20 +90,8 @@ def kernel2d(
i, j = wp.tid()
index = wp.vec3i(i, j)

# Get the boundary id and missing mask
_f_pre = _f_vec()
_f_post = _f_vec()
_boundary_id = boundary_mask[0, index[0], index[1]]
_missing_mask = _missing_mask_vec()
for l in range(self.velocity_set.q):
# q-sized vector of populations
_f_pre[l] = f_pre[l, index[0], index[1]]
_f_post[l] = f_post[l, index[0], index[1]]
# TODO fix vec bool
if missing_mask[l, index[0], index[1]]:
_missing_mask[l] = wp.uint8(1)
else:
_missing_mask[l] = wp.uint8(0)
# read tid data
_f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_2d(f_pre, f_post, boundary_mask, missing_mask, index)

# Apply the boundary condition
if _boundary_id == wp.uint8(HalfwayBounceBackBC.id):
Expand All @@ -129,21 +115,8 @@ def kernel3d(
i, j, k = wp.tid()
index = wp.vec3i(i, j, k)

# Get the boundary id and missing mask
_f_pre = _f_vec()
_f_post = _f_vec()
_boundary_id = boundary_mask[0, index[0], index[1], index[2]]
_missing_mask = _missing_mask_vec()
for l in range(self.velocity_set.q):
# q-sized vector of populations
_f_pre[l] = f_pre[l, index[0], index[1], index[2]]
_f_post[l] = f_post[l, index[0], index[1], index[2]]

# TODO fix vec bool
if missing_mask[l, index[0], index[1], index[2]]:
_missing_mask[l] = wp.uint8(1)
else:
_missing_mask[l] = wp.uint8(0)
# read tid data
_f_pre, _f_post, _boundary_id, _missing_mask = self._get_thread_data_3d(f_pre, f_post, boundary_mask, missing_mask, index)

# Apply the boundary condition
if _boundary_id == wp.uint8(HalfwayBounceBackBC.id):
Expand Down
Loading
Loading