diff --git a/docker_scripts/parallelrunner.py b/docker_scripts/parallelrunner.py
index 26142f7..4963e19 100755
--- a/docker_scripts/parallelrunner.py
+++ b/docker_scripts/parallelrunner.py
@@ -7,6 +7,7 @@
 import multiprocessing
 import os
 import pathlib as pl
+import pprint
 import tempfile
 import time
 import traceback
@@ -161,35 +162,42 @@ def run_input_tasks(self, input_tasks, tasks_uuid):
         input_batches = self.batch_input_tasks(input_tasks, n_of_batches)
 
+        output_tasks = input_tasks.copy()
+        for output_task in output_tasks:
+            output_task["status"] = "SUBMITTED"
+        output_tasks_content = json.dumps(
+            {"uuid": tasks_uuid, "tasks": output_tasks}
+        )
+        self.output_tasks_path.write_text(output_tasks_content)
+
         output_batches = self.run_batches(
             tasks_uuid, input_batches, number_of_workers
         )
 
-        output_tasks = self.unbatch_output_tasks(output_batches)
+        for output_batch in output_batches:
+            output_batch_tasks = output_batch["tasks"]
 
-        output_tasks_content = json.dumps(
-            {"uuid": tasks_uuid, "tasks": output_tasks}
-        )
-        self.output_tasks_path.write_text(output_tasks_content)
+            for output_task_i, output_task in output_batch_tasks:
+                output_tasks[output_task_i] = output_task
+                # logging.info(output_task["status"])
+
+            output_tasks_content = json.dumps(
+                {"uuid": tasks_uuid, "tasks": output_tasks}
+            )
+            self.output_tasks_path.write_text(output_tasks_content)
+            logger.info(f"Finished a batch of {len(output_batch_tasks)} tasks")
 
         logger.info(f"Finished a set of {len(output_tasks)} tasks")
         logger.debug(f"Finished a set of tasks: {output_tasks_content}")
 
     def batch_input_tasks(self, input_tasks, n_of_batches):
-        batches = [[] for _ in range(n_of_batches)]
+        batches = [{"batch_i": None, "tasks": []} for _ in range(n_of_batches)]
 
         for task_i, input_task in enumerate(input_tasks):
             batch_id = task_i % n_of_batches
-            batches[batch_id].append(input_task)
+            batches[batch_id]["batch_i"] = batch_id
+            batches[batch_id]["tasks"].append((task_i, input_task))
         return batches
 
-    def unbatch_output_tasks(self, batches):
-        output_tasks = []
-        n_of_tasks = sum(len(batch) for batch in batches)
-
-        for task_i in range(n_of_tasks):
-            batch_id = task_i % len(batches)
-            output_tasks.append(batches[batch_id].pop(0))
-        return output_tasks
 
     def create_job_inputs(self, input):
         """Create job inputs"""
@@ -296,10 +304,12 @@ async def run_job(self, job_inputs, input_batch):
 
     def process_job_outputs(self, results, batch, status):
         if self.settings.template_id == "TEST_UUID":
             logger.info("Map in test mode, just returning input")
+            for task_i, task in batch["tasks"]:
+                task["status"] = "SUCCESS"
             return batch
 
-        for task_i, task in enumerate(batch):
+        for task_i, task in batch["tasks"]:
             output = task["output"]
             task["status"] = status
             for probe_name, probe_output in results.items():
@@ -358,7 +368,7 @@ def process_job_outputs(self, results, batch, status):
     def transform_batch_to_task_input(self, batch):
         task_input = {}
 
-        for task in batch:
+        for task_i, task in batch["tasks"]:
            input = task["input"]
            for param_name, param_input in input.items():
                param_type = param_input["type"]
@@ -449,6 +459,10 @@ def run_batches(self, tasks_uuid, input_batches, number_of_workers):
         def map_func(batch_with_uuid, trial_number=1):
             return asyncio.run(async_map_func(batch_with_uuid, trial_number))
 
+        def set_batch_status(batch, status):
+            for task_i, task in batch["tasks"]:
+                task["status"] = status
+
         async def async_map_func(batch_with_uuid, trial_number=1):
             batch_uuid, batch = batch_with_uuid
             try:
@@ -489,27 +503,28 @@ async def async_map_func(batch_with_uuid, trial_number=1):
             except ParallelRunner.FatalException as error:
                 logger.info(
                     f"Batch {batch} failed with fatal error ({error}) in "
-                    f"trial {trial_number}, not retrying, raising error"
+                    f"trial {trial_number}, not retrying"
                 )
                 self.jobs_file_write_status_change(
                     id=batch_uuid,
                     status="failed",
                 )
-
-                raise error
+                set_batch_status(batch, "FAILURE")
+                # raise error
             except Exception as error:
                 if trial_number >= self.settings.max_job_trials:
                     logger.info(
                         f"Batch {batch} failed with error ("
                         f"{traceback.format_exc()}) in "
                         f"trial {trial_number}, reach max number of trials of "
-                        f"{self.settings.max_job_trials}, not retrying, raising error"
+                        f"{self.settings.max_job_trials}, not retrying"
                     )
                     self.jobs_file_write_status_change(
                         id=batch_uuid,
                         status="failed",
                     )
-                    raise error
+                    set_batch_status(batch, "FAILURE")
+                    # raise error
                 else:
                     logger.info(
                         f"Batch {batch} failed with error ("
@@ -540,7 +555,7 @@ async def async_map_func(batch_with_uuid, trial_number=1):
 
         with pathos.pools.ThreadPool(nodes=number_of_workers) as pool:
             pool.restart()
-            output_tasks = list(pool.map(map_func, input_batches_with_uuid))
+            output_tasks = pool.uimap(map_func, input_batches_with_uuid)
             pool.close()
             pool.join()
             pool.clear()  # Pool is singleton, need to clear old pool
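
Context for the last hunk: pathos' `uimap` returns a lazy, unordered iterator that yields each result as soon as its worker finishes, whereas `list(pool.map(...))` blocks until every batch is done. That is what lets the reworked `run_input_tasks` write `output_tasks` to disk batch by batch as results stream in. Below is a minimal standalone sketch of this behaviour, not part of the patch; the `work` function, the `batches` list, and the delay values are made up for illustration.

```python
# Assumed illustration only: shows uimap yielding results in completion order.
import time

import pathos.pools


def work(batch):
    # Simulate a batch whose runtime depends on its payload.
    time.sleep(batch["delay"])
    return batch["batch_i"]


batches = [{"batch_i": i, "delay": d} for i, d in enumerate([0.3, 0.1, 0.2])]

with pathos.pools.ThreadPool(nodes=3) as pool:
    pool.restart()
    # Typically prints 1, 2, 0: the fastest batches come back first.
    for batch_i in pool.uimap(work, batches):
        print(f"batch {batch_i} finished")
    pool.close()
    pool.join()
    pool.clear()  # Pool is a singleton, clear it as the module does
```

Because completion order is no longer the original order, each batch now carries its `batch_i` and `(task_i, task)` pairs, so finished tasks are written back into `output_tasks` by index and `unbatch_output_tasks` is no longer needed.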