Skip to content

Commit

Permalink
FEM prints out B.C. info if necessary
Browse files Browse the repository at this point in the history
  • Loading branch information
tianjuxue committed Jun 11, 2023
1 parent facfb62 commit 1d34e86
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 12 deletions.
2 changes: 1 addition & 1 deletion applications/cfd/gamma/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def simulation():

gamma_args = {}

#laser parameter
# laser parameter
gamma_args['eta'] = 0.43
gamma_args['r'] = 5e-5
gamma_args['rho'] = 8440
Expand Down
4 changes: 2 additions & 2 deletions jax_am/fem/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def get_shape_vals_and_grads(ele_type):
vals_and_grads = element.tabulate(1, quad_points)[:, :, re_order, :]
shape_values = vals_and_grads[0, :, :, 0]
shape_grads_ref = onp.transpose(vals_and_grads[1:, :, :, 0], axes=(1, 2, 0))
print(f"ele_type = {ele_type}, quad_points.shape = {quad_points.shape}")
print(f"ele_type = {ele_type}, quad_points.shape = (num_quads, dim) = {quad_points.shape}")
return shape_values, shape_grads_ref, weights


Expand Down Expand Up @@ -179,5 +179,5 @@ def get_face_shape_vals_and_grads(ele_type):
face_shape_vals = vals_and_grads[0, :, :, 0].reshape(num_faces, num_face_quads, -1)
face_shape_grads_ref = vals_and_grads[1:, :, :, 0].reshape(dim, num_faces, num_face_quads, -1)
face_shape_grads_ref = onp.transpose(face_shape_grads_ref, axes=(1, 2, 3, 0))
print(f"face_quad_points.shape = {face_quad_points.shape}")
print(f"face_quad_points.shape = (num_faces, num_face_quads, dim) = {face_quad_points.shape}")
return face_shape_vals, face_shape_grads_ref, face_weights, face_normals, face_inds
32 changes: 31 additions & 1 deletion jax_am/fem/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,8 @@ def get_boundary_conditions_inds(self, location_fns):
-------
boundary_inds_list : List[onp.ndarray]
(num_selected_faces, 2)
boundary_inds_list[k][i, j] returns the index of face j of cell i of surface k
boundary_inds_list[k][i, 0] returns the global cell index of the ith selected face of boundary subset k
boundary_inds_list[k][i, 1] returns the local face index of the ith selected face of boundary subset k
"""
cell_points = onp.take(self.points, self.cells, axis=0) # (num_cells, num_nodes, dim)
cell_face_points = onp.take(cell_points, self.face_inds, axis=1) # (num_cells, num_faces, num_face_nodes, dim)
Expand Down Expand Up @@ -684,3 +685,32 @@ def set_params(self, params):
"""Used for solving inverse problems.
"""
raise NotImplementedError("Child class must implement this function!")

def print_BC_info(self):
"""Print boundary condition information for debugging purposes.
"""
if hasattr(self, 'neumann_boundary_inds_list'):
print(f"\n\n### Neumann B.C. is specified")
for i in range(len(self.neumann_boundary_inds_list)):
print(f"\nNeumann Boundary part {i + 1} information:")
print(self.neumann_boundary_inds_list[i])
print(f"Array.shape = (num_selected_faces, 2) = {self.neumann_boundary_inds_list[i].shape}")
print(f"Interpretation:")
print(f" Array[i, 0] returns the global cell index of the ith selected face")
print(f" Array[i, 1] returns the local face index of the ith selected face")
else:
print(f"\n\n### No Neumann B.C. found.")

if len(self.node_inds_list) != 0:
print(f"\n\n### Dirichlet B.C. is specified")
for i in range(len(self.node_inds_list)):
print(f"\nDirichlet Boundary part {i + 1} information:")
bc_array = onp.stack([self.node_inds_list[i], self.vec_inds_list[i], self.vals_list[i]]).T
print(bc_array)
print(f"Array.shape = (num_selected_dofs, 3) = {bc_array.shape}")
print(f"Interpretation:")
print(f" Array[i, 0] returns the node index of the ith selected dof")
print(f" Array[i, 1] returns the vec index of the ith selected dof")
print(f" Array[i, 2] returns the value assigned to ith selected dof")
else:
print(f"\n\n### No Dirichlet B.C. found.")
2 changes: 2 additions & 0 deletions jax_am/fem/generate_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ def box_mesh(Nx, Ny, Nz, Lx, Ly, Lz, data_dir, ele_type='HEX8'):
https://gitlab.onelab.info/gmsh/gmsh/-/blob/master/examples/api/hex.py
https://gitlab.onelab.info/gmsh/gmsh/-/blob/gmsh_4_7_1/tutorial/python/t1.py
https://gitlab.onelab.info/gmsh/gmsh/-/blob/gmsh_4_7_1/tutorial/python/t3.py
Accepts ele_type = 'HEX8', 'TET4' or 'TET10'
"""

assert ele_type != 'HEX20', f"gmsh cannot produce HEX20 mesh?"
Expand Down
16 changes: 8 additions & 8 deletions jax_am/fem/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,9 @@ def linear_incremental_solver(problem, res_vec, A_fn, dofs, precond, use_petsc):
def get_A_fn(problem, use_petsc):
print(f"Creating sparse matrix with scipy...")
A_sp_scipy = scipy.sparse.csr_array((problem.V, (problem.I, problem.J)), shape=(problem.num_total_dofs, problem.num_total_dofs))
print(f"Creating sparse matrix from scipy using JAX BCOO...")
# print(f"Creating sparse matrix from scipy using JAX BCOO...")
A_sp = BCOO.from_scipy_sparse(A_sp_scipy).sort_indices()
print(f"self.A_sp.data.shape = {A_sp.data.shape}")
print(f"Global sparse matrix takes about {A_sp.data.shape[0]*8*3/2**30} G memory to store.")
# print(f"Global sparse matrix takes about {A_sp.data.shape[0]*8*3/2**30} G memory to store.")
problem.A_sp_scipy = A_sp_scipy

def compute_linearized_residual(dofs):
Expand Down Expand Up @@ -417,10 +416,9 @@ def symmetry(I, J, V):

print(f"Aug - Creating sparse matrix with scipy...")
A_sp_scipy_aug = scipy.sparse.csc_array((V, (I, J)), shape=(group_index, group_index))
print(f"Aug - Creating sparse matrix from scipy using JAX BCOO...")
# print(f"Aug - Creating sparse matrix from scipy using JAX BCOO...")
A_sp_aug = BCOO.from_scipy_sparse(A_sp_scipy_aug).sort_indices()
print(f"Aug - self.A_sp.data.shape = {A_sp_aug.data.shape}")
print(f"Aug - Global sparse matrix takes about {A_sp_aug.data.shape[0]*8*3/2**30} G memory to store.")
# print(f"Aug - Global sparse matrix takes about {A_sp_aug.data.shape[0]*8*3/2**30} G memory to store.")

# TODO: Potential bug: Shouldn't this be problem.A_sp_scipy = A_sp_scipy_aug?
problem.A_sp_scipy_aug = A_sp_scipy_aug
Expand Down Expand Up @@ -462,7 +460,9 @@ def solver_lagrange_multiplier(problem, linear, use_petsc=True):
p_num_eps = problem.p_num_eps
else:
p_num_eps = 1.
print(f"Setting p_num_eps = {p_num_eps}. If periodic B.C. fails to be applied, consider modifying this parameter.")

if not use_petsc:
print(f"Setting p_num_eps = {p_num_eps}. If periodic B.C. fails to be applied, consider modifying this parameter.")

def newton_update_helper(dofs_aug):
res_vec = problem.newton_update(dofs_aug[:problem.num_total_dofs].reshape(sol_shape)).reshape(-1)
Expand Down Expand Up @@ -763,7 +763,7 @@ def f_fwd(params):
return sol, (params, sol)

def f_bwd(res, v):
print("\nRunning backward...")
print("\nRunning backward and solving the adjoint problem...")
params, sol = res
vjp_result = implicit_vjp(problem, sol, params, v, use_petsc)
return (vjp_result,)
Expand Down

0 comments on commit 1d34e86

Please sign in to comment.