add flash implementation with context parallelism (#362)
* add flash implementation with context parallelism

Signed-off-by: xren <[email protected]>

* next more comments

Signed-off-by: xren <[email protected]>

* code comment fix

Signed-off-by: xren <[email protected]>

* comment fix

Signed-off-by: xren <[email protected]>

* add missing space

Signed-off-by: xren <[email protected]>

* fix docstrings

Signed-off-by: xren <[email protected]>

* try to add fa v2 api

Signed-off-by: xren <[email protected]>

* fix a comment

Signed-off-by: xren <[email protected]>

* fix padded kv return

Signed-off-by: xren <[email protected]>

* add docstrings of context parallelism

Signed-off-by: xren <[email protected]>

* minor fix

Signed-off-by: xren <[email protected]>

* minor docstring fix

Signed-off-by: xren <[email protected]>

* fix positional arguments

Signed-off-by: xren <[email protected]>

* make docstring line shorter

Signed-off-by: xren <[email protected]>

* add fa v2 backward api for flash_attn_with_cp

Signed-off-by: xren <[email protected]>

* remove redundant code

Signed-off-by: xren <[email protected]>

* make sure hidden size per attn head is multiple of 8 for FA2

Signed-off-by: xren <[email protected]>

* remove an unnecessary assert check for FA2

Signed-off-by: xren <[email protected]>

* indention fix

Signed-off-by: Xiaowei Ren <[email protected]>

* Update FA version

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Lint

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

---------

Signed-off-by: xren <[email protected]>
Signed-off-by: Xiaowei Ren <[email protected]>
Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
xrennvidia and ksivaman authored Sep 22, 2023
1 parent b95c181 commit 479dbb7
Showing 3 changed files with 513 additions and 20 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -290,7 +290,7 @@ def add_unique(l: List[str], vals: Union[str, List[str]]) -> None:

 # Framework-specific requirements
 if "pytorch" in frameworks():
-    add_unique(install_reqs, ["torch", "flash-attn>=1.0.6, <=2.2.1"])
+    add_unique(install_reqs, ["torch", "flash-attn>=1.0.6, <=2.2.2"])
     add_unique(test_reqs, ["numpy", "onnxruntime", "torchvision"])
 if "jax" in frameworks():
     if not found_pybind11():
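
To illustrate what the updated requirement string admits, here is a minimal standard-library sketch of the inclusive range `>=1.0.6, <=2.2.2`. The helpers `parse_version` and `flash_attn_version_allowed` are hypothetical names introduced for this example and are not part of setup.py; the sketch only handles plain dotted release numbers, not the full set of version forms that pip accepts.

```python
import re
from typing import Tuple


def parse_version(text: str) -> Tuple[int, ...]:
    """Turn a plain dotted release string such as "2.2.2" into (2, 2, 2).

    Trailing zero components are dropped so that "2.2.2.0" compares
    equal to "2.2.2". Pre-release or local suffixes are rejected.
    """
    if not re.fullmatch(r"\d+(\.\d+)*", text):
        raise ValueError(f"unsupported version string: {text!r}")
    parts = [int(part) for part in text.split(".")]
    while len(parts) > 1 and parts[-1] == 0:
        parts.pop()
    return tuple(parts)


# Bounds copied from the updated requirement "flash-attn>=1.0.6, <=2.2.2".
LOWER = parse_version("1.0.6")
UPPER = parse_version("2.2.2")


def flash_attn_version_allowed(text: str) -> bool:
    """Hypothetical helper: True if the version lies in the inclusive range."""
    return LOWER <= parse_version(text) <= UPPER
```

With the previous upper bound of 2.2.1, a check like this would reject 2.2.2; raising the bound is the whole effect of the setup.py change.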
