Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add flash implementation with context parallelism (#362)
* add flash implementation with context parallelism Signed-off-by: xren <[email protected]> * next more comments Signed-off-by: xren <[email protected]> * code comment fix Signed-off-by: xren <[email protected]> * comment fix Signed-off-by: xren <[email protected]> * add missing space Signed-off-by: xren <[email protected]> * fix docstrings Signed-off-by: xren <[email protected]> * try to add fa v2 api Signed-off-by: xren <[email protected]> * fix a comment Signed-off-by: xren <[email protected]> * fix padded kv return Signed-off-by: xren <[email protected]> * add docstrings of context parallelism Signed-off-by: xren <[email protected]> * minor fix Signed-off-by: xren <[email protected]> * minor docstring fix Signed-off-by: xren <[email protected]> * fix positional arguments Signed-off-by: xren <[email protected]> * make docstring line shorter Signed-off-by: xren <[email protected]> * add fa v2 backward api for flash_attn_with_cp Signed-off-by: xren <[email protected]> * remove redundant code Signed-off-by: xren <[email protected]> * make sure hidden size per attn head is multiple of 8 for FA2 Signed-off-by: xren <[email protected]> * remove an unnecessary assert check for FA2 Signed-off-by: xren <[email protected]> * indention fix Signed-off-by: Xiaowei Ren <[email protected]> * Update FA version Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Lint Signed-off-by: Kirthi Shankar Sivamani <[email protected]> --------- Signed-off-by: xren <[email protected]> Signed-off-by: Xiaowei Ren <[email protected]> Signed-off-by: Kirthi Shankar Sivamani <[email protected]> Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
- Loading branch information