From 9059d18692af598807c36de77b7590a843e9907f Mon Sep 17 00:00:00 2001 From: Mike Campbell Date: Tue, 27 Aug 2024 18:56:55 -0500 Subject: [PATCH] Update to remove mpi_distribute --- mirgecom/gas_model.py | 8 -------- mirgecom/simutil.py | 34 +++++++++++++++++----------------- 2 files changed, 17 insertions(+), 25 deletions(-) diff --git a/mirgecom/gas_model.py b/mirgecom/gas_model.py index 26ec36800..d7dfe273d 100644 --- a/mirgecom/gas_model.py +++ b/mirgecom/gas_model.py @@ -806,14 +806,6 @@ def make_operator_fluid_states( dcoll, volume_state.smoothness_d, volume_dd=dd_vol, comm_tag=(_FluidSmoothnessDiffTag, comm_tag))] - smoothness_d_interior_pairs = None - if volume_state.smoothness_d is not None: - smoothness_d_interior_pairs = [ - interp_to_surf_quad(tpair=tpair) - for tpair in interior_trace_pairs( - dcoll, volume_state.smoothness_d, volume_dd=dd_vol, - tag=(_FluidSmoothnessDiffTag, comm_tag))] - smoothness_beta_interior_pairs = None if volume_state.smoothness_beta is not None: smoothness_beta_interior_pairs = [ diff --git a/mirgecom/simutil.py b/mirgecom/simutil.py index 56414af51..35b5baf7c 100644 --- a/mirgecom/simutil.py +++ b/mirgecom/simutil.py @@ -76,11 +76,10 @@ THE SOFTWARE. """ import logging -import sys import os import pickle from functools import partial -from typing import TYPE_CHECKING, Dict, List, Optional +from typing import Dict, List, Optional from contextlib import contextmanager from logpyle import IntervalTimer @@ -1062,7 +1061,6 @@ def distribute_mesh(comm, get_mesh_data, partition_generator_func=None, logmgr=N from mpi4py import MPI from mpi4py.util import pkl5 from socket import gethostname - from meshmode.distributed import mpi_distribute num_ranks = comm.Get_size() my_global_rank = comm.Get_rank() @@ -1088,6 +1086,7 @@ def partition_generator_func(mesh, tag_to_elements, num_ranks): reader_color = 0 if my_node_rank == 0 else 1 reader_comm = comm.Split(reader_color, my_global_rank) my_reader_rank = reader_comm.Get_rank() + num_node_ranks = node_comm.Get_size() if my_node_rank == 0: num_reading_ranks = reader_comm.Get_size() @@ -1146,7 +1145,7 @@ def partition_generator_func(mesh, tag_to_elements, num_ranks): partition_generator_func(mesh, tag_to_elements, num_ranks) - def get_rank_to_mesh_data(): + def get_rank_to_mesh_data_dict(): if tag_to_elements is None: rank_to_mesh_data = _partition_single_volume_mesh( mesh, num_ranks, rank_per_element, @@ -1163,11 +1162,11 @@ def get_rank_to_mesh_data(): rank: node_rank for node_rank, rank in enumerate(node_ranks)} - node_rank_to_mesh_data = { + node_rank_to_mesh_data_dict = { rank_to_node_rank[rank]: mesh_data for rank, mesh_data in rank_to_mesh_data.items()} - return node_rank_to_mesh_data + return node_rank_to_mesh_data_dict reader_comm.Barrier() if my_reader_rank == 0: @@ -1176,9 +1175,13 @@ def get_rank_to_mesh_data(): if logmgr: logmgr.add_quantity(t_mesh_split) with t_mesh_split.get_sub_timer(): - node_rank_to_mesh_data = get_rank_to_mesh_data() + node_rank_to_mesh_data_dict = get_rank_to_mesh_data_dict() else: - node_rank_to_mesh_data = get_rank_to_mesh_data() + node_rank_to_mesh_data_dict = get_rank_to_mesh_data_dict() + + node_rank_to_mesh_data = [ + node_rank_to_mesh_data_dict[rank] + for rank in range(num_node_ranks)] reader_comm.Barrier() if my_reader_rank == 0: @@ -1189,13 +1192,11 @@ def get_rank_to_mesh_data(): if logmgr: logmgr.add_quantity(t_mesh_dist) with t_mesh_dist.get_sub_timer(): - local_mesh_data = mpi_distribute( - node_comm, source_rank=0, - source_data=node_rank_to_mesh_data) + local_mesh_data = \ + node_comm.scatter(node_rank_to_mesh_data, root=0) else: - local_mesh_data = mpi_distribute( - node_comm, source_rank=0, - source_data=node_rank_to_mesh_data) + local_mesh_data = \ + node_comm.scatter(node_rank_to_mesh_data, root=0) else: # my_node_rank > 0, get mesh part from MPI global_nelements = node_comm.bcast(None, root=0) @@ -1203,10 +1204,9 @@ def get_rank_to_mesh_data(): if logmgr: logmgr.add_quantity(t_mesh_dist) with t_mesh_dist.get_sub_timer(): - local_mesh_data = \ - mpi_distribute(node_comm, source_rank=0) + local_mesh_data = node_comm.scatter(None, root=0) else: - local_mesh_data = mpi_distribute(node_comm, source_rank=0) + local_mesh_data = node_comm.scatter(None, root=0) return local_mesh_data, global_nelements