
Commit

Update to remove mpi_distribute
MTCam committed Aug 27, 2024
1 parent dd8f535 commit 9059d18
Showing 2 changed files with 17 additions and 25 deletions.
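In short, this commit drops the dependency on meshmode.distributed.mpi_distribute in distribute_mesh and instead ships each rank its mesh piece with mpi4py's generic object scatter. Below is a minimal, illustrative sketch of the replacement primitive (not part of the commit; it assumes mpi4py is installed, uses placeholder string payloads in place of the real per-rank mesh data, and would be launched under mpiexec):

# Illustrative sketch only: mpi4py object scatter, the primitive that
# replaces mpi_distribute in the diff below.
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

if rank == 0:
    # One list entry per destination rank; rank i receives element i.
    send_list = [f"mesh piece for rank {r}" for r in range(comm.Get_size())]
else:
    send_list = None  # non-root ranks contribute nothing

my_piece = comm.scatter(send_list, root=0)
print(f"rank {rank} received: {my_piece}")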
8 changes: 0 additions & 8 deletions mirgecom/gas_model.py
@@ -806,14 +806,6 @@ def make_operator_fluid_states(
dcoll, volume_state.smoothness_d, volume_dd=dd_vol,
comm_tag=(_FluidSmoothnessDiffTag, comm_tag))]

smoothness_d_interior_pairs = None
if volume_state.smoothness_d is not None:
smoothness_d_interior_pairs = [
interp_to_surf_quad(tpair=tpair)
for tpair in interior_trace_pairs(
dcoll, volume_state.smoothness_d, volume_dd=dd_vol,
tag=(_FluidSmoothnessDiffTag, comm_tag))]

smoothness_beta_interior_pairs = None
if volume_state.smoothness_beta is not None:
smoothness_beta_interior_pairs = [
34 changes: 17 additions & 17 deletions mirgecom/simutil.py
@@ -76,11 +76,10 @@
THE SOFTWARE.
"""
import logging
import sys
import os
import pickle
from functools import partial
from typing import TYPE_CHECKING, Dict, List, Optional
from typing import Dict, List, Optional
from contextlib import contextmanager

from logpyle import IntervalTimer
@@ -1062,7 +1061,6 @@ def distribute_mesh(comm, get_mesh_data, partition_generator_func=None, logmgr=N
from mpi4py import MPI
from mpi4py.util import pkl5
from socket import gethostname
from meshmode.distributed import mpi_distribute

num_ranks = comm.Get_size()
my_global_rank = comm.Get_rank()
@@ -1088,6 +1086,7 @@ def partition_generator_func(mesh, tag_to_elements, num_ranks):
reader_color = 0 if my_node_rank == 0 else 1
reader_comm = comm.Split(reader_color, my_global_rank)
my_reader_rank = reader_comm.Get_rank()
num_node_ranks = node_comm.Get_size()

if my_node_rank == 0:
num_reading_ranks = reader_comm.Get_size()
@@ -1146,7 +1145,7 @@ def partition_generator_func(mesh, tag_to_elements, num_ranks):
partition_generator_func(mesh, tag_to_elements,
num_ranks)

def get_rank_to_mesh_data():
def get_rank_to_mesh_data_dict():
if tag_to_elements is None:
rank_to_mesh_data = _partition_single_volume_mesh(
mesh, num_ranks, rank_per_element,
@@ -1163,11 +1162,11 @@ def get_rank_to_mesh_data():
rank: node_rank
for node_rank, rank in enumerate(node_ranks)}

node_rank_to_mesh_data = {
node_rank_to_mesh_data_dict = {
rank_to_node_rank[rank]: mesh_data
for rank, mesh_data in rank_to_mesh_data.items()}

return node_rank_to_mesh_data
return node_rank_to_mesh_data_dict

reader_comm.Barrier()
if my_reader_rank == 0:
@@ -1176,9 +1175,13 @@ def get_rank_to_mesh_data():
if logmgr:
logmgr.add_quantity(t_mesh_split)
with t_mesh_split.get_sub_timer():
node_rank_to_mesh_data = get_rank_to_mesh_data()
node_rank_to_mesh_data_dict = get_rank_to_mesh_data_dict()
else:
node_rank_to_mesh_data = get_rank_to_mesh_data()
node_rank_to_mesh_data_dict = get_rank_to_mesh_data_dict()

node_rank_to_mesh_data = [
node_rank_to_mesh_data_dict[rank]
for rank in range(num_node_ranks)]

reader_comm.Barrier()
if my_reader_rank == 0:
@@ -1189,24 +1192,21 @@ def get_rank_to_mesh_data():
if logmgr:
logmgr.add_quantity(t_mesh_dist)
with t_mesh_dist.get_sub_timer():
local_mesh_data = mpi_distribute(
node_comm, source_rank=0,
source_data=node_rank_to_mesh_data)
local_mesh_data = \
node_comm.scatter(node_rank_to_mesh_data, root=0)
else:
local_mesh_data = mpi_distribute(
node_comm, source_rank=0,
source_data=node_rank_to_mesh_data)
local_mesh_data = \
node_comm.scatter(node_rank_to_mesh_data, root=0)

else: # my_node_rank > 0, get mesh part from MPI
global_nelements = node_comm.bcast(None, root=0)

if logmgr:
logmgr.add_quantity(t_mesh_dist)
with t_mesh_dist.get_sub_timer():
local_mesh_data = \
mpi_distribute(node_comm, source_rank=0)
local_mesh_data = node_comm.scatter(None, root=0)
else:
local_mesh_data = mpi_distribute(node_comm, source_rank=0)
local_mesh_data = node_comm.scatter(None, root=0)

return local_mesh_data, global_nelements
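A detail worth noting about the change above: in the old code, mpi_distribute was handed source_data as a mapping keyed by destination node rank, while Comm.scatter sends element i of the root's list to rank i of the communicator. That is why node rank 0 now flattens node_rank_to_mesh_data_dict into a list ordered by node rank before scattering, and why the other node ranks call scatter(None, root=0) after picking up global_nelements from the bcast. The following is an illustrative, self-contained sketch of that flow (not the commit's code; placeholder payloads, mpi4py assumed, and MPI.COMM_WORLD standing in for the node-local communicator created with comm.Split):

# Illustrative sketch only, mirroring the scatter flow in distribute_mesh.
from mpi4py import MPI

node_comm = MPI.COMM_WORLD  # stand-in for the node-local communicator
my_node_rank = node_comm.Get_rank()
num_node_ranks = node_comm.Get_size()

if my_node_rank == 0:
    global_nelements = 1000  # placeholder for the value read from the mesh
    node_comm.bcast(global_nelements, root=0)
    # Payload keyed by destination node rank, as get_rank_to_mesh_data_dict
    # returns it; scatter needs it rank-ordered and positional.
    rank_to_data = {r: f"piece {r}" for r in range(num_node_ranks)}
    send_list = [rank_to_data[r] for r in range(num_node_ranks)]
    local_mesh_data = node_comm.scatter(send_list, root=0)
else:
    global_nelements = node_comm.bcast(None, root=0)
    local_mesh_data = node_comm.scatter(None, root=0)

print(f"node rank {my_node_rank}: {local_mesh_data}, nelem={global_nelements}")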

