From 7f14a7e4e230ff2069880bad3dd9ad3d9bf19fbe Mon Sep 17 00:00:00 2001 From: zhf-0 Date: Tue, 5 Sep 2023 14:02:09 +0800 Subject: [PATCH] add the parameter `ncommon` to `pymetis.part_mesh()` function (#57) --- pymetis/__init__.py | 11 +++++-- src/wrapper/wrapper.cpp | 2 +- test/test_partition_mesh.py | 58 +++++++++++++++++++++++++++++++++++++ 3 files changed, 68 insertions(+), 3 deletions(-) diff --git a/pymetis/__init__.py b/pymetis/__init__.py index c18373b..ee0c46a 100644 --- a/pymetis/__init__.py +++ b/pymetis/__init__.py @@ -373,7 +373,8 @@ def part_graph(nparts, adjacency=None, xadj=None, adjncy=None, eweights, options, recursive) -def part_mesh(n_parts, connectivity, options=None, tpwgts=None, gtype=None): +def part_mesh(n_parts, connectivity, options=None, tpwgts=None, gtype=None, + ncommon=1): """This function is used to partition a mesh into *n_parts* parts based on a graph partitioning where each vertex is a node in the graph. A mesh is a collection of non-overlapping elements which are identified by their vertices. @@ -403,6 +404,12 @@ def part_mesh(n_parts, connectivity, options=None, tpwgts=None, gtype=None): ``gtype`` specifies the partitioning is based on a nodal/dual graph of the mesh. It has to be one of :attr:`GType.NODAL` or :attr:`GType.DUAL`. + ``ncommon`` is needed when ``gtype = GType.DUAL``. It Specifies the number of + common nodes that two elements must have in order to put an edge between them + in the dual graph. For example, for tetrahedron meshes, ncommon should be 3, + which creates an edge between two tets when they share a triangular face + (i.e., 3 nodes). + Returns a namedtuple of ``(edge_cuts, element_part, vertex_part)``, where ``edge_cuts`` is the number of cuts to the connectivity graph, ``element_part`` is an array of length n_elements, with entries identifying the element's @@ -448,6 +455,6 @@ def part_mesh(n_parts, connectivity, options=None, tpwgts=None, gtype=None): from pymetis._internal import part_mesh return MeshPartition(*part_mesh(n_parts, conn_offset, conn, - tpwgts, gtype, n_elements, n_vertex, options)) + tpwgts, gtype, n_elements, n_vertex, ncommon, options)) # vim: foldmethod=marker diff --git a/src/wrapper/wrapper.cpp b/src/wrapper/wrapper.cpp index 420e1a8..b2e030c 100644 --- a/src/wrapper/wrapper.cpp +++ b/src/wrapper/wrapper.cpp @@ -216,6 +216,7 @@ namespace idx_t >ype, idx_t &nElements, idx_t &nVertex, + idx_t &ncommon, metis_options &options) { idx_t edgeCuts = 0; @@ -242,7 +243,6 @@ namespace } else if(gtype == METIS_GTYPE_DUAL) { - idx_t ncommon = 1; idx_t objval = 1; int info = METIS_PartMeshDual(&nElements, &nVertex, connectivityOffsets.data(), connectivity.data(), diff --git a/test/test_partition_mesh.py b/test/test_partition_mesh.py index 88d2e80..7295724 100644 --- a/test/test_partition_mesh.py +++ b/test/test_partition_mesh.py @@ -95,6 +95,64 @@ def test_2d_quad_mesh_dual(vis=False): [float(n_vert)/float(n_part)] * n_part, rel=0.1) +def test_2d_quad_mesh_dual_with_ncommon(vis=False): + """ + Generate simple 2D `mesh` connectivity with rectangular elements, eg + + 6 --- 7 --- 9 + | | | + 3 --- 4 --- 5 + | | | + 0 --- 1 --- 2 + if use the default `ncommon = 1` + `_, elem_idx_list, _ = pymetis.part_mesh(2, mesh, gtype=pymetis.GType.DUAL)` + Then the output of `elem_idx_list` is `[0, 0, 0, 0]`, and the number is not + balanced. + + if set `ncommon = 2` + `_, elem_idx_list, _ = pymetis.part_mesh(2, mesh, gtype=pymetis.GType.DUAL, + ncommon)` + Then the output of `elem_idx_list` is `[0, 1, 0, 1]`, the number is balanced. + """ + n_cells_x = 2 + n_cells_y = 2 + points, connectivity = generate_mesh_2d(n_cells_x, n_cells_y) + + n_part = 2 + ncommon = 2 + n_cuts, elem_part, vert_part = pymetis.part_mesh(n_part, connectivity, + None, None, pymetis.GType.DUAL, ncommon) + + print(n_cuts) + print([elem_part.count(it) for it in range(n_part)]) + print([vert_part.count(it) for it in range(n_part)]) + + if vis: + import pyvtk + vtkelements = pyvtk.VtkData( + pyvtk.UnstructuredGrid(points, quad=connectivity), + "Mesh", + pyvtk.CellData(pyvtk.Scalars(elem_part, name="Rank")) + ) + vtkelements.tofile("quad.vtk") + + # Assertions about partition + assert min(elem_part) == 0 + assert max(elem_part) == n_part-1 + assert min(vert_part) == 0 + assert max(vert_part) == n_part-1 + + assert len(elem_part) == n_cells_x*n_cells_y + assert len(vert_part) == (n_cells_x+1)*(n_cells_y+1) + + # Test that the partition assigns approx the same number of elements/vertices + # to each partition + n_elem = n_cells_x*n_cells_y + elem_count = [elem_part.count(it) for it in range(n_part)] + assert elem_count == pytest.approx( + [float(n_elem)/float(n_part)] * n_part, rel=0.1) + + def test_2d_quad_mesh_nodal_with_weights(vis=False): n_cells_x = 70 n_cells_y = 50