Model: Add EOS token support from generation_config.json
GenerationConfig is meant to override various parts of the model
during generation within the transformers library. Rather than
implementing the entire GenerationConfig framework (since it's largely
redundant here), add support for multiple EOS tokens, as vLLM does.

The GenerationConfig is currently used only for generation, but it can
be extended for other purposes if needed.

If more parameters are necessary in the future, add those in as well.

Signed-off-by: kingbri <[email protected]>
bdashore3 committed Apr 20, 2024
1 parent 933c5af commit 8824ea0
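
For context, a minimal sketch of what the GenerationConfig helper
imported by this commit could look like. Only from_file() and
eos_tokens() are referenced in the diff below; everything else here is
an assumption for illustration, not the repository's actual code.

# Hypothetical sketch of common/transformers_utils.py's GenerationConfig.
import json
import pathlib
from typing import List, Optional, Union


class GenerationConfig:
    """Small subset of HF's generation_config.json (EOS handling only)."""

    def __init__(self, eos_token_id: Optional[Union[int, List[int]]] = None, **kwargs):
        # HF allows eos_token_id to be a single int or a list of ints;
        # extra keys from the JSON file are ignored via **kwargs
        self.eos_token_id = eos_token_id

    @classmethod
    def from_file(cls, model_directory: pathlib.Path):
        """Load generation_config.json from a model directory."""
        with open(model_directory / "generation_config.json") as f:
            return cls(**json.load(f))

    def eos_tokens(self) -> List[int]:
        """Always return EOS token IDs as a list."""
        if isinstance(self.eos_token_id, int):
            return [self.eos_token_id]
        return self.eos_token_id or []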
Showing 1 changed file with 29 additions and 2 deletions.

backends/exllamav2/model.py
@@ -36,6 +36,7 @@
     get_template_from_model_json,
     get_template_from_file,
 )
+from common.transformers_utils import GenerationConfig
 from common.utils import coalesce, unwrap


@@ -57,6 +58,7 @@ class ExllamaV2Container:
     # Internal config vars
     cache_mode: str = "FP16"
     use_cfg: bool = False
+    generation_config: Optional[GenerationConfig] = None

# GPU split vars
gpu_split: Optional[list] = None
@@ -193,6 +195,21 @@ def progress(loaded_modules: int, total_modules: int,
             kwargs.get("prompt_template"), model_directory
         )
 
+        # Load generation config overrides
+        generation_config_path = (
+            pathlib.Path(self.config.model_dir) / "generation_config.json"
+        )
+        if generation_config_path.exists():
+            try:
+                self.generation_config = GenerationConfig.from_file(
+                    generation_config_path.parent
+                )
+            except Exception:
+                logger.error(traceback.format_exc())
+                logger.warning(
+                    "Skipping generation config load because of an unexpected error."
+                )
+
         # Catch all for template lookup errors
         if self.prompt_template:
             logger.info(
@@ -566,6 +583,7 @@ def decode_tokens(self, ids: List[int], **kwargs):
             decode_special_tokens=unwrap(kwargs.get("decode_special_tokens"), True),
         )[0]
 
+    # TODO: Maybe support generation_config for eos_token
     def get_special_tokens(
         self, add_bos_token: bool = True, ban_eos_token: bool = False
     ):
@@ -840,13 +858,20 @@ def generate_gen_sync(
             grammar_string, gen_settings, self.model, self.tokenizer
         )
 
+        # Fetch EOS tokens from generation_config if they exist
+        eos_tokens = (
+            self.generation_config.eos_tokens()
+            if self.generation_config
+            else [self.tokenizer.eos_token_id]
+        )
+
         # Ban the EOS token if specified. If not, append to stop conditions
         # as well.
         # Set this below logging to avoid polluting the stop strings array
         if ban_eos_token:
-            gen_settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])
+            gen_settings.disallow_tokens(self.tokenizer, eos_tokens)
         else:
-            stop_conditions.append(self.tokenizer.eos_token_id)
+            stop_conditions += eos_tokens
 
         # Stop conditions
         self.generator.set_stop_conditions(stop_conditions)
@@ -891,6 +916,8 @@ def generate_gen_sync(
             token_healing=token_healing,
             auto_scale_penalty_range=auto_scale_penalty_range,
             generate_window=generate_window,
+            bos_token_id=self.tokenizer.bos_token_id,
+            eos_token_id=eos_tokens,
             add_bos_token=add_bos_token,
             ban_eos_token=ban_eos_token,
             speculative_ngram=self.generator.speculative_ngram,
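As a usage sketch built on the hypothetical GenerationConfig above:
given a generation_config.json listing several EOS tokens (the IDs
below are illustrative, in the style of multi-EOS models), the new code
path collects all of them instead of the single tokenizer EOS ID.

import json
import pathlib
import tempfile

# Write an illustrative generation_config.json into a scratch directory
model_dir = pathlib.Path(tempfile.mkdtemp())
(model_dir / "generation_config.json").write_text(
    json.dumps({"bos_token_id": 128000, "eos_token_id": [128001, 128009]})
)

config = GenerationConfig.from_file(model_dir)
eos_tokens = config.eos_tokens()
print(eos_tokens)  # [128001, 128009]

# Mirrors the new branch in generate_gen_sync(): every EOS token either
# becomes a stop condition or is banned when ban_eos_token is set
stop_conditions: list = []
stop_conditions += eos_tokens
print(stop_conditions)  # [128001, 128009]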
