From 5bb46df3c3b12f2b6d88a47ce04ead8f53e37a58 Mon Sep 17 00:00:00 2001
From: DocShotgun <126566557+DocShotgun@users.noreply.github.com>
Date: Fri, 15 Nov 2024 21:04:25 -0800
Subject: [PATCH] Model: Fix draft model non-FA2 fallback

---
 backends/exllamav2/model.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index c7d2069..44e354f 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -159,7 +159,6 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
 
         if enable_draft:
             self.draft_config = ExLlamaV2Config()
-            self.draft_config.no_flash_attn = self.config.no_flash_attn
             draft_model_path = pathlib.Path(
                 unwrap(draft_args.get("draft_model_dir"), "models")
             )
@@ -253,6 +252,8 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
             or not supports_paged_attn()
         ):
             self.config.no_flash_attn = True
+            if self.draft_config:
+                self.draft_config.no_flash_attn = True
             self.paged = False
             self.max_batch_size = 1
             torch.backends.cuda.enable_flash_sdp(False)
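
For context, a minimal, self-contained sketch of the ordering bug this patch
fixes (not the tabbyAPI source: ExLlamaV2Config and supports_paged_attn() are
stubbed, and build_configs is a hypothetical helper). Before the change, the
draft config copied no_flash_attn once at creation time, before the FA2/paged
attention support checks ran, so a later fallback on the main model never
propagated to the draft model.

from types import SimpleNamespace


def supports_paged_attn() -> bool:
    # Stub: the real check probes for flash-attn 2 / paged attention support.
    return False


def build_configs(enable_draft: bool, user_disabled_fa2: bool):
    # Stand-ins for ExLlamaV2Config objects; only no_flash_attn matters here.
    config = SimpleNamespace(no_flash_attn=user_disabled_fa2)
    # After the patch, the draft config no longer copies no_flash_attn here,
    # before the support checks have run.
    draft_config = SimpleNamespace(no_flash_attn=False) if enable_draft else None

    # Fallback: if FA2 is disabled or paged attention is unsupported, disable
    # flash attention, and (with this patch) do the same for the draft model.
    if config.no_flash_attn or not supports_paged_attn():
        config.no_flash_attn = True
        if draft_config:
            draft_config.no_flash_attn = True

    return config, draft_config


config, draft = build_configs(enable_draft=True, user_disabled_fa2=False)
assert draft.no_flash_attn  # the draft model now follows the main fallback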