Add CFG support #46

Merged · 2 commits · Jan 2, 2024
Changes from all commits
7 changes: 7 additions & 0 deletions OAI/types/common.py
@@ -75,6 +75,7 @@ class CommonCompletionRequest(BaseModel):
add_bos_token: Optional[bool] = True
ban_eos_token: Optional[bool] = False
logit_bias: Optional[Dict[int, float]] = Field(default=None, examples=[[{"1": 10}]])
negative_prompt: Optional[str] = None

# Aliased variables
penalty_range: Optional[int] = Field(
@@ -86,6 +87,10 @@ class CommonCompletionRequest(BaseModel):
),
)

cfg_scale: Optional[float] = Field(
default=1.0, validation_alias=AliasChoices("cfg_scale", "guidance_scale")
)

def to_gen_params(self):
"""Converts to internal generation parameters."""
# Convert stop to an array of strings
@@ -115,4 +120,6 @@ def to_gen_params(self):
"mirostat": self.mirostat_mode == 2,
"mirostat_tau": self.mirostat_tau,
"mirostat_eta": self.mirostat_eta,
"cfg_scale": self.cfg_scale,
"negative_prompt": self.negative_prompt,
}
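For context, here is a minimal sketch of how a client request might exercise the new fields above. The endpoint path, port, and surrounding parameters are assumptions for illustration and are not part of this diff; note that `guidance_scale` is accepted as an alias for `cfg_scale` through `AliasChoices`.

```python
# Hedged usage sketch: endpoint, port, and extra fields are assumptions.
import requests

payload = {
    "prompt": "Write a short bedtime story.",
    "max_tokens": 200,
    "cfg_scale": 1.5,  # "guidance_scale" would also be accepted via the alias
    # Optional; the server falls back to the BOS token when this is omitted
    "negative_prompt": "Violent or frightening content.",
}

response = requests.post("http://localhost:5000/v1/completions", json=payload)
print(response.json())
```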
1 change: 1 addition & 0 deletions OAI/types/model.py
@@ -83,6 +83,7 @@ class ModelLoadRequest(BaseModel):
cache_mode: Optional[str] = "FP16"
prompt_template: Optional[str] = None
num_experts_per_token: Optional[int] = None
use_cfg: Optional[bool] = None
draft: Optional[DraftModelLoadRequest] = None


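Similarly, a hedged sketch of a model load request that opts into CFG; the route and the `name` field are assumptions based on the existing `ModelLoadRequest` schema rather than something introduced here.

```python
# Sketch only: the route and any field other than use_cfg are assumptions.
import requests

load_request = {
    "name": "MyModel-exl2",  # hypothetical model folder name
    "use_cfg": True,         # the new flag added in this PR
}

requests.post("http://localhost:5000/v1/model/load", json=load_request)
```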
5 changes: 5 additions & 0 deletions args.py
@@ -106,6 +106,11 @@ def add_model_args(parser: argparse.ArgumentParser):
type=int,
help="Number of experts to use per token in MoE models",
)
model_group.add_argument(
"--use-cfg",
type=str_to_bool,
help="Enables CFG support",
)


def add_logging_args(parser: argparse.ArgumentParser):
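To show how the new CLI flag behaves, a self-contained sketch follows; `str_to_bool` here is a stand-in for the helper the project already imports, whose real implementation is outside this diff.

```python
import argparse

def str_to_bool(value: str) -> bool:
    # Hypothetical stand-in for the project's helper: accept common truthy strings.
    return value.strip().lower() in ("1", "true", "yes", "on")

parser = argparse.ArgumentParser()
model_group = parser.add_argument_group("Model options")
model_group.add_argument("--use-cfg", type=str_to_bool, help="Enables CFG support")

args = parser.parse_args(["--use-cfg", "true"])
assert args.use_cfg is True
```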
4 changes: 4 additions & 0 deletions config_sample.yml
@@ -85,6 +85,10 @@ model:
# NOTE: For MoE models (ex. Mixtral) only!
#num_experts_per_token:

# Enables CFG support (default: False)
# WARNING: This flag disables Flash Attention! (a stopgap until batched Flash Attention is fixed in exllamav2 upstream)
use_cfg: False

# Options for draft models (speculative decoding). This will use more VRAM!
#draft:
# Overrides the directory to look for draft (default: models)
8 changes: 6 additions & 2 deletions gen_logging.py
@@ -1,8 +1,8 @@
"""
Functions for logging generation events.
"""
from typing import Dict
from pydantic import BaseModel
from typing import Dict, Optional

from logger import init_logger

@@ -53,12 +53,16 @@ def log_generation_params(**kwargs):
logger.info(f"Generation options: {kwargs}\n")


def log_prompt(prompt: str):
def log_prompt(prompt: str, negative_prompt: Optional[str]):
"""Logs the prompt to console."""
if PREFERENCES.prompt:
formatted_prompt = "\n" + prompt
logger.info(f"Prompt: {formatted_prompt if prompt else 'Empty'}\n")

if negative_prompt:
formatted_negative_prompt = "\n" + negative_prompt
logger.info(f"Negative Prompt: {formatted_negative_prompt}\n")


def log_response(response: str):
"""Logs the response to console."""
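A quick usage note for the updated signature: the negative prompt is only logged when one is supplied, and output still depends on prompt logging being enabled in `PREFERENCES` (configured elsewhere in the project).

```python
from gen_logging import log_prompt

# Only the prompt is logged when no negative prompt is given.
log_prompt("Tell me a story about a lighthouse.", None)

# With CFG active, the negative prompt is logged as well.
log_prompt("Tell me a story about a lighthouse.", "Sad or violent endings.")
```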
93 changes: 77 additions & 16 deletions model.py
@@ -47,6 +47,7 @@ class ModelContainer:
cache_fp8: bool = False
gpu_split_auto: bool = True
gpu_split: Optional[list] = None
use_cfg: bool = False

active_loras: List[ExLlamaV2Lora] = []

@@ -95,6 +96,8 @@ def progress(loaded_modules: int, total_modules: int,
tensors, per device
'no_flash_attn' (bool): Turns off flash attention
(increases vram usage) (default: False)
'use_cfg' (bool): Enables CFG support. Disables flash attention
(default: False)
"""

self.quiet = quiet
@@ -135,8 +138,18 @@ def progress(loaded_modules: int, total_modules: int,
kwargs.get("rope_alpha"), self.calculate_rope_alpha(base_seq_len)
)

# Turn off flash attention?
self.config.no_flash_attn = unwrap(kwargs.get("no_flash_attention"), False)
if hasattr(ExLlamaV2Sampler.Settings, "cfg_scale"):
self.use_cfg = unwrap(kwargs.get("use_cfg"), False)
else:
logger.warning(
"CFG is not supported by the currently installed ExLlamaV2 version."
)

# Turn off flash attention if CFG is on
# Workaround until batched FA2 is fixed in exllamav2 upstream
self.config.no_flash_attn = (
True if self.use_cfg else unwrap(kwargs.get("no_flash_attention"), False)
)

# low_mem is currently broken in exllamav2. Don't use it until it's
# fixed.
@@ -348,10 +361,15 @@ def progress(loaded_modules: int, total_modules: int)
if isinstance(value, str):
yield value

batch_size = 2 if self.use_cfg else 1
if self.cache_fp8:
self.cache = ExLlamaV2Cache_8bit(self.model, lazy=self.gpu_split_auto)
self.cache = ExLlamaV2Cache_8bit(
self.model, lazy=self.gpu_split_auto, batch_size=batch_size
)
else:
self.cache = ExLlamaV2Cache(self.model, lazy=self.gpu_split_auto)
self.cache = ExLlamaV2Cache(
self.model, lazy=self.gpu_split_auto, batch_size=batch_size
)

if self.gpu_split_auto:
reserve = [AUTO_SPLIT_RESERVE_BYTES] + [0] * 16
@@ -561,6 +579,23 @@ def generate_gen(self, prompt: str, **kwargs):
gen_settings.mirostat_tau = unwrap(kwargs.get("mirostat_tau"), 1.5)
gen_settings.mirostat_eta = unwrap(kwargs.get("mirostat_eta"), 0.1)

# Set CFG scale and negative prompt
cfg_scale = unwrap(kwargs.get("cfg_scale"), 1.0)
negative_prompt = None
if cfg_scale not in [None, 1.0]:
if self.use_cfg:
gen_settings.cfg_scale = cfg_scale

# If the negative prompt is empty, use the BOS token
negative_prompt = unwrap(
kwargs.get("negative_prompt"), self.tokenizer.bos_token
)
else:
logger.warn(
"CFG is currently disabled. "
+ "Please reload your model with use_cfg = True.",
)

gen_settings.token_presence_penalty = unwrap(
kwargs.get("presence_penalty"), 0.0
)
@@ -635,7 +670,7 @@ def generate_gen(self, prompt: str, **kwargs):
)

# Log prompt to console
log_prompt(prompt)
log_prompt(prompt, negative_prompt)

# Set logit bias
if logit_bias:
@@ -663,8 +698,18 @@ def generate_gen(self, prompt: str, **kwargs):
self.generator.set_stop_conditions(stop_conditions)

# Tokenized context
ids = self.tokenizer.encode(
prompt, add_bos=add_bos_token, encode_special_tokens=True
ids, offsets = self.tokenizer.encode(
[prompt, negative_prompt]
if negative_prompt and gen_settings.cfg_scale not in [None, 1.0]
else prompt,
add_bos=add_bos_token,
encode_special_tokens=True,
return_offsets=True,
)
mask = (
self.tokenizer.padding_mask(ids)
if self.use_cfg and gen_settings.cfg_scale not in [None, 1.0]
else None
)
context_len = len(ids[0])

@@ -683,25 +728,39 @@ def generate_gen(self, prompt: str, **kwargs):
start_time = time.time()
last_chunk_time = start_time

save_tokens = torch.empty((1, 0), dtype=torch.bool)
save_tokens = torch.empty((ids.shape[0], 0), dtype=torch.bool)
chunk_buffer = ""
chunk_tokens = 0

while True:
# Ingest prompt
if chunk_tokens == 0:
ids = torch.cat((ids, save_tokens), dim=-1)
save_tokens = torch.empty((1, 0), dtype=torch.bool)
save_tokens = torch.empty((ids.shape[0], 0), dtype=torch.bool)
overflow = ids.shape[-1] + generate_window - self.config.max_seq_len
active_ids = ids[:, max(0, overflow) :]
chunk_tokens = self.config.max_seq_len - active_ids.shape[-1]

self.generator.begin_stream(
active_ids,
gen_settings,
token_healing=token_healing,
loras=self.active_loras,
)
# Split for exllama versions that have CFG
if self.use_cfg:
self.generator.begin_stream(
active_ids,
gen_settings,
token_healing=token_healing,
loras=self.active_loras,
input_mask=mask,
position_offsets=offsets,
)
else:
self.generator.begin_stream(
active_ids,
gen_settings,
token_healing=token_healing,
loras=self.active_loras,
)

# Reset offsets for subsequent passes if the context is truncated
offsets = None

if auto_scale_penalty_range:
gen_settings.token_repetition_range = generated_tokens
@@ -714,7 +773,9 @@ def generate_gen(self, prompt: str, **kwargs):
ids[:, -1] = self.generator.sequence_ids[:, -2]
token_healing = False

save_tokens = torch.cat((save_tokens, tokens), dim=-1)
save_tokens = torch.cat(
(save_tokens, tokens.expand(save_tokens.shape[0], -1)), dim=-1
)
chunk_buffer += chunk

generated_tokens += 1
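For reviewers unfamiliar with classifier-free guidance: the positive and negative prompts are run as a batch of two sequences (hence the `batch_size = 2` cache, the padding mask, and the position offsets above), and their per-step outputs are blended according to `cfg_scale`. The sketch below shows the usual CFG blend in log-probability space; it is a conceptual illustration, not ExLlamaV2's internal implementation.

```python
import torch

def cfg_blend(cond_logits: torch.Tensor, uncond_logits: torch.Tensor, cfg_scale: float) -> torch.Tensor:
    """Blend logits from the prompt (cond) and the negative prompt (uncond).

    cfg_scale == 1.0 reduces to the plain conditional distribution; values
    above 1.0 push generation away from the negative prompt.
    """
    cond = torch.log_softmax(cond_logits, dim=-1)
    uncond = torch.log_softmax(uncond_logits, dim=-1)
    return uncond + cfg_scale * (cond - uncond)

# Toy example with a five-token vocabulary
cond = torch.randn(1, 5)
uncond = torch.randn(1, 5)
blended = cfg_blend(cond, uncond, cfg_scale=1.5)
next_token = torch.argmax(blended, dim=-1)
```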