Commit
Merge pull request #244 from DocShotgun/draft-flash-attn-fix
Fix draft model non-FA2 fallback

Previously, the draft config copied no_flash_attn from the main config at creation time, before the flash-attention support check had run, so a later fallback to non-FA2 inference never reached the draft model. The flag is now set on the draft config at the point where the fallback is applied.
bdashore3 authored Nov 17, 2024
2 parents 101ebd6 + 5bb46df commit dfc8899
Showing 1 changed file with 2 additions and 1 deletion.
backends/exllamav2/model.py: 2 additions & 1 deletion
@@ -159,7 +159,6 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
 
         if enable_draft:
             self.draft_config = ExLlamaV2Config()
-            self.draft_config.no_flash_attn = self.config.no_flash_attn
             draft_model_path = pathlib.Path(
                 unwrap(draft_args.get("draft_model_dir"), "models")
             )
@@ -253,6 +252,8 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
             or not supports_paged_attn()
         ):
             self.config.no_flash_attn = True
+            if self.draft_config:
+                self.draft_config.no_flash_attn = True
             self.paged = False
             self.max_batch_size = 1
             torch.backends.cuda.enable_flash_sdp(False)
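For context, here is a minimal, self-contained sketch of the bug class this commit fixes: a flag copied at construction time does not track later changes to its source object. The Config class below is a simplified stand-in for ExLlamaV2Config, not the actual tabbyAPI code.

# Minimal sketch of the bug fixed above (Config is a simplified
# stand-in for ExLlamaV2Config, not the actual tabbyAPI code).

class Config:
    def __init__(self):
        self.no_flash_attn = False

main_config = Config()
draft_config = Config()

# Old behavior: copy the flag once, when the draft config is created.
# The flash-attention support check has not run yet at this point.
draft_config.no_flash_attn = main_config.no_flash_attn  # copies False

# Later, the non-FA2 fallback flips the main config...
main_config.no_flash_attn = True

# ...but the draft config still holds the stale value, so the draft
# model would still attempt to use flash attention.
assert draft_config.no_flash_attn is False

# New behavior: propagate the flag at the point of fallback, keeping
# the main and draft configs in sync.
main_config.no_flash_attn = True
if draft_config:
    draft_config.no_flash_attn = True
assert draft_config.no_flash_attn is True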
