Skip to content

Commit

Permalink
Merge pull request #6 from Quentin-Anthony/Quentin-Anthony-patch-1
Browse files Browse the repository at this point in the history
Add rank, local_rank, world_size for any launching mechanism
  • Loading branch information
Quentin-Anthony authored Mar 31, 2023
2 parents ed3a5b0 + 1334bad commit bbe3696
Showing 1 changed file with 17 additions and 4 deletions.
21 changes: 17 additions & 4 deletions magma/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,11 +258,24 @@ def log_table(name, model_outputs, gt_answers_list, global_step):
results_table.add_data(o, gt)
wandb_log({f"eval/{name}": results_table}, step=global_step)


def env2int(env_list, default=-1):
for e in env_list:
val = int(os.environ.get(e, -1))
if val >= 0: return val
return default


def get_world_info():
local_rank = int(os.environ["LOCAL_RANK"])
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
local_rank = env2int(
['LOCAL_RANK', 'MPI_LOCALRANKID', 'OMPI_COMM_WORLD_LOCAL_RANK', 'MV2_COMM_WORLD_LOCAL_RANK', 'SLURM_LOCALID'])
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(local_rank)
rank = env2int(['RANK', 'MPI_RANKID', 'OMPI_COMM_WORLD_RANK', 'MV2_COMM_WORLD_RANK', 'SLURM_PROCID'])
if 'RANK' not in os.environ:
os.environ['RANK'] = str(rank)
world_size = env2int(['WORLD_SIZE', 'OMPI_COMM_WORLD_SIZE', 'MV2_COMM_WORLD_SIZE', 'SLURM_NPROCS'])
if 'WORLD_SIZE' not in os.environ:
os.environ['WORLD_SIZE'] = str(world_size)
return local_rank, rank, world_size


Expand Down

0 comments on commit bbe3696

Please sign in to comment.