Update user buffers documentation #1144

Open · wants to merge 1 commit into base: main
5 changes: 5 additions & 0 deletions rosetta/docs/GPU_performance.md
@@ -111,6 +111,11 @@ The following flag removes extra copies introduced by DUS (dynamic update slice)
Enable user buffers in NCCL for zero-copy collectives and send/recv. Requires `NCCL_NVLS_ENABLE=1` for all-gather (AG), all-reduce (AR), and reduce-scatter (RS).
- --xla_gpu_enable_nccl_user_buffers=true

When user buffers are enabled, a separate memory pool is created for user-buffer registered memory. The environment variable `XLA_PYTHON_CLIENT_COLLECTIVE_MEM_SIZE_MB` configures the size of this pool. It may also be necessary to reduce `XLA_PYTHON_CLIENT_MEM_FRACTION` so that enough memory remains for the user-buffer pool.
- `XLA_PYTHON_CLIENT_COLLECTIVE_MEM_SIZE_MB=0` (default value) - The user buffer pool will start empty, but will grow during execution as more collective memory is required. This setting can result in extra fragmentation and inefficient memory use.
Collaborator:

Is there logging the user can enable to see if this happens?
That way, they know how to detect it and can apply the fix easily.

Author:

Theoretically you can tell via `TF_CPP_VMODULE=bfc_allocator=1`; it will log when the pool is expanded.
In practice, an OOM will probably be the indicator.
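
A minimal sketch of enabling this logging in a JAX program; the `bfc_allocator=1` module filter comes from the comment above, and the workload is purely illustrative:

```python
# Sketch: enable BFC-allocator verbose logging before JAX initializes its
# backend, so expansions of the collective/user-buffer pool appear in the logs.
import os
os.environ["TF_CPP_VMODULE"] = "bfc_allocator=1"  # log allocator pool growth

import jax
import jax.numpy as jnp

# Any collective-heavy step works; a psum across local devices is a stand-in here.
n = jax.local_device_count()
out = jax.pmap(lambda v: jax.lax.psum(v, axis_name="i"), axis_name="i")(
    jnp.ones((n, 8))
)
print(out.shape)
```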

- `XLA_PYTHON_CLIENT_COLLECTIVE_MEM_SIZE_MB=<amount of MiB to preallocate>` - The user buffer pool will preallocate this amount of memory at the beginning. The number should be high enough to cover peak collective memory usage (see the sketch after this list).
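
Putting the knobs above together, a minimal sketch of a launch-time configuration in JAX; the flag and variable names come from this page, while the pool size and memory fraction are placeholder values, not recommendations:

```python
# Sketch: enable NCCL user buffers and pre-size the collective memory pool.
# The numeric values are placeholders; size the pool to the model's peak
# collective memory usage. Everything must be set before JAX initializes.
import os

os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_gpu_enable_nccl_user_buffers=true"
os.environ["NCCL_NVLS_ENABLE"] = "1"                             # needed for AG, AR, RS (see above)
os.environ["XLA_PYTHON_CLIENT_COLLECTIVE_MEM_SIZE_MB"] = "4096"  # preallocate 4 GiB (placeholder)
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.65"            # reduced to leave headroom (placeholder)

import jax
print(jax.devices())
```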


Flags to reduce memory consumed by NCCL.
- --xla_gpu_enable_nccl_comm_splitting=true
- --xla_gpu_enable_nccl_per_stream_comms=false [https://github.com/openxla/xla/pull/9845](https://github.com/openxla/xla/pull/9845)