Skip to content

Commit

Permalink
Explain TORCH_NCCL_AVOID_RECORD_STREAMS=1 for TP
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
awgu committed Dec 13, 2024
1 parent e846b69 commit 639c85e
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions docs/composability.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,6 @@ One issue with seed checkpoints is that we rely on initializing _every_ model st

## On upcasting the final output to fp32
We intentionally upcast the final output tensor to fp32 inside the loss function rather in the `Transformer.forward()` so that forward and backward casts can be fused with the loss forward and backward respectively when we `torch.compile()` the loss function. This can improve both throughput and memory usage.

## Setting `TORCH_NCCL_AVOID_RECORD_STREAMS=1` for TP
Users should set the environemnt variable `TORCH_NCCL_AVOID_RECORD_STREAMS=1` when using tensor parallelism (TP) to avoid unexpectedly high memory usage. `Tensor.record_stream` is a legacy approach for ensuring that a tensor allocated in one stream (e.g. default stream) and used in another stream (e.g. process group stream) is not freed before its usage completes. In particular, `record_stream` gets called for the `async_op=True` collectives used in TP. `Tensor.record_stream(stream)` records a CUDA event in the consumer stream `stream`, and the CUDA caching allocator queries this event for completion upon each future allocation to check if the tensor can be freed. This means that the tensor is not considered free in the caching allocator until the last GPU kernel before the recorded event finishes, which could be arbitrarily far in the future. For example, if the CPU is running far ahead of the GPU and issues `K` allocations before that last GPU kernel finishes, then none of those `K` allocations can reuse the memory from the tensor on which `record_stream` was called. Setting the environment variable `TORCH_NCCL_AVOID_RECORD_STREAMS=1` uses a simple alternative approach. The process group stashes references to the collective tensors until the user calls `wait()` on the collective. This should be intuitive: the collective input/output tensors cannot be freed until after the user calls `wait()`.

0 comments on commit 639c85e

Please sign in to comment.