diff --git a/applications/cfd/gamma/example.py b/applications/cfd/gamma/example.py index 0a35c7b..300d987 100644 --- a/applications/cfd/gamma/example.py +++ b/applications/cfd/gamma/example.py @@ -66,7 +66,7 @@ def simulation(): gamma_args = {} - #laser parameter + # laser parameter gamma_args['eta'] = 0.43 gamma_args['r'] = 5e-5 gamma_args['rho'] = 8440 diff --git a/jax_am/fem/basis.py b/jax_am/fem/basis.py index 1872236..cc3ca84 100644 --- a/jax_am/fem/basis.py +++ b/jax_am/fem/basis.py @@ -117,7 +117,7 @@ def get_shape_vals_and_grads(ele_type): vals_and_grads = element.tabulate(1, quad_points)[:, :, re_order, :] shape_values = vals_and_grads[0, :, :, 0] shape_grads_ref = onp.transpose(vals_and_grads[1:, :, :, 0], axes=(1, 2, 0)) - print(f"ele_type = {ele_type}, quad_points.shape = {quad_points.shape}") + print(f"ele_type = {ele_type}, quad_points.shape = (num_quads, dim) = {quad_points.shape}") return shape_values, shape_grads_ref, weights @@ -179,5 +179,5 @@ def get_face_shape_vals_and_grads(ele_type): face_shape_vals = vals_and_grads[0, :, :, 0].reshape(num_faces, num_face_quads, -1) face_shape_grads_ref = vals_and_grads[1:, :, :, 0].reshape(dim, num_faces, num_face_quads, -1) face_shape_grads_ref = onp.transpose(face_shape_grads_ref, axes=(1, 2, 3, 0)) - print(f"face_quad_points.shape = {face_quad_points.shape}") + print(f"face_quad_points.shape = (num_faces, num_face_quads, dim) = {face_quad_points.shape}") return face_shape_vals, face_shape_grads_ref, face_weights, face_normals, face_inds diff --git a/jax_am/fem/core.py b/jax_am/fem/core.py index 6499099..74a5ab3 100644 --- a/jax_am/fem/core.py +++ b/jax_am/fem/core.py @@ -308,7 +308,8 @@ def get_boundary_conditions_inds(self, location_fns): ------- boundary_inds_list : List[onp.ndarray] (num_selected_faces, 2) - boundary_inds_list[k][i, j] returns the index of face j of cell i of surface k + boundary_inds_list[k][i, 0] returns the global cell index of the ith selected face of boundary subset k + boundary_inds_list[k][i, 1] returns the local face index of the ith selected face of boundary subset k """ cell_points = onp.take(self.points, self.cells, axis=0) # (num_cells, num_nodes, dim) cell_face_points = onp.take(cell_points, self.face_inds, axis=1) # (num_cells, num_faces, num_face_nodes, dim) @@ -684,3 +685,32 @@ def set_params(self, params): """Used for solving inverse problems. """ raise NotImplementedError("Child class must implement this function!") + + def print_BC_info(self): + """Print boundary condition information for debugging purposes. + """ + if hasattr(self, 'neumann_boundary_inds_list'): + print(f"\n\n### Neumann B.C. is specified") + for i in range(len(self.neumann_boundary_inds_list)): + print(f"\nNeumann Boundary part {i + 1} information:") + print(self.neumann_boundary_inds_list[i]) + print(f"Array.shape = (num_selected_faces, 2) = {self.neumann_boundary_inds_list[i].shape}") + print(f"Interpretation:") + print(f" Array[i, 0] returns the global cell index of the ith selected face") + print(f" Array[i, 1] returns the local face index of the ith selected face") + else: + print(f"\n\n### No Neumann B.C. found.") + + if len(self.node_inds_list) != 0: + print(f"\n\n### Dirichlet B.C. is specified") + for i in range(len(self.node_inds_list)): + print(f"\nDirichlet Boundary part {i + 1} information:") + bc_array = onp.stack([self.node_inds_list[i], self.vec_inds_list[i], self.vals_list[i]]).T + print(bc_array) + print(f"Array.shape = (num_selected_dofs, 3) = {bc_array.shape}") + print(f"Interpretation:") + print(f" Array[i, 0] returns the node index of the ith selected dof") + print(f" Array[i, 1] returns the vec index of the ith selected dof") + print(f" Array[i, 2] returns the value assigned to ith selected dof") + else: + print(f"\n\n### No Dirichlet B.C. found.") \ No newline at end of file diff --git a/jax_am/fem/generate_mesh.py b/jax_am/fem/generate_mesh.py index 770dccd..984623f 100644 --- a/jax_am/fem/generate_mesh.py +++ b/jax_am/fem/generate_mesh.py @@ -61,6 +61,8 @@ def box_mesh(Nx, Ny, Nz, Lx, Ly, Lz, data_dir, ele_type='HEX8'): https://gitlab.onelab.info/gmsh/gmsh/-/blob/master/examples/api/hex.py https://gitlab.onelab.info/gmsh/gmsh/-/blob/gmsh_4_7_1/tutorial/python/t1.py https://gitlab.onelab.info/gmsh/gmsh/-/blob/gmsh_4_7_1/tutorial/python/t3.py + + Accepts ele_type = 'HEX8', 'TET4' or 'TET10' """ assert ele_type != 'HEX20', f"gmsh cannot produce HEX20 mesh?" diff --git a/jax_am/fem/solver.py b/jax_am/fem/solver.py index d2ebbf7..6a40ec8 100644 --- a/jax_am/fem/solver.py +++ b/jax_am/fem/solver.py @@ -205,10 +205,9 @@ def linear_incremental_solver(problem, res_vec, A_fn, dofs, precond, use_petsc): def get_A_fn(problem, use_petsc): print(f"Creating sparse matrix with scipy...") A_sp_scipy = scipy.sparse.csr_array((problem.V, (problem.I, problem.J)), shape=(problem.num_total_dofs, problem.num_total_dofs)) - print(f"Creating sparse matrix from scipy using JAX BCOO...") + # print(f"Creating sparse matrix from scipy using JAX BCOO...") A_sp = BCOO.from_scipy_sparse(A_sp_scipy).sort_indices() - print(f"self.A_sp.data.shape = {A_sp.data.shape}") - print(f"Global sparse matrix takes about {A_sp.data.shape[0]*8*3/2**30} G memory to store.") + # print(f"Global sparse matrix takes about {A_sp.data.shape[0]*8*3/2**30} G memory to store.") problem.A_sp_scipy = A_sp_scipy def compute_linearized_residual(dofs): @@ -417,10 +416,9 @@ def symmetry(I, J, V): print(f"Aug - Creating sparse matrix with scipy...") A_sp_scipy_aug = scipy.sparse.csc_array((V, (I, J)), shape=(group_index, group_index)) - print(f"Aug - Creating sparse matrix from scipy using JAX BCOO...") + # print(f"Aug - Creating sparse matrix from scipy using JAX BCOO...") A_sp_aug = BCOO.from_scipy_sparse(A_sp_scipy_aug).sort_indices() - print(f"Aug - self.A_sp.data.shape = {A_sp_aug.data.shape}") - print(f"Aug - Global sparse matrix takes about {A_sp_aug.data.shape[0]*8*3/2**30} G memory to store.") + # print(f"Aug - Global sparse matrix takes about {A_sp_aug.data.shape[0]*8*3/2**30} G memory to store.") # TODO: Potential bug: Shouldn't this be problem.A_sp_scipy = A_sp_scipy_aug? problem.A_sp_scipy_aug = A_sp_scipy_aug @@ -462,7 +460,9 @@ def solver_lagrange_multiplier(problem, linear, use_petsc=True): p_num_eps = problem.p_num_eps else: p_num_eps = 1. - print(f"Setting p_num_eps = {p_num_eps}. If periodic B.C. fails to be applied, consider modifying this parameter.") + + if not use_petsc: + print(f"Setting p_num_eps = {p_num_eps}. If periodic B.C. fails to be applied, consider modifying this parameter.") def newton_update_helper(dofs_aug): res_vec = problem.newton_update(dofs_aug[:problem.num_total_dofs].reshape(sol_shape)).reshape(-1) @@ -763,7 +763,7 @@ def f_fwd(params): return sol, (params, sol) def f_bwd(res, v): - print("\nRunning backward...") + print("\nRunning backward and solving the adjoint problem...") params, sol = res vjp_result = implicit_vjp(problem, sol, params, v, use_petsc) return (vjp_result,)