When using MaxText with Slurm, our jobs only see one GPU per node, because jax.distributed assumes one GPU per process when used with Slurm (see the JAX docs).
This behavior can be overridden by passing local_device_ids to jax.distributed.initialize, so one way to fix this is to change initialize_jax_for_gpu as follows (max_utils.py line 243):
def initialize_jax_for_gpu():
  """Jax distributed initialize for GPUs."""
  if os.environ.get("JAX_COORDINATOR_IP") is not None:
    coordinator_ip = str(os.getenv("JAX_COORDINATOR_IP"))
    coordinator_port = str(os.getenv("JAX_COORDINATOR_PORT"))
    # CUDA_VISIBLE_DEVICES is a comma-separated string such as "0,1,2,3".
    # Parse it into a list of ints (assumes numeric indices, not GPU UUIDs).
    visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", "")
    device_list = [int(d) for d in visible_devices.split(",") if d.strip()]
    if len(device_list) == 0:
      # Unset or empty: let JAX fall back to its default device selection.
      device_list = None
    jax.distributed.initialize(
        coordinator_address=f"{coordinator_ip}:{coordinator_port}",
        num_processes=int(os.getenv("NNODES")),
        process_id=int(os.getenv("NODE_RANK")),
        local_device_ids=device_list,
    )
    max_logging.log(f"JAX global devices: {jax.devices()}")
This can probably use more robust error handling.
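As a rough idea of what sturdier handling might look like, here is a minimal sketch. The helper names parse_visible_devices and require_int_env are hypothetical (they are not part of MaxText or JAX), and the sketch assumes CUDA_VISIBLE_DEVICES holds numeric indices rather than GPU UUIDs or MIG identifiers.

```python
import os
from typing import List, Optional


def parse_visible_devices(value: Optional[str]) -> Optional[List[int]]:
  """Turn a CUDA_VISIBLE_DEVICES-style string into a list of device indices.

  Hypothetical helper. Returns None when the value is unset or empty, so the
  caller can let JAX fall back to its default device selection.
  """
  if value is None or not value.strip():
    return None
  ids = []
  for token in value.split(","):
    token = token.strip()
    if not token:
      continue  # tolerate stray commas such as "0,1,"
    if not token.isdecimal():
      # GPU UUIDs ("GPU-...") and MIG identifiers are not handled here.
      raise ValueError(f"Unsupported device entry {token!r} in {value!r}")
    ids.append(int(token))
  return ids or None


def require_int_env(name: str) -> int:
  """Read a required integer environment variable, failing with a clear message.

  Hypothetical helper for values such as NNODES and NODE_RANK.
  """
  raw = os.environ.get(name)
  if raw is None:
    raise RuntimeError(f"Environment variable {name} must be set")
  try:
    return int(raw)
  except ValueError as err:
    raise RuntimeError(f"{name} must be an integer, got {raw!r}") from err
```

With helpers like these, initialize_jax_for_gpu could pass parse_visible_devices(os.getenv("CUDA_VISIBLE_DEVICES")) as local_device_ids and require_int_env("NNODES") as num_processes, so a missing or malformed variable fails with a readable message instead of a TypeError from int(None).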