Skip to content

Commit

Permalink
add the parameter ncommon to pymetis.part_mesh() function (#57)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhf-0 authored Sep 5, 2023
1 parent 69bbc12 commit 7f14a7e
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 3 deletions.
11 changes: 9 additions & 2 deletions pymetis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,8 @@ def part_graph(nparts, adjacency=None, xadj=None, adjncy=None,
eweights, options, recursive)


def part_mesh(n_parts, connectivity, options=None, tpwgts=None, gtype=None):
def part_mesh(n_parts, connectivity, options=None, tpwgts=None, gtype=None,
ncommon=1):
"""This function is used to partition a mesh into *n_parts* parts based on a
graph partitioning where each vertex is a node in the graph. A mesh is a
collection of non-overlapping elements which are identified by their vertices.
Expand Down Expand Up @@ -403,6 +404,12 @@ def part_mesh(n_parts, connectivity, options=None, tpwgts=None, gtype=None):
``gtype`` specifies the partitioning is based on a nodal/dual graph of the mesh.
It has to be one of :attr:`GType.NODAL` or :attr:`GType.DUAL`.
``ncommon`` is needed when ``gtype = GType.DUAL``. It Specifies the number of
common nodes that two elements must have in order to put an edge between them
in the dual graph. For example, for tetrahedron meshes, ncommon should be 3,
which creates an edge between two tets when they share a triangular face
(i.e., 3 nodes).
Returns a namedtuple of ``(edge_cuts, element_part, vertex_part)``, where
``edge_cuts`` is the number of cuts to the connectivity graph, ``element_part``
is an array of length n_elements, with entries identifying the element's
Expand Down Expand Up @@ -448,6 +455,6 @@ def part_mesh(n_parts, connectivity, options=None, tpwgts=None, gtype=None):

from pymetis._internal import part_mesh
return MeshPartition(*part_mesh(n_parts, conn_offset, conn,
tpwgts, gtype, n_elements, n_vertex, options))
tpwgts, gtype, n_elements, n_vertex, ncommon, options))

# vim: foldmethod=marker
2 changes: 1 addition & 1 deletion src/wrapper/wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ namespace
idx_t &gtype,
idx_t &nElements,
idx_t &nVertex,
idx_t &ncommon,
metis_options &options)
{
idx_t edgeCuts = 0;
Expand All @@ -242,7 +243,6 @@ namespace
}
else if(gtype == METIS_GTYPE_DUAL)
{
idx_t ncommon = 1;
idx_t objval = 1;
int info = METIS_PartMeshDual(&nElements, &nVertex,
connectivityOffsets.data(), connectivity.data(),
Expand Down
58 changes: 58 additions & 0 deletions test/test_partition_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,64 @@ def test_2d_quad_mesh_dual(vis=False):
[float(n_vert)/float(n_part)] * n_part, rel=0.1)


def test_2d_quad_mesh_dual_with_ncommon(vis=False):
"""
Generate simple 2D `mesh` connectivity with rectangular elements, eg
6 --- 7 --- 9
| | |
3 --- 4 --- 5
| | |
0 --- 1 --- 2
if use the default `ncommon = 1`
`_, elem_idx_list, _ = pymetis.part_mesh(2, mesh, gtype=pymetis.GType.DUAL)`
Then the output of `elem_idx_list` is `[0, 0, 0, 0]`, and the number is not
balanced.
if set `ncommon = 2`
`_, elem_idx_list, _ = pymetis.part_mesh(2, mesh, gtype=pymetis.GType.DUAL,
ncommon)`
Then the output of `elem_idx_list` is `[0, 1, 0, 1]`, the number is balanced.
"""
n_cells_x = 2
n_cells_y = 2
points, connectivity = generate_mesh_2d(n_cells_x, n_cells_y)

n_part = 2
ncommon = 2
n_cuts, elem_part, vert_part = pymetis.part_mesh(n_part, connectivity,
None, None, pymetis.GType.DUAL, ncommon)

print(n_cuts)
print([elem_part.count(it) for it in range(n_part)])
print([vert_part.count(it) for it in range(n_part)])

if vis:
import pyvtk
vtkelements = pyvtk.VtkData(
pyvtk.UnstructuredGrid(points, quad=connectivity),
"Mesh",
pyvtk.CellData(pyvtk.Scalars(elem_part, name="Rank"))
)
vtkelements.tofile("quad.vtk")

# Assertions about partition
assert min(elem_part) == 0
assert max(elem_part) == n_part-1
assert min(vert_part) == 0
assert max(vert_part) == n_part-1

assert len(elem_part) == n_cells_x*n_cells_y
assert len(vert_part) == (n_cells_x+1)*(n_cells_y+1)

# Test that the partition assigns approx the same number of elements/vertices
# to each partition
n_elem = n_cells_x*n_cells_y
elem_count = [elem_part.count(it) for it in range(n_part)]
assert elem_count == pytest.approx(
[float(n_elem)/float(n_part)] * n_part, rel=0.1)


def test_2d_quad_mesh_nodal_with_weights(vis=False):
n_cells_x = 70
n_cells_y = 50
Expand Down

0 comments on commit 7f14a7e

Please sign in to comment.