Model: Add CFG support
CFG, or classifier-free guidance, helps push a model in different directions based on what the user provides.

Currently, CFG is ignored if the negative prompt is blank (it shouldn't be used in that way anyway).

Signed-off-by: kingbri <[email protected]>
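
For illustration, a client request could exercise the new fields like this. This is a hedged sketch: the endpoint URL, port, and the other request fields are assumptions rather than part of this commit; only cfg_scale, its guidance_scale alias, and negative_prompt come from the change.

# Hypothetical client call showing the new CFG request fields.
# The endpoint URL and surrounding fields are assumptions for illustration only.
import requests

payload = {
    "prompt": "Write a short, upbeat release note.",
    "max_tokens": 200,
    "cfg_scale": 1.5,  # also accepted under the alias "guidance_scale"
    "negative_prompt": "Dry, bureaucratic language.",
}

response = requests.post("http://localhost:5000/v1/completions", json=payload)
print(response.json())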
bdashore3 committed Jan 2, 2024
1 parent bb7a8e4 commit b378773
Showing 6 changed files with 96 additions and 18 deletions.
7 changes: 7 additions & 0 deletions OAI/types/common.py
@@ -75,6 +75,7 @@ class CommonCompletionRequest(BaseModel):
add_bos_token: Optional[bool] = True
ban_eos_token: Optional[bool] = False
logit_bias: Optional[Dict[int, float]] = Field(default=None, examples=[[{"1": 10}]])
negative_prompt: Optional[str] = None

# Aliased variables
penalty_range: Optional[int] = Field(
@@ -86,6 +87,10 @@ class CommonCompletionRequest(BaseModel):
),
)

cfg_scale: Optional[float] = Field(
default=1.0, validation_alias=AliasChoices("cfg_scale", "guidance_scale")
)

def to_gen_params(self):
"""Converts to internal generation parameters."""
# Convert stop to an array of strings
@@ -115,4 +120,6 @@ def to_gen_params(self):
"mirostat": self.mirostat_mode == 2,
"mirostat_tau": self.mirostat_tau,
"mirostat_eta": self.mirostat_eta,
"cfg_scale": self.cfg_scale,
"negative_prompt": self.negative_prompt,
}
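
Because the field uses a validation alias, clients may send either cfg_scale or guidance_scale. A minimal, self-contained Pydantic v2 sketch (not the project's actual request class) demonstrates the behavior:

# Standalone model illustrating the aliased cfg_scale field.
from typing import Optional
from pydantic import AliasChoices, BaseModel, Field

class CFGParams(BaseModel):
    cfg_scale: Optional[float] = Field(
        default=1.0, validation_alias=AliasChoices("cfg_scale", "guidance_scale")
    )
    negative_prompt: Optional[str] = None

print(CFGParams.model_validate({"guidance_scale": 1.5}).cfg_scale)  # 1.5
print(CFGParams.model_validate({"cfg_scale": 2.0}).cfg_scale)       # 2.0
print(CFGParams().cfg_scale)                                        # 1.0 (default)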
1 change: 1 addition & 0 deletions OAI/types/model.py
@@ -83,6 +83,7 @@ class ModelLoadRequest(BaseModel):
cache_mode: Optional[str] = "FP16"
prompt_template: Optional[str] = None
num_experts_per_token: Optional[int] = None
use_cfg: Optional[bool] = None
draft: Optional[DraftModelLoadRequest] = None


5 changes: 5 additions & 0 deletions args.py
@@ -106,6 +106,11 @@ def add_model_args(parser: argparse.ArgumentParser):
type=int,
help="Number of experts to use per token in MoE models",
)
model_group.add_argument(
"--use-cfg",
type=str_to_bool,
help="Enables CFG support",
)
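
The --use-cfg flag follows the project's string-to-boolean argparse pattern. A standalone sketch of that pattern is below; the str_to_bool body shown here is an assumption for illustration, since the project ships its own helper.

# Generic sketch of the boolean-flag pattern used by --use-cfg.
import argparse

def str_to_bool(value: str) -> bool:
    """Map common string spellings to a boolean for argparse."""
    if value.lower() in ("true", "1", "yes", "y"):
        return True
    if value.lower() in ("false", "0", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"Cannot interpret {value!r} as a boolean")

parser = argparse.ArgumentParser()
model_group = parser.add_argument_group("model")
model_group.add_argument("--use-cfg", type=str_to_bool, help="Enables CFG support")
print(parser.parse_args(["--use-cfg", "true"]).use_cfg)  # True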


def add_logging_args(parser: argparse.ArgumentParser):
4 changes: 4 additions & 0 deletions config_sample.yml
@@ -85,6 +85,10 @@ model:
# NOTE: For MoE models (ex. Mixtral) only!
#num_experts_per_token:

# Enables CFG support (default: False)
# WARNING: This flag disables Flash Attention! (a stopgap until this is fixed upstream)
use_cfg: False

# Options for draft models (speculative decoding). This will use more VRAM!
#draft:
# Overrides the directory to look for draft (default: models)
8 changes: 6 additions & 2 deletions gen_logging.py
@@ -1,8 +1,8 @@
"""
Functions for logging generation events.
"""
from typing import Dict
from pydantic import BaseModel
from typing import Dict, Optional

from logger import init_logger

@@ -53,12 +53,16 @@ def log_generation_params(**kwargs):
logger.info(f"Generation options: {kwargs}\n")


def log_prompt(prompt: str):
def log_prompt(prompt: str, negative_prompt: Optional[str]):
"""Logs the prompt to console."""
if PREFERENCES.prompt:
formatted_prompt = "\n" + prompt
logger.info(f"Prompt: {formatted_prompt if prompt else 'Empty'}\n")

if negative_prompt:
formatted_negative_prompt = "\n" + negative_prompt
logger.info(f"Negative Prompt: {formatted_negative_prompt}\n")


def log_response(response: str):
"""Logs the response to console."""
89 changes: 73 additions & 16 deletions model.py
@@ -47,6 +47,7 @@ class ModelContainer:
cache_fp8: bool = False
gpu_split_auto: bool = True
gpu_split: Optional[list] = None
use_cfg: bool = False

active_loras: List[ExLlamaV2Lora] = []

Expand Down Expand Up @@ -95,6 +96,8 @@ def progress(loaded_modules: int, total_modules: int,
tensors, per device
'no_flash_attn' (bool): Turns off flash attention
(increases vram usage) (default: False)
'use_cfg" (bool): Enables CFG support. Disables flash attention
(default: False)
"""

self.quiet = quiet
@@ -135,8 +138,18 @@ def progress(loaded_modules: int, total_modules: int,
kwargs.get("rope_alpha"), self.calculate_rope_alpha(base_seq_len)
)

# Turn off flash attention?
self.config.no_flash_attn = unwrap(kwargs.get("no_flash_attention"), False)
if hasattr(ExLlamaV2Sampler.Settings, "cfg_scale"):
self.use_cfg = unwrap(kwargs.get("use_cfg"), False)
else:
logger.warning(
"CFG is not supported by the currently installed ExLlamaV2 version."
)

# Turn off flash attention if CFG is on
# Workaround until batched FA2 is fixed in exllamav2 upstream
self.config.no_flash_attn = (
True if self.use_cfg else unwrap(kwargs.get("no_flash_attention"), False)
)

# low_mem is currently broken in exllamav2. Don't use it until it's
# fixed.
@@ -348,10 +361,15 @@ def progress(loaded_modules: int, total_modules: int)
if isinstance(value, str):
yield value

batch_size = 2 if self.use_cfg else 1
if self.cache_fp8:
self.cache = ExLlamaV2Cache_8bit(self.model, lazy=self.gpu_split_auto)
self.cache = ExLlamaV2Cache_8bit(
self.model, lazy=self.gpu_split_auto, batch_size=batch_size
)
else:
self.cache = ExLlamaV2Cache(self.model, lazy=self.gpu_split_auto)
self.cache = ExLlamaV2Cache(
self.model, lazy=self.gpu_split_auto, batch_size=batch_size
)

if self.gpu_split_auto:
reserve = [AUTO_SPLIT_RESERVE_BYTES] + [0] * 16
@@ -561,6 +579,19 @@ def generate_gen(self, prompt: str, **kwargs):
gen_settings.mirostat_tau = unwrap(kwargs.get("mirostat_tau"), 1.5)
gen_settings.mirostat_eta = unwrap(kwargs.get("mirostat_eta"), 0.1)

# Set CFG scale and negative prompt
cfg_scale = unwrap(kwargs.get("cfg_scale"), 1.0)
negative_prompt = None
if cfg_scale not in [None, 1.0]:
if self.use_cfg:
gen_settings.cfg_scale = cfg_scale
negative_prompt = kwargs.get("negative_prompt")
else:
logger.warn(
"CFG is currently disabled. "
+ "Please reload your model with use_cfg = True.",
)

gen_settings.token_presence_penalty = unwrap(
kwargs.get("presence_penalty"), 0.0
)
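
For context on what cfg_scale does once it reaches the sampler: the model is evaluated on two rows per step, one conditioned on the prompt and one on the negative prompt (which is also why the cache above is allocated with batch_size=2), and the two output distributions are interpolated. The sketch below follows the standard classifier-free guidance formulation from the literature; it is illustrative only and is not the ExLlamaV2 implementation.

# Minimal sketch of the classifier-free guidance mix this commit enables.
# Row 0 is conditioned on the prompt, row 1 on the negative prompt.
import torch
import torch.nn.functional as F

def cfg_mix(cond_logits: torch.Tensor,
            uncond_logits: torch.Tensor,
            cfg_scale: float) -> torch.Tensor:
    """Interpolate log-probabilities; scale > 1 pushes away from the negative prompt."""
    cond = F.log_softmax(cond_logits, dim=-1)
    uncond = F.log_softmax(uncond_logits, dim=-1)
    return uncond + cfg_scale * (cond - uncond)

# cfg_scale == 1.0 reduces to the conditional distribution, which is why the
# code above skips CFG entirely when the scale is unset or 1.0.
vocab = 8
mixed = cfg_mix(torch.randn(vocab), torch.randn(vocab), cfg_scale=1.5)
print(mixed.shape)  # torch.Size([8])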
@@ -635,7 +666,7 @@ def generate_gen(self, prompt: str, **kwargs):
)

# Log prompt to console
log_prompt(prompt)
log_prompt(prompt, negative_prompt)

# Set logit bias
if logit_bias:
@@ -663,8 +694,18 @@ def generate_gen(self, prompt: str, **kwargs):
self.generator.set_stop_conditions(stop_conditions)

# Tokenized context
ids = self.tokenizer.encode(
prompt, add_bos=add_bos_token, encode_special_tokens=True
ids, offsets = self.tokenizer.encode(
[prompt, negative_prompt]
if negative_prompt and gen_settings.cfg_scale not in [None, 1.0]
else prompt,
add_bos=add_bos_token,
encode_special_tokens=True,
return_offsets=True,
)
mask = (
self.tokenizer.padding_mask(ids)
if self.use_cfg and gen_settings.cfg_scale not in [None, 1.0]
else None
)
context_len = len(ids[0])
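
Because the prompt and negative prompt rarely tokenize to the same length, the batched encode above returns position offsets and a padding mask so the shorter row's pad tokens are ignored. The generic sketch below illustrates that idea; ExLlamaV2's tokenizer handles this internally, and the exact padding convention shown here is an assumption.

# Illustrative left-padding of two token sequences into one batch.
import torch

def pad_batch(rows, pad_id=0):
    """Left-pad rows to a common length, returning ids, a mask, and offsets."""
    max_len = max(len(r) for r in rows)
    ids = torch.full((len(rows), max_len), pad_id, dtype=torch.long)
    mask = torch.zeros((len(rows), max_len), dtype=torch.bool)
    offsets = []
    for i, row in enumerate(rows):
        pad = max_len - len(row)        # left-pad so generation stays right-aligned
        ids[i, pad:] = torch.tensor(row)
        mask[i, pad:] = True            # True marks real (non-pad) tokens
        offsets.append(-pad)            # position offset for the padded row
    return ids, mask, torch.tensor(offsets)

ids, mask, offsets = pad_batch([[5, 6, 7, 8], [9, 10]])
print(ids.shape, offsets)  # torch.Size([2, 4]) tensor([ 0, -2])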

@@ -683,25 +724,39 @@ def generate_gen(self, prompt: str, **kwargs):
start_time = time.time()
last_chunk_time = start_time

save_tokens = torch.empty((1, 0), dtype=torch.bool)
save_tokens = torch.empty((ids.shape[0], 0), dtype=torch.bool)
chunk_buffer = ""
chunk_tokens = 0

while True:
# Ingest prompt
if chunk_tokens == 0:
ids = torch.cat((ids, save_tokens), dim=-1)
save_tokens = torch.empty((1, 0), dtype=torch.bool)
save_tokens = torch.empty((ids.shape[0], 0), dtype=torch.bool)
overflow = ids.shape[-1] + generate_window - self.config.max_seq_len
active_ids = ids[:, max(0, overflow) :]
chunk_tokens = self.config.max_seq_len - active_ids.shape[-1]

self.generator.begin_stream(
active_ids,
gen_settings,
token_healing=token_healing,
loras=self.active_loras,
)
# Split for exllama versions that have CFG
if self.use_cfg:
self.generator.begin_stream(
active_ids,
gen_settings,
token_healing=token_healing,
loras=self.active_loras,
input_mask=mask,
position_offsets=offsets,
)
else:
self.generator.begin_stream(
active_ids,
gen_settings,
token_healing=token_healing,
loras=self.active_loras,
)

# Reset offsets for subsequent passes if the context is truncated
offsets = None

if auto_scale_penalty_range:
gen_settings.token_repetition_range = generated_tokens
@@ -714,7 +769,9 @@ def generate_gen(self, prompt: str, **kwargs):
ids[:, -1] = self.generator.sequence_ids[:, -2]
token_healing = False

save_tokens = torch.cat((save_tokens, tokens), dim=-1)
save_tokens = torch.cat(
(save_tokens, tokens.expand(save_tokens.shape[0], -1)), dim=-1
)
chunk_buffer += chunk

generated_tokens += 1
