WIP: banned strings #833

Draft: wants to merge 2 commits into base: main
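This draft adds a `banned_strings` sampling option. Each banned string is tokenized server-side; the stop checker buffers output tokens that prefix-match a banned token sequence, releases them if the match breaks, and rolls the sequence back for recomputation if a full match completes. A minimal sketch of how a client might exercise the new field through the OpenAI-compatible completions endpoint (host, port, and model name are illustrative):

import requests

resp = requests.post(
    "http://localhost:2242/v1/completions",  # assumed local server address
    json={
        "model": "my-model",                  # placeholder model name
        "prompt": "The quick brown",
        "max_tokens": 32,
        # New field from this PR: a single string or a list of strings.
        "banned_strings": ["fox", "lazy dog"],
    },
)
print(resp.json()["choices"][0]["text"])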
3 changes: 3 additions & 0 deletions aphrodite/common/sampling_params.py
@@ -173,6 +173,7 @@ class SamplingParams(
input into sections where repetition is evaluated separately.
Common examples are newlines, quotes, and other structural tokens.
Defaults to None.
banned_strings: A list of token ID sequences that must not appear
in the generated output; each inner list is the tokenization of
one banned string. Defaults to None.
"""

n: int = 1
@@ -210,6 +211,7 @@ class SamplingParams(
prompt_logprobs: Optional[int] = None
detokenize: bool = True
custom_token_bans: Optional[List[int]] = None
banned_strings: Optional[List[List[int]]] = None
skip_special_tokens: bool = True
spaces_between_special_tokens: bool = True
# Optional[List[LogitsProcessorFunc]] type.
@@ -264,6 +266,7 @@ class SamplingParams(
"prompt_logprobs": None,
"detokenize": True,
"custom_token_bans": [],
"banned_strings": [],
"skip_special_tokens": True,
"spaces_between_special_tokens": True,
"include_stop_str_in_output": False,
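For reference, a minimal sketch of how the new field is shaped once it reaches SamplingParams: it carries token IDs rather than raw strings, with one inner list per banned string (import path and IDs are illustrative; the tokenization normally happens in the OpenAI protocol layer below):

from aphrodite import SamplingParams  # assumed top-level export

banned = [[12345], [678, 910]]  # e.g. tokenizer.encode(s, add_special_tokens=False)
params = SamplingParams(n=1, max_tokens=64, banned_strings=banned)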
65 changes: 52 additions & 13 deletions aphrodite/common/sequence.py
@@ -50,16 +50,17 @@ class SequenceStatus(enum.IntEnum):
WAITING = 0
RUNNING = 1
SWAPPED = 2
# Note: anything after SWAPPED (2) will be considered
BUFFERING = 3
# Note: anything after BUFFERING (3) will be considered
# as a finished status.
FINISHED_STOPPED = 3
FINISHED_LENGTH_CAPPED = 4
FINISHED_ABORTED = 5
FINISHED_IGNORED = 6
FINISHED_STOPPED = 4
FINISHED_LENGTH_CAPPED = 5
FINISHED_ABORTED = 6
FINISHED_IGNORED = 7

@staticmethod
def is_finished(status: "SequenceStatus") -> bool:
return status > SequenceStatus.SWAPPED
return status > SequenceStatus.BUFFERING

@staticmethod
def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
@@ -147,6 +148,8 @@ class SequenceData(msgspec.Struct,

_new_appended_tokens: List[int] = msgspec.field(default_factory=list)

_status: SequenceStatus = SequenceStatus.BUFFERING

def __post_init__(self) -> None:
assert self._prompt_token_ids.typecode == "l"
assert self._output_token_ids.typecode == "l"
@@ -159,6 +162,15 @@ def _update_cached_all_tokens(self):
assert isinstance(self._prompt_token_ids, array)
self._cached_all_token_ids: List[int] = list(self._prompt_token_ids +
self._output_token_ids)

def pop_token(self) -> None:
"""Remove the most recent output token and rewind cached state."""
if len(self._output_token_ids) > 0:
tokens = self._get_mutable_output_tokens()
tokens.pop()
self._output_token_ids = array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
tokens)
self._update_cached_all_tokens()
self._num_computed_tokens = max(0, self._num_computed_tokens - 1)

@property
def cumulative_logprob(self) -> float:
@@ -202,13 +214,19 @@ def output_token_ids_array(self) -> array:
"""
assert isinstance(self._output_token_ids, array)
return self._output_token_ids

def _get_mutable_output_tokens(self) -> List[int]:
"""Return a mutable copy of the output token ids."""
return list(self._output_token_ids)

def append_token_id(self, token_id: int, logprob: float) -> None:
self._output_token_ids.append(token_id)
self._new_appended_tokens.append(token_id)
self._cached_all_token_ids.append(token_id)
self._cumulative_logprob += logprob

def is_buffering(self) -> bool:
return self._status == SequenceStatus.BUFFERING

def get_len(self) -> int:
return len(self._output_token_ids) + len(self._prompt_token_ids)

@@ -238,9 +256,15 @@ def get_num_computed_tokens(self) -> int:

def update_num_computed_tokens(self, num_new_computed_tokens: int):
"""Update number of tokens computed so far."""
self._num_computed_tokens += num_new_computed_tokens
assert self._num_computed_tokens <= self.get_len(), (
self._num_computed_tokens, self.get_len())
if self.stage == SequenceStage.PREFILL:
self._num_computed_tokens = num_new_computed_tokens
else:
# Don't increment counter for buffered tokens
if not self.is_buffering():
self._num_computed_tokens += num_new_computed_tokens

# Ensure we don't exceed sequence length
self._num_computed_tokens = min(self._num_computed_tokens, self.get_len())
# If all tokens are computed, it means it is in decoding phase.
if self.get_num_uncomputed_tokens() == 0:
self._stage = SequenceStage.DECODE
@@ -290,6 +314,11 @@ def apply_delta(self, delta: SequenceDataDelta):
@property
def stage(self) -> SequenceStage:
return self._stage

def reset_state_for_recompute(self) -> None:
self._num_computed_tokens = 0
self._stage = SequenceStage.PREFILL
self._new_appended_tokens = []

def __repr__(self) -> str:
return (f"SequenceData("
@@ -458,11 +487,21 @@ def reset_state_for_recompute(self):
def append_token_id(
self,
token_id: int,
logprobs: Dict[int, Logprob],
logprobs: Optional[Dict[int, Logprob]],
) -> None:
assert token_id in logprobs
self.output_logprobs.append(logprobs)
self.data.append_token_id(token_id, logprobs[token_id].logprob)
"""Modified to handle buffered tokens (which won't have logprobs)"""
if logprobs is not None:
assert token_id in logprobs
self.output_logprobs.append(logprobs)
self.data.append_token_id(token_id, logprobs[token_id].logprob)
else:
# For buffered tokens, we don't have logprobs
self.output_logprobs.append({}) # Empty logprobs for buffered token
self.data.append_token_id(token_id, 0.0) # Use 0.0 as placeholder logprob # noqa

def is_buffering(self) -> bool:
"""Check if sequence is currently buffering tokens."""
return self.status == SequenceStatus.BUFFERING

def get_len(self) -> int:
return self.data.get_len()
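A self-contained mirror of the pop_token bookkeeping, using the same array("l", ...) representation as SequenceData, to illustrate the invariant the rollback path depends on (prompt tokens and counts are made up):

from array import array

prompt = [1, 2]
output = array("l", [11, 22, 33])
cached = prompt + list(output)  # as in _cached_all_token_ids
num_computed = len(cached)

# pop_token(): drop the last output token and rewind the counters.
tokens = list(output)           # _get_mutable_output_tokens()
tokens.pop()
output = array("l", tokens)
cached = prompt + list(output)  # _update_cached_all_tokens()
num_computed = max(0, num_computed - 1)

assert list(output) == [11, 22] and num_computed == 4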
35 changes: 35 additions & 0 deletions aphrodite/endpoints/openai/protocol.py
@@ -433,6 +433,7 @@ class CompletionRequest(OpenAIBaseModel):
dynatemp_exponent: Optional[float] = 1.0
nsigma: Optional[float] = 0.0
custom_token_bans: Optional[List[int]] = None
banned_strings: Optional[Union[str, List[str]]] = None
# doc: end-completion-sampling-params

# doc: begin-completion-extra-params
@@ -501,6 +502,18 @@ def to_sampling_params(
token_id = tokenizer.encode(f'a{s}')[-1]
dry_sequence_breaker_ids.append(token_id)

banned_sequences = []
if self.banned_strings:
# Handle both single string and list of strings
banned_strings = (
[self.banned_strings] if isinstance(self.banned_strings, str)
else self.banned_strings
)
for s in banned_strings:
# Tokenize each banned string into a list of token IDs
token_ids = tokenizer.encode(s, add_special_tokens=False)
banned_sequences.append(token_ids)

return SamplingParams(
n=self.n,
best_of=self.best_of,
@@ -548,6 +561,7 @@ def to_sampling_params(
dynatemp_exponent=self.dynatemp_exponent,
nsigma=self.nsigma,
custom_token_bans=self.custom_token_bans,
banned_strings=banned_sequences,
)

@model_validator(mode="before")
@@ -606,6 +620,27 @@ def parse_dry_sequence_breakers(cls, data):
)

return data

@model_validator(mode='before')
@classmethod
def validate_banned_strings(cls, data):
if 'banned_strings' in data and data['banned_strings'] is not None:
banned = data['banned_strings']
if isinstance(banned, str):
# Single string is fine
return data

if not isinstance(banned, list):
raise ValueError(
"banned_strings must be a string or list of strings"
)

if not all(isinstance(x, str) for x in banned):
raise ValueError(
"All elements in banned_strings must be strings"
)

return data


class EmbeddingRequest(OpenAIBaseModel):
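One caveat with encoding banned strings once at request time: tokenizer.encode captures a single tokenization, and the same surface text often tokenizes differently in context (for example with a leading space), so only that exact token sequence is caught. A quick way to observe this with a Hugging Face tokenizer (model name illustrative):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # illustrative model
print(tok.encode("fox", add_special_tokens=False))
print(tok.encode(" fox", add_special_tokens=False))
# The two ID lists typically differ, so banning "fox" would not ban " fox".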
25 changes: 25 additions & 0 deletions aphrodite/engine/output_processor/single_step.py
@@ -83,6 +83,31 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
sample = outputs.samples[0]
# only have one sequence
seq = seq_group.seqs[0]

buffered_tokens, rollback_len = self.stop_checker.maybe_stop_sequence(
seq,
0, # new_char_count will be calculated after appending
sampling_params,
lora_req=seq_group.lora_request,
)

if rollback_len > 0:
# Roll back the sequence
for _ in range(rollback_len):
seq.output_logprobs.pop()
seq.data.pop_token()
return

if buffered_tokens == []:
return []  # Empty list signals the token was buffered; skip emitting it

if buffered_tokens:
for token_id in buffered_tokens:
seq.append_token_id(token_id, None)
if sampling_params.detokenize:
new_char_count = self.detokenizer.decode_sequence_inplace(
seq, sampling_params)

seq.append_token_id(sample.output_token, sample.logprobs)
if sampling_params.detokenize and self.detokenizer:
new_char_count = self.detokenizer.decode_sequence_inplace(
92 changes: 91 additions & 1 deletion aphrodite/engine/output_processor/stop_checker.py
@@ -1,4 +1,4 @@
from typing import Callable, Optional
from typing import Callable, Optional, List, Dict, Tuple

from transformers import PreTrainedTokenizer

@@ -20,13 +20,80 @@ def __init__(self, max_model_len: int,
# Do not use it directly, but use `self._get_max_model_len`.
self._max_model_len = max_model_len
self.get_tokenizer_for_seq = get_tokenizer_for_seq
self._sequence_buffers: Dict[int, List[int]] = {} # seq_id -> buffered_tokens
self._active_banned_patterns: Dict[int, List[List[int]]] = {} # seq_id -> potential banned patterns
self._rollback_length: Dict[int, int] = {} # seq_id -> num tokens to rollback

def _get_max_model_len(self, lora_req: Optional[LoRARequest]):
if lora_req and lora_req.long_lora_max_len:
return lora_req.long_lora_max_len
else:
return self._max_model_len

def _check_banned_sequences(
self,
seq: Sequence,
token_id: int,
sampling_params: SamplingParams,
) -> Tuple[bool, Optional[List[int]], bool, int]:
"""Check if the token continues a banned sequence.

Returns:
Tuple[should_buffer: bool, tokens_to_release: Optional[List[int]],
should_ban: bool, rollback_length: int]
"""
seq_id = seq.seq_id
buffer = self._sequence_buffers.get(seq_id, [])
active_patterns = self._active_banned_patterns.get(seq_id, [])

# Check if this token starts any banned sequences
if not buffer:
matching_patterns = [
pattern for pattern in sampling_params.banned_strings
if pattern[0] == token_id
]
if matching_patterns:
# Only ban if any pattern is single token
should_ban = any(len(pattern) == 1 for pattern in matching_patterns)
# Buffer if any pattern is longer
should_buffer = any(len(pattern) > 1 for pattern in matching_patterns)

if should_buffer:
self._sequence_buffers[seq_id] = [token_id]
self._active_banned_patterns[seq_id] = [p for p in matching_patterns if len(p) > 1]
seq.status = SequenceStatus.BUFFERING

return should_buffer, None, should_ban, 1 if should_ban else 0

# Check existing buffer
if buffer:
buffer.append(token_id)
next_idx = len(buffer) - 1

# Update active patterns
still_active = []
for pattern in active_patterns:
if len(pattern) > next_idx and pattern[next_idx] == token_id:
if len(pattern) == len(buffer):
# Found complete banned sequence - clear buffers and rollback
del self._sequence_buffers[seq_id]
del self._active_banned_patterns[seq_id]
seq.status = SequenceStatus.RUNNING
return True, None, True, len(buffer)
still_active.append(pattern)

if still_active:
self._active_banned_patterns[seq_id] = still_active
return True, None, False, 0
else:
# No patterns match anymore - release buffer
tokens = self._sequence_buffers.pop(seq_id)
self._active_banned_patterns.pop(seq_id)
seq.status = SequenceStatus.RUNNING
return False, tokens, False, 0

return False, None, False, 0

def maybe_stop_sequence(
self,
seq: Sequence,
@@ -86,6 +153,29 @@ def maybe_stop_sequence(
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return None, 0

# Check banned strings if any exist
if sampling_params.banned_strings:
should_buffer, tokens_to_release, should_ban, rollback_len = self._check_banned_sequences(
seq, seq.get_last_token_id(), sampling_params)

if should_ban:
# Remove the banned sequence and roll back
if new_char_count:
seq.output_text = seq.output_text[:-new_char_count]
# Set status to WAITING to trigger rescheduling
seq.status = SequenceStatus.WAITING
# Reset sequence state for recomputation
seq.data.reset_state_for_recompute()
return None, rollback_len

if should_buffer:
return [], 0 # Signal to skip this token

if tokens_to_release:
return tokens_to_release, 0 # Release buffered tokens

return None, 0

@staticmethod
def _check_stop_strings(seq: Sequence, new_char_count: int,
sampling_params: SamplingParams) -> Optional[str]:
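The buffering logic above amounts to an incremental prefix match over the banned token sequences. A standalone sketch of the same state machine, stripped of the Sequence plumbing (function and variable names are illustrative, not the PR's API):

from typing import List, Tuple

def step(buffer: List[int], active: List[List[int]],
         patterns: List[List[int]], token: int,
         ) -> Tuple[str, List[int], List[List[int]], List[int]]:
    """Feed one token; return (action, buffer, active, released)."""
    if not buffer:
        matches = [p for p in patterns if p and p[0] == token]
        if any(len(p) == 1 for p in matches):
            return "ban", [], [], []          # single-token pattern hit
        if matches:
            return "buffer", [token], [p for p in matches if len(p) > 1], []
        return "emit", [], [], [token]
    buffer = buffer + [token]
    idx = len(buffer) - 1
    still = [p for p in active if len(p) > idx and p[idx] == token]
    if any(len(p) == len(buffer) for p in still):
        return "ban", [], [], []              # full pattern matched: roll back
    if still:
        return "buffer", buffer, still, []
    return "emit", [], [], buffer             # match broke: release the buffer

# With banned sequence [5, 6, 7]: tokens 5 and 6 are held back, then 8
# breaks the match and releases them; a 7 instead would trigger a ban.
buf, act = [], []
for t in [5, 6, 8]:
    action, buf, act, released = step(buf, act, [[5, 6, 7]], t)
    print(t, action, released)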