Use async unsorted map
wvangeit committed Dec 23, 2024
1 parent 4c91a70 commit 494c635
Showing 1 changed file with 38 additions and 23 deletions.
61 changes: 38 additions & 23 deletions docker_scripts/parallelrunner.py
@@ -7,6 +7,7 @@
import multiprocessing
import os
import pathlib as pl
import pprint
import tempfile
import time
import traceback
@@ -161,35 +162,42 @@ def run_input_tasks(self, input_tasks, tasks_uuid):

input_batches = self.batch_input_tasks(input_tasks, n_of_batches)

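# Record every task as SUBMITTED and persist the file up front, so progress is visible before any batch finishes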
output_tasks = input_tasks.copy()
for output_task in output_tasks:
output_task["status"] = "SUBMITTED"
output_tasks_content = json.dumps(
{"uuid": tasks_uuid, "tasks": output_tasks}
)
self.output_tasks_path.write_text(output_tasks_content)

output_batches = self.run_batches(
tasks_uuid, input_batches, number_of_workers
)

output_tasks = self.unbatch_output_tasks(output_batches)
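# Write each batch's results back into the full task list (by original task index) as soon as that batch completes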
for output_batch in output_batches:
output_batch_tasks = output_batch["tasks"]

output_tasks_content = json.dumps(
{"uuid": tasks_uuid, "tasks": output_tasks}
)
self.output_tasks_path.write_text(output_tasks_content)
for output_task_i, output_task in output_batch_tasks:
output_tasks[output_task_i] = output_task
# logging.info(output_task["status"])

output_tasks_content = json.dumps(
{"uuid": tasks_uuid, "tasks": output_tasks}
)
self.output_tasks_path.write_text(output_tasks_content)
logger.info(f"Finished a batch of {len(output_batch_tasks)} tasks")
logger.info(f"Finished a set of {len(output_tasks)} tasks")
logger.debug(f"Finished a set of tasks: {output_tasks_content}")

def batch_input_tasks(self, input_tasks, n_of_batches):
batches = [[] for _ in range(n_of_batches)]
batches = [{"batch_i": None, "tasks": []} for _ in range(n_of_batches)]

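# Distribute tasks round-robin over the batches, keeping each task's original index so results can be written back in place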
for task_i, input_task in enumerate(input_tasks):
batch_id = task_i % n_of_batches
batches[batch_id].append(input_task)
batches[batch_id]['batch_i'] = batch_id
batches[batch_id]['tasks'].append((task_i, input_task))
return batches

def unbatch_output_tasks(self, batches):
output_tasks = []
n_of_tasks = sum(len(batch) for batch in batches)

for task_i in range(n_of_tasks):
batch_id = task_i % len(batches)
output_tasks.append(batches[batch_id].pop(0))
return output_tasks

def create_job_inputs(self, input):
"""Create job inputs"""
@@ -296,10 +304,12 @@ async def run_job(self, job_inputs, input_batch):
def process_job_outputs(self, results, batch, status):
if self.settings.template_id == "TEST_UUID":
logger.info("Map in test mode, just returning input")
for task_i, task in batch["tasks"]:
task["status"] = "SUCCESS"

return batch

for task_i, task in enumerate(batch):
for task_i, task in batch["tasks"]:
output = task["output"]
task["status"] = status
for probe_name, probe_output in results.items():
@@ -358,7 +368,7 @@ def process_job_outputs(self, results, batch, status):

def transform_batch_to_task_input(self, batch):
task_input = {}
for task in batch:
for task_i, task in batch["tasks"]:
input = task["input"]
for param_name, param_input in input.items():
param_type = param_input["type"]
@@ -449,6 +459,10 @@ def run_batches(self, tasks_uuid, input_batches, number_of_workers):
def map_func(batch_with_uuid, trial_number=1):
return asyncio.run(async_map_func(batch_with_uuid, trial_number))

def set_batch_status(batch, message):
# Mark every task in the batch with the given status message
for task_i, task in batch["tasks"]:
task["status"] = message

async def async_map_func(batch_with_uuid, trial_number=1):
batch_uuid, batch = batch_with_uuid
try:
@@ -489,27 +503,28 @@ async def async_map_func(batch_with_uuid, trial_number=1):
except ParallelRunner.FatalException as error:
logger.info(
f"Batch {batch} failed with fatal error ({error}) in "
f"trial {trial_number}, not retrying, raising error"
f"trial {trial_number}, not retrying"
)
self.jobs_file_write_status_change(
id=batch_uuid,
status="failed",
)

raise error
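# Mark the whole batch as failed instead of re-raising, so the remaining batches keep running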
set_batch_status(batch, "FAILURE")
# raise error
except Exception as error:
if trial_number >= self.settings.max_job_trials:
logger.info(
f"Batch {batch} failed with error ("
f"{traceback.format_exc()}) in "
f"trial {trial_number}, reach max number of trials of "
f"{self.settings.max_job_trials}, not retrying, raising error"
f"{self.settings.max_job_trials}, not retrying"
)
self.jobs_file_write_status_change(
id=batch_uuid,
status="failed",
)
raise error
set_batch_status(batch, "FAILURE")
# raise error
else:
logger.info(
f"Batch {batch} failed with error ("
@@ -540,7 +555,7 @@ async def async_map_func(batch_with_uuid, trial_number=1):

with pathos.pools.ThreadPool(nodes=number_of_workers) as pool:
pool.restart()
output_tasks = list(pool.map(map_func, input_batches_with_uuid))
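# pathos' uimap is an unordered async map: it yields batch results as they complete instead of in submission order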
output_tasks = pool.uimap(map_func, input_batches_with_uuid)
pool.close()
pool.join()
pool.clear() # Pool is singleton, need to clear old pool
