Cannot see multiple GPUs when using Slurm (with proposed fix) #865

Open
gabeweisz opened this issue Sep 4, 2024 · 0 comments
When using MaxText with Slurm, our jobs only see one GPU per node because jax.distributed assumes one GPU per process when it is used under Slurm (see the JAX docs).

This behavior can be overridden by passing local_device_ids to jax.distributed.initialize, so one way to fix this is to change initialize_jax_for_gpu as follows (max_utils.py line 243):
def initialize_jax_for_gpu():
  """Jax distributed initialize for GPUs."""
  if os.environ.get("JAX_COORDINATOR_IP") is not None:
    coordinator_ip = str(os.getenv("JAX_COORDINATOR_IP"))
    coordinator_port = str(os.getenv("JAX_COORDINATOR_PORT"))
    # Pass the devices listed in CUDA_VISIBLE_DEVICES explicitly so that
    # jax.distributed does not fall back to one GPU per process under Slurm.
    visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
    device_list = [int(d) for d in visible_devices.split(",")] if visible_devices else None
    jax.distributed.initialize(
        coordinator_address=f"{coordinator_ip}:{coordinator_port}",
        num_processes=int(os.getenv("NNODES")),
        process_id=int(os.getenv("NODE_RANK")),
        local_device_ids=device_list,
    )
    max_logging.log(f"JAX global devices: {jax.devices()}")

This can probably use more robust error handling.
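One possible direction, as a minimal sketch only and not tested against MaxText: factor the parsing into a small helper (the name parse_visible_devices is made up here) that falls back to JAX's default device mapping when CUDA_VISIBLE_DEVICES is unset or contains non-integer entries such as GPU UUIDs.

import os
from typing import Optional, Sequence

import max_logging  # assumed importable, as in max_utils.py


def parse_visible_devices() -> Optional[Sequence[int]]:
  """Return local device IDs from CUDA_VISIBLE_DEVICES, or None to keep JAX's default mapping."""
  raw = os.getenv("CUDA_VISIBLE_DEVICES", "").strip()
  if not raw:
    return None
  try:
    return [int(d) for d in raw.split(",") if d.strip()]
  except ValueError:
    # Entries that are not plain indices (e.g. GPU UUIDs) cannot be passed as
    # local_device_ids, so log and let jax.distributed use its default behavior.
    max_logging.log(f"Could not parse CUDA_VISIBLE_DEVICES={raw!r}; using default device mapping.")
    return None

initialize_jax_for_gpu would then call local_device_ids=parse_visible_devices() instead of building the list inline.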
