diff --git a/src/atlas/meshgenerator/detail/RegularMeshGenerator.cc b/src/atlas/meshgenerator/detail/RegularMeshGenerator.cc
index 6035bc64a..cc07bab77 100644
--- a/src/atlas/meshgenerator/detail/RegularMeshGenerator.cc
+++ b/src/atlas/meshgenerator/detail/RegularMeshGenerator.cc
@@ -45,6 +45,11 @@ namespace atlas {
 namespace meshgenerator {
 
 RegularMeshGenerator::RegularMeshGenerator(const eckit::Parametrisation& p) {
+
+    std::string mpi_comm = mpi::comm().name();
+    p.get("mpi_comm", mpi_comm);
+    options.set("mpi_comm", mpi_comm);
+
     configure_defaults();
 
     // options copied from Structured MeshGenerator
@@ -87,15 +92,19 @@ RegularMeshGenerator::RegularMeshGenerator(const eckit::Parametrisation& p) {
 }
 
 void RegularMeshGenerator::configure_defaults() {
+    std::string mpi_comm;
+    options.get("mpi_comm", mpi_comm);
+    auto& comm = mpi::comm(mpi_comm);
+
     // This option sets number of parts the mesh will be split in
-    options.set("nb_parts", mpi::size());
+    options.set("nb_parts", comm.size());
 
     // This option sets the part that will be generated
-    options.set("part", mpi::rank());
+    options.set("part", comm.rank());
 
     // This options sets the default partitioner
     std::string partitioner;
-    if (grid::Partitioner::exists("ectrans") && mpi::size() > 1) {
+    if (grid::Partitioner::exists("ectrans") && comm.size() > 1) {
         partitioner = "ectrans";
     }
     else {
@@ -126,8 +135,10 @@ void RegularMeshGenerator::generate(const Grid& grid, Mesh& mesh) const {
     // if ( nb_parts == 1 || eckit::mpi::size() == 1 ) partitioner_factory =
     // "equal_regions"; // Only one part --> Trans is slower
 
+    mpi::push(options.getString("mpi_comm"));
     grid::Partitioner partitioner(partitioner_type, nb_parts);
     grid::Distribution distribution(partitioner.partition(grid));
+    mpi::pop();
 
     generate(grid, distribution, mesh);
 }
@@ -162,6 +173,7 @@ void RegularMeshGenerator::generate(const Grid& grid, const grid::Distribution&
 void RegularMeshGenerator::generate_mesh(const RegularGrid& rg, const grid::Distribution& distribution,
                                          // const Region& region,
                                          Mesh& mesh) const {
+    mpi::Scope mpi_scope(options.getString("mpi_comm"));
     int mypart = options.get<size_t>("part");
     int nparts = options.get<size_t>("nb_parts");
     int nx     = rg.nx();
@@ -552,6 +564,10 @@ void RegularMeshGenerator::generate_mesh(const RegularGrid& rg, const grid::Dist
     }
 #endif
 
+    mesh.metadata().set("nb_parts", options.getInt("nb_parts"));
+    mesh.metadata().set("part", options.getInt("part"));
+    mesh.metadata().set("mpi_comm", options.getString("mpi_comm"));
+
     generateGlobalElementNumbering(mesh);
 
     nodes.metadata().set("parallel", true);
diff --git a/src/tests/mesh/test_meshgen_splitcomm.cc b/src/tests/mesh/test_meshgen_splitcomm.cc
index 9bde00b84..f415244e2 100644
--- a/src/tests/mesh/test_meshgen_splitcomm.cc
+++ b/src/tests/mesh/test_meshgen_splitcomm.cc
@@ -53,6 +53,16 @@ Grid grid_CS() {
     return g;
 }
 
+Grid grid_Regular() {
+    static Grid g (color() == 0 ? "S8" : "L16");
+    return g;
+}
+
+Grid grid_healpix() {
+    static Grid g (color() == 0 ? "H8" : "H16");
+    return g;
+}
+
 struct Fixture {
     Fixture() {
         mpi::comm().split(color(), "split");
@@ -98,6 +108,24 @@ CASE("StructuredMeshGenerator") {
     mesh.polygon().outputPythonScript(grid().name() + "_polygons_1.py");
 }
 
+CASE("RegularMeshGenerator") {
+    Fixture fixture;
+
+    MeshGenerator meshgen{"regular", option::mpi_comm("split")};
+    Mesh mesh = meshgen.generate(grid_Regular());
+    EXPECT_EQUAL(mesh.nb_parts(), mpi::comm("split").size());
+    EXPECT_EQUAL(mesh.part(), mpi::comm("split").rank());
+    EXPECT_EQUAL(mesh.mpi_comm(), "split");
+    EXPECT_EQUAL(mpi::comm().name(), "world");
+    output::Gmsh gmsh(grid_Regular().name() + "_regular.msh");
+    gmsh.write(mesh);
+
+    // partitioning graph and polygon output
+    EXPECT_NO_THROW(mesh.partitionGraph());
+    EXPECT_NO_THROW(mesh.polygons());
+    mesh.polygon().outputPythonScript(grid_Regular().name() + "_regular_polygons_1.py");
+}
+
 CASE("DelaunayMeshGenerator") {
     if( ATLAS_HAVE_TESSELATION ) {
         Fixture fixture;
@@ -194,7 +222,6 @@ CASE("MatchingPartitioner") {
     mesh_B.polygon().outputPythonScript(grid_B().name() + "_polygons_3.py");
 }
 
-
 }  // namespace test
 }  // namespace atlas