diff --git a/rsopt/__main__.py b/rsopt/__main__.py
new file mode 100644
index 0000000..8d0ac04
--- /dev/null
+++ b/rsopt/__main__.py
@@ -0,0 +1,11 @@
+"""Probe script: a zero exit status means mpi4py imported and MPI initialized and finalized cleanly."""
+import mpi4py
+
+# Importing mpi4py.MPI initializes MPI automatically by default; disable that so the
+# explicit MPI.Init() below is the only initialization attempt in this process.
+mpi4py.rc.initialize = False
+
+from mpi4py import MPI
+
+MPI.Init()
+MPI.Finalize()
diff --git a/rsopt/mpi.py b/rsopt/mpi.py
index 456b612..f761ac4 100644
--- a/rsopt/mpi.py
+++ b/rsopt/mpi.py
@@ -1,18 +1,77 @@
+import atexit
+import os
+import subprocess
+import sys
+from inspect import currentframe, getframeinfo
+import rsopt
+
+__active_env = None
+
+
+def _finalize_mpi():
+    """Finalize MPI if it is initialized and has not been finalized yet (registered to run at interpreter exit)."""
+    from mpi4py import MPI
+    if MPI.Is_initialized() and not MPI.Is_finalized():
+        MPI.Finalize()
+
 
 def get_mpi_environment():
+    """Checks MPI environment and whether or not MPI is initialized
+
+    The outcome is cached in a module-level variable, so the checks below run at most once per process.
+
+    Params:
+        None
+
+    Returns:
+        None if mpi is unavailable; else a dict representing the active MPI environment"""
+    global __active_env
+
+    # If an earlier call already settled the question, reuse its answer
+    if __active_env == "no_mpi":
+        return None
+    if __active_env:
+        return __active_env
+
+    # Test for mpi4py install
     try:
+        import mpi4py
+        # Keep the import below from initializing MPI; that is only attempted after the probe further down succeeds
+        mpi4py.rc.initialize = False
         from mpi4py import MPI
     except ModuleNotFoundError:
         # mpi4py not installed so it can't be used
-        return
+        __active_env = "no_mpi"
+        return None
+
+    frameinfo = getframeinfo(currentframe())
+    print(f"Initializing MPI from {frameinfo.filename}:L{frameinfo.lineno}", flush=True)
+
+    # Probe MPI initialization in a separate process (rsopt/__main__.py); a non-zero exit status means MPI is unusable
+    fname = os.path.join(os.path.dirname(rsopt.__file__), "__main__.py")
+    # Use the running interpreter rather than whatever "python" resolves to on PATH (it may be absent or different)
+    pp = subprocess.run([sys.executable, fname])
+
+    if pp.returncode != 0:
+        __active_env = "no_mpi"
+        return None
+
+    # Automatic initialization was disabled above, so MPI must be initialized here before COMM_WORLD is queried
+    if not MPI.Is_initialized():
+        MPI.Init()
+        atexit.register(_finalize_mpi)
 
     if not MPI.COMM_WORLD.Get_size() - 1:
         # MPI not being used
         # (if user did start MPI with size 1 this would be an illegal configuration since: main + 1 worker = 2 ranks)
-        return
+        __active_env = "no_mpi"
+        return None
 
     nworkers = MPI.COMM_WORLD.Get_size() - 1
     is_manager = MPI.COMM_WORLD.Get_rank() == 0
     mpi_environment = {'mpi_comm': MPI.COMM_WORLD, 'comms': 'mpi', 'nworkers': nworkers, 'is_manager': is_manager}
 
-    return mpi_environment
\ No newline at end of file
+    # Save global environment
+    __active_env = mpi_environment
+
+    return mpi_environment
diff --git a/rsopt/util.py b/rsopt/util.py
index 0672e65..9857fb9 100644
--- a/rsopt/util.py
+++ b/rsopt/util.py
@@ -4,6 +4,7 @@ import numpy as np
 
 import pickle
 from libensemble.tools import save_libE_output
+from .mpi import get_mpi_environment
 
 SLURM_PREFIX = 'nid'
 
@@ -54,12 +55,11 @@ def return_nodelist(nodelist_string):
 
 def return_used_nodes():
     """Returns all used processor names to rank 0 or an empty list if MPI not used. For ranks != 0 returns None."""
-    try:
-        from mpi4py import MPI
-    except ModuleNotFoundError:
-        # If MPI not being used to start rsopt then no nodes will have srun executed yet
+    if not get_mpi_environment():
         return []
 
+    from mpi4py import MPI
+
     rank = MPI.COMM_WORLD.Get_rank()
     name = MPI.Get_processor_name()
     all_names = MPI.COMM_WORLD.gather(name, root=0)
@@ -92,12 +92,11 @@ def return_unused_node():
 
 def broadcast(data, root_rank=0):
     """broadcast, or don't bother"""
-    try:
-        from mpi4py import MPI
-    except ModuleNotFoundError:
-        # If MPI not available for import then assume it isn't needed
+    if not get_mpi_environment():
         return data
 
+    from mpi4py import MPI
+
     if MPI.COMM_WORLD.Get_size() == 1:
         return data
 
@@ -121,4 +120,4 @@ def _libe_save(H, persis_info, mess, filename):
     np.save(filename, H)
 
     with open(filename + ".pickle", "wb") as f:
-        pickle.dump(persis_info, f)
\ No newline at end of file
+        pickle.dump(persis_info, f)