Skip to content

Commit

Permalink
Merge pull request #1 from mehdiataei/example_contour
Browse files Browse the repository at this point in the history
Fixed the airfoil shape problem
  • Loading branch information
loliverhennigh authored Nov 27, 2023
2 parents 2adefe6 + 37824c5 commit 4d4f553
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 35 deletions.
57 changes: 31 additions & 26 deletions examples/CFD/airfoil3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
4. Simulation Parameters: The example allows for the setting of various simulation parameters,
including the Reynolds number, inlet velocity, and characteristic length.
5. Visualization: The example outputs data in VTK format, which can be visualized using software such
as Paraview. The error between the old and new velocity fields is also printed out at each time step
to monitor the convergence of the solution.
5. In-situ visualization: The example outputs rendering images of the q-criterion using
PhantomGaze library (https://github.com/loliverhennigh/PhantomGaze) without any I/O overhead
while the data is still on the GPU.
"""


Expand All @@ -41,23 +41,30 @@
import jax
import scipy

# PhantomGaze for rendering
# PhantomGaze for in-situ rendering
import phantomgaze as pg

# Function to create a NACA airfoil shape given its length, thickness, and angle of attack
def makeNacaAirfoil(length, thickness=30, angle=0):
def nacaAirfoil(x, thickness, chordLength):
coeffs = [0.2969, -0.1260, -0.3516, 0.2843, -0.1015]
exponents = [0.5, 1, 2, 3, 4]
af = [coeff * (x / chordLength) ** exp for coeff, exp in zip(coeffs, exponents)]
return 5. * thickness / 100 * chordLength * np.sum(af)
yt = [coeff * (x / chordLength) ** exp for coeff, exp in zip(coeffs, exponents)]
yt = 5. * thickness / 100 * chordLength * np.sum(yt)

x = np.arange(length)
y = np.arange(-int(length * thickness / 200), int(length * thickness / 200))
xx, yy = np.meshgrid(x, y)
domain = np.where(np.abs(yy) < nacaAirfoil(xx, thickness, length), 1, 0).T
return yt

domain = scipy.ndimage.rotate(np.rot90(domain), -angle)
x = np.linspace(0, length, num=length)
yt = np.array([nacaAirfoil(xi, thickness, length) for xi in x])

y_max = int(np.max(yt)) + 1
domain = np.zeros((2 * y_max, len(x)), dtype=int)

for i, xi in enumerate(x):
upper_bound = int(y_max + yt[i])
lower_bound = int(y_max - yt[i])
domain[lower_bound:upper_bound, i] = 1

domain = scipy.ndimage.rotate(domain, angle, reshape=True)
domain = np.where(domain > 0.5, 1, 0)

return domain
Expand All @@ -77,10 +84,9 @@ def set_boundary_conditions(self):
self.boundingBoxIndices['bottom'], self.boundingBoxIndices['top']))
self.BCs.append(BounceBack(tuple(wall.T), self.gridInfo, self.precisionPolicy))

# Store wall boundary for visualization
self.boundary = jnp.zeros((self.nx, self.ny, self.nz), dtype=jnp.float32)
self.boundary = self.boundary.at[tuple(wall.T)].set(1.0)
self.boundary = self.boundary[2:-2, 2:-2, 2:-2]
# Store airfoil boundary for visualization
self.visualization_bc = jnp.zeros((self.nx, self.ny, self.nz), dtype=jnp.float32)
self.visualization_bc = self.visualization_bc.at[tuple(airfoil_indices.T)].set(1.0)

doNothing = self.boundingBoxIndices['right']
self.BCs.append(DoNothing(tuple(doNothing.T), self.gridInfo, self.precisionPolicy))
Expand All @@ -94,16 +100,15 @@ def set_boundary_conditions(self):

def output_data(self, **kwargs):
# Compute q-criterion and vorticity using finite differences
# Get velocity field
# Get velocity field
u = kwargs['u'][..., 1:-1, :]

# vorticity and q-criterion
norm_mu, q = q_criterion(u)

# Make phantomgaze volume
dx = 0.01
origin = (0.0, 0.0, 0.0)
upper_bound = (self.boundary.shape[0] * dx, self.boundary.shape[1] * dx, self.boundary.shape[2] * dx)
upper_bound = (self.visualization_bc.shape[0] * dx, self.visualization_bc.shape[1] * dx, self.visualization_bc.shape[2] * dx)
q_volume = pg.objects.Volume(
q,
spacing=(dx, dx, dx),
Expand All @@ -115,7 +120,7 @@ def output_data(self, **kwargs):
origin=origin,
)
boundary_volume = pg.objects.Volume(
self.boundary,
self.visualization_bc,
spacing=(dx, dx, dx),
origin=origin,
)
Expand All @@ -124,13 +129,13 @@ def output_data(self, **kwargs):
colormap = pg.Colormap("jet", vmin=0.0, vmax=0.05)

# Get camera parameters
focal_point = (self.boundary.shape[0] * dx / 2, self.boundary.shape[1] * dx / 2, self.boundary.shape[2] * dx / 2)
radius = 3.0
focal_point = (self.visualization_bc.shape[0] * dx / 2, self.visualization_bc.shape[1] * dx / 2, self.visualization_bc.shape[2] * dx / 2)
radius = 5.0
angle = kwargs['timestep'] * 0.0001
camera_position = (focal_point[0] + radius * np.sin(angle), focal_point[1], focal_point[2] + radius * np.cos(angle))

# Rotate camera
camera = pg.Camera(position=camera_position, focal_point=focal_point, view_up=(0.0, 1.0, 0.0), max_depth=30.0, height=1080, width=1920)
camera = pg.Camera(position=camera_position, focal_point=focal_point, view_up=(0.0, 1.0, 0.0), max_depth=30.0, height=1080, width=1920, background=pg.SolidBackground(color=(0.0, 0.0, 0.0)))

# Make wireframe
screen_buffer = pg.render.wireframe(lower_bound=origin, upper_bound=upper_bound, thickness=0.01, camera=camera)
Expand All @@ -142,7 +147,7 @@ def output_data(self, **kwargs):
screen_buffer = pg.render.contour(q_volume, threshold=0.00003, color=norm_mu_volume, colormap=colormap, camera=camera, screen_buffer=screen_buffer)

# Render boundary
boundary_colormap = pg.Colormap("Greys_r", vmin=0.0, vmax=3.0, opacity=np.linspace(0.0, 6.0, 256)) # This will make it grey
boundary_colormap = pg.Colormap("bone_r", vmin=0.0, vmax=3.0, opacity=np.linspace(0.0, 6.0, 256))
screen_buffer = pg.render.volume(boundary_volume, camera=camera, colormap=boundary_colormap, screen_buffer=screen_buffer)

# Show the rendered image
Expand All @@ -164,10 +169,10 @@ def output_data(self, **kwargs):
print("airfoil shape: ", airfoil.shape)

ny = 3 * ny
nx = 4 * nx
nx = 5 * nx
nz = 101

Re = 10000.0
Re = 30000.0
prescribed_vel = 0.1
clength = airfoil_length

Expand Down
18 changes: 9 additions & 9 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
jax==0.4.11
jaxlib==0.4.11
jax==0.4.20
jaxlib==0.4.20
jmp==0.0.4
matplotlib==3.7.1
numpy==1.24.2
pyvista==0.38.5
matplotlib==3.8.0
numpy==1.26.1
pyvista==0.42.3
Rtree==1.0.1
trimesh==3.20.2
orbax-checkpoint==0.2.3
portpicker===1.5.2
termcolor==2.3.0
trimesh==4.0.0
orbax-checkpoint==0.4.1
termcolor==2.3.0
PhantomGaze @ git+https://github.com/loliverhennigh/PhantomGaze.git

0 comments on commit 4d4f553

Please sign in to comment.