Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
PR #7854: [NVIDIA SPMD] Add HLO support to run windowed einsum in mul…
…tiple streams Imported from GitHub PR openxla/xla#7854 This PR contains the HLO changes to be able to run windowed einsum in multiple cuda streams. This adds a debug option to use multi-streaming for einsum, it will unroll the loop when turned on and mark the corresponding dots with different stream ids. With the unrolled windowed einsum loop and overlapping multiple gemms, we get 2-3% speedup on the gpt-3 175b models. There are some improvements to be made, major ones are: 1. redundant memcpy for each einsum loop, this can be mitigated by improving loop_double_buffer_transformer to unroll all nested while loops. 2. Some collective-permutes are scheduled after the async gemm, they won't be able to obtain enough SMs to run and are therefore partially exposed, this can be fixed by forcing them to be delayed in the latency hiding scheduler pass. I'm working on both the optimizations above in separate PRs. Detailed discussion [here](openxla/xla#8865) Copybara import of the project: -- 9674dd75275376c41220da45b04cd9318453e501 by TJ Xu <[email protected]>: Add option to run windowed einsum in multiple streams Unroll windowed einsum loop and assign stream id to dots Add tests -- 39983ab1847849ffc5cd7323c096c6f42c44f3c9 by TJ Xu <[email protected]>: Added a pass to annotate stream-related attributes to instructions -- 43409a97bcc0473c2415c800064a595cb3d9ea47 by TJ <[email protected]>: Wrap non-default stream with async instruction -- c6b98c9c43406d8a2e7efd6c47f0ff3df5a78a1b by TJ <[email protected]>: Add lhlo lowering for async compute instruction Add asyncCompute resource in latency hiding scheduler -- 6fb8439198bc667d3b51979d4b6bad5dd39ec743 by TJ <[email protected]>: Rename async_stream_attribute_wrapper to stream_attribute_async_wrapper -- c799e1b5811c3bc0d4b3a90580dc36deff474885 by TJ <[email protected]>: add missing BUILD deps -- 1095524f6bb9da1904ee579a196876bd45e3d4d2 by TJ <[email protected]>: Added a pass to process windowed einsum after spmd pipeline Merging this change closes #7854 PiperOrigin-RevId: 613899647
- Loading branch information