diff --git a/docs/composability.md b/docs/composability.md index ffbb052a..87c0b022 100644 --- a/docs/composability.md +++ b/docs/composability.md @@ -22,3 +22,6 @@ One issue with seed checkpoints is that we rely on initializing _every_ model st ## On upcasting the final output to fp32 We intentionally upcast the final output tensor to fp32 inside the loss function rather in the `Transformer.forward()` so that forward and backward casts can be fused with the loss forward and backward respectively when we `torch.compile()` the loss function. This can improve both throughput and memory usage. + +## Setting `TORCH_NCCL_AVOID_RECORD_STREAMS=1` for TP +Users should set the environemnt variable `TORCH_NCCL_AVOID_RECORD_STREAMS=1` when using tensor parallelism (TP) to avoid unexpectedly high memory usage. `Tensor.record_stream` is a legacy approach for ensuring that a tensor allocated in one stream (e.g. default stream) and used in another stream (e.g. process group stream) is not freed before its usage completes. In particular, `record_stream` gets called for the `async_op=True` collectives used in TP. `Tensor.record_stream(stream)` records a CUDA event in the consumer stream `stream`, and the CUDA caching allocator queries this event for completion upon each future allocation to check if the tensor can be freed. This means that the tensor is not considered free in the caching allocator until the last GPU kernel before the recorded event finishes, which could be arbitrarily far in the future. For example, if the CPU is running far ahead of the GPU and issues `K` allocations before that last GPU kernel finishes, then none of those `K` allocations can reuse the memory from the tensor on which `record_stream` was called. Setting the environment variable `TORCH_NCCL_AVOID_RECORD_STREAMS=1` uses a simple alternative approach. The process group stashes references to the collective tensors until the user calls `wait()` on the collective. This should be intuitive: the collective input/output tensors cannot be freed until after the user calls `wait()`.