Skip to content

Commit

Permalink
Fixed mixed-precision for Warp
Browse files (browse the repository at this point in the history)
  • Loading branch information
mehdiataei committed Sep 24, 2024
1 parent 6eca3f2 commit 9496f6f
Show file tree
Hide file tree
Showing 24 changed files with 74 additions and 87 deletions.
2 changes: 1 addition & 1 deletion examples/cfd/flow_past_sphere_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def setup_boundary_masker(self):
self.bc_mask, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.bc_mask, self.missing_mask, (0, 0, 0))

def initialize_fields(self):
self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend)
self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.precision_policy, self.backend)

def setup_stepper(self, omega):
self.stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions, collision_type="BGK")
Expand Down
2 changes: 1 addition & 1 deletion examples/cfd/lid_driven_cavity_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def setup_boundary_masker(self):
self.bc_mask, self.missing_mask = indices_boundary_masker(self.boundary_conditions, self.bc_mask, self.missing_mask)

def initialize_fields(self):
self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend)
self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.precision_policy, self.backend)

def setup_stepper(self, omega):
self.stepper = IncompressibleNavierStokesStepper(omega, boundary_conditions=self.boundary_conditions)
Expand Down
2 changes: 1 addition & 1 deletion examples/cfd/turbulent_channel_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def initialize_fields(self):
u_init = jnp.full(shape=shape, fill_value=1e-2 * u_init)
else:
u_init = wp.array(1e-2 * u_init, dtype=self.precision_policy.compute_precision.wp_dtype)
self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend, u=u_init)
self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.precision_policy, self.backend, u=u_init)

def setup_stepper(self):
force = self.get_force()
Expand Down
2 changes: 1 addition & 1 deletion examples/cfd/windtunnel_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def setup_boundary_masker(self):
self.bc_mask, self.missing_mask = mesh_boundary_masker(bc_mesh, origin, spacing, self.bc_mask, self.missing_mask)

def initialize_fields(self):
self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.backend)
self.f_0 = initialize_eq(self.f_0, self.grid, self.velocity_set, self.precision_policy, self.backend)

def setup_stepper(self):
self.stepper = IncompressibleNavierStokesStepper(self.omega, boundary_conditions=self.boundary_conditions, collision_type="KBC")
Expand Down
2 changes: 1 addition & 1 deletion examples/performance/mlups_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def main():
backend, precision_policy = setup_simulation(args)
velocity_set = xlb.velocity_set.D3Q19(precision_policy=precision_policy, backend=backend)
grid, f_0, f_1, missing_mask, bc_mask = create_grid_and_fields(args.cube_edge)
f_0 = initialize_eq(f_0, grid, velocity_set, backend)
f_0 = initialize_eq(f_0, grid, velocity_set, precision_policy, backend)

elapsed_time = run(f_0, f_1, backend, precision_policy, grid, bc_mask, missing_mask, args.num_steps)
mlups = calculate_mlups(args.cube_edge, args.num_steps, elapsed_time)
Expand Down
6 changes: 3 additions & 3 deletions xlb/helper/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
from xlb.operator.equilibrium import QuadraticEquilibrium


-def initialize_eq(f, grid, velocity_set, backend, rho=None, u=None):
+def initialize_eq(f, grid, velocity_set, precision_policy, backend, rho=None, u=None):
 if rho is None:
-rho = grid.create_field(cardinality=1, fill_value=1.0)
+rho = grid.create_field(cardinality=1, fill_value=1.0, dtype=precision_policy.compute_precision)
 if u is None:
-u = grid.create_field(cardinality=velocity_set.d, fill_value=0.0)
+u = grid.create_field(cardinality=velocity_set.d, fill_value=0.0, dtype=precision_policy.compute_precision)
equilibrium = QuadraticEquilibrium()

if backend == ComputeBackend.JAX:
Expand Down
4 changes: 2 additions & 2 deletions xlb/operator/boundary_condition/bc_do_nothing.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def kernel2d(

# Write the result
for l in range(self.velocity_set.q):
f_post[l, index[0], index[1]] = _f[l]
f_post[l, index[0], index[1]] = self.store_dtype(_f[l])

# Construct the warp kernel
@wp.kernel
Expand All @@ -112,7 +112,7 @@ def kernel3d(

# Write the result
for l in range(self.velocity_set.q):
f_post[l, index[0], index[1], index[2]] = _f[l]
f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l])

kernel = kernel3d if self.velocity_set.d == 3 else kernel2d

Expand Down
4 changes: 2 additions & 2 deletions xlb/operator/boundary_condition/bc_equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def kernel2d(

# Write the result
for l in range(self.velocity_set.q):
f_post[l, index[0], index[1]] = _f[l]
f_post[l, index[0], index[1]] = self.store_dtype(_f[l])

# Construct the warp kernel
@wp.kernel
Expand All @@ -137,7 +137,7 @@ def kernel3d(

# Write the result
for l in range(self.velocity_set.q):
f_post[l, index[0], index[1], index[2]] = _f[l]
f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l])

kernel = kernel3d if self.velocity_set.d == 3 else kernel2d

Expand Down
4 changes: 2 additions & 2 deletions xlb/operator/boundary_condition/bc_extrapolation_outflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def kernel2d(

# Write the distribution function
for l in range(self.velocity_set.q):
f_post[l, index[0], index[1]] = _f[l]
f_post[l, index[0], index[1]] = self.store_dtype(_f[l])

# Construct the warp kernel
@wp.kernel
Expand Down Expand Up @@ -270,7 +270,7 @@ def kernel3d(

# Write the distribution function
for l in range(self.velocity_set.q):
f_post[l, index[0], index[1], index[2]] = _f[l]
f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l])

kernel = kernel3d if self.velocity_set.d == 3 else kernel2d

Expand Down
4 changes: 2 additions & 2 deletions xlb/operator/boundary_condition/bc_fullway_bounce_back.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def kernel2d(

# Write the result to the output
for l in range(self.velocity_set.q):
f_post[l, index[0], index[1]] = _f[l]
f_post[l, index[0], index[1]] = self.store_dtype(_f[l])

# Construct the warp kernel
@wp.kernel
Expand All @@ -121,7 +121,7 @@ def kernel3d(

# Write the result to the output
for l in range(self.velocity_set.q):
f_post[l, index[0], index[1], index[2]] = _f[l]
f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l])

kernel = kernel3d if self.velocity_set.d == 3 else kernel2d

Expand Down
4 changes: 2 additions & 2 deletions xlb/operator/boundary_condition/bc_halfway_bounce_back.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def kernel2d(

# Write the distribution function
for l in range(self.velocity_set.q):
f_post[l, index[0], index[1]] = _f[l]
f_post[l, index[0], index[1]] = self.store_dtype(_f[l])

# Construct the warp kernel
@wp.kernel
Expand All @@ -136,7 +136,7 @@ def kernel3d(

# Write the distribution function
for l in range(self.velocity_set.q):
f_post[l, index[0], index[1], index[2]] = _f[l]
f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l])

kernel = kernel3d if self.velocity_set.d == 3 else kernel2d

Expand Down
4 changes: 2 additions & 2 deletions xlb/operator/boundary_condition/bc_regularized.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def kernel2d(

# Write the distribution function
for l in range(self.velocity_set.q):
f_post[l, index[0], index[1]] = _f[l]
f_post[l, index[0], index[1]] = self.store_dtype(_f[l])

# Construct the warp kernel
@wp.kernel
Expand All @@ -370,7 +370,7 @@ def kernel3d(

# Write the distribution function
for l in range(self.velocity_set.q):
f_post[l, index[0], index[1], index[2]] = _f[l]
f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l])

kernel = kernel3d if self.velocity_set.d == 3 else kernel2d
if self.velocity_set.d == 3 and self.bc_type == "velocity":
Expand Down
4 changes: 2 additions & 2 deletions xlb/operator/boundary_condition/bc_zouhe.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def kernel2d(

# Write the distribution function
for l in range(self.velocity_set.q):
f_post[l, index[0], index[1]] = _f[l]
f_post[l, index[0], index[1]] = self.store_dtype(_f[l])

# Construct the warp kernel
@wp.kernel
Expand All @@ -378,7 +378,7 @@ def kernel3d(

# Write the distribution function
for l in range(self.velocity_set.q):
f_post[l, index[0], index[1], index[2]] = _f[l]
f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l])

kernel = kernel3d if self.velocity_set.d == 3 else kernel2d
if self.velocity_set.d == 3 and self.bc_type == "velocity":
Expand Down
8 changes: 4 additions & 4 deletions xlb/operator/boundary_condition/boundary_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ def _get_thread_data_2d(
_missing_mask = _missing_mask_vec()
for l in range(self.velocity_set.q):
# q-sized vector of populations
-_f_pre[l] = f_pre[l, index[0], index[1]]
-_f_post[l] = f_post[l, index[0], index[1]]
+_f_pre[l] = self.compute_dtype(f_pre[l, index[0], index[1]])
+_f_post[l] = self.compute_dtype(f_post[l, index[0], index[1]])

# TODO fix vec bool
if missing_mask[l, index[0], index[1]]:
Expand All @@ -106,8 +106,8 @@ def _get_thread_data_3d(
_missing_mask = _missing_mask_vec()
for l in range(self.velocity_set.q):
# q-sized vector of populations
-_f_pre[l] = f_pre[l, index[0], index[1], index[2]]
-_f_post[l] = f_post[l, index[0], index[1], index[2]]
+_f_pre[l] = self.compute_dtype(f_pre[l, index[0], index[1], index[2]])
+_f_post[l] = self.compute_dtype(f_post[l, index[0], index[1], index[2]])

# TODO fix vec bool
if missing_mask[l, index[0], index[1], index[2]]:
Expand Down
4 changes: 2 additions & 2 deletions xlb/operator/collision/bgk.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def kernel2d(

# Write the result
for l in range(self.velocity_set.q):
fout[l, index[0], index[1]] = _fout[l]
fout[l, index[0], index[1]] = self.store_dtype(_fout[l])

# Construct the warp kernel
@wp.kernel
Expand All @@ -86,7 +86,7 @@ def kernel3d(

# Write the result
for l in range(self.velocity_set.q):
fout[l, index[0], index[1], index[2]] = _fout[l]
fout[l, index[0], index[1], index[2]] = self.store_dtype(_fout[l])

kernel = kernel3d if self.velocity_set.d == 3 else kernel2d

Expand Down
4 changes: 2 additions & 2 deletions xlb/operator/collision/kbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def kernel2d(

# Write the result
for l in range(self.velocity_set.q):
fout[l, index[0], index[1]] = _fout[l]
fout[l, index[0], index[1]] = self.store_dtype(_fout[l])

# Construct the warp kernel
@wp.kernel
Expand Down Expand Up @@ -369,7 +369,7 @@ def kernel3d(

# Write the result
for l in range(self.velocity_set.q):
fout[l, index[0], index[1], index[2]] = _fout[l]
fout[l, index[0], index[1], index[2]] = self.store_dtype(_fout[l])

functional = functional3d if self.velocity_set.d == 3 else functional2d
kernel = kernel3d if self.velocity_set.d == 3 else kernel2d
Expand Down
4 changes: 2 additions & 2 deletions xlb/operator/equilibrium/quadratic_equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def kernel3d(

# Set the output
for l in range(self.velocity_set.q):
f[l, index[0], index[1], index[2]] = feq[l]
f[l, index[0], index[1], index[2]] = self.store_dtype(feq[l])

@wp.kernel
def kernel2d(
Expand All @@ -100,7 +100,7 @@ def kernel2d(

# Set the output
for l in range(self.velocity_set.q):
f[l, index[0], index[1]] = feq[l]
f[l, index[0], index[1]] = self.store_dtype(feq[l])

kernel = kernel3d if self.velocity_set.d == 3 else kernel2d

Expand Down
4 changes: 2 additions & 2 deletions xlb/operator/force/exact_difference_force.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def kernel2d(

# Write the result
for l in range(self.velocity_set.q):
fout[l, index[0], index[1]] = _fout[l]
fout[l, index[0], index[1]] = self.store_dtype(_fout[l])

# Construct the warp kernel
@wp.kernel
Expand All @@ -134,7 +134,7 @@ def kernel3d(

# Write the result
for l in range(self.velocity_set.q):
fout[l, index[0], index[1], index[2]] = _fout[l]
fout[l, index[0], index[1], index[2]] = self.store_dtype(_fout[l])

kernel = kernel3d if self.velocity_set.d == 3 else kernel2d
return functional, kernel
Expand Down
4 changes: 2 additions & 2 deletions xlb/operator/macroscopic/first_moment.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def kernel3d(
_u = functional(_f, _rho)

for d in range(self.velocity_set.d):
u[d, index[0], index[1], index[2]] = _u[d]
u[d, index[0], index[1], index[2]] = self.store_dtype(_u[d])

@wp.kernel
def kernel2d(
Expand All @@ -71,7 +71,7 @@ def kernel2d(
_u = functional(_f, _rho)

for d in range(self.velocity_set.d):
u[d, index[0], index[1]] = _u[d]
u[d, index[0], index[1]] = self.store_dtype(_u[d])

kernel = kernel3d if self.velocity_set.d == 3 else kernel2d

Expand Down
8 changes: 4 additions & 4 deletions xlb/operator/macroscopic/macroscopic.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ def kernel3d(
_f[l] = f[l, index[0], index[1], index[2]]
_rho, _u = functional(_f)

-rho[0, index[0], index[1], index[2]] = _rho
+rho[0, index[0], index[1], index[2]] = self.store_dtype(_rho)
 for d in range(self.velocity_set.d):
-u[d, index[0], index[1], index[2]] = _u[d]
+u[d, index[0], index[1], index[2]] = self.store_dtype(_u[d])

@wp.kernel
def kernel2d(
Expand All @@ -68,9 +68,9 @@ def kernel2d(
_f[l] = f[l, index[0], index[1]]
_rho, _u = functional(_f)

-rho[0, index[0], index[1]] = _rho
+rho[0, index[0], index[1]] = self.store_dtype(_rho)
 for d in range(self.velocity_set.d):
-u[d, index[0], index[1]] = _u[d]
+u[d, index[0], index[1]] = self.store_dtype(_u[d])

kernel = kernel3d if self.velocity_set.d == 3 else kernel2d

Expand Down
4 changes: 2 additions & 2 deletions xlb/operator/macroscopic/second_moment.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def kernel3d(

# Set the output
for d in range(_pi_dim):
pi[d, index[0], index[1], index[2]] = _pi[d]
pi[d, index[0], index[1], index[2]] = self.store_dtype(_pi[d])

@wp.kernel
def kernel2d(
Expand All @@ -114,7 +114,7 @@ def kernel2d(

# Set the output
for d in range(_pi_dim):
pi[d, index[0], index[1]] = _pi[d]
pi[d, index[0], index[1]] = self.store_dtype(_pi[d])

kernel = kernel3d if self.velocity_set.d == 3 else kernel2d

Expand Down
Loading

0 comments on commit 9496f6f

Please sign in to comment.