Apply CEESD changes
MTCam committed Jul 18, 2024
1 parent 4383375 commit e733704
Showing 16 changed files with 2,123 additions and 59 deletions.
1,540 changes: 1,540 additions & 0 deletions meshmode/array_context.py

Large diffs are not rendered by default.

57 changes: 40 additions & 17 deletions meshmode/discretization/connection/direct.py
@@ -47,6 +47,7 @@
ConcurrentElementInameTag,
DiscretizationDOFAxisTag,
DiscretizationElementAxisTag,
DiscretizationDOFPickListAxisTag,
)


@@ -478,6 +479,10 @@ def _per_target_group_pick_info(
cgrp = self.groups[i_tgrp]
tgrp = self.to_discr.groups[i_tgrp]

if tgrp.nelements == 1:
from warnings import warn
warn("_per_target_group_pick_info: tgrp has 1 element")

batch_dof_pick_lists = [
self._resample_point_pick_indices(i_tgrp, i_batch)
for i_batch in range(len(cgrp.batches))]
@@ -541,17 +546,22 @@ def _per_target_group_pick_info(
_FromGroupPickData(
from_group_index=source_group_index,
dof_pick_lists=actx.freeze(
actx.tag(NameHint("dof_pick_lists"),
actx.from_numpy(dof_pick_lists))),
actx.tag_axis(0, DiscretizationDOFPickListAxisTag(),
actx.tag(NameHint("dof_pick_lists"),
actx.from_numpy(dof_pick_lists)))),
dof_pick_list_indices=actx.freeze(
actx.tag(NameHint("dof_pick_list_indices"),
actx.from_numpy(dof_pick_list_indices))),
actx.tag_axis(0, DiscretizationElementAxisTag(),
actx.tag(NameHint("dof_pick_list_indices"),
actx.from_numpy(dof_pick_list_indices)))),
from_el_present=actx.freeze(
actx.tag(NameHint("from_el_present"),
actx.from_numpy(from_el_present.astype(np.int8)))),
actx.tag_axis(0, DiscretizationElementAxisTag(),
actx.tag(NameHint("from_el_present"),
actx.from_numpy(
from_el_present.astype(np.int8))))),
from_element_indices=actx.freeze(
actx.tag(NameHint("from_el_indices"),
actx.from_numpy(from_el_indices))),
actx.tag_axis(0, DiscretizationElementAxisTag(),
actx.tag(NameHint("from_el_indices"),
actx.from_numpy(from_el_indices)))),
is_surjective=from_el_present.all()
))

@@ -723,25 +733,29 @@ def group_pick_knl(is_surjective: bool):
group_pick_info = None

if group_pick_info is not None:
group_array_contributions = []
# group_array_contributions = []

if actx.permits_advanced_indexing and not _force_use_loopy:
for fgpd in group_pick_info:
from_element_indices = actx.thaw(fgpd.from_element_indices)

if ary[fgpd.from_group_index].size:
grp_ary_contrib = ary[fgpd.from_group_index][
tag_axes(actx, {
1: DiscretizationDOFAxisTag()},
_reshape_and_preserve_tags(
actx, from_element_indices, (-1, 1)),
actx, from_element_indices, (-1, 1))),
actx.thaw(fgpd.dof_pick_lists)[
actx.thaw(fgpd.dof_pick_list_indices)]
]

if not fgpd.is_surjective:
from_el_present = actx.thaw(fgpd.from_el_present)
grp_ary_contrib = actx.np.where(
_reshape_and_preserve_tags(
actx, from_el_present, (-1, 1)),
tag_axes(actx, {
1: DiscretizationDOFAxisTag()},
_reshape_and_preserve_tags(
actx, from_el_present, (-1, 1))),
grp_ary_contrib,
0)

@@ -791,8 +805,10 @@ def group_pick_knl(is_surjective: bool):
mat = self._resample_matrix(actx, i_tgrp, i_batch)
if actx.permits_advanced_indexing and not _force_use_loopy:
batch_result = actx.np.where(
tag_axes(actx, {
1: DiscretizationDOFAxisTag()},
_reshape_and_preserve_tags(
actx, from_el_present, (-1, 1)),
actx, from_el_present, (-1, 1))),
actx.einsum("ij,ej->ei",
mat, grp_ary[from_element_indices]),
0)
@@ -813,11 +829,15 @@ def group_pick_knl(is_surjective: bool):

if actx.permits_advanced_indexing and not _force_use_loopy:
batch_result = actx.np.where(
tag_axes(actx, {
1: DiscretizationDOFAxisTag()},
_reshape_and_preserve_tags(
actx, from_el_present, (-1, 1)),
actx, from_el_present, (-1, 1))),
from_vec[
tag_axes(actx, {
1: DiscretizationDOFAxisTag()},
_reshape_and_preserve_tags(
actx, from_element_indices, (-1, 1)),
actx, from_element_indices, (-1, 1))),
pick_list],
0)
else:
@@ -844,10 +864,13 @@ def group_pick_knl(is_surjective: bool):
else:
# If no batched data at all, return zeros for this
# particular group array
group_array = actx.zeros(
group_array = tag_axes(actx, {
0: DiscretizationElementAxisTag(),
1: DiscretizationDOFAxisTag()},
actx.zeros(
shape=(self.to_discr.groups[i_tgrp].nelements,
self.to_discr.groups[i_tgrp].nunit_dofs),
dtype=ary.entry_dtype)
dtype=ary.entry_dtype))

group_arrays.append(group_array)

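The direct.py changes above add per-axis discretization tags (on top of the existing NameHint tags) to the frozen pick tables and index arrays, so that transform-based array contexts can recognize element, DOF, and pick-list axes. An illustrative sketch of that tagging pattern (not part of the diff); it assumes the tag classes live in meshmode.transform_metadata and uses only ArrayContext methods that already appear above:

    import numpy as np
    from meshmode.transform_metadata import (
        DiscretizationDOFAxisTag, DiscretizationElementAxisTag)

    def freeze_tagged(actx, ary_np: np.ndarray, axis_tags):
        # axis_tags maps axis index -> tag instance, e.g.
        # {0: DiscretizationElementAxisTag(), 1: DiscretizationDOFAxisTag()}
        ary = actx.from_numpy(ary_np)
        for axis, tag in axis_tags.items():
            ary = actx.tag_axis(axis, tag, ary)
        return actx.freeze(ary)
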
6 changes: 6 additions & 0 deletions meshmode/discretization/poly_element.py
@@ -569,8 +569,14 @@ def __init__(self, mesh_el_group: _MeshTensorProductElementGroup,
"`unit_nodes` dim = {unit_nodes.shape[0]}.")

self._basis = basis
self._bases_1d = basis.bases[0]
self._nodes = unit_nodes

def bases_1d(self):
"""Return 1D component bases used to construct the tensor product basis.
"""
return self._bases_1d

def basis_obj(self):
return self._basis

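A hypothetical usage sketch for the new bases_1d accessor (not part of the diff); it assumes the stored object is a modepy-style Basis, which exposes a functions sequence:

    def num_1d_modes(tp_group) -> int:
        # tp_group: a tensor-product element group as constructed above
        basis_1d = tp_group.bases_1d()     # accessor added in this commit
        return len(basis_1d.functions)     # modepy Basis exposes `functions`
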
112 changes: 82 additions & 30 deletions meshmode/distributed.py
@@ -2,6 +2,7 @@
.. autoclass:: InterRankBoundaryInfo
.. autoclass:: MPIBoundaryCommSetupHelper
.. autofunction:: mpi_distribute
.. autofunction:: get_partition_by_pymetis
.. autofunction:: membership_list_to_map
.. autofunction:: get_connected_parts
@@ -36,11 +37,22 @@
"""

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Hashable, List, Mapping, Sequence, Union, cast
from typing import (
TYPE_CHECKING,
Any,
Hashable,
List,
Optional,
Mapping,
Sequence,
Set,
Union,
cast
)
from warnings import warn

import numpy as np

from contextlib import contextmanager
from arraycontext import ArrayContext

from meshmode.discretization import ElementGroupFactory
@@ -66,6 +78,70 @@

# {{{ mesh distributor

@contextmanager
def _duplicate_mpi_comm(mpi_comm):
dup_comm = mpi_comm.Dup()
try:
yield dup_comm
finally:
dup_comm.Free()


def mpi_distribute(
mpi_comm: "mpi4py.MPI.Intracomm",
source_data: Optional[Mapping[int, Any]] = None,
source_rank: int = 0) -> Optional[Any]:
"""
Distribute data to a set of processes.
:arg mpi_comm: An ``MPI.Intracomm``
:arg source_data: A :class:`dict` mapping destination ranks to data to be sent.
Only present on the source rank.
:arg source_rank: The rank from which the data is being sent.
:returns: The data local to the current process if there is any, otherwise
*None*.
"""
with _duplicate_mpi_comm(mpi_comm) as mpi_comm:
num_proc = mpi_comm.Get_size()
rank = mpi_comm.Get_rank()

local_data = None

if rank == source_rank:
if source_data is None:
raise TypeError("source rank has no data.")

sending_to = [False] * num_proc
for dest_rank in source_data.keys():
sending_to[dest_rank] = True

mpi_comm.scatter(sending_to, root=source_rank)

reqs = []
for dest_rank, data in source_data.items():
if dest_rank == rank:
local_data = data
logger.info("rank %d: received data", rank)
else:
reqs.append(mpi_comm.isend(data, dest=dest_rank))

logger.info("rank %d: sent all data", rank)

from mpi4py import MPI
MPI.Request.waitall(reqs)

else:
receiving = mpi_comm.scatter([], root=source_rank)

if receiving:
local_data = mpi_comm.recv(source=source_rank)
logger.info("rank %d: received data", rank)

return local_data


# TODO: Deprecate?
class MPIMeshDistributor:
"""
.. automethod:: is_mananger_rank
@@ -97,9 +173,7 @@ def send_mesh_parts(self, mesh, part_per_element, num_parts):
Sends each part to a different rank.
Returns one part that was not sent to any other rank.
"""
mpi_comm = self.mpi_comm
rank = mpi_comm.Get_rank()
assert num_parts <= mpi_comm.Get_size()
assert num_parts <= self.mpi_comm.Get_size()

assert self.is_mananger_rank()

@@ -108,38 +182,16 @@ def send_mesh_parts(self, mesh, part_per_element, num_parts):
from meshmode.mesh.processing import partition_mesh
parts = partition_mesh(mesh, part_num_to_elements)

local_part = None

reqs = []
for r, part in parts.items():
if r == self.manager_rank:
local_part = part
else:
reqs.append(mpi_comm.isend(part, dest=r, tag=TAG_DISTRIBUTE_MESHES))

logger.info("rank %d: sent all mesh parts", rank)
for req in reqs:
req.wait()

return local_part
return mpi_distribute(
self.mpi_comm, source_data=parts, source_rank=self.manager_rank)

def receive_mesh_part(self):
"""
Returns the mesh sent by the manager rank.
"""
mpi_comm = self.mpi_comm
rank = mpi_comm.Get_rank()

assert not self.is_mananger_rank(), "Manager rank cannot receive mesh"

from mpi4py import MPI
status = MPI.Status()
result = self.mpi_comm.recv(
source=self.manager_rank, tag=TAG_DISTRIBUTE_MESHES,
status=status)
logger.info("rank %d: received local mesh (size = %d)", rank, status.count)

return result
return mpi_distribute(self.mpi_comm, source_rank=self.manager_rank)

# }}}

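A hypothetical usage sketch for the new mpi_distribute helper (not part of the diff). The call is collective, so every rank must enter it; the payload dict is a placeholder:

    from mpi4py import MPI
    from meshmode.distributed import mpi_distribute

    comm = MPI.COMM_WORLD

    if comm.rank == 0:
        # {destination rank: data}; only the source rank passes source_data
        payload = {r: {"part_id": r} for r in range(comm.size)}
        local_part = mpi_distribute(comm, source_data=payload, source_rank=0)
    else:
        local_part = mpi_distribute(comm, source_rank=0)

    # local_part holds the data addressed to this rank, or None if nothing
    # was sent to it
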
20 changes: 18 additions & 2 deletions meshmode/mesh/__init__.py
@@ -903,7 +903,8 @@ def check_mesh_consistency(
"parameter force_positive_orientation=True to make_mesh().")
else:
warn("Unimplemented: Cannot check element orientation for a mesh with "
"mesh.dim != mesh.ambient_dim", stacklevel=2)
f"mesh.dim != mesh.ambient_dim ({mesh.dim=},{mesh.ambient_dim=})",
stacklevel=2)


def is_mesh_consistent(
@@ -944,6 +945,7 @@ def make_mesh(
node_vertex_consistency_tolerance: Optional[float] = None,
skip_element_orientation_test: bool = False,
force_positive_orientation: bool = False,
face_vertex_indices_to_tags=None,
) -> "Mesh":
"""Construct a new mesh from a given list of *groups*.
@@ -1032,6 +1034,15 @@
nodal_adjacency = (
NodalAdjacency(neighbors_starts=nb_starts, neighbors=nbs))

face_vert_ind_to_tags_local = None
if face_vertex_indices_to_tags is not None:
face_vert_ind_to_tags_local = face_vertex_indices_to_tags.copy()

if (facial_adjacency_groups is False or facial_adjacency_groups is None):
if face_vertex_indices_to_tags is not None:
facial_adjacency_groups = _compute_facial_adjacency_from_vertices(
groups, np.int32, np.int8, face_vertex_indices_to_tags)

if (
facial_adjacency_groups is not False
and facial_adjacency_groups is not None):
@@ -1058,8 +1069,13 @@
if force_positive_orientation:
if mesh.dim == mesh.ambient_dim:
import meshmode.mesh.processing as mproc
mesh_making_kwargs = {
"face_vertex_indices_to_tags": face_vert_ind_to_tags_local
}
mesh = mproc.perform_flips(
mesh, mproc.find_volume_mesh_element_orientations(mesh) < 0)
mesh=mesh,
flip_flags=mproc.find_volume_mesh_element_orientations(mesh) < 0,
skip_tests=False, mesh_making_kwargs=mesh_making_kwargs)
else:
raise ValueError("cannot enforce positive element orientation "
"on non-volume meshes")
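A hypothetical sketch of the new face_vertex_indices_to_tags keyword to make_mesh (not part of the diff). The mapping shape shown (sets of face vertex indices to lists of boundary tag names) and the placeholder vertices/groups are assumptions:

    from meshmode.mesh import make_mesh

    # assumed key/value shape: frozenset of vertex indices -> list of tag names
    face_vertex_indices_to_tags = {
        frozenset((0, 1)): ["inflow"],
        frozenset((2, 3)): ["outflow"],
    }

    mesh = make_mesh(
        vertices, groups,    # placeholders for actual mesh data
        face_vertex_indices_to_tags=face_vertex_indices_to_tags,
        force_positive_orientation=True)
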
17 changes: 15 additions & 2 deletions meshmode/mesh/io.py
@@ -257,6 +257,7 @@ def get_mesh(self, return_tag_to_elements_map=False):

# compute facial adjacency for Mesh if there is tag information
facial_adjacency_groups = None
face_vert_ind_to_tags_local = face_vertex_indices_to_tags.copy()
if is_conforming and self.tags:
from meshmode.mesh import _compute_facial_adjacency_from_vertices
facial_adjacency_groups = _compute_facial_adjacency_from_vertices(
@@ -266,6 +267,7 @@
vertices, groups,
is_conforming=is_conforming,
facial_adjacency_groups=facial_adjacency_groups,
face_vertex_indices_to_tags=face_vert_ind_to_tags_local,
**self.mesh_construction_kwargs)

return (mesh, tag_to_elements) if return_tag_to_elements_map else mesh
@@ -294,10 +296,21 @@ def read_gmsh(
belong to that volume.
"""
from gmsh_interop.reader import read_gmsh
import time
print("Reading gmsh mesh from disk file...")
recv = GmshMeshReceiver(mesh_construction_kwargs=mesh_construction_kwargs)
read_gmsh(recv, filename, force_dimension=force_ambient_dim)

return recv.get_mesh(return_tag_to_elements_map=return_tag_to_elements_map)
read_start = time.time()
read_gmsh(recv, filename, force_dimension=force_ambient_dim)
read_finish = time.time()
print("Done. Populating meshmode data structures...")
retval = recv.get_mesh(
return_tag_to_elements_map=return_tag_to_elements_map)
get_mesh_finish = time.time()
print("Done.")
print(f"Read GMSH: {read_finish - read_start}\n"
f"MeshData: {get_mesh_finish - read_finish}")
return retval


def generate_gmsh(source, dimensions=None, order=None, other_options=None,
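A hypothetical usage sketch for the instrumented reader (not part of the diff); the filename is a placeholder:

    from meshmode.mesh.io import read_gmsh

    mesh, tag_to_elements = read_gmsh(
        "box.msh",    # placeholder filename
        return_tag_to_elements_map=True)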