Skip to content

Commit

Permalink
added countour plots
Browse files Browse the repository at this point in the history
  • Loading branch information
loliverhennigh committed Nov 20, 2023
1 parent 3cd7fd0 commit 2adefe6
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 20 deletions.
77 changes: 63 additions & 14 deletions examples/CFD/airfoil3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@
import jax
import scipy

# PhantomGaze for rendering
import phantomgaze as pg

# Function to create a NACA airfoil shape given its length, thickness, and angle of attack
def makeNacaAirfoil(length, thickness=30, angle=0):
def nacaAirfoil(x, thickness, chordLength):
Expand Down Expand Up @@ -74,6 +77,11 @@ def set_boundary_conditions(self):
self.boundingBoxIndices['bottom'], self.boundingBoxIndices['top']))
self.BCs.append(BounceBack(tuple(wall.T), self.gridInfo, self.precisionPolicy))

# Store wall boundary for visualization
self.boundary = jnp.zeros((self.nx, self.ny, self.nz), dtype=jnp.float32)
self.boundary = self.boundary.at[tuple(wall.T)].set(1.0)
self.boundary = self.boundary[2:-2, 2:-2, 2:-2]

doNothing = self.boundingBoxIndices['right']
self.BCs.append(DoNothing(tuple(doNothing.T), self.gridInfo, self.precisionPolicy))

Expand All @@ -85,20 +93,61 @@ def set_boundary_conditions(self):
self.BCs.append(EquilibriumBC(tuple(inlet.T), self.gridInfo, self.precisionPolicy, rho_inlet, vel_inlet))

def output_data(self, **kwargs):
# 1:-1 to remove boundary voxels (not needed for visualization when using full-way bounce-back)
rho = np.array(kwargs['rho'][..., 1:-1, :])
u = np.array(kwargs['u'][..., 1:-1, :])
timestep = kwargs['timestep']
u_prev = kwargs['u_prev'][..., 1:-1, :]

u_old = np.linalg.norm(u_prev, axis=2)
u_new = np.linalg.norm(u, axis=2)
# Compute q-criterion and vorticity using finite differences
# Get velocity field
u = kwargs['u'][..., 1:-1, :]

# vorticity and q-criterion
norm_mu, q = q_criterion(u)

# Make phantomgaze volume
dx = 0.01
origin = (0.0, 0.0, 0.0)
upper_bound = (self.boundary.shape[0] * dx, self.boundary.shape[1] * dx, self.boundary.shape[2] * dx)
q_volume = pg.objects.Volume(
q,
spacing=(dx, dx, dx),
origin=origin,
)
norm_mu_volume = pg.objects.Volume(
norm_mu,
spacing=(dx, dx, dx),
origin=origin,
)
boundary_volume = pg.objects.Volume(
self.boundary,
spacing=(dx, dx, dx),
origin=origin,
)

# Make colormap for norm_mu
colormap = pg.Colormap("jet", vmin=0.0, vmax=0.05)

# Get camera parameters
focal_point = (self.boundary.shape[0] * dx / 2, self.boundary.shape[1] * dx / 2, self.boundary.shape[2] * dx / 2)
radius = 3.0
angle = kwargs['timestep'] * 0.0001
camera_position = (focal_point[0] + radius * np.sin(angle), focal_point[1], focal_point[2] + radius * np.cos(angle))

# Rotate camera
camera = pg.Camera(position=camera_position, focal_point=focal_point, view_up=(0.0, 1.0, 0.0), max_depth=30.0, height=1080, width=1920)

# Make wireframe
screen_buffer = pg.render.wireframe(lower_bound=origin, upper_bound=upper_bound, thickness=0.01, camera=camera)

# Render axes
screen_buffer = pg.render.axes(size=0.1, center=(0.0, 0.0, 1.1), camera=camera, screen_buffer=screen_buffer)

# Render q-criterion
screen_buffer = pg.render.contour(q_volume, threshold=0.00003, color=norm_mu_volume, colormap=colormap, camera=camera, screen_buffer=screen_buffer)

# Render boundary
boundary_colormap = pg.Colormap("Greys_r", vmin=0.0, vmax=3.0, opacity=np.linspace(0.0, 6.0, 256)) # This will make it grey
screen_buffer = pg.render.volume(boundary_volume, camera=camera, colormap=boundary_colormap, screen_buffer=screen_buffer)

# Show the rendered image
plt.imsave('q_criterion_' + str(kwargs['timestep']).zfill(7) + '.png', np.minimum(screen_buffer.image.get(), 1.0))

err = np.sum(np.abs(u_old - u_new))
print('error= {:07.6f}'.format(err))
# save_image(timestep, rho, u)
fields = {"rho": rho[..., 0], "u_x": u[..., 0], "u_y": u[..., 1], "u_z": u[..., 2]}
save_fields_vtk(timestep, fields)

if __name__ == '__main__':
airfoil_length = 101
Expand Down Expand Up @@ -142,4 +191,4 @@ def output_data(self, **kwargs):

sim = Airfoil(**kwargs)
print('Domain size: ', sim.nx, sim.ny, sim.nz)
sim.run(20000)
sim.run(20000)
10 changes: 5 additions & 5 deletions src/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,8 +792,8 @@ def run(self, t_max):
rho_prev = downsample_field(rho_prev, self.downsamplingFactor)
u_prev = downsample_field(u_prev, self.downsamplingFactor)
# Gather the data from all processes and convert it to numpy arrays (move to host memory)
rho_prev = np.array(process_allgather(rho_prev))
u_prev = np.array(process_allgather(u_prev))
rho_prev = process_allgather(rho_prev)
u_prev = process_allgather(u_prev)


# Perform one time-step (collision, streaming, and boundary conditions)
Expand All @@ -810,8 +810,8 @@ def run(self, t_max):
u = downsample_field(u, self.downsamplingFactor)

# Gather the data from all processes and convert it to numpy arrays (move to host memory)
rho = np.array(process_allgather(rho))
u = np.array(process_allgather(u))
rho = process_allgather(rho)
u = process_allgather(u)

# Save the data
self.handle_io_timestep(timestep, f, fstar, rho, u, rho_prev, u_prev)
Expand Down Expand Up @@ -978,4 +978,4 @@ def apply_force(self, f_postcollision, feq, rho, u):
return f_postcollision




65 changes: 64 additions & 1 deletion src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,4 +348,67 @@ def axangle2mat(axis, angle, is_normalized=False):
return jnp.array([
[x * xC + c, xyC - zs, zxC + ys],
[xyC + zs, y * yC + c, yzC - xs],
[zxC - ys, yzC + xs, z * zC + c]])
[zxC - ys, yzC + xs, z * zC + c]])

@partial(jit)
def q_criterion(u):
# Compute derivatives
u_x = u[..., 0]
u_y = u[..., 1]
u_z = u[..., 2]

# Compute derivatives
u_x_dx = (u_x[2:, 1:-1, 1:-1] - u_x[:-2, 1:-1, 1:-1]) / 2
u_x_dy = (u_x[1:-1, 2:, 1:-1] - u_x[1:-1, :-2, 1:-1]) / 2
u_x_dz = (u_x[1:-1, 1:-1, 2:] - u_x[1:-1, 1:-1, :-2]) / 2
u_y_dx = (u_y[2:, 1:-1, 1:-1] - u_y[:-2, 1:-1, 1:-1]) / 2
u_y_dy = (u_y[1:-1, 2:, 1:-1] - u_y[1:-1, :-2, 1:-1]) / 2
u_y_dz = (u_y[1:-1, 1:-1, 2:] - u_y[1:-1, 1:-1, :-2]) / 2
u_z_dx = (u_z[2:, 1:-1, 1:-1] - u_z[:-2, 1:-1, 1:-1]) / 2
u_z_dy = (u_z[1:-1, 2:, 1:-1] - u_z[1:-1, :-2, 1:-1]) / 2
u_z_dz = (u_z[1:-1, 1:-1, 2:] - u_z[1:-1, 1:-1, :-2]) / 2

# Compute vorticity
mu_x = u_z_dy - u_y_dz
mu_y = u_x_dz - u_z_dx
mu_z = u_y_dx - u_x_dy
norm_mu = jnp.sqrt(mu_x ** 2 + mu_y ** 2 + mu_z ** 2)

# Compute strain rate
s_0_0 = u_x_dx
s_0_1 = 0.5 * (u_x_dy + u_y_dx)
s_0_2 = 0.5 * (u_x_dz + u_z_dx)
s_1_0 = s_0_1
s_1_1 = u_y_dy
s_1_2 = 0.5 * (u_y_dz + u_z_dy)
s_2_0 = s_0_2
s_2_1 = s_1_2
s_2_2 = u_z_dz
s_dot_s = (
s_0_0 ** 2 + s_0_1 ** 2 + s_0_2 ** 2 +
s_1_0 ** 2 + s_1_1 ** 2 + s_1_2 ** 2 +
s_2_0 ** 2 + s_2_1 ** 2 + s_2_2 ** 2
)

# Compute omega
omega_0_0 = 0.0
omega_0_1 = 0.5 * (u_x_dy - u_y_dx)
omega_0_2 = 0.5 * (u_x_dz - u_z_dx)
omega_1_0 = -omega_0_1
omega_1_1 = 0.0
omega_1_2 = 0.5 * (u_y_dz - u_z_dy)
omega_2_0 = -omega_0_2
omega_2_1 = -omega_1_2
omega_2_2 = 0.0
omega_dot_omega = (
omega_0_0 ** 2 + omega_0_1 ** 2 + omega_0_2 ** 2 +
omega_1_0 ** 2 + omega_1_1 ** 2 + omega_1_2 ** 2 +
omega_2_0 ** 2 + omega_2_1 ** 2 + omega_2_2 ** 2
)

# Compute q-criterion
q = 0.5 * (omega_dot_omega - s_dot_s)

return norm_mu, q


0 comments on commit 2adefe6

Please sign in to comment.