Adding torch accelerator and requirements file to FSDP2 example #1375
Conversation
distributed/FSDP2/example.py (outdated)
torch.distributed.init_process_group(backend="nccl", device_id=device) | ||
if torch.accelerator.is_available(): | ||
device_type = torch.accelerator.current_accelerator() | ||
device: torch.device = torch.device(f"{device_type}:{rank}") |
Why do we need device: torch.device = instead of just device =?
It was just a flag for me, but I'll change it to use just torch.device
done :)
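For reference, with the annotation dropped the assignment would presumably read:

device = torch.device(f"{device_type}:{rank}")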
distributed/FSDP2/example.py (outdated)
backend = torch.distributed.get_default_backend_for_device(device)
torch.distributed.init_process_group(backend=backend, device_id=device)
I think these 2 lines should work for cpu as well. You can simplify the code:

if torch.accelerator.is_available():
    ...
else:
    device = torch.device("cpu")
backend = torch.distributed.get_default_backend_for_device(device)
torch.distributed.init_process_group(backend=backend, device_id=device)
done
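Putting both suggestions together, a minimal sketch of how the initialization presumably looks after these changes (assuming rank comes from torchrun's LOCAL_RANK environment variable, and torch>=2.7 as pinned in requirements.txt):

import os
import torch
import torch.distributed as dist

rank = int(os.environ.get("LOCAL_RANK", 0))
if torch.accelerator.is_available():
    # current_accelerator() returns a torch.device such as "cuda" or "xpu"
    device_type = torch.accelerator.current_accelerator()
    device = torch.device(f"{device_type}:{rank}")
else:
    device = torch.device("cpu")
# Picks nccl, xccl, gloo, etc. depending on the device
backend = dist.get_default_backend_for_device(device)
dist.init_process_group(backend=backend, device_id=device)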
Signed-off-by: dggaytan <[email protected]>

Force-pushed from 1f0d7d3 to 5e960d8: "Adding torch accelerator support to FSDP2 example and …"
Updates to FSDP2 example:

Script Renaming and Documentation Updates:
- Renamed train.py to example.py and updated references in README.md to reflect the new filename. Added instructions to install dependencies via requirements.txt before running the example.

GPU Verification and Device Initialization:
- Added a verify_min_gpu_count function to ensure at least two GPUs are available before running the example (a sketch follows below).
- Updated main() to dynamically detect and configure the device type using torch.accelerator. This improves compatibility with different hardware setups.

New supporting files:

Dependency Management:
- Added a requirements.txt file listing required dependencies (torch>=2.7 and numpy).

Script for Running Examples:
- Added run_example.sh to simplify launching the FSDP2 example.

Integration into Distributed Examples:
- Added distributed_FSDP2 in run_distributed_examples.sh to include the FSDP2 example in the distributed testing workflow.

CC: @msaroufim @malfet @dvrogozh
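For illustration, a minimal sketch of what the verify_min_gpu_count check described above might look like (a hypothetical reconstruction, not the PR's exact code; torch.accelerator.device_count() is assumed available in torch>=2.7):

import torch

def verify_min_gpu_count(min_gpus: int = 2) -> bool:
    # Count visible accelerator devices (CUDA, XPU, ...); fail fast if none
    if not torch.accelerator.is_available():
        return False
    return torch.accelerator.device_count() >= min_gpus

run_example.sh would then presumably launch the example with torchrun, e.g. torchrun --nproc_per_node=2 example.py.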