Skip to content

Commit

Permalink
renamed connectivty bitmask
Browse files Browse the repository at this point in the history
  • Loading branch information
hsalehipour committed Nov 23, 2023
1 parent 5b571cd commit d2b2855
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 82 deletions.
40 changes: 20 additions & 20 deletions src/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,25 +351,25 @@ def show_simulation_parameters(self):
def _create_boundary_data(self):
"""
Create boundary data for the Lattice Boltzmann simulation by setting boundary conditions,
creating grid connectivity bitmask, and preparing local bitmasks and normal arrays.
creating grid mask, and preparing local masks and normal arrays.
"""
self.BCs = []
self.set_boundary_conditions()
# Accumulate the indices of all BCs to create the grid connectivity bitmask with FALSE along directions that
# Accumulate the indices of all BCs to create the grid mask with FALSE along directions that
# stream into a boundary voxel.
solid_halo_list = [np.array(bc.indices).T for bc in self.BCs if bc.isSolid]
solid_halo_voxels = np.unique(np.vstack(solid_halo_list), axis=0) if solid_halo_list else None

# Create the grid connectivity bitmask on each process
# Create the grid mask on each process
start = time.time()
connectivity_bitmask = self.create_grid_connectivity_bitmask(solid_halo_voxels)
print("Time to create the grid connectivity bitmask:", time.time() - start)
grid_mask = self.create_grid_mask(solid_halo_voxels)
print("Time to create the grid mask:", time.time() - start)

start = time.time()
for bc in self.BCs:
assert bc.implementationStep in ['PostStreaming', 'PostCollision']
bc.create_local_bitmask_and_normal_arrays(connectivity_bitmask)
print("Time to create the local bitmasks and normal arrays:", time.time() - start)
bc.create_local_mask_and_normal_arrays(grid_mask)
print("Time to create the local masks and normal arrays:", time.time() - start)

# This is another non-JITed way of creating the distributed arrays. It is not used at the moment.
# def distributed_array_init(self, shape, type, init_val=None):
Expand Down Expand Up @@ -411,42 +411,42 @@ def distributed_array_init(self, shape, type, init_val=0, sharding=None):
return jax.lax.with_sharding_constraint(x, sharding)

@partial(jit, static_argnums=(0,))
def create_grid_connectivity_bitmask(self, solid_halo_voxels):
def create_grid_mask(self, solid_halo_voxels):
"""
This function creates a bitmask for the background grid that accounts for the location of the boundaries.
This function creates a mask for the background grid that accounts for the location of the boundaries.
Parameters
----------
solid_halo_voxels: A numpy array representing the voxels in the halo of the solid object.
Returns
-------
A JAX array representing the connectivity bitmask of the grid.
A JAX array representing the grid mask of the grid.
"""
# Halo width (hw_x is different to accommodate the domain sharding per XLA device)
hw_x = self.nDevices
hw_y = hw_z = 1
if self.dim == 2:
connectivity_bitmask = self.distributed_array_init((self.nx + 2 * hw_x, self.ny + 2 * hw_y, self.lattice.q), jnp.bool_, init_val=True)
connectivity_bitmask = connectivity_bitmask.at[(slice(hw_x, -hw_x), slice(hw_y, -hw_y), slice(None))].set(False)
grid_mask = self.distributed_array_init((self.nx + 2 * hw_x, self.ny + 2 * hw_y, self.lattice.q), jnp.bool_, init_val=True)
grid_mask = grid_mask.at[(slice(hw_x, -hw_x), slice(hw_y, -hw_y), slice(None))].set(False)
if solid_halo_voxels is not None:
solid_halo_voxels = solid_halo_voxels.at[:, 0].add(hw_x)
solid_halo_voxels = solid_halo_voxels.at[:, 1].add(hw_y)
connectivity_bitmask = connectivity_bitmask.at[tuple(solid_halo_voxels.T)].set(True)
grid_mask = grid_mask.at[tuple(solid_halo_voxels.T)].set(True)

connectivity_bitmask = self.streaming(connectivity_bitmask)
return lax.with_sharding_constraint(connectivity_bitmask, self.sharding)
grid_mask = self.streaming(grid_mask)
return lax.with_sharding_constraint(grid_mask, self.sharding)

elif self.dim == 3:
connectivity_bitmask = self.distributed_array_init((self.nx + 2 * hw_x, self.ny + 2 * hw_y, self.nz + 2 * hw_z, self.lattice.q), jnp.bool_, init_val=True)
connectivity_bitmask = connectivity_bitmask.at[(slice(hw_x, -hw_x), slice(hw_y, -hw_y), slice(hw_z, -hw_z), slice(None))].set(False)
grid_mask = self.distributed_array_init((self.nx + 2 * hw_x, self.ny + 2 * hw_y, self.nz + 2 * hw_z, self.lattice.q), jnp.bool_, init_val=True)
grid_mask = grid_mask.at[(slice(hw_x, -hw_x), slice(hw_y, -hw_y), slice(hw_z, -hw_z), slice(None))].set(False)
if solid_halo_voxels is not None:
solid_halo_voxels = solid_halo_voxels.at[:, 0].add(hw_x)
solid_halo_voxels = solid_halo_voxels.at[:, 1].add(hw_y)
solid_halo_voxels = solid_halo_voxels.at[:, 2].add(hw_z)
connectivity_bitmask = connectivity_bitmask.at[tuple(solid_halo_voxels.T)].set(True)
connectivity_bitmask = self.streaming(connectivity_bitmask)
return lax.with_sharding_constraint(connectivity_bitmask, self.sharding)
grid_mask = grid_mask.at[tuple(solid_halo_voxels.T)].set(True)
grid_mask = self.streaming(grid_mask)
return lax.with_sharding_constraint(grid_mask, self.sharding)

def bounding_box_indices(self):
"""
Expand Down
124 changes: 62 additions & 62 deletions src/boundary_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,70 +52,70 @@ def __init__(self, indices, gridInfo, precision_policy):
self.needsExtraConfiguration = False
self.implementationStep = "PostStreaming"

def create_local_bitmask_and_normal_arrays(self, connectivity_bitmask):
def create_local_mask_and_normal_arrays(self, grid_mask):

"""
Creates local bitmask and normal arrays for the boundary condition.
Creates local mask and normal arrays for the boundary condition.
Parameters
----------
connectivity_bitmask : array-like
The connectivity bitmask for the lattice.
grid_mask : array-like
The grid mask for the lattice.
Returns
-------
None
Notes
-----
This method creates local bitmask and normal arrays for the boundary condition based on the connectivity bitmask.
This method creates local mask and normal arrays for the boundary condition based on the grid mask.
If the boundary condition requires extra configuration, the `configure` method is called.
"""

if self.needsExtraConfiguration:
boundaryBitmask = self.get_boundary_bitmask(connectivity_bitmask)
self.configure(boundaryBitmask)
boundaryMask = self.get_boundary_mask(grid_mask)
self.configure(boundaryMask)
self.needsExtraConfiguration = False

boundaryBitmask = self.get_boundary_bitmask(connectivity_bitmask)
self.normals = self.get_normals(boundaryBitmask)
self.imissing, self.iknown = self.get_missing_indices(boundaryBitmask)
self.imissingBitmask, self.iknownBitmask, self.imiddleBitmask = self.get_missing_bitmask(boundaryBitmask)
boundaryMask = self.get_boundary_mask(grid_mask)
self.normals = self.get_normals(boundaryMask)
self.imissing, self.iknown = self.get_missing_indices(boundaryMask)
self.imissingMask, self.iknownMask, self.imiddleMask = self.get_missing_mask(boundaryMask)

return

def get_boundary_bitmask(self, connectivity_bitmask):
def get_boundary_mask(self, grid_mask):
"""
Add jax.device_count() to the self.indices in x-direction, and 1 to the self.indices other directions
This is to make sure the boundary condition is applied to the correct nodes as connectivity_bitmask is
This is to make sure the boundary condition is applied to the correct nodes as grid_mask is
expanded by (jax.device_count(), 1, 1)
Parameters
----------
connectivity_bitmask : array-like
The connectivity bitmask for the lattice.
grid_mask : array-like
The grid mask for the lattice.
Returns
-------
boundaryBitmask : array-like
boundaryMask : array-like
"""
shifted_indices = np.array(self.indices)
shifted_indices[0] += device_count()
shifted_indices[1:] += 1
# Convert back to tuple
shifted_indices = tuple(shifted_indices)
boundaryBitmask = np.array(connectivity_bitmask[shifted_indices])
boundaryMask = np.array(grid_mask[shifted_indices])

return boundaryBitmask
return boundaryMask

def configure(self, boundaryBitmask):
def configure(self, boundaryMask):
"""
Configures the boundary condition.
Parameters
----------
boundaryBitmask : array-like
The connectivity bitmask for the boundary voxels.
boundaryMask : array-like
The grid mask for the boundary voxels.
Returns
-------
Expand Down Expand Up @@ -152,14 +152,14 @@ def prepare_populations(self, fout, fin, implementation_step):
"""
return fout

def get_normals(self, boundaryBitmask):
def get_normals(self, boundaryMask):
"""
Calculates the normal vectors at the boundary nodes.
Parameters
----------
boundaryBitmask : array-like
The boundary bitmask for the lattice.
boundaryMask : array-like
The boundary mask for the lattice.
Returns
-------
Expand All @@ -168,22 +168,22 @@ def get_normals(self, boundaryBitmask):
Notes
-----
This method calculates the normal vectors by dotting the boundary bitmask with the main lattice directions.
This method calculates the normal vectors by dotting the boundary mask with the main lattice directions.
"""
main_c = self.lattice.c.T[self.lattice.main_indices]
m = boundaryBitmask[..., self.lattice.main_indices]
m = boundaryMask[..., self.lattice.main_indices]
normals = -np.dot(m, main_c)
return normals

def get_missing_indices(self, boundaryBitmask):
def get_missing_indices(self, boundaryMask):
"""
Returns two int8 arrays the same shape as boundaryBitmask. The non-zero entries of these arrays indicate missing
Returns two int8 arrays the same shape as boundaryMask. The non-zero entries of these arrays indicate missing
directions that require BCs (imissing) as well as their corresponding opposite directions (iknown).
Parameters
----------
boundaryBitmask : array-like
The boundary bitmask for the lattice.
boundaryMask : array-like
The boundary mask for the lattice.
Returns
-------
Expand All @@ -192,45 +192,45 @@ def get_missing_indices(self, boundaryBitmask):
Notes
-----
This method calculates the missing and known indices based on the boundary bitmask. The missing indices are the
non-zero entries of the boundary bitmask, and the known indices are their corresponding opposite directions.
This method calculates the missing and known indices based on the boundary mask. The missing indices are the
non-zero entries of the boundary mask, and the known indices are their corresponding opposite directions.
"""

# Find imissing, iknown 1-to-1 corresponding indices
# Note: the "zero" index is used as default value here and won't affect BC computations
nbd = len(self.indices[0])
imissing = np.vstack([np.arange(self.lattice.q, dtype='uint8')] * nbd)
iknown = np.vstack([self.lattice.opp_indices] * nbd)
imissing[~boundaryBitmask] = 0
iknown[~boundaryBitmask] = 0
imissing[~boundaryMask] = 0
iknown[~boundaryMask] = 0
return imissing, iknown

def get_missing_bitmask(self, boundaryBitmask):
def get_missing_mask(self, boundaryMask):
"""
Returns three boolean arrays the same shape as boundaryBitmask.
Note: these boundary bitmasks are useful for reduction (eg. summation) operators of selected q-directions.
Returns three boolean arrays the same shape as boundaryMask.
Note: these boundary masks are useful for reduction (eg. summation) operators of selected q-directions.
Parameters
----------
boundaryBitmask : array-like
The boundary bitmask for the lattice.
boundaryMask : array-like
The boundary mask for the lattice.
Returns
-------
tuple of array-like
The missing, known, and middle bitmasks for the boundary condition.
The missing, known, and middle masks for the boundary condition.
Notes
-----
This method calculates the missing, known, and middle bitmasks based on the boundary bitmask. The missing bitmask
is the boundary bitmask, the known bitmask is the opposite directions of the missing bitmask, and the middle bitmask
This method calculates the missing, known, and middle masks based on the boundary mask. The missing mask
is the boundary mask, the known mask is the opposite directions of the missing mask, and the middle mask
is the directions that are neither missing nor known.
"""
# Find Bitmasks for imissing, iknown and imiddle
imissingBitmask = boundaryBitmask
iknownBitmask = imissingBitmask[:, self.lattice.opp_indices]
imiddleBitmask = ~(imissingBitmask | iknownBitmask)
return imissingBitmask, iknownBitmask, imiddleBitmask
# Find masks for imissing, iknown and imiddle
imissingMask = boundaryMask
iknownMask = imissingMask[:, self.lattice.opp_indices]
imiddleMask = ~(imissingMask | iknownMask)
return imissingMask, iknownMask, imiddleMask

@partial(jit, static_argnums=(0,))
def apply(self, fout, fin):
Expand Down Expand Up @@ -478,14 +478,14 @@ def __init__(self, indices, gridInfo, precision_policy, vel=None):
self.isSolid = True
self.vel = vel

def configure(self, boundaryBitmask):
def configure(self, boundaryMask):
"""
Configures the boundary condition.
Parameters
----------
boundaryBitmask : array-like
The connectivity bitmask for the boundary voxels.
boundaryMask : array-like
The grid mask for the boundary voxels.
Returns
-------
Expand All @@ -497,7 +497,7 @@ def configure(self, boundaryBitmask):
the boundary nodes to be the indices of fluid nodes adjacent of the solid nodes.
"""
# Perform index shift for halfway BB.
hasFluidNeighbour = ~boundaryBitmask[:, self.lattice.opp_indices]
hasFluidNeighbour = ~boundaryMask[:, self.lattice.opp_indices]
nbd_orig = len(self.indices[0])
idx = np.array(self.indices).T
idx_trg = []
Expand Down Expand Up @@ -685,12 +685,12 @@ def __init__(self, indices, gridInfo, precision_policy, type, prescribed):
self.prescribed = prescribed
self.needsExtraConfiguration = True

def configure(self, boundaryBitmask):
def configure(self, boundaryMask):
"""
Correct boundary indices to ensure that only voxelized surfaces with normal vectors along main cartesian axes
are assigned this type of BC.
"""
nv = np.dot(self.lattice.c, ~boundaryBitmask.T)
nv = np.dot(self.lattice.c, ~boundaryMask.T)
corner_voxels = np.count_nonzero(nv, axis=0) > 1
# removed_voxels = np.array(self.indices)[:, corner_voxels]
self.indices = tuple(np.array(self.indices)[:, ~corner_voxels])
Expand All @@ -702,8 +702,8 @@ def calculate_vel(self, fpop, rho):
"""
Calculate velocity based on the prescribed pressure/density (Zou/He BC)
"""
unormal = -1. + 1. / rho * (jnp.sum(fpop[self.indices] * self.imiddleBitmask, axis=1) +
2. * jnp.sum(fpop[self.indices] * self.iknownBitmask, axis=1))
unormal = -1. + 1. / rho * (jnp.sum(fpop[self.indices] * self.imiddleMask, axis=1) +
2. * jnp.sum(fpop[self.indices] * self.iknownMask, axis=1))

# Return the above unormal as a normal vector which sets the tangential velocities to zero
vel = unormal[:, None] * self.normals
Expand All @@ -716,8 +716,8 @@ def calculate_rho(self, fpop, vel):
"""
unormal = np.sum(self.normals*vel, axis=1)

rho = (1.0/(1.0 + unormal))[..., None] * (jnp.sum(fpop[self.indices] * self.imiddleBitmask, axis=1, keepdims=True) +
2.*jnp.sum(fpop[self.indices] * self.iknownBitmask, axis=1, keepdims=True))
rho = (1.0/(1.0 + unormal))[..., None] * (jnp.sum(fpop[self.indices] * self.imiddleMask, axis=1, keepdims=True) +
2.*jnp.sum(fpop[self.indices] * self.iknownMask, axis=1, keepdims=True))
return rho

@partial(jit, static_argnums=(0,), inline=True)
Expand Down Expand Up @@ -939,16 +939,16 @@ def __init__(self, indices, gridInfo, precision_policy):
self.needsExtraConfiguration = True
self.sound_speed = 1./jnp.sqrt(3.)

def configure(self, boundaryBitmask):
def configure(self, boundaryMask):
"""
Configure the boundary condition by finding neighbouring voxel indices.
Parameters
----------
boundaryBitmask : np.ndarray
The connectivity bitmask for the boundary voxels.
boundaryMask : np.ndarray
The grid mask for the boundary voxels.
"""
hasFluidNeighbour = ~boundaryBitmask[:, self.lattice.opp_indices]
hasFluidNeighbour = ~boundaryMask[:, self.lattice.opp_indices]
idx = np.array(self.indices).T
idx_trg = []
for i in range(self.lattice.q):
Expand Down Expand Up @@ -1066,7 +1066,7 @@ def set_proximity_ratio(self):
solid_indices = idx + c[:, q]
solid_indices_tuple = tuple(map(tuple, solid_indices.T))
sdf_s = self.implicit_distances[solid_indices_tuple]
mask = self.iknownBitmask[:, q]
mask = self.iknownMask[:, q]
self.weights[mask, q] = sdf_f[mask] / (sdf_f[mask] - sdf_s[mask])
return

Expand Down

0 comments on commit d2b2855

Please sign in to comment.