Skip to content

Commit b741340

Browse files
committed
fix: Update MPIMapProcessor and MPISpawnMapProcessor with new RunConfig
1 parent f87cca3 commit b741340

File tree

1 file changed

+40
-20
lines changed

1 file changed

+40
-20
lines changed

CHAP/common/processor.py

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1921,24 +1921,28 @@ class MPIMapProcessor(Processor):
19211921
"""A Processor that applies a parallel generic sub-pipeline to
19221922
a map configuration.
19231923
"""
1924-
def process(self, data, sub_pipeline=None, inputdir='.', outputdir='.',
1925-
interactive=False, log_level='INFO'):
1924+
def process(self, data, config=None, sub_pipeline=None, inputdir=None,
1925+
outputdir=None, interactive=None, log_level=None):
19261926
"""Run a parallel generic sub-pipeline.
19271927
19281928
:param data: Input data.
19291929
:type data: list[PipelineData]
1930+
:param config: Initialization parameters for an instance of
1931+
common.models.map.MapConfig.
1932+
:type config: dict, optional
19301933
:param sub_pipeline: The sub-pipeline.
19311934
:type sub_pipeline: Pipeline, optional
19321935
:param inputdir: Input directory, used only if files in the
1933-
input configuration are not absolute paths,
1934-
defaults to `'.'`.
1936+
input configuration are not absolute paths.
19351937
:type inputdir: str, optional
19361938
:param outputdir: Directory to which any output figures will
1937-
be saved, defaults to `'.'`.
1939+
be saved.
19381940
:type outputdir: str, optional
1939-
:param interactive: Allows for user interactions, defaults to
1940-
`False`.
1941+
:param interactive: Allows for user interactions.
19411942
:type interactive: bool, optional
1943+
:ivar log_level: Logger level (not case sesitive).
1944+
:type log_level: Literal[
1945+
'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], optional
19421946
:return: The `data` field of the first item in the returned
19431947
list of sub-pipeline items.
19441948
"""
@@ -1956,11 +1960,24 @@ def process(self, data, sub_pipeline=None, inputdir='.', outputdir='.',
19561960
num_proc = comm.Get_size()
19571961
rank = comm.Get_rank()
19581962

1959-
# Get the map configuration from data
1960-
map_config = self.get_config(
1961-
data, 'common.models.map.MapConfig', inputdir=inputdir)
1963+
# Get the validated map configuration
1964+
try:
1965+
map_config = self.get_config(
1966+
data, 'common.models.map.MapConfig', inputdir=inputdir)
1967+
except:
1968+
self.logger.info('No valid Map configuration in input pipeline '
1969+
'data, using config parameter instead.')
1970+
try:
1971+
# Local modules
1972+
from CHAP.common.models.map import MapConfig
1973+
1974+
map_config = MapConfig(**config, inputdir=inputdir)
1975+
except Exception as exc:
1976+
raise RuntimeError from exc
19621977

19631978
# Create the spec reader configuration for each processor
1979+
# FIX: catered to EDD with one spec scan
1980+
assert len(map_config.spec_scans) == 1
19641981
spec_scans = map_config.spec_scans[0]
19651982
scan_numbers = spec_scans.scan_numbers
19661983
num_scan = len(scan_numbers)
@@ -1985,7 +2002,7 @@ def process(self, data, sub_pipeline=None, inputdir='.', outputdir='.',
19852002
run_config = {'inputdir': inputdir, 'outputdir': outputdir,
19862003
'interactive': interactive, 'log_level': log_level}
19872004
run_config.update(sub_pipeline.get('config'))
1988-
run_config = RunConfig(**run_config, comm=comm, logger=self.logger)
2005+
run_config = RunConfig(**run_config, comm=comm)
19892006
pipeline_config = []
19902007
for item in sub_pipeline['pipeline']:
19912008
if isinstance(item, dict):
@@ -2000,20 +2017,17 @@ def process(self, data, sub_pipeline=None, inputdir='.', outputdir='.',
20002017
pipeline_config.append(item)
20012018

20022019
# Run the sub-pipeline on each processor
2003-
return run(
2004-
pipeline_config, inputdir=run_config.inputdir,
2005-
outputdir=run_config.outputdir, interactive=run_config.interactive,
2006-
logger=self.logger, comm=comm)
2020+
return run(run_config, pipeline_config, logger=self.logger, comm=comm)
20072021

20082022

20092023
class MPISpawnMapProcessor(Processor):
20102024
"""A Processor that applies a parallel generic sub-pipeline to
20112025
a map configuration by spawning workers processes.
20122026
"""
20132027
def process(
2014-
self, data, num_proc=1, root_as_worker=True, collect_on_root=True,
2015-
sub_pipeline=None, inputdir='.', outputdir='.', interactive=False,
2016-
log_level='INFO'):
2028+
self, data, num_proc=1, root_as_worker=True, collect_on_root=False,
2029+
sub_pipeline=None, inputdir=None, outputdir=None, interactive=None,
2030+
log_level=None):
20172031
"""Spawn workers running a parallel generic sub-pipeline.
20182032
20192033
:param data: Input data.
@@ -2024,7 +2038,7 @@ def process(
20242038
defaults to `True`.
20252039
:type root_as_worker: bool, optional
20262040
:param collect_on_root: Collect the result of the spawned
2027-
workers on the root node, defaults to `True`.
2041+
workers on the root node, defaults to `False`.
20282042
:type collect_on_root: bool, optional
20292043
:param sub_pipeline: The sub-pipeline.
20302044
:type sub_pipeline: Pipeline, optional
@@ -2038,6 +2052,9 @@ def process(
20382052
:param interactive: Allows for user interactions, defaults to
20392053
`False`.
20402054
:type interactive: bool, optional
2055+
:ivar log_level: Logger level (not case sesitive).
2056+
:type log_level: Literal[
2057+
'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], optional
20412058
:return: The `data` field of the first item in the returned
20422059
list of sub-pipeline items.
20432060
"""
@@ -2144,7 +2161,7 @@ def process(
21442161

21452162
# Run the sub-pipeline on the root node
21462163
if root_as_worker:
2147-
data = runner(run_config, pipeline_config[0], common_comm)
2164+
data = runner(run_config, pipeline_config[0], comm=common_comm)
21482165
elif collect_on_root:
21492166
run_config.spawn = 0
21502167
pipeline_config = [{'common.MPICollectProcessor': {
@@ -2158,7 +2175,10 @@ def process(
21582175

21592176
# Disconnect spawned workers and cleanup temporary files
21602177
if num_proc > first_proc:
2178+
# Align with the barrier in main() on common_comm
2179+
# when disconnecting the spawned worker
21612180
common_comm.barrier()
2181+
# Disconnect spawned workers and cleanup temporary files
21622182
sub_comm.Disconnect()
21632183
for tmp_name in tmp_names:
21642184
os.remove(tmp_name)

0 commit comments

Comments
 (0)