From e0e05a74083a159f7836c89ca69ce2ec053aad60 Mon Sep 17 00:00:00 2001
From: ByteDance
Date: Thu, 11 Jul 2024 13:07:40 +0800
Subject: [PATCH] add process sync via temp file in lmms_eval/evaluator.py

---
 lmms_eval/evaluator.py | 20 ++++++++++++++++----
 1 file changed, 16 insertions(+), 4 deletions(-)

diff --git a/lmms_eval/evaluator.py b/lmms_eval/evaluator.py
index 6788467e..041eec31 100755
--- a/lmms_eval/evaluator.py
+++ b/lmms_eval/evaluator.py
@@ -1,3 +1,5 @@
+import os
+import time
 import random
 import itertools
 import json
@@ -428,6 +430,12 @@ def evaluate(
         # Ensure all ranks wait for rank 0 to finish aggregation
         torch.distributed.barrier()
 
+    # Synchronize processes with a temp file in case the evaluation metric requires GPUs
+    # TODO: fix barriers taking up GPU computation
+    os.makedirs(cli_args.output_path, exist_ok=True)
+    if os.path.exists(f"{cli_args.output_path}/rank{int(os.environ.get('RANK', 0))}_metric_eval_done.txt"):
+        os.remove(f"{cli_args.output_path}/rank{int(os.environ.get('RANK', 0))}_metric_eval_done.txt")
+
     if lm.rank == 0:
         ### Get task ordering for correct sample-wide aggregation
         group_to_task = {}
@@ -628,8 +636,12 @@ def print_tasks(task_hierarchy, task_order, task_version, task_group_alias):
         }
         if log_samples:
             results_dict["samples"] = dict(samples)
-
-        return results_dict
-    else:
-        return None
+    else:
+        results_dict = None
+    with open(f"{cli_args.output_path}/rank{int(os.environ.get('RANK', 0))}_metric_eval_done.txt", 'w') as f:
+        f.write(f"rank {int(os.environ.get('RANK', 0))} eval done")
+    while len([file for file in os.listdir(cli_args.output_path) if file.endswith('metric_eval_done.txt')]) < lm.accelerator.num_processes:
+        time.sleep(1)
+
+    return results_dict
 
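
Note on the synchronization pattern: the change above amounts to a small file-based barrier. Each rank writes a rank<N>_metric_eval_done.txt marker once its metric evaluation finishes, then polls the output directory until one marker per process exists, so no rank has to sit in a GPU-backed collective such as torch.distributed.barrier() while another rank is still computing metrics. The sketch below is a minimal, standalone illustration of that pattern only; output_path, rank, and num_processes are hypothetical stand-ins for cli_args.output_path, int(os.environ.get("RANK", 0)), and lm.accelerator.num_processes in the patch.

import os
import time

def file_barrier(output_path: str, rank: int, num_processes: int, poll_interval: float = 1.0) -> None:
    """Block until every rank has dropped its *_metric_eval_done.txt marker."""
    os.makedirs(output_path, exist_ok=True)
    # Announce that this rank is done by writing its marker file.
    marker = os.path.join(output_path, f"rank{rank}_metric_eval_done.txt")
    with open(marker, "w") as f:
        f.write(f"rank {rank} eval done")
    # Wait with CPU-only polling until every rank's marker is present.
    while len([name for name in os.listdir(output_path) if name.endswith("metric_eval_done.txt")]) < num_processes:
        time.sleep(poll_interval)

Each rank would call file_barrier(output_dir, int(os.environ.get("RANK", 0)), world_size) after finishing its metrics, mirroring the with/while block the patch appends at the end of evaluate().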