diff --git a/OAI/types/common.py b/OAI/types/common.py
index 5040c71c..0cc848a9 100644
--- a/OAI/types/common.py
+++ b/OAI/types/common.py
@@ -75,6 +75,7 @@ class CommonCompletionRequest(BaseModel):
     add_bos_token: Optional[bool] = True
     ban_eos_token: Optional[bool] = False
     logit_bias: Optional[Dict[int, float]] = Field(default=None, examples=[[{"1": 10}]])
+    negative_prompt: Optional[str] = None
 
     # Aliased variables
     penalty_range: Optional[int] = Field(
@@ -86,6 +87,10 @@ class CommonCompletionRequest(BaseModel):
         ),
     )
 
+    cfg_scale: Optional[float] = Field(
+        default=1.0, validation_alias=AliasChoices("cfg_scale", "guidance_scale")
+    )
+
     def to_gen_params(self):
         """Converts to internal generation parameters."""
         # Convert stop to an array of strings
@@ -115,4 +120,6 @@ def to_gen_params(self):
             "mirostat": self.mirostat_mode == 2,
             "mirostat_tau": self.mirostat_tau,
            "mirostat_eta": self.mirostat_eta,
+            "cfg_scale": self.cfg_scale,
+            "negative_prompt": self.negative_prompt,
         }
diff --git a/OAI/types/model.py b/OAI/types/model.py
index bfd395ed..9653324b 100644
--- a/OAI/types/model.py
+++ b/OAI/types/model.py
@@ -83,6 +83,7 @@ class ModelLoadRequest(BaseModel):
     cache_mode: Optional[str] = "FP16"
     prompt_template: Optional[str] = None
     num_experts_per_token: Optional[int] = None
+    use_cfg: Optional[bool] = None
 
     draft: Optional[DraftModelLoadRequest] = None
diff --git a/args.py b/args.py
index 1aa25316..1aee3271 100644
--- a/args.py
+++ b/args.py
@@ -106,6 +106,11 @@ def add_model_args(parser: argparse.ArgumentParser):
         type=int,
         help="Number of experts to use per token in MoE models",
     )
+    model_group.add_argument(
+        "--use-cfg",
+        type=str_to_bool,
+        help="Enables CFG support",
+    )
 
 
 def add_logging_args(parser: argparse.ArgumentParser):
diff --git a/config_sample.yml b/config_sample.yml
index 2006002c..557d886c 100644
--- a/config_sample.yml
+++ b/config_sample.yml
@@ -85,6 +85,10 @@ model:
   # NOTE: For MoE models (ex. Mixtral) only!
   #num_experts_per_token:
 
+  # Enables CFG support (default: False)
+  # WARNING: This flag disables Flash Attention! (a stopgap until batched flash attention is fixed upstream)
+  use_cfg: False
+
   # Options for draft models (speculative decoding). This will use more VRAM!
   #draft:
     # Overrides the directory to look for draft (default: models)
diff --git a/gen_logging.py b/gen_logging.py
index 0ad5abb3..a82cea35 100644
--- a/gen_logging.py
+++ b/gen_logging.py
@@ -1,8 +1,8 @@
 """
 Functions for logging generation events.
 """
-from typing import Dict
 from pydantic import BaseModel
+from typing import Dict, Optional
 
 from logger import init_logger
 
@@ -53,12 +53,16 @@ def log_generation_params(**kwargs):
     logger.info(f"Generation options: {kwargs}\n")
 
 
-def log_prompt(prompt: str):
+def log_prompt(prompt: str, negative_prompt: Optional[str]):
     """Logs the prompt to console."""
     if PREFERENCES.prompt:
         formatted_prompt = "\n" + prompt
         logger.info(f"Prompt: {formatted_prompt if prompt else 'Empty'}\n")
 
+    if negative_prompt:
+        formatted_negative_prompt = "\n" + negative_prompt
+        logger.info(f"Negative Prompt: {formatted_negative_prompt}\n")
+
 
 def log_response(response: str):
     """Logs the response to console."""
diff --git a/model.py b/model.py
index b462ba6a..d61512fc 100644
--- a/model.py
+++ b/model.py
@@ -47,6 +47,7 @@ class ModelContainer:
     cache_fp8: bool = False
     gpu_split_auto: bool = True
     gpu_split: Optional[list] = None
+    use_cfg: bool = False
 
     active_loras: List[ExLlamaV2Lora] = []
 
@@ -95,6 +96,8 @@ def progress(loaded_modules: int, total_modules: int,
                 tensors, per device
             'no_flash_attn' (bool): Turns off flash attention
                 (increases vram usage) (default: False)
+            'use_cfg' (bool): Enables CFG support. Disables flash attention
+                (default: False)
         """
 
         self.quiet = quiet
@@ -135,8 +138,18 @@ def progress(loaded_modules: int, total_modules: int,
             kwargs.get("rope_alpha"), self.calculate_rope_alpha(base_seq_len)
         )
 
-        # Turn off flash attention?
-        self.config.no_flash_attn = unwrap(kwargs.get("no_flash_attention"), False)
+        if hasattr(ExLlamaV2Sampler.Settings, "cfg_scale"):
+            self.use_cfg = unwrap(kwargs.get("use_cfg"), False)
+        else:
+            logger.warning(
+                "CFG is not supported by the currently installed ExLlamaV2 version."
+            )
+
+        # Turn off flash attention if CFG is on
+        # Workaround until batched FA2 is fixed in exllamav2 upstream
+        self.config.no_flash_attn = (
+            True if self.use_cfg else unwrap(kwargs.get("no_flash_attention"), False)
+        )
 
         # low_mem is currently broken in exllamav2. Don't use it until it's
         # fixed.
@@ -348,10 +361,15 @@ def progress(loaded_modules: int, total_modules: int)
             if isinstance(value, str):
                 yield value
 
+        batch_size = 2 if self.use_cfg else 1
         if self.cache_fp8:
-            self.cache = ExLlamaV2Cache_8bit(self.model, lazy=self.gpu_split_auto)
+            self.cache = ExLlamaV2Cache_8bit(
+                self.model, lazy=self.gpu_split_auto, batch_size=batch_size
+            )
         else:
-            self.cache = ExLlamaV2Cache(self.model, lazy=self.gpu_split_auto)
+            self.cache = ExLlamaV2Cache(
+                self.model, lazy=self.gpu_split_auto, batch_size=batch_size
+            )
 
         if self.gpu_split_auto:
             reserve = [AUTO_SPLIT_RESERVE_BYTES] + [0] * 16
@@ -561,6 +579,19 @@ def generate_gen(self, prompt: str, **kwargs):
         gen_settings.mirostat_tau = unwrap(kwargs.get("mirostat_tau"), 1.5)
         gen_settings.mirostat_eta = unwrap(kwargs.get("mirostat_eta"), 0.1)
 
+        # Set CFG scale and negative prompt
+        cfg_scale = unwrap(kwargs.get("cfg_scale"), 1.0)
+        negative_prompt = None
+        if cfg_scale not in [None, 1.0]:
+            if self.use_cfg:
+                gen_settings.cfg_scale = cfg_scale
+                negative_prompt = kwargs.get("negative_prompt")
+            else:
+                logger.warning(
+                    "CFG is currently disabled. "
+                    + "Please reload your model with use_cfg = True.",
+                )
+
         gen_settings.token_presence_penalty = unwrap(
             kwargs.get("presence_penalty"), 0.0
         )
@@ -635,7 +666,7 @@ def generate_gen(self, prompt: str, **kwargs):
         )
 
         # Log prompt to console
-        log_prompt(prompt)
+        log_prompt(prompt, negative_prompt)
 
         # Set logit bias
         if logit_bias:
@@ -663,8 +694,18 @@ def generate_gen(self, prompt: str, **kwargs):
         self.generator.set_stop_conditions(stop_conditions)
 
         # Tokenized context
-        ids = self.tokenizer.encode(
-            prompt, add_bos=add_bos_token, encode_special_tokens=True
+        ids, offsets = self.tokenizer.encode(
+            [prompt, negative_prompt]
+            if negative_prompt and gen_settings.cfg_scale not in [None, 1.0]
+            else prompt,
+            add_bos=add_bos_token,
+            encode_special_tokens=True,
+            return_offsets=True,
+        )
+        mask = (
+            self.tokenizer.padding_mask(ids)
+            if self.use_cfg and gen_settings.cfg_scale not in [None, 1.0]
+            else None
         )
         context_len = len(ids[0])
 
@@ -683,7 +724,7 @@ def generate_gen(self, prompt: str, **kwargs):
         start_time = time.time()
         last_chunk_time = start_time
 
-        save_tokens = torch.empty((1, 0), dtype=torch.bool)
+        save_tokens = torch.empty((ids.shape[0], 0), dtype=torch.bool)
         chunk_buffer = ""
         chunk_tokens = 0
 
@@ -691,17 +732,31 @@ def generate_gen(self, prompt: str, **kwargs):
             # Ingest prompt
             if chunk_tokens == 0:
                 ids = torch.cat((ids, save_tokens), dim=-1)
-                save_tokens = torch.empty((1, 0), dtype=torch.bool)
+                save_tokens = torch.empty((ids.shape[0], 0), dtype=torch.bool)
                 overflow = ids.shape[-1] + generate_window - self.config.max_seq_len
                 active_ids = ids[:, max(0, overflow) :]
                 chunk_tokens = self.config.max_seq_len - active_ids.shape[-1]
 
-                self.generator.begin_stream(
-                    active_ids,
-                    gen_settings,
-                    token_healing=token_healing,
-                    loras=self.active_loras,
-                )
+                # Split for exllama versions that have CFG
+                if self.use_cfg:
+                    self.generator.begin_stream(
+                        active_ids,
+                        gen_settings,
+                        token_healing=token_healing,
+                        loras=self.active_loras,
+                        input_mask=mask,
+                        position_offsets=offsets,
+                    )
+                else:
+                    self.generator.begin_stream(
+                        active_ids,
+                        gen_settings,
+                        token_healing=token_healing,
+                        loras=self.active_loras,
+                    )
+
+                # Reset offsets for subsequent passes if the context is truncated
+                offsets = None
 
                 if auto_scale_penalty_range:
                     gen_settings.token_repetition_range = generated_tokens
@@ -714,7 +769,9 @@ def generate_gen(self, prompt: str, **kwargs):
                 ids[:, -1] = self.generator.sequence_ids[:, -2]
                 token_healing = False
 
-            save_tokens = torch.cat((save_tokens, tokens), dim=-1)
+            save_tokens = torch.cat(
+                (save_tokens, tokens.expand(save_tokens.shape[0], -1)), dim=-1
+            )
             chunk_buffer += chunk
             generated_tokens += 1