From 2adefe639ba688aa41726361f21416d000fe608b Mon Sep 17 00:00:00 2001 From: oliver Date: Sun, 19 Nov 2023 21:52:40 -0800 Subject: [PATCH 1/2] added countour plots --- examples/CFD/airfoil3d.py | 77 ++++++++++++++++++++++++++++++++------- src/base.py | 10 ++--- src/utils.py | 65 ++++++++++++++++++++++++++++++++- 3 files changed, 132 insertions(+), 20 deletions(-) diff --git a/examples/CFD/airfoil3d.py b/examples/CFD/airfoil3d.py index e601551..366339c 100644 --- a/examples/CFD/airfoil3d.py +++ b/examples/CFD/airfoil3d.py @@ -41,6 +41,9 @@ import jax import scipy +# PhantomGaze for rendering +import phantomgaze as pg + # Function to create a NACA airfoil shape given its length, thickness, and angle of attack def makeNacaAirfoil(length, thickness=30, angle=0): def nacaAirfoil(x, thickness, chordLength): @@ -74,6 +77,11 @@ def set_boundary_conditions(self): self.boundingBoxIndices['bottom'], self.boundingBoxIndices['top'])) self.BCs.append(BounceBack(tuple(wall.T), self.gridInfo, self.precisionPolicy)) + # Store wall boundary for visualization + self.boundary = jnp.zeros((self.nx, self.ny, self.nz), dtype=jnp.float32) + self.boundary = self.boundary.at[tuple(wall.T)].set(1.0) + self.boundary = self.boundary[2:-2, 2:-2, 2:-2] + doNothing = self.boundingBoxIndices['right'] self.BCs.append(DoNothing(tuple(doNothing.T), self.gridInfo, self.precisionPolicy)) @@ -85,20 +93,61 @@ def set_boundary_conditions(self): self.BCs.append(EquilibriumBC(tuple(inlet.T), self.gridInfo, self.precisionPolicy, rho_inlet, vel_inlet)) def output_data(self, **kwargs): - # 1:-1 to remove boundary voxels (not needed for visualization when using full-way bounce-back) - rho = np.array(kwargs['rho'][..., 1:-1, :]) - u = np.array(kwargs['u'][..., 1:-1, :]) - timestep = kwargs['timestep'] - u_prev = kwargs['u_prev'][..., 1:-1, :] - - u_old = np.linalg.norm(u_prev, axis=2) - u_new = np.linalg.norm(u, axis=2) + # Compute q-criterion and vorticity using finite differences + # Get velocity field + u = kwargs['u'][..., 1:-1, :] + + # vorticity and q-criterion + norm_mu, q = q_criterion(u) + + # Make phantomgaze volume + dx = 0.01 + origin = (0.0, 0.0, 0.0) + upper_bound = (self.boundary.shape[0] * dx, self.boundary.shape[1] * dx, self.boundary.shape[2] * dx) + q_volume = pg.objects.Volume( + q, + spacing=(dx, dx, dx), + origin=origin, + ) + norm_mu_volume = pg.objects.Volume( + norm_mu, + spacing=(dx, dx, dx), + origin=origin, + ) + boundary_volume = pg.objects.Volume( + self.boundary, + spacing=(dx, dx, dx), + origin=origin, + ) + + # Make colormap for norm_mu + colormap = pg.Colormap("jet", vmin=0.0, vmax=0.05) + + # Get camera parameters + focal_point = (self.boundary.shape[0] * dx / 2, self.boundary.shape[1] * dx / 2, self.boundary.shape[2] * dx / 2) + radius = 3.0 + angle = kwargs['timestep'] * 0.0001 + camera_position = (focal_point[0] + radius * np.sin(angle), focal_point[1], focal_point[2] + radius * np.cos(angle)) + + # Rotate camera + camera = pg.Camera(position=camera_position, focal_point=focal_point, view_up=(0.0, 1.0, 0.0), max_depth=30.0, height=1080, width=1920) + + # Make wireframe + screen_buffer = pg.render.wireframe(lower_bound=origin, upper_bound=upper_bound, thickness=0.01, camera=camera) + + # Render axes + screen_buffer = pg.render.axes(size=0.1, center=(0.0, 0.0, 1.1), camera=camera, screen_buffer=screen_buffer) + + # Render q-criterion + screen_buffer = pg.render.contour(q_volume, threshold=0.00003, color=norm_mu_volume, colormap=colormap, camera=camera, screen_buffer=screen_buffer) + + # Render boundary + boundary_colormap = pg.Colormap("Greys_r", vmin=0.0, vmax=3.0, opacity=np.linspace(0.0, 6.0, 256)) # This will make it grey + screen_buffer = pg.render.volume(boundary_volume, camera=camera, colormap=boundary_colormap, screen_buffer=screen_buffer) + + # Show the rendered image + plt.imsave('q_criterion_' + str(kwargs['timestep']).zfill(7) + '.png', np.minimum(screen_buffer.image.get(), 1.0)) - err = np.sum(np.abs(u_old - u_new)) - print('error= {:07.6f}'.format(err)) - # save_image(timestep, rho, u) - fields = {"rho": rho[..., 0], "u_x": u[..., 0], "u_y": u[..., 1], "u_z": u[..., 2]} - save_fields_vtk(timestep, fields) if __name__ == '__main__': airfoil_length = 101 @@ -142,4 +191,4 @@ def output_data(self, **kwargs): sim = Airfoil(**kwargs) print('Domain size: ', sim.nx, sim.ny, sim.nz) - sim.run(20000) \ No newline at end of file + sim.run(20000) diff --git a/src/base.py b/src/base.py index c4dbf36..5f7370f 100644 --- a/src/base.py +++ b/src/base.py @@ -792,8 +792,8 @@ def run(self, t_max): rho_prev = downsample_field(rho_prev, self.downsamplingFactor) u_prev = downsample_field(u_prev, self.downsamplingFactor) # Gather the data from all processes and convert it to numpy arrays (move to host memory) - rho_prev = np.array(process_allgather(rho_prev)) - u_prev = np.array(process_allgather(u_prev)) + rho_prev = process_allgather(rho_prev) + u_prev = process_allgather(u_prev) # Perform one time-step (collision, streaming, and boundary conditions) @@ -810,8 +810,8 @@ def run(self, t_max): u = downsample_field(u, self.downsamplingFactor) # Gather the data from all processes and convert it to numpy arrays (move to host memory) - rho = np.array(process_allgather(rho)) - u = np.array(process_allgather(u)) + rho = process_allgather(rho) + u = process_allgather(u) # Save the data self.handle_io_timestep(timestep, f, fstar, rho, u, rho_prev, u_prev) @@ -978,4 +978,4 @@ def apply_force(self, f_postcollision, feq, rho, u): return f_postcollision - \ No newline at end of file + diff --git a/src/utils.py b/src/utils.py index d7b2eb1..b9105ea 100644 --- a/src/utils.py +++ b/src/utils.py @@ -348,4 +348,67 @@ def axangle2mat(axis, angle, is_normalized=False): return jnp.array([ [x * xC + c, xyC - zs, zxC + ys], [xyC + zs, y * yC + c, yzC - xs], - [zxC - ys, yzC + xs, z * zC + c]]) \ No newline at end of file + [zxC - ys, yzC + xs, z * zC + c]]) + +@partial(jit) +def q_criterion(u): + # Compute derivatives + u_x = u[..., 0] + u_y = u[..., 1] + u_z = u[..., 2] + + # Compute derivatives + u_x_dx = (u_x[2:, 1:-1, 1:-1] - u_x[:-2, 1:-1, 1:-1]) / 2 + u_x_dy = (u_x[1:-1, 2:, 1:-1] - u_x[1:-1, :-2, 1:-1]) / 2 + u_x_dz = (u_x[1:-1, 1:-1, 2:] - u_x[1:-1, 1:-1, :-2]) / 2 + u_y_dx = (u_y[2:, 1:-1, 1:-1] - u_y[:-2, 1:-1, 1:-1]) / 2 + u_y_dy = (u_y[1:-1, 2:, 1:-1] - u_y[1:-1, :-2, 1:-1]) / 2 + u_y_dz = (u_y[1:-1, 1:-1, 2:] - u_y[1:-1, 1:-1, :-2]) / 2 + u_z_dx = (u_z[2:, 1:-1, 1:-1] - u_z[:-2, 1:-1, 1:-1]) / 2 + u_z_dy = (u_z[1:-1, 2:, 1:-1] - u_z[1:-1, :-2, 1:-1]) / 2 + u_z_dz = (u_z[1:-1, 1:-1, 2:] - u_z[1:-1, 1:-1, :-2]) / 2 + + # Compute vorticity + mu_x = u_z_dy - u_y_dz + mu_y = u_x_dz - u_z_dx + mu_z = u_y_dx - u_x_dy + norm_mu = jnp.sqrt(mu_x ** 2 + mu_y ** 2 + mu_z ** 2) + + # Compute strain rate + s_0_0 = u_x_dx + s_0_1 = 0.5 * (u_x_dy + u_y_dx) + s_0_2 = 0.5 * (u_x_dz + u_z_dx) + s_1_0 = s_0_1 + s_1_1 = u_y_dy + s_1_2 = 0.5 * (u_y_dz + u_z_dy) + s_2_0 = s_0_2 + s_2_1 = s_1_2 + s_2_2 = u_z_dz + s_dot_s = ( + s_0_0 ** 2 + s_0_1 ** 2 + s_0_2 ** 2 + + s_1_0 ** 2 + s_1_1 ** 2 + s_1_2 ** 2 + + s_2_0 ** 2 + s_2_1 ** 2 + s_2_2 ** 2 + ) + + # Compute omega + omega_0_0 = 0.0 + omega_0_1 = 0.5 * (u_x_dy - u_y_dx) + omega_0_2 = 0.5 * (u_x_dz - u_z_dx) + omega_1_0 = -omega_0_1 + omega_1_1 = 0.0 + omega_1_2 = 0.5 * (u_y_dz - u_z_dy) + omega_2_0 = -omega_0_2 + omega_2_1 = -omega_1_2 + omega_2_2 = 0.0 + omega_dot_omega = ( + omega_0_0 ** 2 + omega_0_1 ** 2 + omega_0_2 ** 2 + + omega_1_0 ** 2 + omega_1_1 ** 2 + omega_1_2 ** 2 + + omega_2_0 ** 2 + omega_2_1 ** 2 + omega_2_2 ** 2 + ) + + # Compute q-criterion + q = 0.5 * (omega_dot_omega - s_dot_s) + + return norm_mu, q + + From 37824c59cca9cf271f141f7d7ae4893e2f2fe8d5 Mon Sep 17 00:00:00 2001 From: Mehdi Ataeei Date: Thu, 23 Nov 2023 14:04:46 -0500 Subject: [PATCH 2/2] Fixed the airfoil shape problem --- examples/CFD/airfoil3d.py | 57 +++++++++++++++++++++------------------ requirements.txt | 18 ++++++------- 2 files changed, 40 insertions(+), 35 deletions(-) diff --git a/examples/CFD/airfoil3d.py b/examples/CFD/airfoil3d.py index 366339c..16cedfe 100644 --- a/examples/CFD/airfoil3d.py +++ b/examples/CFD/airfoil3d.py @@ -21,9 +21,9 @@ 4. Simulation Parameters: The example allows for the setting of various simulation parameters, including the Reynolds number, inlet velocity, and characteristic length. -5. Visualization: The example outputs data in VTK format, which can be visualized using software such - as Paraview. The error between the old and new velocity fields is also printed out at each time step - to monitor the convergence of the solution. +5. In-situ visualization: The example outputs rendering images of the q-criterion using + PhantomGaze library (https://github.com/loliverhennigh/PhantomGaze) without any I/O overhead + while the data is still on the GPU. """ @@ -41,23 +41,30 @@ import jax import scipy -# PhantomGaze for rendering +# PhantomGaze for in-situ rendering import phantomgaze as pg -# Function to create a NACA airfoil shape given its length, thickness, and angle of attack def makeNacaAirfoil(length, thickness=30, angle=0): def nacaAirfoil(x, thickness, chordLength): coeffs = [0.2969, -0.1260, -0.3516, 0.2843, -0.1015] exponents = [0.5, 1, 2, 3, 4] - af = [coeff * (x / chordLength) ** exp for coeff, exp in zip(coeffs, exponents)] - return 5. * thickness / 100 * chordLength * np.sum(af) + yt = [coeff * (x / chordLength) ** exp for coeff, exp in zip(coeffs, exponents)] + yt = 5. * thickness / 100 * chordLength * np.sum(yt) - x = np.arange(length) - y = np.arange(-int(length * thickness / 200), int(length * thickness / 200)) - xx, yy = np.meshgrid(x, y) - domain = np.where(np.abs(yy) < nacaAirfoil(xx, thickness, length), 1, 0).T + return yt - domain = scipy.ndimage.rotate(np.rot90(domain), -angle) + x = np.linspace(0, length, num=length) + yt = np.array([nacaAirfoil(xi, thickness, length) for xi in x]) + + y_max = int(np.max(yt)) + 1 + domain = np.zeros((2 * y_max, len(x)), dtype=int) + + for i, xi in enumerate(x): + upper_bound = int(y_max + yt[i]) + lower_bound = int(y_max - yt[i]) + domain[lower_bound:upper_bound, i] = 1 + + domain = scipy.ndimage.rotate(domain, angle, reshape=True) domain = np.where(domain > 0.5, 1, 0) return domain @@ -77,10 +84,9 @@ def set_boundary_conditions(self): self.boundingBoxIndices['bottom'], self.boundingBoxIndices['top'])) self.BCs.append(BounceBack(tuple(wall.T), self.gridInfo, self.precisionPolicy)) - # Store wall boundary for visualization - self.boundary = jnp.zeros((self.nx, self.ny, self.nz), dtype=jnp.float32) - self.boundary = self.boundary.at[tuple(wall.T)].set(1.0) - self.boundary = self.boundary[2:-2, 2:-2, 2:-2] + # Store airfoil boundary for visualization + self.visualization_bc = jnp.zeros((self.nx, self.ny, self.nz), dtype=jnp.float32) + self.visualization_bc = self.visualization_bc.at[tuple(airfoil_indices.T)].set(1.0) doNothing = self.boundingBoxIndices['right'] self.BCs.append(DoNothing(tuple(doNothing.T), self.gridInfo, self.precisionPolicy)) @@ -94,16 +100,15 @@ def set_boundary_conditions(self): def output_data(self, **kwargs): # Compute q-criterion and vorticity using finite differences - # Get velocity field + # Get velocity field u = kwargs['u'][..., 1:-1, :] - # vorticity and q-criterion norm_mu, q = q_criterion(u) # Make phantomgaze volume dx = 0.01 origin = (0.0, 0.0, 0.0) - upper_bound = (self.boundary.shape[0] * dx, self.boundary.shape[1] * dx, self.boundary.shape[2] * dx) + upper_bound = (self.visualization_bc.shape[0] * dx, self.visualization_bc.shape[1] * dx, self.visualization_bc.shape[2] * dx) q_volume = pg.objects.Volume( q, spacing=(dx, dx, dx), @@ -115,7 +120,7 @@ def output_data(self, **kwargs): origin=origin, ) boundary_volume = pg.objects.Volume( - self.boundary, + self.visualization_bc, spacing=(dx, dx, dx), origin=origin, ) @@ -124,13 +129,13 @@ def output_data(self, **kwargs): colormap = pg.Colormap("jet", vmin=0.0, vmax=0.05) # Get camera parameters - focal_point = (self.boundary.shape[0] * dx / 2, self.boundary.shape[1] * dx / 2, self.boundary.shape[2] * dx / 2) - radius = 3.0 + focal_point = (self.visualization_bc.shape[0] * dx / 2, self.visualization_bc.shape[1] * dx / 2, self.visualization_bc.shape[2] * dx / 2) + radius = 5.0 angle = kwargs['timestep'] * 0.0001 camera_position = (focal_point[0] + radius * np.sin(angle), focal_point[1], focal_point[2] + radius * np.cos(angle)) # Rotate camera - camera = pg.Camera(position=camera_position, focal_point=focal_point, view_up=(0.0, 1.0, 0.0), max_depth=30.0, height=1080, width=1920) + camera = pg.Camera(position=camera_position, focal_point=focal_point, view_up=(0.0, 1.0, 0.0), max_depth=30.0, height=1080, width=1920, background=pg.SolidBackground(color=(0.0, 0.0, 0.0))) # Make wireframe screen_buffer = pg.render.wireframe(lower_bound=origin, upper_bound=upper_bound, thickness=0.01, camera=camera) @@ -142,7 +147,7 @@ def output_data(self, **kwargs): screen_buffer = pg.render.contour(q_volume, threshold=0.00003, color=norm_mu_volume, colormap=colormap, camera=camera, screen_buffer=screen_buffer) # Render boundary - boundary_colormap = pg.Colormap("Greys_r", vmin=0.0, vmax=3.0, opacity=np.linspace(0.0, 6.0, 256)) # This will make it grey + boundary_colormap = pg.Colormap("bone_r", vmin=0.0, vmax=3.0, opacity=np.linspace(0.0, 6.0, 256)) screen_buffer = pg.render.volume(boundary_volume, camera=camera, colormap=boundary_colormap, screen_buffer=screen_buffer) # Show the rendered image @@ -164,10 +169,10 @@ def output_data(self, **kwargs): print("airfoil shape: ", airfoil.shape) ny = 3 * ny - nx = 4 * nx + nx = 5 * nx nz = 101 - Re = 10000.0 + Re = 30000.0 prescribed_vel = 0.1 clength = airfoil_length diff --git a/requirements.txt b/requirements.txt index 2c912ed..11ee0fd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,11 @@ -jax==0.4.11 -jaxlib==0.4.11 +jax==0.4.20 +jaxlib==0.4.20 jmp==0.0.4 -matplotlib==3.7.1 -numpy==1.24.2 -pyvista==0.38.5 +matplotlib==3.8.0 +numpy==1.26.1 +pyvista==0.42.3 Rtree==1.0.1 -trimesh==3.20.2 -orbax-checkpoint==0.2.3 -portpicker===1.5.2 -termcolor==2.3.0 \ No newline at end of file +trimesh==4.0.0 +orbax-checkpoint==0.4.1 +termcolor==2.3.0 +PhantomGaze @ git+https://github.com/loliverhennigh/PhantomGaze.git \ No newline at end of file