Commit

API: Add allowed_tokens support

This is the opposite of banned tokens. Exllama-specific implementation of #181.

Signed-off-by: kingbri <[email protected]>
bdashore3 committed Aug 30, 2024
1 parent 10d9419 commit 2171257
Showing 3 changed files with 25 additions and 2 deletions.

backends/exllamav2/model.py (7 additions, 1 deletion)
@@ -1099,6 +1099,11 @@ async def generate_gen(
         if banned_tokens:
             gen_settings.disallow_tokens(self.tokenizer, banned_tokens)
 
+        # Set allowed tokens
+        allowed_tokens = unwrap(kwargs.get("allowed_tokens"), [])
+        if allowed_tokens:
+            gen_settings.allow_tokens(self.tokenizer, allowed_tokens)
+
         # Set logit bias
         if logit_bias:
             # Create a vocab tensor if it doesn't exist for token biasing
@@ -1167,7 +1172,7 @@ async def generate_gen(
         log_prompt(
             f"{self.tokenizer.bos_token if add_bos_token else ''}{prompt}",
             request_id,
-            negative_prompt
+            negative_prompt,
         )
 
         # Create and add a new job
@@ -1313,6 +1318,7 @@ async def generate_gen(
             logprobs=request_logprobs,
             stop_conditions=stop_conditions,
             banned_tokens=banned_tokens,
+            allowed_tokens=allowed_tokens,
             banned_strings=banned_strings,
             logit_bias=logit_bias,
             filters=grammar_handler.filters,
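
Constraining generation to an allow-list is conceptually the inverse of disallow_tokens: every logit whose token id falls outside the allowed set is masked off before sampling. A minimal standalone sketch of that idea (an illustration only, not exllamav2's actual allow_tokens code):

import torch

def mask_to_allowed(logits: torch.Tensor, allowed_ids: list[int]) -> torch.Tensor:
    """Set every logit whose token id is not in allowed_ids to -inf."""
    masked = torch.full_like(logits, float("-inf"))
    masked[..., allowed_ids] = logits[..., allowed_ids]
    return masked

# Example: restrict sampling to token ids 128 and 330 over a fake 32k vocabulary
logits = torch.randn(1, 32000)
constrained = mask_to_allowed(logits, [128, 330])
next_token = torch.argmax(constrained, dim=-1)  # always 128 or 330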

common/sampling.py (14 additions, 1 deletion)
@@ -50,6 +50,13 @@ class BaseSamplerRequest(BaseModel):
         examples=[[128, 330]],
     )
 
+    allowed_tokens: Optional[Union[List[int], str]] = Field(
+        default_factory=lambda: get_default_sampler_value("allowed_tokens", []),
+        validation_alias=AliasChoices("allowed_tokens", "allowed_token_ids"),
+        description="Aliases: allowed_token_ids",
+        examples=[[128, 330]],
+    )
+
     token_healing: Optional[bool] = Field(
         default_factory=lambda: get_default_sampler_value("token_healing", False)
     )
@@ -287,12 +294,17 @@ def to_gen_params(self, **kwargs):
         if self.banned_strings and isinstance(self.banned_strings, str):
             self.banned_strings = [self.banned_strings]
 
-        # Convert string banned tokens to an integer list
+        # Convert string banned and allowed tokens to an integer list
         if self.banned_tokens and isinstance(self.banned_tokens, str):
            self.banned_tokens = [
                int(x) for x in self.banned_tokens.split(",") if x.isdigit()
            ]
 
+        if self.allowed_tokens and isinstance(self.allowed_tokens, str):
+            self.allowed_tokens = [
+                int(x) for x in self.allowed_tokens.split(",") if x.isdigit()
+            ]
+
         gen_params = {
             "max_tokens": self.max_tokens,
             "min_tokens": self.min_tokens,
@@ -305,6 +317,7 @@ def to_gen_params(self, **kwargs):
             "token_healing": self.token_healing,
             "logit_bias": self.logit_bias,
             "banned_tokens": self.banned_tokens,
+            "allowed_tokens": self.allowed_tokens,
             "temperature": self.temperature,
             "temperature_last": self.temperature_last,
             "min_temp": self.min_temp,
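
On the request side, the new field accepts either an integer list or a comma-separated string, and allowed_token_ids works as an alias. A hypothetical client call showing both forms (the URL, port, and any auth headers depend on your deployment and are assumptions here):

import requests  # assumes an OpenAI-style completions route is exposed by the server

payload = {
    "prompt": "Answer with yes or no:",
    "max_tokens": 1,
    "allowed_tokens": [128, 330],        # integer-list form
    # "allowed_token_ids": "128,330",    # alias plus comma-separated string also parses
}

# Placeholder endpoint; add an API key header if your server requires one.
resp = requests.post("http://localhost:5000/v1/completions", json=payload)
print(resp.json())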

sampler_overrides/sample_preset.yml (4 additions, 0 deletions)
@@ -126,6 +126,10 @@ banned_tokens:
   override: []
   force: false
   additive: false
+allowed_tokens:
+  override: []
+  force: false
+  additive: false
 
 # MARK: CFG scale
 cfg_scale:
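
The preset entry follows the same override/force/additive shape as the other samplers in this file. One plausible reading of those flags, sketched as a standalone helper (an assumption about the semantics, not tabbyAPI's actual override code):

def resolve_allowed_tokens(request_value, override, force=False, additive=False):
    """Force wins outright, additive merges, otherwise a non-empty request wins."""
    if force or not request_value:
        return list(override)
    if additive:
        return list(request_value) + [t for t in override if t not in request_value]
    return list(request_value)

print(resolve_allowed_tokens([1, 2], [128, 330], force=True))     # [128, 330]
print(resolve_allowed_tokens([1, 2], [128, 330], additive=True))  # [1, 2, 128, 330]
print(resolve_allowed_tokens([], [128, 330]))                     # [128, 330]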
