Skip to content

Commit

Permalink
used lax.broadcast_in_dim instead of jnp.repeat plus other minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
hsalehipour committed Aug 20, 2024
1 parent 4cecd27 commit 76597d4
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 16 deletions.
4 changes: 3 additions & 1 deletion xlb/operator/boundary_condition/bc_fullway_bounce_back.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import jax.numpy as jnp
from jax import jit
import jax.lax as lax
from functools import partial
import warp as wp
from typing import Any
Expand Down Expand Up @@ -47,7 +48,8 @@ def __init__(
@partial(jit, static_argnums=(0))
def apply_jax(self, f_pre, f_post, boundary_mask, missing_mask):
boundary = boundary_mask == self.id
boundary = jnp.repeat(boundary, self.velocity_set.q, axis=0)
new_shape = (self.velocity_set.q,) + boundary.shape[1:]
boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1)))
return jnp.where(boundary, f_pre[self.velocity_set.opp_indices, ...], f_post)

def _construct_warp(self):
Expand Down
4 changes: 3 additions & 1 deletion xlb/operator/boundary_condition/bc_halfway_bounce_back.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import jax.numpy as jnp
from jax import jit
import jax.lax as lax
from functools import partial
import warp as wp
from typing import Any
Expand Down Expand Up @@ -50,7 +51,8 @@ def __init__(
@partial(jit, static_argnums=(0))
def apply_jax(self, f_pre, f_post, boundary_mask, missing_mask):
boundary = boundary_mask == self.id
boundary = jnp.repeat(boundary, self.velocity_set.q, axis=0)
new_shape = (self.velocity_set.q,) + boundary.shape[1:]
boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1)))
return jnp.where(
jnp.logical_and(missing_mask, boundary),
f_pre[self.velocity_set.opp_indices],
Expand Down
14 changes: 8 additions & 6 deletions xlb/operator/boundary_condition/bc_regularized.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import jax.numpy as jnp
from jax import jit
import jax.lax as lax
from functools import partial
import warp as wp
from typing import Any
Expand Down Expand Up @@ -139,7 +140,8 @@ def regularize_fpop(self, fpop, feq):
def apply_jax(self, f_pre, f_post, boundary_mask, missing_mask):
# creat a mask to slice boundary cells
boundary = boundary_mask == self.id
boundary = jnp.repeat(boundary, self.velocity_set.q, axis=0)
new_shape = (self.velocity_set.q,) + boundary.shape[1:]
boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1)))

# compute the equilibrium based on prescribed values and the type of BC
feq = self.calculate_equilibrium(f_post, missing_mask)
Expand Down Expand Up @@ -185,7 +187,7 @@ def get_normal_vectors_2d(
return normals

@wp.func
def _helper_function(
def _get_fsum(
fpop: Any,
missing_mask: Any,
):
Expand Down Expand Up @@ -256,7 +258,7 @@ def functional3d_velocity(
normals = get_normal_vectors_3d(missing_mask)

# calculate rho
fsum = _helper_function(_f, missing_mask)
fsum = _get_fsum(_f, missing_mask)
unormal = self.compute_dtype(0.0)
for d in range(_d):
unormal += _u[d] * normals[d]
Expand All @@ -283,7 +285,7 @@ def functional3d_pressure(
normals = get_normal_vectors_3d(missing_mask)

# calculate velocity
fsum = _helper_function(_f, missing_mask)
fsum = _get_fsum(_f, missing_mask)
unormal = -1.0 + fsum / _rho
_u = unormal * normals

Expand All @@ -308,7 +310,7 @@ def functional2d_velocity(
normals = get_normal_vectors_2d(missing_mask)

# calculate rho
fsum = _helper_function(_f, missing_mask)
fsum = _get_fsum(_f, missing_mask)
unormal = self.compute_dtype(0.0)
for d in range(_d):
unormal += _u[d] * normals[d]
Expand All @@ -335,7 +337,7 @@ def functional2d_pressure(
normals = get_normal_vectors_2d(missing_mask)

# calculate velocity
fsum = _helper_function(_f, missing_mask)
fsum = _get_fsum(_f, missing_mask)
unormal = -1.0 + fsum / _rho
_u = unormal * normals

Expand Down
14 changes: 8 additions & 6 deletions xlb/operator/boundary_condition/bc_zouhe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import jax.numpy as jnp
from jax import jit
import jax.lax as lax
from functools import partial
import warp as wp
from typing import Any
Expand Down Expand Up @@ -156,7 +157,8 @@ def bounceback_nonequilibrium(self, fpop, feq, missing_mask):
def apply_jax(self, f_pre, f_post, boundary_mask, missing_mask):
# creat a mask to slice boundary cells
boundary = boundary_mask == self.id
boundary = jnp.repeat(boundary, self.velocity_set.q, axis=0)
new_shape = (self.velocity_set.q,) + boundary.shape[1:]
boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1)))

# compute the equilibrium based on prescribed values and the type of BC
feq = self.calculate_equilibrium(f_post, missing_mask)
Expand Down Expand Up @@ -193,7 +195,7 @@ def get_normal_vectors_2d(
return normals

@wp.func
def _helper_function(
def _get_fsum(
fpop: Any,
missing_mask: Any,
):
Expand Down Expand Up @@ -238,7 +240,7 @@ def functional3d_velocity(
normals = get_normal_vectors_3d(missing_mask)

# calculate rho
fsum = _helper_function(_f, missing_mask)
fsum = _get_fsum(_f, missing_mask)
unormal = self.compute_dtype(0.0)
for d in range(_d):
unormal += _u[d] * normals[d]
Expand All @@ -262,7 +264,7 @@ def functional3d_pressure(
normals = get_normal_vectors_3d(missing_mask)

# calculate velocity
fsum = _helper_function(_f, missing_mask)
fsum = _get_fsum(_f, missing_mask)
unormal = -1.0 + fsum / _rho
_u = unormal * normals

Expand All @@ -284,7 +286,7 @@ def functional2d_velocity(
normals = get_normal_vectors_2d(missing_mask)

# calculate rho
fsum = _helper_function(_f, missing_mask)
fsum = _get_fsum(_f, missing_mask)
unormal = self.compute_dtype(0.0)
for d in range(_d):
unormal += _u[d] * normals[d]
Expand All @@ -308,7 +310,7 @@ def functional2d_pressure(
normals = get_normal_vectors_2d(missing_mask)

# calculate velocity
fsum = _helper_function(_f, missing_mask)
fsum = _get_fsum(_f, missing_mask)
unormal = -1.0 + fsum / _rho
_u = unormal * normals

Expand Down
2 changes: 1 addition & 1 deletion xlb/operator/macroscopic/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from xlb.operator.macroscopic.zero_first_moments import FirstAndZerothMoment as Macroscopic
from xlb.operator.macroscopic.zero_first_moments import ZeroAndFirstMoments as Macroscopic
from xlb.operator.macroscopic.second_moment import SecondMoment as SecondMoment
2 changes: 1 addition & 1 deletion xlb/operator/macroscopic/zero_first_moments.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from xlb.operator.operator import Operator


class FirstAndZerothMoment(Operator):
class ZeroAndFirstMoments(Operator):
"""
A class to compute first and zeroth moments of distribution functions.
Expand Down

0 comments on commit 76597d4

Please sign in to comment.