From 5da335eb3d11a0e2a3d4513f0bce073e1997562b Mon Sep 17 00:00:00 2001
From: DocShotgun <126566557+DocShotgun@users.noreply.github.com>
Date: Tue, 17 Dec 2024 09:34:43 -0800
Subject: [PATCH 1/2] Model: Robust request length checking in generator

* Ensure that length of positive/negative prompt + max_tokens does not exceed max_seq_len
* Ensure that total required pages for CFG request does not exceed allocated cache_size
---
 backends/exllamav2/model.py | 46 ++++++++++++++++++++++++++++++++-----
 1 file changed, 40 insertions(+), 6 deletions(-)

diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 50cef42..f46f4f9 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -1301,17 +1301,51 @@ async def generate_gen(
 
         # The first index will always be the positive prompt
         context_len = input_ids[0].size(dim=-1)
-        if context_len > self.config.max_seq_len:
-            raise ValueError(
-                f"Context length {context_len} is greater than max_seq_len "
-                f"{self.config.max_seq_len}"
-            )
+
+        # The second index will be the negative prompt if CFG is enabled
+        if negative_prompt is not None:
+            negative_context_len = input_ids[1].size(dim=-1)
+        else:
+            negative_context_len = 0
 
         # Automatically set max_tokens to fill up the context
         # This should be an OK default, but may be changed in the future
         max_tokens = unwrap(
-            kwargs.get("max_tokens"), self.config.max_seq_len - context_len
+            kwargs.get("max_tokens"),
+            self.config.max_seq_len - max(context_len, negative_context_len),
         )
+        if max_tokens < 1:
+            logger.warning("max_tokens must be a positive integer, " "setting to 1.")
+            max_tokens = 1
+
+        # Check total length of request
+        if context_len + max_tokens > self.config.max_seq_len:
+            raise ValueError(
+                f"Request length {context_len} + {max_tokens} is greater than "
+                f"max_seq_len {self.config.max_seq_len}"
+            )
+
+        # Check total length of negative prompt request if CFG is enabled
+        if negative_prompt is not None:
+            if negative_context_len + max_tokens > self.config.max_seq_len:
+                raise ValueError(
+                    f"Request length for negative prompt "
+                    f"{negative_context_len} + {max_tokens} is greater than "
+                    f"max_seq_len {self.config.max_seq_len}"
+                )
+            # Check total required pages for CFG request
+            if (
+                sum(
+                    256 * math.ceil((context + max_tokens) / 256)
+                    for context in (context_len, negative_context_len)
+                )
+                > self.cache_size
+            ):
+                raise ValueError(
+                    f"Total required page size for request "
+                    f"{context_len} + {negative_context_len} + {max_tokens} * 2 "
+                    f"is greater than cache_size {self.cache_size}"
+                )
 
         # Set min_tokens to generate while keeping EOS banned
         min_tokens = unwrap(kwargs.get("min_tokens"), 0)

From 4d11323c17805b61833eafcc869d5221a89ff5fb Mon Sep 17 00:00:00 2001
From: DocShotgun <126566557+DocShotgun@users.noreply.github.com>
Date: Tue, 17 Dec 2024 09:37:33 -0800
Subject: [PATCH 2/2] Tree: Format

---
 backends/exllamav2/model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index f46f4f9..3478820 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -1315,7 +1315,7 @@ async def generate_gen(
             self.config.max_seq_len - max(context_len, negative_context_len),
         )
         if max_tokens < 1:
-            logger.warning("max_tokens must be a positive integer, " "setting to 1.")
+            logger.warning("max_tokens must be a positive integer, setting to 1.")
             max_tokens = 1
 
         # Check total length of request
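
Below is a minimal, self-contained sketch of the validation logic these commits introduce, restated as a standalone function so it can be read and tested outside of generate_gen(). The function name validate_request_length, the PAGE_SIZE constant, and the plain integer parameters are illustrative assumptions, not part of the patch; in the actual code these values come from kwargs, self.config.max_seq_len, and self.cache_size, and the 256-token page size is inferred from the patch's 256 * math.ceil(...) arithmetic.

    import math

    # Assumption: the paged KV cache allocates in 256-token pages, matching the
    # 256 * math.ceil(...) arithmetic used in the patch.
    PAGE_SIZE = 256


    def validate_request_length(
        context_len: int,
        negative_context_len: int,
        max_tokens: int | None,
        max_seq_len: int,
        cache_size: int,
        cfg_enabled: bool,
    ) -> int:
        """Return a usable max_tokens, or raise ValueError if the request cannot fit."""
        # Default max_tokens fills whatever context remains after the longer prompt
        if max_tokens is None:
            max_tokens = max_seq_len - max(context_len, negative_context_len)
        if max_tokens < 1:
            max_tokens = 1

        # The positive prompt plus generation must fit within max_seq_len
        if context_len + max_tokens > max_seq_len:
            raise ValueError(
                f"Request length {context_len} + {max_tokens} exceeds "
                f"max_seq_len {max_seq_len}"
            )

        if cfg_enabled:
            # The negative prompt plus generation must also fit within max_seq_len
            if negative_context_len + max_tokens > max_seq_len:
                raise ValueError(
                    f"Negative prompt length {negative_context_len} + {max_tokens} "
                    f"exceeds max_seq_len {max_seq_len}"
                )

            # CFG keeps both sequences in the cache at once, so the sum of their
            # page-aligned footprints must fit within the allocated cache
            required_tokens = sum(
                PAGE_SIZE * math.ceil((context + max_tokens) / PAGE_SIZE)
                for context in (context_len, negative_context_len)
            )
            if required_tokens > cache_size:
                raise ValueError(
                    f"Required cache of {required_tokens} tokens exceeds "
                    f"cache_size {cache_size}"
                )

        return max_tokens


    # Example: a 3000-token prompt and 2900-token negative prompt with default
    # max_tokens, max_seq_len 4096, and an 8192-token cache returns 1096, since
    # both page-aligned sequences (4096 + 4096 tokens) just fit in the cache.
    print(validate_request_length(3000, 2900, None, 4096, 8192, True))

The default of max_seq_len - max(context_len, negative_context_len) keeps the automatic max_tokens valid for whichever of the two prompts is longer, which is why the subsequent per-prompt checks only reject requests where an explicit max_tokens was supplied or the cache is smaller than two full sequences.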