From b255b36e00e33eda506209cf2561847c86d0d034 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Thu, 14 Dec 2023 03:09:01 -0600 Subject: [PATCH 01/76] functional grammar.py, missing EBNF --- vllm/grammar.py | 256 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 256 insertions(+) create mode 100644 vllm/grammar.py diff --git a/vllm/grammar.py b/vllm/grammar.py new file mode 100644 index 000000000000..5dbe14aed38b --- /dev/null +++ b/vllm/grammar.py @@ -0,0 +1,256 @@ +import collections +import functools + + +class TokenIndex: + def __init__(self, tokenizer): + self.tokenizer = tokenizer + + self.tok_id_map = tokenizer.vocab + + # map id -> token str including whitespace + self.norm_vocab = {} + for token_id in self.tok_id_map.values(): + # TODO: look into difference between tokens, e.g. 28705, 35 are both " 9" + # assert norm_token not in self.norm_vocab, + norm_token = tokenizer.decode([tokenizer.bos_token_id, token_id])[len(tokenizer.bos_token):] + self.norm_vocab[norm_token] = token_id + + # get index allowing efficient retrieval of valid tokens given a sequence + # given tokens ["art", "artist", "argument", "alice"] + # map "a" -> ["ar", "al"] + # map "ar" -> ["art", "artist"] + # map "art" -> [None, "artist"] (None indicates match) + self.char_map = collections.defaultdict(set) + for word in self.norm_vocab: + for i in range(1, len(word) + 1): + prefix = word[:i] + if i < len(word): + self.char_map[prefix].add(word[i]) + else: + # Add None for complete matches + self.char_map[prefix].add(None) + + + def get_valid_next_charset(self, seq, legal_chars): + results = set(self.char_map[seq]) & legal_chars + return results + + def is_token(self, tok): + return tok in self.norm_vocab + + +class TokenConstraintLogitProcessor: + def __init__(self, tokenizer, nfa): + self.tokenizer = tokenizer + self.token_index = TokenIndex(tokenizer) + self.nfa = nfa + self.prev_token_ids = [] + self.prev_text = "" + + def __call__(self, token_ids, logits): + + # ensure integrity + assert token_ids[:len(self.prev_token_ids)] == self.prev_token_ids + self.prev_token_ids = token_ids + + # get new text and step NFA forward + text = tokenizer.decode(token_ids) + new_text = text[len(self.prev_text):] + self.prev_text = text + self.nfa.step_seq(new_text) + + # get valid new token ids + valid_tokens = set(self.get_allowable_next_token_set()) + valid_token_ids = [ + self.tokenizer.eos_token_id if t is None else self.token_index.norm_vocab[t] + for t in valid_tokens + ] + + if not valid_token_ids: + raise ValueError("Found no valid tokens, this should never occur.") + + logits = [ + logit_val if tok_id in valid_token_ids else -float("inf") + for tok_id, logit_val in zip(sorted(self.token_index.tok_id_map.values()), logits) + ] + return logits + + def get_allowable_next_token_set(self, current_text="", nfa_next_tok_states=False): + """ + Get set of valid tokens. 
+ 1) Ask NFA for legal first char + 3) While legal TokenIndex hasn't been exhausted + A) Ask TokenIndex for legal Nth char set + B) Ask NFA for + """ + if nfa_next_tok_states is None: + return [None] + if nfa_next_tok_states == False: + nfa_next_tok_states = self.nfa.next_char_state_map() + + legal_tokens = [] + + if None in nfa_next_tok_states: + legal_tokens.append(None) + del nfa_next_tok_states[None] + + for char, next_states in nfa_next_tok_states.items(): + # all current sequences are legal per nfa, find legal next token with token index + new_seq = current_text + char + tokidx_next_chars = self.token_index.get_valid_next_charset( + new_seq, + self.nfa.legal_chars + ) + + if self.token_index.is_token(new_seq): + legal_tokens.append(new_seq) + + # given legal next chars in token index, get the subset allowed by NFA and recurse + legal_tokens += self.get_allowable_next_token_set( + new_seq, + self.nfa.simulate_step(tokidx_next_chars, next_states) + ) + + return legal_tokens + + +class EpsilonNFA: + """ + Traverses a Character-Based Epsilon-NFA. + + Used by find valid next character sequences. + + self.nfa (dict): A dictionary representing the NFA. It includes: + - 'states' (list): A list of states (UUIDsn) inn the NFA. + - 'initial_state' (UUID or any hashable ID): The initial state of the NFA. + - 'final_states' (list): A list of final or accepting states (UUIDs). + - 'alphabets' (list): The set of input symbols (characters). + - 'transition_function' (dict): A dictionary representing the state + transitions. Each key is a state (UUID), and its value is another + dictionary mapping input symbols to lists of next states (UUIDs). + + self.nfa should never be mutated. + """ + def __init__(self, nfa): + self.nfa = nfa + + # Set of states you may be in + self.current_states = set([self.nfa["initial_state"]]) + self.current_str = "" + + self.legal_chars = set([char for char in self.nfa["alphabets"] if char != "$"]) + + self._resolved_epsilon_cache = {} + + def step_seq(self, seq): + for char in seq: + self.step(char) + + def step(self, char): + """ + Updates the canonical state + """ + next_states = self.next_char_state_map()[char] + if not next_states: + raise ValueError(f"Illegal transition from '{self.current_str}', no next state for '{char}'") + self.current_states = next_states + self.current_str += char + + def simulate_step(self, chars, state_set=None): + """ + Return map of chars and their resulting next state-set given a current state-set and new chars + """ + if state_set is None: + state_set = self.current_states + state_map = self.next_char_state_map(state_set) + return {tok: state_map[tok] for tok in chars if tok in state_map} + + + def copy(self): + new_nfa = EpsilonNFA(self.nfa) + new_nfa.current_states = self.current_states + new_nfa.current_str = self.current_str + new_nfa._resolved_epsilon_cache = self._resolved_epsilon_cache + return new_nfa + + @property + def allow_stop_next(self): + return None in self.next_char_state_map() + + def next_char_state_map(self, current_states=None): + """ + Creates a mapping of possible next chars to a set of valid states for each char + """ + if current_states is None: + current_states = self.current_states + + char_to_states = collections.defaultdict(set) + + if bool(current_states & set(self.nfa["final_states"])): + char_to_states[None] = None + + for state in self._resolve_epsilon_closure(current_states): + for char, next_states in self.nfa["transition_function"][state].items(): + if next_states and char != "$": + 
char_to_states[char].update(next_states) + + return char_to_states + + def _resolve_epsilon_closure(self, states): + closure = set() + for state in states: + if state in self._resolved_epsilon_cache: + new_closures = self._resolved_epsilon_cache[state] + else: + new_closures = self._get_epsilon_closure(state) + self._resolved_epsilon_cache[state] = new_closures + closure.update(self._get_epsilon_closure(state)) + return closure + + def _get_epsilon_closure(self, state, visited=None): + if visited is None: + visited = set() + + stack = [state] + while stack: + current_state = stack.pop() + if current_state not in visited: + visited.add(current_state) + stack.extend(self.nfa["transition_function"][current_state].get('$', [])) + + return visited + + +if __name__ == "__main__": + from automata_toolkit import regex_to_nfa + import transformers + import numpy as np + + tokenizer = transformers.AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") + + sample_from_logits = lambda lgts: np.random.choice(len(lgts), p=np.exp(lgts)/np.sum(np.exp(lgts))) + + for i in range(4): + + logit_processor = TokenConstraintLogitProcessor( + tokenizer=tokenizer, + nfa=EpsilonNFA(nfa=regex_to_nfa.regex_to_nfa( + r"(large )?(language )((models )+(inference engines ))(are )((useful)+((very )*complex))." + )), + ) + + token_ids = [] + while True: + logits = logit_processor( + token_ids=token_ids, + logits=np.random.uniform(-10, 10, len(tokenizer.vocab)) + ) + new_token_id = sample_from_logits(logits) + token_ids.append(new_token_id) + if new_token_id == tokenizer.eos_token_id: + break + print(f"run #{i}") + print("\ttokenid", token_ids) + print("\ttokens:", [tokenizer.decode(tok_id, ) for tok_id in token_ids]) + print("\tresult:", tokenizer.decode(token_ids, skip_special_tokens=False)) From 5e1b0d226c78ef79c9c9b08c32f26a9b01b9b85a Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Thu, 14 Dec 2023 03:10:26 -0600 Subject: [PATCH 02/76] removed redundant code --- vllm/grammar.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/grammar.py b/vllm/grammar.py index 5dbe14aed38b..d67e37732030 100644 --- a/vllm/grammar.py +++ b/vllm/grammar.py @@ -6,11 +6,9 @@ class TokenIndex: def __init__(self, tokenizer): self.tokenizer = tokenizer - self.tok_id_map = tokenizer.vocab - # map id -> token str including whitespace self.norm_vocab = {} - for token_id in self.tok_id_map.values(): + for token_id in tokenizer.vocab.values(): # TODO: look into difference between tokens, e.g. 
28705, 35 are both " 9" # assert norm_token not in self.norm_vocab, norm_token = tokenizer.decode([tokenizer.bos_token_id, token_id])[len(tokenizer.bos_token):] @@ -72,7 +70,7 @@ def __call__(self, token_ids, logits): logits = [ logit_val if tok_id in valid_token_ids else -float("inf") - for tok_id, logit_val in zip(sorted(self.token_index.tok_id_map.values()), logits) + for tok_id, logit_val in zip(sorted(self.token_index.norm_vocab.values()), logits) ] return logits From 383e8e2c7a83db794e3c257e4105cbde4847c8a6 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Thu, 14 Dec 2023 04:09:58 -0600 Subject: [PATCH 03/76] change ordering --- vllm/grammar.py | 150 ++++++++++++++++++++++++------------------------ 1 file changed, 75 insertions(+), 75 deletions(-) diff --git a/vllm/grammar.py b/vllm/grammar.py index d67e37732030..051f373d16be 100644 --- a/vllm/grammar.py +++ b/vllm/grammar.py @@ -38,81 +38,6 @@ def is_token(self, tok): return tok in self.norm_vocab -class TokenConstraintLogitProcessor: - def __init__(self, tokenizer, nfa): - self.tokenizer = tokenizer - self.token_index = TokenIndex(tokenizer) - self.nfa = nfa - self.prev_token_ids = [] - self.prev_text = "" - - def __call__(self, token_ids, logits): - - # ensure integrity - assert token_ids[:len(self.prev_token_ids)] == self.prev_token_ids - self.prev_token_ids = token_ids - - # get new text and step NFA forward - text = tokenizer.decode(token_ids) - new_text = text[len(self.prev_text):] - self.prev_text = text - self.nfa.step_seq(new_text) - - # get valid new token ids - valid_tokens = set(self.get_allowable_next_token_set()) - valid_token_ids = [ - self.tokenizer.eos_token_id if t is None else self.token_index.norm_vocab[t] - for t in valid_tokens - ] - - if not valid_token_ids: - raise ValueError("Found no valid tokens, this should never occur.") - - logits = [ - logit_val if tok_id in valid_token_ids else -float("inf") - for tok_id, logit_val in zip(sorted(self.token_index.norm_vocab.values()), logits) - ] - return logits - - def get_allowable_next_token_set(self, current_text="", nfa_next_tok_states=False): - """ - Get set of valid tokens. - 1) Ask NFA for legal first char - 3) While legal TokenIndex hasn't been exhausted - A) Ask TokenIndex for legal Nth char set - B) Ask NFA for - """ - if nfa_next_tok_states is None: - return [None] - if nfa_next_tok_states == False: - nfa_next_tok_states = self.nfa.next_char_state_map() - - legal_tokens = [] - - if None in nfa_next_tok_states: - legal_tokens.append(None) - del nfa_next_tok_states[None] - - for char, next_states in nfa_next_tok_states.items(): - # all current sequences are legal per nfa, find legal next token with token index - new_seq = current_text + char - tokidx_next_chars = self.token_index.get_valid_next_charset( - new_seq, - self.nfa.legal_chars - ) - - if self.token_index.is_token(new_seq): - legal_tokens.append(new_seq) - - # given legal next chars in token index, get the subset allowed by NFA and recurse - legal_tokens += self.get_allowable_next_token_set( - new_seq, - self.nfa.simulate_step(tokidx_next_chars, next_states) - ) - - return legal_tokens - - class EpsilonNFA: """ Traverses a Character-Based Epsilon-NFA. 
@@ -219,6 +144,81 @@ def _get_epsilon_closure(self, state, visited=None): return visited +class TokenConstraintLogitProcessor: + def __init__(self, tokenizer, nfa): + self.tokenizer = tokenizer + self.token_index = TokenIndex(tokenizer) + self.nfa = nfa + self.prev_token_ids = [] + self.prev_text = "" + + def __call__(self, token_ids, logits): + + # ensure integrity + assert token_ids[:len(self.prev_token_ids)] == self.prev_token_ids + self.prev_token_ids = token_ids + + # get new text and step NFA forward + text = tokenizer.decode(token_ids) + new_text = text[len(self.prev_text):] + self.prev_text = text + self.nfa.step_seq(new_text) + + # get valid new token ids + valid_tokens = set(self.get_allowable_next_token_set()) + valid_token_ids = [ + self.tokenizer.eos_token_id if t is None else self.token_index.norm_vocab[t] + for t in valid_tokens + ] + + if not valid_token_ids: + raise ValueError("Found no valid tokens, this should never occur.") + + logits = [ + logit_val if tok_id in valid_token_ids else -float("inf") + for tok_id, logit_val in zip(sorted(self.token_index.norm_vocab.values()), logits) + ] + return logits + + + def get_allowable_next_token_set(self, current_text="", nfa_next_tok_states=False): + """ + Get set of valid tokens. + 1) Ask NFA for legal first char + 3) While legal TokenIndex hasn't been exhausted + A) Ask TokenIndex for legal Nth char set + B) Ask NFA for + """ + if nfa_next_tok_states is None: + return [None] + if nfa_next_tok_states == False: + nfa_next_tok_states = self.nfa.next_char_state_map() + + legal_tokens = [] + + if None in nfa_next_tok_states: + legal_tokens.append(None) + del nfa_next_tok_states[None] + + for char, next_states in nfa_next_tok_states.items(): + # all current sequences are legal per nfa, find legal next token with token index + new_seq = current_text + char + tokidx_next_chars = self.token_index.get_valid_next_charset( + new_seq, + self.nfa.legal_chars + ) + + if self.token_index.is_token(new_seq): + legal_tokens.append(new_seq) + + # given legal next chars in token index, get the subset allowed by NFA and recurse + legal_tokens += self.get_allowable_next_token_set( + new_seq, + self.nfa.simulate_step(tokidx_next_chars, next_states) + ) + + return legal_tokens + if __name__ == "__main__": from automata_toolkit import regex_to_nfa From c11640703e896dbe8feac8c3711f5d0172e9b487 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Thu, 14 Dec 2023 06:31:57 -0600 Subject: [PATCH 04/76] bug fix --- vllm/grammar.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/grammar.py b/vllm/grammar.py index 051f373d16be..1a65f5dcbc54 100644 --- a/vllm/grammar.py +++ b/vllm/grammar.py @@ -176,7 +176,7 @@ def __call__(self, token_ids, logits): logits = [ logit_val if tok_id in valid_token_ids else -float("inf") - for tok_id, logit_val in zip(sorted(self.token_index.norm_vocab.values()), logits) + for tok_id, logit_val in zip(sorted(self.tokenizer.vocab.values()), logits) ] return logits From 8f234f6b85d0ea0d7836eacba6d38bcb5094b026 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Thu, 14 Dec 2023 06:35:42 -0600 Subject: [PATCH 05/76] remove unused code --- vllm/grammar.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/vllm/grammar.py b/vllm/grammar.py index 1a65f5dcbc54..c258d53fe05f 100644 --- a/vllm/grammar.py +++ b/vllm/grammar.py @@ -1,5 +1,4 @@ import collections -import functools class TokenIndex: @@ -89,18 +88,6 @@ def simulate_step(self, chars, state_set=None): state_map = self.next_char_state_map(state_set) 
return {tok: state_map[tok] for tok in chars if tok in state_map} - - def copy(self): - new_nfa = EpsilonNFA(self.nfa) - new_nfa.current_states = self.current_states - new_nfa.current_str = self.current_str - new_nfa._resolved_epsilon_cache = self._resolved_epsilon_cache - return new_nfa - - @property - def allow_stop_next(self): - return None in self.next_char_state_map() - def next_char_state_map(self, current_states=None): """ Creates a mapping of possible next chars to a set of valid states for each char From fc05133309194fbe6ab352af742547c62c4a42e4 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Thu, 14 Dec 2023 18:00:27 -0600 Subject: [PATCH 06/76] remove dead code --- vllm/grammar.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/grammar.py b/vllm/grammar.py index c258d53fe05f..ede736024cd0 100644 --- a/vllm/grammar.py +++ b/vllm/grammar.py @@ -3,8 +3,6 @@ class TokenIndex: def __init__(self, tokenizer): - self.tokenizer = tokenizer - # map id -> token str including whitespace self.norm_vocab = {} for token_id in tokenizer.vocab.values(): From 72e12f1608fedb3815261bb65707e8f6981efc34 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Tue, 19 Dec 2023 12:14:07 -0600 Subject: [PATCH 07/76] basic functioning EBNF grammar generator based on lark --- vllm/lark_interactive.py | 276 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 276 insertions(+) create mode 100644 vllm/lark_interactive.py diff --git a/vllm/lark_interactive.py b/vllm/lark_interactive.py new file mode 100644 index 000000000000..d31e659cdad7 --- /dev/null +++ b/vllm/lark_interactive.py @@ -0,0 +1,276 @@ +from typing import Optional + +from lark import Lark +from lark.parsers.lalr_interactive_parser import InteractiveParser +from lark.parsers.lalr_parser_state import ParserState +from lark.lexer import Token, LexerState, PatternStr, PatternRE +from lark.exceptions import UnexpectedCharacters, UnexpectedToken + +import regex + +from copy import deepcopy, copy +import functools + + +class FastParserState(ParserState): + """ + https://github.com/lark-parser/lark/issues/1142#issuecomment-1863209804 + """ + copy_memo = {} + + def __copy__(self): + new_value_stack = [] + for value in self.value_stack: + key = f"{id(self)}_{id(value)}" + if key not in self.copy_memo: + self.copy_memo[key] = deepcopy(value, self.copy_memo) + new_value_stack.append(self.copy_memo[key]) + + new_instance = type(self)( + self.parse_conf, + self.lexer, # XXX copy + copy(self.state_stack), + new_value_stack, + ) + + self.copy_memo[id(self)] = new_instance + return new_instance + + +class FastInteractiveParser(InteractiveParser): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.parser_state = FastParserState( + self.parser_state.parse_conf, + self.parser_state.lexer, + self.parser_state.state_stack, + self.parser_state.value_stack, + ) + + def __copy__(self): + return type(self)( + self.parser, + copy(self.parser_state), + copy(self.lexer_thread), + ) + + +@functools.lru_cache(10000) +def check_pattern_partial_match(compiled_pattern, seq): + return + + +def get_partial_pattern_validator(pattern): + """ + Accepts a pattern object, either lark.lexer.PatternStr or lark.lexer.PatternRE + Returns a function which validates a partial string + + e.g. 
for PatternRE "abc*", returns true for "a", "ab", "abc", "abcccc" + """ + if isinstance(pattern, PatternRE): + compiled_pattern = regex.compile(pattern.value) + return ( + lambda seq: compiled_pattern.fullmatch(seq, partial=True) is not None + ) + elif isinstance(pattern, PatternStr): + return ( + lambda seq: pattern.value.startswith(seq) + ) + else: + raise TypeError(f"Invalid pattern type: {type(pattern)}") + + +class InteractivePredictiveLALRParser: + """ + Parser which consumes an EBNF grammar and provides helpers to determine allowable language model tokens + + Interfaces: + - step_seq(sequence): Update the parser with a new sequence to append + - is_valid_next_seq(sequence): Determine whether a candidate sequence is valid + + Core components for terminal level, and sub-terminal level processing: + - 1) Lark LALR parser: Applies state transitions, determining set of valid next-terminals + - 2) Incremental terminal filter: Eliminates next-terminal candidates if terminal pattern doesn't match + """ + + def __init__(self, grammar: str, start: str): + self.parser = Lark( + grammar, + regex=True, # use `regex` not `re` + start=start, + parser='lalr', + ) + base_interactive_parser = self.parser.parse_interactive() + self.interactive_parser = FastInteractiveParser( + base_interactive_parser.parser, + base_interactive_parser.parser_state, + base_interactive_parser.lexer_thread + ) + + + self.partial_seq_validator = { + term.name: get_partial_pattern_validator(term.pattern) + for term in self.parser.terminals + } + + self._ignored_terms = set(self.parser.lexer_conf.ignore) + + # for processing terminals interactively + self.last_terminal_pos = 0 + self.valid_next_terminals = None + + # for calculating `accepts()` efficiently + self._accepts_cache = {} + + self.sequence_history = "" + + def _accepts(self): + if self.sequence_history not in self._accepts_cache: + accepted_terminals = self.interactive_parser.accepts() + self._accepts_cache[self.sequence_history] = accepted_terminals + return self._accepts_cache[self.sequence_history] + + @property + def terminal_partial_seq(self): + """ + Return the incomplete subsequence which will eventually comprise a terminal + """ + return self.sequence_history[self.last_terminal_pos:] + + def step_seq(self, sequence: str): + """ + Append sequence to parser and apply state updates + - Append the sequence to the canonical self.sequence_history + - Parse the changes + - Update the character position of the last complete terminal + - Update the set of candidate terminals + """ + new_seq = self.sequence_history + sequence + self.interactive_parser.lexer_thread.state.text = new_seq + try: + self.interactive_parser.exhaust_lexer() + except UnexpectedCharacters as e: + self.last_terminal_pos = e.pos_in_stream + else: + self.last_terminal_pos = len(new_seq) + + self._update_candidate_terminals() + + if not self.valid_next_terminals: + raise ValueError(f"Invalid continuation for `{self.sequence_history}` `{sequence}`") + + def _update_sequence(self, full_sequence: str): + """Set the complete sequences value in the lexer and base""" + assert self.full_sequence.startswith(self.sequence_history) + self.interactive_parser.lexer_thread.state.text = full_sequence + self.sequence_history = full_sequence + + def _update_candidate_terminals(self): + """ + Update the set of candidate terminals + - If a new terminal is reached, get the accepted set of terminals from the parser + - If the new sequence doesn't comprise a full terminal, filter based on partial pattern match + """ 
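+        # No pending partial-terminal text: take the parser's accepts() plus ignored terminals (e.g. WS). +        # Otherwise: keep only the previous candidates whose pattern still partially matches the pending text.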
+ if not self.terminal_partial_seq: + self.valid_next_terminals = self._accepts() | self._ignored_terms + else: + self.valid_next_terminals = set([ + term for term in self.valid_next_terminals + if self.partial_seq_validator[term](self.terminal_partial_seq) + ]) + + def is_valid_next_seq(self, new_seq: Optional[str]): + """ + Check if current un-terminalized sequence + new_seq is valid for any terminal + + new_seq can be a string or None representing EOS + """ + if new_seq is None: + return "$END" in self.valid_next_terminals + for term in self.valid_next_terminals: + if term == "$END": + continue + if self.partial_seq_validator[term](self.terminal_partial_seq + new_seq): + return True + return False + + + def get_valid_next_tokens(self, token_trie): + valid_node_stack = [] + for term in self.valid_next_terminals: + import pdb;pdb.set_trace() + + + +def test_simple_sequence(parser): + for token in ['{', '"', 'k', 'ey', '":', '"val', 'ue', '"']: + print("full:", parser.sequence_history) + print("adding", token) + parser.step_seq(token) + print("partial:", parser.terminal_partial_seq) + print("valid terms:", parser.valid_next_terminals) + + +def test_valid_next_tokens(parser): + # random complicated json file courtesy of https://github.com/simdjson/simdjson/issues/1316#issue-748663718 + complex_json_file = '{"$schema": "http://json-schema.org/draft-04/schema#", "additionalProperties": false, "properties": {"nc:Vehicle": {"description": "A conveyance designed to carry an operator, passengers and/or cargo, over land.", "oneOf": [{"$ref": "#/definitions/nc:VehicleType"}, {"type": "array", "items": {"$ref": "#/definitions/nc:VehicleType"}}]}, "nc:VehicleAxleQuantity": {"description": "A count of common axles of rotation of one or more wheels of a vehicle, whether power driven or freely rotating.", "oneOf": [{"$ref": "#/definitions/niem-xs:nonNegativeInteger"}, {"type": "array", "items": {"$ref": "#/definitions/niem-xs:nonNegativeInteger"}}]}, "nc:VehicleMSRPAmount": {"description": "A manufacturer\'s suggested retail price of a vehicle; a price at which a manufacturer recommends a vehicle be sold.", "oneOf": [{"$ref": "#/definitions/nc:AmountType"}, {"type": "array", "items": {"$ref": "#/definitions/nc:AmountType"}}]}, "nc:Amount": {"description": "An amount of money.", "oneOf": [{"$ref": "#/definitions/niem-xs:decimal"}, {"type": "array", "items": {"$ref": "#/definitions/niem-xs:decimal"}}]}, "nc:Currency": {"description": "A data concept for a unit of money or exchange.", "oneOf": [{"anyOf": [{"$ref": "#/properties/nc:CurrencyCode"}]}, {"type": "array", "items": {"anyOf": [{"$ref": "#/properties/nc:CurrencyCode"}]}}]}, "nc:CurrencyCode": {"description": "A unit of money or exchange.", "oneOf": [{"$ref": "#/definitions/iso_4217:CurrencyCodeType"}, {"type": "array", "items": {"$ref": "#/definitions/iso_4217:CurrencyCodeType"}}]}, "nc:VehicleIdentification": {"description": "A unique identification for a specific vehicle.", "oneOf": [{"$ref": "#/definitions/nc:IdentificationType"}, {"type": "array", "items": {"$ref": "#/definitions/nc:IdentificationType"}}]}, "nc:IdentificationID": {"description": "An identifier.", "oneOf": [{"$ref": "#/definitions/niem-xs:string"}, {"type": "array", "items": {"$ref": "#/definitions/niem-xs:string"}}]}}, "definitions": {"nc:VehicleType": {"description": "A data type for a conveyance designed to carry an operator, passengers and/or cargo, over land.", "allOf": [{"$ref": "#/definitions/nc:ConveyanceType"}, {"type": "object", "properties": {"nc:VehicleAxleQuantity": 
{"$ref": "#/properties/nc:VehicleAxleQuantity"}, "nc:VehicleIdentification": {"$ref": "#/properties/nc:VehicleIdentification"}, "nc:VehicleMSRPAmount": {"$ref": "#/properties/nc:VehicleMSRPAmount"}}}]}, "nc:ConveyanceType": {"description": "A data type for a means of transport from place to place.", "allOf": [{"$ref": "#/definitions/_base"}, {"$ref": "#/definitions/nc:ItemType"}, {"type": "object", "properties": {}}]}, "nc:ItemType": {"description": "A data type for an article or thing.", "allOf": [{"$ref": "#/definitions/_base"}, {"type": "object", "properties": {}}]}, "nc:AmountType": {"description": "A data type for an amount of money.", "type": "object", "properties": {"nc:Amount": {"$ref": "#/properties/nc:Amount"}, "nc:Currency": {"$ref": "#/properties/nc:Currency"}}}, "iso_4217:CurrencyCodeType": {"description": "A data type for a currency that qualifies a monetary amount.", "oneOf": [{"$ref": "#/definitions/iso_4217:CurrencyCodeSimpleType"}, {"type": "object", "properties": {"rdf:value": {"$ref": "#/definitions/iso_4217:CurrencyCodeSimpleType"}}}]}, "iso_4217:CurrencyCodeSimpleType": {"type": "string", "description": "A data type for a currency that qualifies a monetary amount.", "oneOf": [{"enum": ["EUR"], "description": "Euro"}, {"enum": ["GBP"], "description": "Pound Sterling"}, {"enum": ["USD"], "description": "US Dollar"}]}, "nc:IdentificationType": {"description": "A data type for a representation of an identity.", "type": "object", "properties": {"nc:IdentificationID": {"$ref": "#/properties/nc:IdentificationID"}}}, "niem-xs:decimal": {"description": "A data type for arbitrary precision decimal numbers.", "type": "number"}, "niem-xs:nonNegativeInteger": {"description": "A data type for an integer with a minimum value of 0.", "type": "number"}, "niem-xs:string": {"description": "A data type for character strings in XML.", "type": "string"}, "_base": {"type": "object", "patternProperties": {"^ism:.*": {"type": "string"}, "^ntk:.*": {"type": "string"}}, "properties": {"@id": {"format": "uriref"}, "@base": {"format": "uriref"}}}}}' + + unicode_chars = [chr(i) for i in range(256)] + + import time + start = time.time() + for char in complex_json_file: + parser.step_seq(char) + for ch in unicode_chars: + parser.is_valid_next_seq(ch) + + print("took", time.time() - start, "seconds to process", len(complex_json_file), "characters") + + +def main(): + # Usage + ebnf_grammar = """ + ?value: dict + | list + | string + | SIGNED_NUMBER -> number + | "true" -> true + | "false" -> false + | "null" -> null + + list : "[" [value ("," value)*] "]" + + dict : "{" [pair ("," pair)*] "}" + pair : string ":" value + + string : ESCAPED_STRING + + %import common.ESCAPED_STRING + %import common.SIGNED_NUMBER + %import common.WS + %ignore WS + """ + + parser = InteractivePredictiveLALRParser(ebnf_grammar, 'value') + test_valid_next_tokens(parser) + + +if __name__ == "__main__": + main() + + profile = True + if profile: + import cProfile + import pstats + from io import StringIO + profile = cProfile.Profile() + profile.enable() + main() + profile.disable() + + # Sorting the statistics by cumulative time + s = StringIO() + sortby = 'cumulative' + ps = pstats.Stats(profile, stream=s).sort_stats(sortby) + ps.print_stats() + print(s.getvalue()) From 50ea6528572ae7d334b9652d32057a0db40637cb Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Tue, 19 Dec 2023 12:37:11 -0600 Subject: [PATCH 08/76] clean up --- vllm/lark_interactive.py | 46 ++++++++++++++++------------------------ 1 file changed, 18 
insertions(+), 28 deletions(-) diff --git a/vllm/lark_interactive.py b/vllm/lark_interactive.py index d31e659cdad7..93cd5c6464d2 100644 --- a/vllm/lark_interactive.py +++ b/vllm/lark_interactive.py @@ -55,11 +55,6 @@ def __copy__(self): ) -@functools.lru_cache(10000) -def check_pattern_partial_match(compiled_pattern, seq): - return - - def get_partial_pattern_validator(pattern): """ Accepts a pattern object, either lark.lexer.PatternStr or lark.lexer.PatternRE @@ -73,8 +68,9 @@ def get_partial_pattern_validator(pattern): lambda seq: compiled_pattern.fullmatch(seq, partial=True) is not None ) elif isinstance(pattern, PatternStr): + base_str = pattern.value return ( - lambda seq: pattern.value.startswith(seq) + lambda seq: base_str.startswith(seq) ) else: raise TypeError(f"Invalid pattern type: {type(pattern)}") @@ -107,7 +103,6 @@ def __init__(self, grammar: str, start: str): base_interactive_parser.lexer_thread ) - self.partial_seq_validator = { term.name: get_partial_pattern_validator(term.pattern) for term in self.parser.terminals @@ -137,7 +132,7 @@ def terminal_partial_seq(self): """ return self.sequence_history[self.last_terminal_pos:] - def step_seq(self, sequence: str): + def step_seq(self, new_seq: str): """ Append sequence to parser and apply state updates - Append the sequence to the canonical self.sequence_history @@ -145,25 +140,24 @@ def step_seq(self, sequence: str): - Update the character position of the last complete terminal - Update the set of candidate terminals """ - new_seq = self.sequence_history + sequence - self.interactive_parser.lexer_thread.state.text = new_seq + self._append_to_sequence(new_seq) + try: self.interactive_parser.exhaust_lexer() except UnexpectedCharacters as e: self.last_terminal_pos = e.pos_in_stream else: - self.last_terminal_pos = len(new_seq) + self.last_terminal_pos = len(self.sequence_history) self._update_candidate_terminals() if not self.valid_next_terminals: raise ValueError(f"Invalid continuation for `{self.sequence_history}` `{sequence}`") - def _update_sequence(self, full_sequence: str): + def _append_to_sequence(self, new_seq: str): """Set the complete sequences value in the lexer and base""" - assert self.full_sequence.startswith(self.sequence_history) - self.interactive_parser.lexer_thread.state.text = full_sequence - self.sequence_history = full_sequence + self.sequence_history += new_seq + self.interactive_parser.lexer_thread.state.text = self.sequence_history def _update_candidate_terminals(self): """ @@ -188,20 +182,12 @@ def is_valid_next_seq(self, new_seq: Optional[str]): if new_seq is None: return "$END" in self.valid_next_terminals for term in self.valid_next_terminals: - if term == "$END": - continue - if self.partial_seq_validator[term](self.terminal_partial_seq + new_seq): - return True + if term != "$END": + if self.partial_seq_validator[term](self.terminal_partial_seq + new_seq): + return True return False - def get_valid_next_tokens(self, token_trie): - valid_node_stack = [] - for term in self.valid_next_terminals: - import pdb;pdb.set_trace() - - - def test_simple_sequence(parser): for token in ['{', '"', 'k', 'ey', '":', '"val', 'ue', '"']: print("full:", parser.sequence_history) @@ -215,7 +201,8 @@ def test_valid_next_tokens(parser): # random complicated json file courtesy of https://github.com/simdjson/simdjson/issues/1316#issue-748663718 complex_json_file = '{"$schema": "http://json-schema.org/draft-04/schema#", "additionalProperties": false, "properties": {"nc:Vehicle": {"description": "A conveyance designed to 
carry an operator, passengers and/or cargo, over land.", "oneOf": [{"$ref": "#/definitions/nc:VehicleType"}, {"type": "array", "items": {"$ref": "#/definitions/nc:VehicleType"}}]}, "nc:VehicleAxleQuantity": {"description": "A count of common axles of rotation of one or more wheels of a vehicle, whether power driven or freely rotating.", "oneOf": [{"$ref": "#/definitions/niem-xs:nonNegativeInteger"}, {"type": "array", "items": {"$ref": "#/definitions/niem-xs:nonNegativeInteger"}}]}, "nc:VehicleMSRPAmount": {"description": "A manufacturer\'s suggested retail price of a vehicle; a price at which a manufacturer recommends a vehicle be sold.", "oneOf": [{"$ref": "#/definitions/nc:AmountType"}, {"type": "array", "items": {"$ref": "#/definitions/nc:AmountType"}}]}, "nc:Amount": {"description": "An amount of money.", "oneOf": [{"$ref": "#/definitions/niem-xs:decimal"}, {"type": "array", "items": {"$ref": "#/definitions/niem-xs:decimal"}}]}, "nc:Currency": {"description": "A data concept for a unit of money or exchange.", "oneOf": [{"anyOf": [{"$ref": "#/properties/nc:CurrencyCode"}]}, {"type": "array", "items": {"anyOf": [{"$ref": "#/properties/nc:CurrencyCode"}]}}]}, "nc:CurrencyCode": {"description": "A unit of money or exchange.", "oneOf": [{"$ref": "#/definitions/iso_4217:CurrencyCodeType"}, {"type": "array", "items": {"$ref": "#/definitions/iso_4217:CurrencyCodeType"}}]}, "nc:VehicleIdentification": {"description": "A unique identification for a specific vehicle.", "oneOf": [{"$ref": "#/definitions/nc:IdentificationType"}, {"type": "array", "items": {"$ref": "#/definitions/nc:IdentificationType"}}]}, "nc:IdentificationID": {"description": "An identifier.", "oneOf": [{"$ref": "#/definitions/niem-xs:string"}, {"type": "array", "items": {"$ref": "#/definitions/niem-xs:string"}}]}}, "definitions": {"nc:VehicleType": {"description": "A data type for a conveyance designed to carry an operator, passengers and/or cargo, over land.", "allOf": [{"$ref": "#/definitions/nc:ConveyanceType"}, {"type": "object", "properties": {"nc:VehicleAxleQuantity": {"$ref": "#/properties/nc:VehicleAxleQuantity"}, "nc:VehicleIdentification": {"$ref": "#/properties/nc:VehicleIdentification"}, "nc:VehicleMSRPAmount": {"$ref": "#/properties/nc:VehicleMSRPAmount"}}}]}, "nc:ConveyanceType": {"description": "A data type for a means of transport from place to place.", "allOf": [{"$ref": "#/definitions/_base"}, {"$ref": "#/definitions/nc:ItemType"}, {"type": "object", "properties": {}}]}, "nc:ItemType": {"description": "A data type for an article or thing.", "allOf": [{"$ref": "#/definitions/_base"}, {"type": "object", "properties": {}}]}, "nc:AmountType": {"description": "A data type for an amount of money.", "type": "object", "properties": {"nc:Amount": {"$ref": "#/properties/nc:Amount"}, "nc:Currency": {"$ref": "#/properties/nc:Currency"}}}, "iso_4217:CurrencyCodeType": {"description": "A data type for a currency that qualifies a monetary amount.", "oneOf": [{"$ref": "#/definitions/iso_4217:CurrencyCodeSimpleType"}, {"type": "object", "properties": {"rdf:value": {"$ref": "#/definitions/iso_4217:CurrencyCodeSimpleType"}}}]}, "iso_4217:CurrencyCodeSimpleType": {"type": "string", "description": "A data type for a currency that qualifies a monetary amount.", "oneOf": [{"enum": ["EUR"], "description": "Euro"}, {"enum": ["GBP"], "description": "Pound Sterling"}, {"enum": ["USD"], "description": "US Dollar"}]}, "nc:IdentificationType": {"description": "A data type for a representation of an identity.", "type": "object", 
"properties": {"nc:IdentificationID": {"$ref": "#/properties/nc:IdentificationID"}}}, "niem-xs:decimal": {"description": "A data type for arbitrary precision decimal numbers.", "type": "number"}, "niem-xs:nonNegativeInteger": {"description": "A data type for an integer with a minimum value of 0.", "type": "number"}, "niem-xs:string": {"description": "A data type for character strings in XML.", "type": "string"}, "_base": {"type": "object", "patternProperties": {"^ism:.*": {"type": "string"}, "^ntk:.*": {"type": "string"}}, "properties": {"@id": {"format": "uriref"}, "@base": {"format": "uriref"}}}}}' - unicode_chars = [chr(i) for i in range(256)] + test_chars_per_iter = 1000 + unicode_chars = [chr(i) for i in range(test_chars_per_iter)] import time start = time.time() @@ -224,7 +211,10 @@ def test_valid_next_tokens(parser): for ch in unicode_chars: parser.is_valid_next_seq(ch) - print("took", time.time() - start, "seconds to process", len(complex_json_file), "characters") + print("took", + (time.time() - start) / (len(complex_json_file)), + "seconds per step with", + test_chars_per_iter, "characters in vocabulary") def main(): From 77f347ae17880317e6b914e709797a9bea85a656 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Tue, 19 Dec 2023 13:40:53 -0600 Subject: [PATCH 09/76] clean up and add TokenTrie --- vllm/lark_interactive.py | 66 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 60 insertions(+), 6 deletions(-) diff --git a/vllm/lark_interactive.py b/vllm/lark_interactive.py index 93cd5c6464d2..739432237537 100644 --- a/vllm/lark_interactive.py +++ b/vllm/lark_interactive.py @@ -1,3 +1,7 @@ +from copy import deepcopy, copy +import functools +import os +import regex from typing import Optional from lark import Lark @@ -6,10 +10,6 @@ from lark.lexer import Token, LexerState, PatternStr, PatternRE from lark.exceptions import UnexpectedCharacters, UnexpectedToken -import regex - -from copy import deepcopy, copy -import functools class FastParserState(ParserState): @@ -81,8 +81,8 @@ class InteractivePredictiveLALRParser: Parser which consumes an EBNF grammar and provides helpers to determine allowable language model tokens Interfaces: - - step_seq(sequence): Update the parser with a new sequence to append - - is_valid_next_seq(sequence): Determine whether a candidate sequence is valid + - step_seq(new_seq): Update the parser with a new sequence to append + - is_valid_next_seq(new_seq): Determine whether a candidate sequence is valid Core components for terminal level, and sub-terminal level processing: - 1) Lark LALR parser: Applies state transitions, determining set of valid next-terminals @@ -188,6 +188,60 @@ def is_valid_next_seq(self, new_seq: Optional[str]): return False +class TokenTrie: + def __init__(self, tokenizer): + """ + Trie structure for efficiently finding tokens which are suffixes of other sequences + """ + self.norm_vocab = {} + for token_id in tokenizer.vocab.values(): + norm_token = tokenizer.decode([tokenizer.bos_token_id, token_id])[len(tokenizer.bos_token):] + self.norm_vocab[norm_token] = token_id + + self.trie = {} + for word in self.norm_vocab: + current_dict = self.trie + for char in word: + if char not in current_dict: + current_dict[char] = {} + current_dict = current_dict[char] + current_dict['is_complete_token'] = True + + def get_next_level_token_prefixes(self, subprefix: str, legal_chars: Optional[set[str]] = None): + """ + Traverse the trie starting from a specified subprefix to identify all child nodes that represent + the longest possible strings 
without omitting any nodes that contain complete tokens. + """ + def _traverse(node, current_prefix): + # Base case: if the current node is a complete token or has multiple branches + if 'is_complete_token' in node or len(node) > 1: + next_level_prefixes.add(current_prefix) + return + + # Recursive case: continue traversal + for char, next_node in node.items(): + if char != 'is_complete_token': + _traverse(next_node, current_prefix + char) + + # Start from the node corresponding to the subprefix + current_node = self.trie + for char in subprefix: + if char not in current_node: + return [] # Subprefix not in trie + current_node = current_node[char] + + next_level_prefixes = set() + # Filter children based on legal_chars if provided + children = current_node.items() if not legal_chars else ((char, node) for char, node in current_node.items() if char in legal_chars) + + # Traverse for each child + for char, child_node in children: + _traverse(child_node, subprefix + char) + + return list(next_level_prefixes) + + + def test_simple_sequence(parser): for token in ['{', '"', 'k', 'ey', '":', '"val', 'ue', '"']: print("full:", parser.sequence_history) From 69fbd80e3be798912987302b17adb64d20ea9e7c Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Tue, 19 Dec 2023 15:05:06 -0600 Subject: [PATCH 10/76] clean up and add NextTokenValidator --- vllm/lark_interactive.py | 155 +++++++++++++++++++++++++++++++-------- 1 file changed, 126 insertions(+), 29 deletions(-) diff --git a/vllm/lark_interactive.py b/vllm/lark_interactive.py index 739432237537..e8094d304223 100644 --- a/vllm/lark_interactive.py +++ b/vllm/lark_interactive.py @@ -1,3 +1,4 @@ +import collections from copy import deepcopy, copy import functools import os @@ -119,6 +120,9 @@ def __init__(self, grammar: str, start: str): self.sequence_history = "" + # initiate + self.step_seq("") + def _accepts(self): if self.sequence_history not in self._accepts_cache: accepted_terminals = self.interactive_parser.accepts() @@ -189,14 +193,17 @@ def is_valid_next_seq(self, new_seq: Optional[str]): class TokenTrie: - def __init__(self, tokenizer): + IS_TOKEN = (None, "is complete token") + + def __init__(self, tokenizer, legal_chars: Optional[set[str]] = None): """ Trie structure for efficiently finding tokens which are suffixes of other sequences """ self.norm_vocab = {} for token_id in tokenizer.vocab.values(): norm_token = tokenizer.decode([tokenizer.bos_token_id, token_id])[len(tokenizer.bos_token):] - self.norm_vocab[norm_token] = token_id + if legal_chars is None or all([char in legal_chars for char in norm_token]): + self.norm_vocab[norm_token] = token_id self.trie = {} for word in self.norm_vocab: @@ -205,41 +212,126 @@ def __init__(self, tokenizer): if char not in current_dict: current_dict[char] = {} current_dict = current_dict[char] - current_dict['is_complete_token'] = True + current_dict[self.IS_TOKEN] = True - def get_next_level_token_prefixes(self, subprefix: str, legal_chars: Optional[set[str]] = None): + def get_next_level_token_prefixes(self, subprefix: str, _node=None): """ Traverse the trie starting from a specified subprefix to identify all child nodes that represent the longest possible strings without omitting any nodes that contain complete tokens. 
""" - def _traverse(node, current_prefix): - # Base case: if the current node is a complete token or has multiple branches - if 'is_complete_token' in node or len(node) > 1: - next_level_prefixes.add(current_prefix) - return + # if not first level of recursion, and at a branching point or is a token, or return self + if _node is not None and (len(_node) > 1 or self.IS_TOKEN in _node): + return {subprefix} + + # get the current node if at the first level of recursion + if _node is None: + _node = self.trie + for char in subprefix: + if char not in _node: + return set() + _node = _node[char] + + # Single child, need to go deeper + results = set() + for char, next_node in _node.items(): + if char != self.IS_TOKEN: + results |= self.get_next_level_token_prefixes(subprefix + char, _node=next_node) + return results + + def is_token(self, seq): + return seq in self.norm_vocab + + +class NextTokenValidator: + """ + Given a grammar and a tokenset, construct a parser and token trie. + + Interface: + - step_seq(new_seq): Append a sequence, update internal states + - property valid_token_set: The valid set of tokens within the vocabulary that can occure next + """ + def __init__( + self, + tokenizer, + grammar: str, + grammar_start: str = "start", + num_threads: Optional[int] = None + ): + self.parser = InteractivePredictiveLALRParser( + grammar=grammar, + start=grammar_start + ) + self.tokenizer = tokenizer + self.token_trie = TokenTrie(tokenizer) - # Recursive case: continue traversal - for char, next_node in node.items(): - if char != 'is_complete_token': - _traverse(next_node, current_prefix + char) + if num_threads is None: + self.num_threads = os.cpu_count() // 2 - # Start from the node corresponding to the subprefix - current_node = self.trie - for char in subprefix: - if char not in current_node: - return [] # Subprefix not in trie - current_node = current_node[char] + def step_seq(self, new_seq): + self.parser.step_seq(new_seq) - next_level_prefixes = set() - # Filter children based on legal_chars if provided - children = current_node.items() if not legal_chars else ((char, node) for char, node in current_node.items() if char in legal_chars) + @property + def valid_token_set(self): + """ + Generate the set of valid tokens given the current sequence + + 1) Push all first level token prefixes to the stack + 2) for each token in the stack, validate against the parser + - if valid, add all children to the stack + - if valid AND a token, add to valid_token_set + """ + valid_token_set = set() + token_prefix_stack = collections.deque([""]) + while token_prefix_stack: + print(len(token_prefix_stack)) + token_prefix = token_prefix_stack.pop() + for child_token_prefix in self.token_trie.get_next_level_token_prefixes(token_prefix): + # TODO: Handle EOS token by passing None + if self.parser.is_valid_next_seq(child_token_prefix): + token_prefix_stack.append(child_token_prefix) + if self.token_trie.is_token(child_token_prefix): + valid_token_set.add(child_token_prefix) + + return valid_token_set - # Traverse for each child - for char, child_node in children: - _traverse(child_node, subprefix + char) - return list(next_level_prefixes) +def test_next_token_validator_simple(): + grammar = """ + ?value: "hello" | "world" + """ + tokenizer = transformers.AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") + ntv = NextTokenValidator(tokenizer, json_grammar, "value") + + valid_toks = ntv.valid_token_set + assert valid_tokns == {'wo', 'hell', 'h', 'he', 'hel', 'world', 'wor', 'w', 'hello'} + + +def 
test_token_trie_sanity_hf_tokenizer(): + """Ensure token trie produces the same number of N 3 letter tokens""" + tokenizer = transformers.AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") + toktrie = TokenTrie(tokenizer) + + all_prefixes = toktrie.get_next_level_token_prefixes("") + + # every token should be composable from a single unique char, so they will all be len of 1 + assert all([len(p) == 1 for p in all_prefixes]) + + # every token should have one of these prefixes as a start character + assert all([ + t[0] in all_prefixes + for t in toktrie.norm_vocab + ]) + + # construct the set of next level prefixes + all_subprefixes = set() + for pfx in all_prefixes: + all_subprefixes |= toktrie.get_next_level_token_prefixes(pfx) + + import pdb;pdb.set_trace() + + # these should have varying length because some tokens don't have level-2 prefixes + assert len(set([len(spfx) for spfx in all_subprefixes])) > 1 def test_simple_sequence(parser): @@ -273,7 +365,7 @@ def test_valid_next_tokens(parser): def main(): # Usage - ebnf_grammar = """ + json_grammar = """ ?value: dict | list | string @@ -295,12 +387,15 @@ def main(): %ignore WS """ - parser = InteractivePredictiveLALRParser(ebnf_grammar, 'value') + parser = InteractivePredictiveLALRParser(json_grammar, 'value') test_valid_next_tokens(parser) if __name__ == "__main__": - main() + import transformers + test_next_token_validator() + import sys + sys.exit() profile = True if profile: @@ -318,3 +413,5 @@ def main(): ps = pstats.Stats(profile, stream=s).sort_stats(sortby) ps.print_stats() print(s.getvalue()) + else: + main() From 6af9a381236b1149d971ece66d58de22c63a7f4e Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Tue, 19 Dec 2023 15:24:11 -0600 Subject: [PATCH 11/76] clean up, add NextTokenValidator.valid_token_id_set --- vllm/lark_interactive.py | 114 ++++++++++++++++++++++++++++----------- 1 file changed, 83 insertions(+), 31 deletions(-) diff --git a/vllm/lark_interactive.py b/vllm/lark_interactive.py index e8094d304223..6d7db55cb0dc 100644 --- a/vllm/lark_interactive.py +++ b/vllm/lark_interactive.py @@ -156,7 +156,7 @@ def step_seq(self, new_seq: str): self._update_candidate_terminals() if not self.valid_next_terminals: - raise ValueError(f"Invalid continuation for `{self.sequence_history}` `{sequence}`") + raise ValueError(f"Invalid continuation for `{self.sequence_history}` `{new_seq}`") def _append_to_sequence(self, new_seq: str): """Set the complete sequences value in the lexer and base""" @@ -205,6 +205,10 @@ def __init__(self, tokenizer, legal_chars: Optional[set[str]] = None): if legal_chars is None or all([char in legal_chars for char in norm_token]): self.norm_vocab[norm_token] = token_id + self.token_to_id_set = collections.defaultdict(set) + for token_str, token_id in self.norm_vocab.items(): + self.token_to_id_set[token_str].add(token_id) + self.trie = {} for word in self.norm_vocab: current_dict = self.trie @@ -283,7 +287,6 @@ def valid_token_set(self): valid_token_set = set() token_prefix_stack = collections.deque([""]) while token_prefix_stack: - print(len(token_prefix_stack)) token_prefix = token_prefix_stack.pop() for child_token_prefix in self.token_trie.get_next_level_token_prefixes(token_prefix): # TODO: Handle EOS token by passing None @@ -294,17 +297,30 @@ def valid_token_set(self): return valid_token_set + @property + def valid_token_id_set(self): + """ + get valid token id based on self.valid_token_set + note that some tokens correspond to multiple IDs + """ + return set.union(*[ + 
self.token_trie.token_to_id_set[tok] + for tok in self.valid_token_set + ]) + def test_next_token_validator_simple(): - grammar = """ + hello_grammar = """ ?value: "hello" | "world" """ tokenizer = transformers.AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") - ntv = NextTokenValidator(tokenizer, json_grammar, "value") + ntv = NextTokenValidator(tokenizer, hello_grammar, "value") - valid_toks = ntv.valid_token_set - assert valid_tokns == {'wo', 'hell', 'h', 'he', 'hel', 'world', 'wor', 'w', 'hello'} + assert ntv.valid_token_set == {'wo', 'hell', 'h', 'he', 'hel', 'world', 'wor', 'w', 'hello'} + + import pdb;pdb.set_trace() + assert ntv.valid_token_id_set == {265, 809, 107, 2805, 21558, 28727, 13436, 22493, 9471} def test_token_trie_sanity_hf_tokenizer(): @@ -344,6 +360,29 @@ def test_simple_sequence(parser): def test_valid_next_tokens(parser): + json_grammar = """ + ?value: dict + | list + | string + | SIGNED_NUMBER -> number + | "true" -> true + | "false" -> false + | "null" -> null + + list : "[" [value ("," value)*] "]" + + dict : "{" [pair ("," pair)*] "}" + pair : string ":" value + + string : ESCAPED_STRING + + %import common.ESCAPED_STRING + %import common.SIGNED_NUMBER + %import common.WS + %ignore WS + """ + + parser = InteractivePredictiveLALRParser(json_grammar, 'value') # random complicated json file courtesy of https://github.com/simdjson/simdjson/issues/1316#issue-748663718 complex_json_file = '{"$schema": "http://json-schema.org/draft-04/schema#", "additionalProperties": false, "properties": {"nc:Vehicle": {"description": "A conveyance designed to carry an operator, passengers and/or cargo, over land.", "oneOf": [{"$ref": "#/definitions/nc:VehicleType"}, {"type": "array", "items": {"$ref": "#/definitions/nc:VehicleType"}}]}, "nc:VehicleAxleQuantity": {"description": "A count of common axles of rotation of one or more wheels of a vehicle, whether power driven or freely rotating.", "oneOf": [{"$ref": "#/definitions/niem-xs:nonNegativeInteger"}, {"type": "array", "items": {"$ref": "#/definitions/niem-xs:nonNegativeInteger"}}]}, "nc:VehicleMSRPAmount": {"description": "A manufacturer\'s suggested retail price of a vehicle; a price at which a manufacturer recommends a vehicle be sold.", "oneOf": [{"$ref": "#/definitions/nc:AmountType"}, {"type": "array", "items": {"$ref": "#/definitions/nc:AmountType"}}]}, "nc:Amount": {"description": "An amount of money.", "oneOf": [{"$ref": "#/definitions/niem-xs:decimal"}, {"type": "array", "items": {"$ref": "#/definitions/niem-xs:decimal"}}]}, "nc:Currency": {"description": "A data concept for a unit of money or exchange.", "oneOf": [{"anyOf": [{"$ref": "#/properties/nc:CurrencyCode"}]}, {"type": "array", "items": {"anyOf": [{"$ref": "#/properties/nc:CurrencyCode"}]}}]}, "nc:CurrencyCode": {"description": "A unit of money or exchange.", "oneOf": [{"$ref": "#/definitions/iso_4217:CurrencyCodeType"}, {"type": "array", "items": {"$ref": "#/definitions/iso_4217:CurrencyCodeType"}}]}, "nc:VehicleIdentification": {"description": "A unique identification for a specific vehicle.", "oneOf": [{"$ref": "#/definitions/nc:IdentificationType"}, {"type": "array", "items": {"$ref": "#/definitions/nc:IdentificationType"}}]}, "nc:IdentificationID": {"description": "An identifier.", "oneOf": [{"$ref": "#/definitions/niem-xs:string"}, {"type": "array", "items": {"$ref": "#/definitions/niem-xs:string"}}]}}, "definitions": {"nc:VehicleType": {"description": "A data type for a conveyance designed to carry an operator, passengers and/or cargo, over land.", 
"allOf": [{"$ref": "#/definitions/nc:ConveyanceType"}, {"type": "object", "properties": {"nc:VehicleAxleQuantity": {"$ref": "#/properties/nc:VehicleAxleQuantity"}, "nc:VehicleIdentification": {"$ref": "#/properties/nc:VehicleIdentification"}, "nc:VehicleMSRPAmount": {"$ref": "#/properties/nc:VehicleMSRPAmount"}}}]}, "nc:ConveyanceType": {"description": "A data type for a means of transport from place to place.", "allOf": [{"$ref": "#/definitions/_base"}, {"$ref": "#/definitions/nc:ItemType"}, {"type": "object", "properties": {}}]}, "nc:ItemType": {"description": "A data type for an article or thing.", "allOf": [{"$ref": "#/definitions/_base"}, {"type": "object", "properties": {}}]}, "nc:AmountType": {"description": "A data type for an amount of money.", "type": "object", "properties": {"nc:Amount": {"$ref": "#/properties/nc:Amount"}, "nc:Currency": {"$ref": "#/properties/nc:Currency"}}}, "iso_4217:CurrencyCodeType": {"description": "A data type for a currency that qualifies a monetary amount.", "oneOf": [{"$ref": "#/definitions/iso_4217:CurrencyCodeSimpleType"}, {"type": "object", "properties": {"rdf:value": {"$ref": "#/definitions/iso_4217:CurrencyCodeSimpleType"}}}]}, "iso_4217:CurrencyCodeSimpleType": {"type": "string", "description": "A data type for a currency that qualifies a monetary amount.", "oneOf": [{"enum": ["EUR"], "description": "Euro"}, {"enum": ["GBP"], "description": "Pound Sterling"}, {"enum": ["USD"], "description": "US Dollar"}]}, "nc:IdentificationType": {"description": "A data type for a representation of an identity.", "type": "object", "properties": {"nc:IdentificationID": {"$ref": "#/properties/nc:IdentificationID"}}}, "niem-xs:decimal": {"description": "A data type for arbitrary precision decimal numbers.", "type": "number"}, "niem-xs:nonNegativeInteger": {"description": "A data type for an integer with a minimum value of 0.", "type": "number"}, "niem-xs:string": {"description": "A data type for character strings in XML.", "type": "string"}, "_base": {"type": "object", "patternProperties": {"^ism:.*": {"type": "string"}, "^ntk:.*": {"type": "string"}}, "properties": {"@id": {"format": "uriref"}, "@base": {"format": "uriref"}}}}}' @@ -363,41 +402,54 @@ def test_valid_next_tokens(parser): test_chars_per_iter, "characters in vocabulary") -def main(): - # Usage - json_grammar = """ - ?value: dict - | list - | string - | SIGNED_NUMBER -> number - | "true" -> true - | "false" -> false - | "null" -> null - list : "[" [value ("," value)*] "]" +def profile_predictor(): + import pstats + from io import StringIO + import cProfile + hello_grammar = """ + ?value: "hello" | "world" + """ + tokenizer = transformers.AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") + ntv = NextTokenValidator(tokenizer, hello_grammar, "value") - dict : "{" [pair ("," pair)*] "}" - pair : string ":" value + profile = cProfile.Profile() + profile.enable() + ##### - string : ESCAPED_STRING + valid_toks = ntv.valid_token_set + ntv.step_seq("h") + valid_toks = ntv.valid_token_set + ntv.step_seq("e") + valid_toks = ntv.valid_token_set + ntv.step_seq("l") + valid_toks = ntv.valid_token_set + ntv.step_seq("l") + valid_toks = ntv.valid_token_set + ntv.step_seq("o") - %import common.ESCAPED_STRING - %import common.SIGNED_NUMBER - %import common.WS - %ignore WS - """ + ##### + profile.disable() + + # Sorting the statistics by cumulative time + s = StringIO() + sortby = 'cumulative' + ps = pstats.Stats(profile, stream=s).sort_stats(sortby) + ps.print_stats() + print(s.getvalue()) + + + +def main(): + 
test_next_token_validator_simple() + profile_predictor() - parser = InteractivePredictiveLALRParser(json_grammar, 'value') - test_valid_next_tokens(parser) if __name__ == "__main__": import transformers - test_next_token_validator() - import sys - sys.exit() - profile = True + profile = False if profile: import cProfile import pstats From c783db75256a702a2690df58c64d5f3c1b89b9b3 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Tue, 19 Dec 2023 15:28:53 -0600 Subject: [PATCH 12/76] clean up --- vllm/lark_interactive.py | 48 ++++++++++++++++++++-------------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/vllm/lark_interactive.py b/vllm/lark_interactive.py index 6d7db55cb0dc..9e7e120c40b0 100644 --- a/vllm/lark_interactive.py +++ b/vllm/lark_interactive.py @@ -12,11 +12,13 @@ from lark.exceptions import UnexpectedCharacters, UnexpectedToken - +###################### +# Fix Lark Speed Issue +###################### +""" +https://github.com/lark-parser/lark/issues/1142#issuecomment-1863209804 +""" class FastParserState(ParserState): - """ - https://github.com/lark-parser/lark/issues/1142#issuecomment-1863209804 - """ copy_memo = {} def __copy__(self): @@ -54,6 +56,9 @@ def __copy__(self): copy(self.parser_state), copy(self.lexer_thread), ) +###################### +###################### +###################### def get_partial_pattern_validator(pattern): @@ -252,7 +257,7 @@ class NextTokenValidator: Interface: - step_seq(new_seq): Append a sequence, update internal states - - property valid_token_set: The valid set of tokens within the vocabulary that can occure next + - property valid_token_str_set: The valid set of vocabulary tokens strings which can occur next """ def __init__( self, @@ -265,7 +270,6 @@ def __init__( grammar=grammar, start=grammar_start ) - self.tokenizer = tokenizer self.token_trie = TokenTrie(tokenizer) if num_threads is None: @@ -275,16 +279,16 @@ def step_seq(self, new_seq): self.parser.step_seq(new_seq) @property - def valid_token_set(self): + def valid_token_str_set(self): """ Generate the set of valid tokens given the current sequence 1) Push all first level token prefixes to the stack 2) for each token in the stack, validate against the parser - - if valid, add all children to the stack + - if valid, add all children to the stack for later processing - if valid AND a token, add to valid_token_set """ - valid_token_set = set() + valid_token_str_set = set() token_prefix_stack = collections.deque([""]) while token_prefix_stack: token_prefix = token_prefix_stack.pop() @@ -293,19 +297,19 @@ def valid_token_set(self): if self.parser.is_valid_next_seq(child_token_prefix): token_prefix_stack.append(child_token_prefix) if self.token_trie.is_token(child_token_prefix): - valid_token_set.add(child_token_prefix) + valid_token_str_set.add(child_token_prefix) - return valid_token_set + return valid_token_str_set @property def valid_token_id_set(self): """ - get valid token id based on self.valid_token_set - note that some tokens correspond to multiple IDs + get valid token id based on self.valid_token_str_set + note that some token strings correspond to multiple token IDs """ return set.union(*[ self.token_trie.token_to_id_set[tok] - for tok in self.valid_token_set + for tok in self.valid_token_str_set ]) @@ -317,9 +321,7 @@ def test_next_token_validator_simple(): tokenizer = transformers.AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") ntv = NextTokenValidator(tokenizer, hello_grammar, "value") - assert ntv.valid_token_set == {'wo', 'hell', 'h', 
'he', 'hel', 'world', 'wor', 'w', 'hello'} - - import pdb;pdb.set_trace() + assert ntv.valid_token_str_set == {'wo', 'hell', 'h', 'he', 'hel', 'world', 'wor', 'w', 'hello'} assert ntv.valid_token_id_set == {265, 809, 107, 2805, 21558, 28727, 13436, 22493, 9471} @@ -344,8 +346,6 @@ def test_token_trie_sanity_hf_tokenizer(): for pfx in all_prefixes: all_subprefixes |= toktrie.get_next_level_token_prefixes(pfx) - import pdb;pdb.set_trace() - # these should have varying length because some tokens don't have level-2 prefixes assert len(set([len(spfx) for spfx in all_subprefixes])) > 1 @@ -417,15 +417,15 @@ def profile_predictor(): profile.enable() ##### - valid_toks = ntv.valid_token_set + valid_toks = ntv.valid_token_str_set ntv.step_seq("h") - valid_toks = ntv.valid_token_set + valid_toks = ntv.valid_token_str_set ntv.step_seq("e") - valid_toks = ntv.valid_token_set + valid_toks = ntv.valid_token_str_set ntv.step_seq("l") - valid_toks = ntv.valid_token_set + valid_toks = ntv.valid_token_str_set ntv.step_seq("l") - valid_toks = ntv.valid_token_set + valid_toks = ntv.valid_token_str_set ntv.step_seq("o") ##### From bd942029fe22f9e9bae7a339170ae625d34bd444 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Tue, 19 Dec 2023 15:32:02 -0600 Subject: [PATCH 13/76] clean up --- vllm/lark_interactive.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/vllm/lark_interactive.py b/vllm/lark_interactive.py index 9e7e120c40b0..bd8d116d42a2 100644 --- a/vllm/lark_interactive.py +++ b/vllm/lark_interactive.py @@ -1,6 +1,5 @@ import collections from copy import deepcopy, copy -import functools import os import regex from typing import Optional @@ -12,12 +11,10 @@ from lark.exceptions import UnexpectedCharacters, UnexpectedToken -###################### +######################################################################### # Fix Lark Speed Issue -###################### -""" -https://github.com/lark-parser/lark/issues/1142#issuecomment-1863209804 -""" +# https://github.com/lark-parser/lark/issues/1142#issuecomment-1863209804 +######################################################################### class FastParserState(ParserState): copy_memo = {} @@ -56,9 +53,8 @@ def __copy__(self): copy(self.parser_state), copy(self.lexer_thread), ) -###################### -###################### -###################### +######################################################################### +######################################################################### def get_partial_pattern_validator(pattern): @@ -272,6 +268,7 @@ def __init__( ) self.token_trie = TokenTrie(tokenizer) + # TODO: threading if num_threads is None: self.num_threads = os.cpu_count() // 2 @@ -313,7 +310,6 @@ def valid_token_id_set(self): ]) - def test_next_token_validator_simple(): hello_grammar = """ ?value: "hello" | "world" From ec6cb14e564b3f9cc18f042bb1cf0f1a56242bed Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Tue, 19 Dec 2023 15:50:09 -0600 Subject: [PATCH 14/76] implement GrammarLogitProcessor --- vllm/lark_interactive.py | 83 +++++++++++++++++++++++++++++++++++++--- 1 file changed, 78 insertions(+), 5 deletions(-) diff --git a/vllm/lark_interactive.py b/vllm/lark_interactive.py index bd8d116d42a2..4bd3446c22bc 100644 --- a/vllm/lark_interactive.py +++ b/vllm/lark_interactive.py @@ -262,11 +262,13 @@ def __init__( grammar_start: str = "start", num_threads: Optional[int] = None ): + self.tokenizer = tokenizer + self.token_trie = TokenTrie(tokenizer) + self.parser = InteractivePredictiveLALRParser( 
grammar=grammar, start=grammar_start ) - self.token_trie = TokenTrie(tokenizer) # TODO: threading if num_threads is None: @@ -310,6 +312,78 @@ def valid_token_id_set(self): ]) +class GrammarLogitProcessor(NextTokenValidator): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.generation_token_ids = [] + self.generation_text = "" + + def __call__(self, token_ids, logits): + + # ensure integrity + assert token_ids[:len(self.generation_token_ids)] == self.generation_token_ids + self.generation_token_ids = token_ids + + # step forward + all_text = self.tokenizer.decode(token_ids) + new_text = all_text[len(self.generation_text):] + self.generation_text = all_text + self.step_seq(new_text) + + # get valid token IDs and modify logits + valid_token_ids = self.valid_token_id_set + logits = [ + logit_val if tok_id in valid_token_ids else -float("inf") + for tok_id, logit_val in zip(sorted(self.tokenizer.vocab.values()), logits) + ] + return logits + + +def test_generate_json_randomly_via_logit_processor(): + json_grammar = """ + ?value: dict + | list + | string + | SIGNED_NUMBER -> number + | "true" -> true + | "false" -> false + | "null" -> null + + list : "[" [value ("," value)*] "]" + + dict : "{" [pair ("," pair)*] "}" + pair : string ":" value + + string : ESCAPED_STRING + + %import common.ESCAPED_STRING + %import common.SIGNED_NUMBER + %import common.WS + #%ignore WS # we don't ignore whitespace because that makes the json uninteresting + """ + tokenizer = transformers.AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") + + logit_processor = GrammarLogitProcessor( + tokenizer, + json_grammar, + grammar_start="value" + ) + + sample_from_logits = lambda lgts: np.random.choice(len(lgts), p=np.exp(lgts)/np.sum(np.exp(lgts))) + + token_ids = [] + for _ in range(20): + logits = logit_processor( + token_ids=token_ids, + logits=np.random.uniform(-10, 10, len(tokenizer.vocab)) + ) + new_token_id = sample_from_logits(logits) + token_ids.append(new_token_id) + + import pdb;pdb.set_trace() + + def test_next_token_validator_simple(): hello_grammar = """ ?value: "hello" | "world" @@ -355,7 +429,7 @@ def test_simple_sequence(parser): print("valid terms:", parser.valid_next_terminals) -def test_valid_next_tokens(parser): +def test_valid_next_tokens(): json_grammar = """ ?value: dict | list @@ -437,13 +511,12 @@ def profile_predictor(): def main(): - test_next_token_validator_simple() - profile_predictor() - + test_generate_json_randomly_via_logit_processor() if __name__ == "__main__": import transformers + import numpy as np profile = False if profile: From 648c676a523c88c68bc6f85ae9ee60e8c6fc22bc Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Tue, 19 Dec 2023 15:51:00 -0600 Subject: [PATCH 15/76] commit grammar work before I delete it so its in my history --- vllm/grammar.py | 196 +++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 195 insertions(+), 1 deletion(-) diff --git a/vllm/grammar.py b/vllm/grammar.py index ede736024cd0..eb90ea086f31 100644 --- a/vllm/grammar.py +++ b/vllm/grammar.py @@ -1,4 +1,10 @@ import collections +from parsimonious.grammar import Grammar + +import parsimonious + +from lark import Lark, Transformer, v_args +from parsley import makeGrammar class TokenIndex: @@ -129,6 +135,7 @@ def _get_epsilon_closure(self, state, visited=None): return visited + class TokenConstraintLogitProcessor: def __init__(self, tokenizer, nfa): self.tokenizer = tokenizer @@ -205,11 +212,198 @@ def get_allowable_next_token_set(self, current_text="", 
nfa_next_tok_states=Fals return legal_tokens +class LarkPDA: + """ + Traverses a Lark Parser's PushDown Automata character by character + """ + def __init__(self, lark_grammar: str, start: str = "value"): + self.parser = Lark( + lark_grammar, + start=start, + lexer="basic" + ) + + def step_seq(self, seq): + for char in seq: + self.step(char) + + def step(self, char): + """ + Updates the canonical state + """ + next_states = self.next_char_state_map()[char] + if not next_states: + raise ValueError(f"Illegal transition from '{self.current_str}', no next state for '{char}'") + self.current_states = next_states + self.current_str += char + + def simulate_step(self, chars, state_set=None): + """ + Return map of chars and their resulting next state-set given a current state-set and new chars + """ + if state_set is None: + state_set = self.current_states + state_map = self.next_char_state_map(state_set) + return {tok: state_map[tok] for tok in chars if tok in state_map} + + def next_char_state_map(self, current_states=None): + """ + Creates a mapping of possible next chars to a set of valid states for each char + """ + if current_states is None: + current_states = self.current_states + + char_to_states = collections.defaultdict(set) + + if bool(current_states & set(self.nfa["final_states"])): + char_to_states[None] = None + + for state in self._resolve_epsilon_closure(current_states): + for char, next_states in self.nfa["transition_function"][state].items(): + if next_states and char != "$": + char_to_states[char].update(next_states) + + return char_to_states + + +def lark_to_pushdown_automata_spec(lark_grammar: str, start: str = "value") -> dict: + parser = Lark(lark_grammar, start=start, lexer="basic") + import pdb;pdb.set_trace() + + +def handle_regex(pattern): + """ + Handle a regex pattern and convert it into a state representation for PDA. + + Args: + pattern (str): The pattern of the regular expression. + + Returns: + str: A unique representation of the regex for the PDA. + """ + # Convert the regex pattern into a suitable representation for the PDA. + # This is a placeholder for the actual conversion logic. 
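+    # As written, this wraps the raw pattern in an opaque "REGEX(...)" tag;
+    # the pattern is not compiled into automaton states or PDA transitions.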
+ return f"REGEX({pattern})" + + +from parsimonious.nodes import NodeVisitor +from parsimonious.expressions import Literal, Sequence, Regex + +def convert_grammar_to_pda(grammar): + pda = { + 'states': set(), + 'initial_state': 'q0', + 'final_states': {'q_accept'}, + 'stack_alphabet': set(), + 'transition_function': collections.defaultdict(lambda: collections.defaultdict(list)) + } + + def process_expression(state, expr, is_terminal=False): + if isinstance(expr, Literal): + pda['stack_alphabet'].add(expr.literal) + next_state = 'q_accept' if is_terminal else 'q' + str(len(pda['states']) + 1) + pda['transition_function'][state][('epsilon', expr.literal)].append((next_state, [])) + if not is_terminal: + pda['states'].add(next_state) + elif isinstance(expr, Sequence): + current_state = state + for i, member in enumerate(expr.members): + is_last_member = i == len(expr.members) - 1 + process_expression(current_state, member, is_terminal=is_last_member) + if not is_last_member: + current_state = 'q' + str(len(pda['states'])) + elif isinstance(expr, Regex): + raise NotImplementedError("Regex handling not implemented") + + for non_terminal, production in grammar.items(): + pda['states'].add(non_terminal) + pda['stack_alphabet'].add(non_terminal) + process_expression(non_terminal, production) + + return dict(pda) + + +# Test Case +def test_convert_simple_grammar_to_pda(): + grammar = Grammar(""" + expression = term "+" term + term = "number" + """) + + expected_pda = { + 'states': {'expression', 'term', 'q0', 'q_accept'}, + 'initial_state': 'q0', + 'final_states': {'q_accept'}, + 'stack_alphabet': {'expression', 'term', '+', 'number'}, + 'transition_function': { + 'q0': {('epsilon', 'epsilon'): [('expression', ['term', '+', 'term'])]}, + 'expression': {('epsilon', 'term'): [('term', [])]}, + 'term': {('epsilon', 'number'): [('q_accept', [])]} + } + } + + actual_pda = convert_grammar_to_pda(grammar) + import pprint + pprint.pprint(actual_pda) + assert actual_pda == expected_pda, "PDA configuration does not match expected output" + + if __name__ == "__main__": - from automata_toolkit import regex_to_nfa import transformers import numpy as np + test_convert_simple_grammar_to_pda() + import pdb;pdb.set_trace() + + + grammar = r""" + expr = (entry / emptyline)* + entry = section pair* + + section = lpar word rpar ws + pair = key equal value ws? + + key = word+ + value = (word / quoted)+ + word = ~r"[-\w]+" + quoted = ~'"[^\"]+"' + equal = ws? "=" ws? 
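+    # bracket and whitespace terminals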
+ lpar = "[" + rpar = "]" + ws = ~"\s*" + emptyline = ws+ + """ + + pda_config = convert_grammar_to_pda(Grammar(grammar)) + import pdb;pdb.set_trace() + + try_parsley() + + json_grammar = r""" + ?start: value + + ?value: object + | array + | string + | SIGNED_NUMBER -> number + | "true" -> true + | "false" -> false + | "null" -> null + + array : "[" [value ("," value)*] "]" + object : "{" [pair ("," pair)*] "}" + pair : string ":" value + + string : ESCAPED_STRING + + %import common.ESCAPED_STRING + %import common.SIGNED_NUMBER + %import common.WS + + %ignore WS + """ + #lark_to_pushdown_automata_spec(json_grammar) + tokenizer = transformers.AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") sample_from_logits = lambda lgts: np.random.choice(len(lgts), p=np.exp(lgts)/np.sum(np.exp(lgts))) From 2b9b265ebdc45a372e0d3606cdc86214464da7bd Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Tue, 19 Dec 2023 15:51:35 -0600 Subject: [PATCH 16/76] remove old grammar, rename new grammar --- vllm/grammar.py | 783 ++++++++++++++++++++++----------------- vllm/lark_interactive.py | 538 --------------------------- 2 files changed, 444 insertions(+), 877 deletions(-) delete mode 100644 vllm/lark_interactive.py diff --git a/vllm/grammar.py b/vllm/grammar.py index eb90ea086f31..4bd3446c22bc 100644 --- a/vllm/grammar.py +++ b/vllm/grammar.py @@ -1,171 +1,338 @@ import collections -from parsimonious.grammar import Grammar +from copy import deepcopy, copy +import os +import regex +from typing import Optional + +from lark import Lark +from lark.parsers.lalr_interactive_parser import InteractiveParser +from lark.parsers.lalr_parser_state import ParserState +from lark.lexer import Token, LexerState, PatternStr, PatternRE +from lark.exceptions import UnexpectedCharacters, UnexpectedToken + + +######################################################################### +# Fix Lark Speed Issue +# https://github.com/lark-parser/lark/issues/1142#issuecomment-1863209804 +######################################################################### +class FastParserState(ParserState): + copy_memo = {} + + def __copy__(self): + new_value_stack = [] + for value in self.value_stack: + key = f"{id(self)}_{id(value)}" + if key not in self.copy_memo: + self.copy_memo[key] = deepcopy(value, self.copy_memo) + new_value_stack.append(self.copy_memo[key]) + + new_instance = type(self)( + self.parse_conf, + self.lexer, # XXX copy + copy(self.state_stack), + new_value_stack, + ) -import parsimonious + self.copy_memo[id(self)] = new_instance + return new_instance -from lark import Lark, Transformer, v_args -from parsley import makeGrammar +class FastInteractiveParser(InteractiveParser): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.parser_state = FastParserState( + self.parser_state.parse_conf, + self.parser_state.lexer, + self.parser_state.state_stack, + self.parser_state.value_stack, + ) -class TokenIndex: - def __init__(self, tokenizer): - # map id -> token str including whitespace - self.norm_vocab = {} - for token_id in tokenizer.vocab.values(): - # TODO: look into difference between tokens, e.g. 
28705, 35 are both " 9" - # assert norm_token not in self.norm_vocab, - norm_token = tokenizer.decode([tokenizer.bos_token_id, token_id])[len(tokenizer.bos_token):] - self.norm_vocab[norm_token] = token_id - - # get index allowing efficient retrieval of valid tokens given a sequence - # given tokens ["art", "artist", "argument", "alice"] - # map "a" -> ["ar", "al"] - # map "ar" -> ["art", "artist"] - # map "art" -> [None, "artist"] (None indicates match) - self.char_map = collections.defaultdict(set) - for word in self.norm_vocab: - for i in range(1, len(word) + 1): - prefix = word[:i] - if i < len(word): - self.char_map[prefix].add(word[i]) - else: - # Add None for complete matches - self.char_map[prefix].add(None) + def __copy__(self): + return type(self)( + self.parser, + copy(self.parser_state), + copy(self.lexer_thread), + ) +######################################################################### +######################################################################### - def get_valid_next_charset(self, seq, legal_chars): - results = set(self.char_map[seq]) & legal_chars - return results +def get_partial_pattern_validator(pattern): + """ + Accepts a pattern object, either lark.lexer.PatternStr or lark.lexer.PatternRE + Returns a function which validates a partial string + + e.g. for PatternRE "abc*", returns true for "a", "ab", "abc", "abcccc" + """ + if isinstance(pattern, PatternRE): + compiled_pattern = regex.compile(pattern.value) + return ( + lambda seq: compiled_pattern.fullmatch(seq, partial=True) is not None + ) + elif isinstance(pattern, PatternStr): + base_str = pattern.value + return ( + lambda seq: base_str.startswith(seq) + ) + else: + raise TypeError(f"Invalid pattern type: {type(pattern)}") + - def is_token(self, tok): - return tok in self.norm_vocab +class InteractivePredictiveLALRParser: + """ + Parser which consumes an EBNF grammar and provides helpers to determine allowable language model tokens + Interfaces: + - step_seq(new_seq): Update the parser with a new sequence to append + - is_valid_next_seq(new_seq): Determine whether a candidate sequence is valid -class EpsilonNFA: + Core components for terminal level, and sub-terminal level processing: + - 1) Lark LALR parser: Applies state transitions, determining set of valid next-terminals + - 2) Incremental terminal filter: Eliminates next-terminal candidates if terminal pattern doesn't match """ - Traverses a Character-Based Epsilon-NFA. - Used by find valid next character sequences. + def __init__(self, grammar: str, start: str): + self.parser = Lark( + grammar, + regex=True, # use `regex` not `re` + start=start, + parser='lalr', + ) + base_interactive_parser = self.parser.parse_interactive() + self.interactive_parser = FastInteractiveParser( + base_interactive_parser.parser, + base_interactive_parser.parser_state, + base_interactive_parser.lexer_thread + ) + + self.partial_seq_validator = { + term.name: get_partial_pattern_validator(term.pattern) + for term in self.parser.terminals + } - self.nfa (dict): A dictionary representing the NFA. It includes: - - 'states' (list): A list of states (UUIDsn) inn the NFA. - - 'initial_state' (UUID or any hashable ID): The initial state of the NFA. - - 'final_states' (list): A list of final or accepting states (UUIDs). - - 'alphabets' (list): The set of input symbols (characters). - - 'transition_function' (dict): A dictionary representing the state - transitions. 
Each key is a state (UUID), and its value is another - dictionary mapping input symbols to lists of next states (UUIDs). + self._ignored_terms = set(self.parser.lexer_conf.ignore) - self.nfa should never be mutated. - """ - def __init__(self, nfa): - self.nfa = nfa + # for processing terminals interactively + self.last_terminal_pos = 0 + self.valid_next_terminals = None - # Set of states you may be in - self.current_states = set([self.nfa["initial_state"]]) - self.current_str = "" + # for calculating `accepts()` efficiently + self._accepts_cache = {} - self.legal_chars = set([char for char in self.nfa["alphabets"] if char != "$"]) + self.sequence_history = "" - self._resolved_epsilon_cache = {} + # initiate + self.step_seq("") - def step_seq(self, seq): - for char in seq: - self.step(char) + def _accepts(self): + if self.sequence_history not in self._accepts_cache: + accepted_terminals = self.interactive_parser.accepts() + self._accepts_cache[self.sequence_history] = accepted_terminals + return self._accepts_cache[self.sequence_history] - def step(self, char): + @property + def terminal_partial_seq(self): """ - Updates the canonical state + Return the incomplete subsequence which will eventually comprise a terminal """ - next_states = self.next_char_state_map()[char] - if not next_states: - raise ValueError(f"Illegal transition from '{self.current_str}', no next state for '{char}'") - self.current_states = next_states - self.current_str += char + return self.sequence_history[self.last_terminal_pos:] - def simulate_step(self, chars, state_set=None): + def step_seq(self, new_seq: str): """ - Return map of chars and their resulting next state-set given a current state-set and new chars + Append sequence to parser and apply state updates + - Append the sequence to the canonical self.sequence_history + - Parse the changes + - Update the character position of the last complete terminal + - Update the set of candidate terminals """ - if state_set is None: - state_set = self.current_states - state_map = self.next_char_state_map(state_set) - return {tok: state_map[tok] for tok in chars if tok in state_map} + self._append_to_sequence(new_seq) + + try: + self.interactive_parser.exhaust_lexer() + except UnexpectedCharacters as e: + self.last_terminal_pos = e.pos_in_stream + else: + self.last_terminal_pos = len(self.sequence_history) + + self._update_candidate_terminals() - def next_char_state_map(self, current_states=None): + if not self.valid_next_terminals: + raise ValueError(f"Invalid continuation for `{self.sequence_history}` `{new_seq}`") + + def _append_to_sequence(self, new_seq: str): + """Set the complete sequences value in the lexer and base""" + self.sequence_history += new_seq + self.interactive_parser.lexer_thread.state.text = self.sequence_history + + def _update_candidate_terminals(self): + """ + Update the set of candidate terminals + - If a new terminal is reached, get the accepted set of terminals from the parser + - If the new sequence doesn't comprise a full terminal, filter based on partial pattern match """ - Creates a mapping of possible next chars to a set of valid states for each char + if not self.terminal_partial_seq: + self.valid_next_terminals = self._accepts() | self._ignored_terms + else: + self.valid_next_terminals = set([ + term for term in self.valid_next_terminals + if self.partial_seq_validator[term](self.terminal_partial_seq) + ]) + + def is_valid_next_seq(self, new_seq: Optional[str]): """ - if current_states is None: - current_states = self.current_states + Check 
if current un-terminalized sequence + new_seq is valid for any terminal - char_to_states = collections.defaultdict(set) + new_seq can be a string or None representing EOS + """ + if new_seq is None: + return "$END" in self.valid_next_terminals + for term in self.valid_next_terminals: + if term != "$END": + if self.partial_seq_validator[term](self.terminal_partial_seq + new_seq): + return True + return False - if bool(current_states & set(self.nfa["final_states"])): - char_to_states[None] = None - for state in self._resolve_epsilon_closure(current_states): - for char, next_states in self.nfa["transition_function"][state].items(): - if next_states and char != "$": - char_to_states[char].update(next_states) +class TokenTrie: + IS_TOKEN = (None, "is complete token") - return char_to_states + def __init__(self, tokenizer, legal_chars: Optional[set[str]] = None): + """ + Trie structure for efficiently finding tokens which are suffixes of other sequences + """ + self.norm_vocab = {} + for token_id in tokenizer.vocab.values(): + norm_token = tokenizer.decode([tokenizer.bos_token_id, token_id])[len(tokenizer.bos_token):] + if legal_chars is None or all([char in legal_chars for char in norm_token]): + self.norm_vocab[norm_token] = token_id - def _resolve_epsilon_closure(self, states): - closure = set() - for state in states: - if state in self._resolved_epsilon_cache: - new_closures = self._resolved_epsilon_cache[state] - else: - new_closures = self._get_epsilon_closure(state) - self._resolved_epsilon_cache[state] = new_closures - closure.update(self._get_epsilon_closure(state)) - return closure + self.token_to_id_set = collections.defaultdict(set) + for token_str, token_id in self.norm_vocab.items(): + self.token_to_id_set[token_str].add(token_id) - def _get_epsilon_closure(self, state, visited=None): - if visited is None: - visited = set() + self.trie = {} + for word in self.norm_vocab: + current_dict = self.trie + for char in word: + if char not in current_dict: + current_dict[char] = {} + current_dict = current_dict[char] + current_dict[self.IS_TOKEN] = True + + def get_next_level_token_prefixes(self, subprefix: str, _node=None): + """ + Traverse the trie starting from a specified subprefix to identify all child nodes that represent + the longest possible strings without omitting any nodes that contain complete tokens. + """ + # if not first level of recursion, and at a branching point or is a token, or return self + if _node is not None and (len(_node) > 1 or self.IS_TOKEN in _node): + return {subprefix} + + # get the current node if at the first level of recursion + if _node is None: + _node = self.trie + for char in subprefix: + if char not in _node: + return set() + _node = _node[char] + + # Single child, need to go deeper + results = set() + for char, next_node in _node.items(): + if char != self.IS_TOKEN: + results |= self.get_next_level_token_prefixes(subprefix + char, _node=next_node) + return results - stack = [state] - while stack: - current_state = stack.pop() - if current_state not in visited: - visited.add(current_state) - stack.extend(self.nfa["transition_function"][current_state].get('$', [])) + def is_token(self, seq): + return seq in self.norm_vocab - return visited +class NextTokenValidator: + """ + Given a grammar and a tokenset, construct a parser and token trie. 
-class TokenConstraintLogitProcessor: - def __init__(self, tokenizer, nfa): + Interface: + - step_seq(new_seq): Append a sequence, update internal states + - property valid_token_str_set: The valid set of vocabulary tokens strings which can occur next + """ + def __init__( + self, + tokenizer, + grammar: str, + grammar_start: str = "start", + num_threads: Optional[int] = None + ): self.tokenizer = tokenizer - self.token_index = TokenIndex(tokenizer) - self.nfa = nfa - self.prev_token_ids = [] - self.prev_text = "" + self.token_trie = TokenTrie(tokenizer) + + self.parser = InteractivePredictiveLALRParser( + grammar=grammar, + start=grammar_start + ) + + # TODO: threading + if num_threads is None: + self.num_threads = os.cpu_count() // 2 + + def step_seq(self, new_seq): + self.parser.step_seq(new_seq) + + @property + def valid_token_str_set(self): + """ + Generate the set of valid tokens given the current sequence + + 1) Push all first level token prefixes to the stack + 2) for each token in the stack, validate against the parser + - if valid, add all children to the stack for later processing + - if valid AND a token, add to valid_token_set + """ + valid_token_str_set = set() + token_prefix_stack = collections.deque([""]) + while token_prefix_stack: + token_prefix = token_prefix_stack.pop() + for child_token_prefix in self.token_trie.get_next_level_token_prefixes(token_prefix): + # TODO: Handle EOS token by passing None + if self.parser.is_valid_next_seq(child_token_prefix): + token_prefix_stack.append(child_token_prefix) + if self.token_trie.is_token(child_token_prefix): + valid_token_str_set.add(child_token_prefix) + + return valid_token_str_set + + @property + def valid_token_id_set(self): + """ + get valid token id based on self.valid_token_str_set + note that some token strings correspond to multiple token IDs + """ + return set.union(*[ + self.token_trie.token_to_id_set[tok] + for tok in self.valid_token_str_set + ]) + + +class GrammarLogitProcessor(NextTokenValidator): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.generation_token_ids = [] + self.generation_text = "" def __call__(self, token_ids, logits): # ensure integrity - assert token_ids[:len(self.prev_token_ids)] == self.prev_token_ids - self.prev_token_ids = token_ids - - # get new text and step NFA forward - text = tokenizer.decode(token_ids) - new_text = text[len(self.prev_text):] - self.prev_text = text - self.nfa.step_seq(new_text) - - # get valid new token ids - valid_tokens = set(self.get_allowable_next_token_set()) - valid_token_ids = [ - self.tokenizer.eos_token_id if t is None else self.token_index.norm_vocab[t] - for t in valid_tokens - ] + assert token_ids[:len(self.generation_token_ids)] == self.generation_token_ids + self.generation_token_ids = token_ids - if not valid_token_ids: - raise ValueError("Found no valid tokens, this should never occur.") + # step forward + all_text = self.tokenizer.decode(token_ids) + new_text = all_text[len(self.generation_text):] + self.generation_text = all_text + self.step_seq(new_text) + # get valid token IDs and modify logits + valid_token_ids = self.valid_token_id_set logits = [ logit_val if tok_id in valid_token_ids else -float("inf") for tok_id, logit_val in zip(sorted(self.tokenizer.vocab.values()), logits) @@ -173,261 +340,199 @@ def __call__(self, token_ids, logits): return logits - def get_allowable_next_token_set(self, current_text="", nfa_next_tok_states=False): - """ - Get set of valid tokens. 
- 1) Ask NFA for legal first char - 3) While legal TokenIndex hasn't been exhausted - A) Ask TokenIndex for legal Nth char set - B) Ask NFA for - """ - if nfa_next_tok_states is None: - return [None] - if nfa_next_tok_states == False: - nfa_next_tok_states = self.nfa.next_char_state_map() - - legal_tokens = [] +def test_generate_json_randomly_via_logit_processor(): + json_grammar = """ + ?value: dict + | list + | string + | SIGNED_NUMBER -> number + | "true" -> true + | "false" -> false + | "null" -> null - if None in nfa_next_tok_states: - legal_tokens.append(None) - del nfa_next_tok_states[None] + list : "[" [value ("," value)*] "]" - for char, next_states in nfa_next_tok_states.items(): - # all current sequences are legal per nfa, find legal next token with token index - new_seq = current_text + char - tokidx_next_chars = self.token_index.get_valid_next_charset( - new_seq, - self.nfa.legal_chars - ) + dict : "{" [pair ("," pair)*] "}" + pair : string ":" value - if self.token_index.is_token(new_seq): - legal_tokens.append(new_seq) + string : ESCAPED_STRING - # given legal next chars in token index, get the subset allowed by NFA and recurse - legal_tokens += self.get_allowable_next_token_set( - new_seq, - self.nfa.simulate_step(tokidx_next_chars, next_states) - ) + %import common.ESCAPED_STRING + %import common.SIGNED_NUMBER + %import common.WS + #%ignore WS # we don't ignore whitespace because that makes the json uninteresting + """ + tokenizer = transformers.AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") - return legal_tokens + logit_processor = GrammarLogitProcessor( + tokenizer, + json_grammar, + grammar_start="value" + ) + sample_from_logits = lambda lgts: np.random.choice(len(lgts), p=np.exp(lgts)/np.sum(np.exp(lgts))) -class LarkPDA: - """ - Traverses a Lark Parser's PushDown Automata character by character - """ - def __init__(self, lark_grammar: str, start: str = "value"): - self.parser = Lark( - lark_grammar, - start=start, - lexer="basic" + token_ids = [] + for _ in range(20): + logits = logit_processor( + token_ids=token_ids, + logits=np.random.uniform(-10, 10, len(tokenizer.vocab)) ) + new_token_id = sample_from_logits(logits) + token_ids.append(new_token_id) - def step_seq(self, seq): - for char in seq: - self.step(char) + import pdb;pdb.set_trace() - def step(self, char): - """ - Updates the canonical state - """ - next_states = self.next_char_state_map()[char] - if not next_states: - raise ValueError(f"Illegal transition from '{self.current_str}', no next state for '{char}'") - self.current_states = next_states - self.current_str += char - def simulate_step(self, chars, state_set=None): - """ - Return map of chars and their resulting next state-set given a current state-set and new chars - """ - if state_set is None: - state_set = self.current_states - state_map = self.next_char_state_map(state_set) - return {tok: state_map[tok] for tok in chars if tok in state_map} +def test_next_token_validator_simple(): + hello_grammar = """ + ?value: "hello" | "world" + """ + tokenizer = transformers.AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") + ntv = NextTokenValidator(tokenizer, hello_grammar, "value") - def next_char_state_map(self, current_states=None): - """ - Creates a mapping of possible next chars to a set of valid states for each char - """ - if current_states is None: - current_states = self.current_states + assert ntv.valid_token_str_set == {'wo', 'hell', 'h', 'he', 'hel', 'world', 'wor', 'w', 'hello'} + assert ntv.valid_token_id_set == {265, 809, 
107, 2805, 21558, 28727, 13436, 22493, 9471} - char_to_states = collections.defaultdict(set) - if bool(current_states & set(self.nfa["final_states"])): - char_to_states[None] = None +def test_token_trie_sanity_hf_tokenizer(): + """Ensure token trie produces the same number of N 3 letter tokens""" + tokenizer = transformers.AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") + toktrie = TokenTrie(tokenizer) - for state in self._resolve_epsilon_closure(current_states): - for char, next_states in self.nfa["transition_function"][state].items(): - if next_states and char != "$": - char_to_states[char].update(next_states) + all_prefixes = toktrie.get_next_level_token_prefixes("") - return char_to_states + # every token should be composable from a single unique char, so they will all be len of 1 + assert all([len(p) == 1 for p in all_prefixes]) + # every token should have one of these prefixes as a start character + assert all([ + t[0] in all_prefixes + for t in toktrie.norm_vocab + ]) -def lark_to_pushdown_automata_spec(lark_grammar: str, start: str = "value") -> dict: - parser = Lark(lark_grammar, start=start, lexer="basic") - import pdb;pdb.set_trace() + # construct the set of next level prefixes + all_subprefixes = set() + for pfx in all_prefixes: + all_subprefixes |= toktrie.get_next_level_token_prefixes(pfx) + # these should have varying length because some tokens don't have level-2 prefixes + assert len(set([len(spfx) for spfx in all_subprefixes])) > 1 -def handle_regex(pattern): - """ - Handle a regex pattern and convert it into a state representation for PDA. - Args: - pattern (str): The pattern of the regular expression. +def test_simple_sequence(parser): + for token in ['{', '"', 'k', 'ey', '":', '"val', 'ue', '"']: + print("full:", parser.sequence_history) + print("adding", token) + parser.step_seq(token) + print("partial:", parser.terminal_partial_seq) + print("valid terms:", parser.valid_next_terminals) - Returns: - str: A unique representation of the regex for the PDA. - """ - # Convert the regex pattern into a suitable representation for the PDA. - # This is a placeholder for the actual conversion logic. 
- return f"REGEX({pattern})" - - -from parsimonious.nodes import NodeVisitor -from parsimonious.expressions import Literal, Sequence, Regex - -def convert_grammar_to_pda(grammar): - pda = { - 'states': set(), - 'initial_state': 'q0', - 'final_states': {'q_accept'}, - 'stack_alphabet': set(), - 'transition_function': collections.defaultdict(lambda: collections.defaultdict(list)) - } - - def process_expression(state, expr, is_terminal=False): - if isinstance(expr, Literal): - pda['stack_alphabet'].add(expr.literal) - next_state = 'q_accept' if is_terminal else 'q' + str(len(pda['states']) + 1) - pda['transition_function'][state][('epsilon', expr.literal)].append((next_state, [])) - if not is_terminal: - pda['states'].add(next_state) - elif isinstance(expr, Sequence): - current_state = state - for i, member in enumerate(expr.members): - is_last_member = i == len(expr.members) - 1 - process_expression(current_state, member, is_terminal=is_last_member) - if not is_last_member: - current_state = 'q' + str(len(pda['states'])) - elif isinstance(expr, Regex): - raise NotImplementedError("Regex handling not implemented") - - for non_terminal, production in grammar.items(): - pda['states'].add(non_terminal) - pda['stack_alphabet'].add(non_terminal) - process_expression(non_terminal, production) - - return dict(pda) - - -# Test Case -def test_convert_simple_grammar_to_pda(): - grammar = Grammar(""" - expression = term "+" term - term = "number" - """) - - expected_pda = { - 'states': {'expression', 'term', 'q0', 'q_accept'}, - 'initial_state': 'q0', - 'final_states': {'q_accept'}, - 'stack_alphabet': {'expression', 'term', '+', 'number'}, - 'transition_function': { - 'q0': {('epsilon', 'epsilon'): [('expression', ['term', '+', 'term'])]}, - 'expression': {('epsilon', 'term'): [('term', [])]}, - 'term': {('epsilon', 'number'): [('q_accept', [])]} - } - } - actual_pda = convert_grammar_to_pda(grammar) - import pprint - pprint.pprint(actual_pda) - assert actual_pda == expected_pda, "PDA configuration does not match expected output" +def test_valid_next_tokens(): + json_grammar = """ + ?value: dict + | list + | string + | SIGNED_NUMBER -> number + | "true" -> true + | "false" -> false + | "null" -> null + list : "[" [value ("," value)*] "]" -if __name__ == "__main__": - import transformers - import numpy as np + dict : "{" [pair ("," pair)*] "}" + pair : string ":" value - test_convert_simple_grammar_to_pda() - import pdb;pdb.set_trace() + string : ESCAPED_STRING + %import common.ESCAPED_STRING + %import common.SIGNED_NUMBER + %import common.WS + %ignore WS + """ - grammar = r""" - expr = (entry / emptyline)* - entry = section pair* + parser = InteractivePredictiveLALRParser(json_grammar, 'value') + # random complicated json file courtesy of https://github.com/simdjson/simdjson/issues/1316#issue-748663718 + complex_json_file = '{"$schema": "http://json-schema.org/draft-04/schema#", "additionalProperties": false, "properties": {"nc:Vehicle": {"description": "A conveyance designed to carry an operator, passengers and/or cargo, over land.", "oneOf": [{"$ref": "#/definitions/nc:VehicleType"}, {"type": "array", "items": {"$ref": "#/definitions/nc:VehicleType"}}]}, "nc:VehicleAxleQuantity": {"description": "A count of common axles of rotation of one or more wheels of a vehicle, whether power driven or freely rotating.", "oneOf": [{"$ref": "#/definitions/niem-xs:nonNegativeInteger"}, {"type": "array", "items": {"$ref": "#/definitions/niem-xs:nonNegativeInteger"}}]}, "nc:VehicleMSRPAmount": {"description": "A 
manufacturer\'s suggested retail price of a vehicle; a price at which a manufacturer recommends a vehicle be sold.", "oneOf": [{"$ref": "#/definitions/nc:AmountType"}, {"type": "array", "items": {"$ref": "#/definitions/nc:AmountType"}}]}, "nc:Amount": {"description": "An amount of money.", "oneOf": [{"$ref": "#/definitions/niem-xs:decimal"}, {"type": "array", "items": {"$ref": "#/definitions/niem-xs:decimal"}}]}, "nc:Currency": {"description": "A data concept for a unit of money or exchange.", "oneOf": [{"anyOf": [{"$ref": "#/properties/nc:CurrencyCode"}]}, {"type": "array", "items": {"anyOf": [{"$ref": "#/properties/nc:CurrencyCode"}]}}]}, "nc:CurrencyCode": {"description": "A unit of money or exchange.", "oneOf": [{"$ref": "#/definitions/iso_4217:CurrencyCodeType"}, {"type": "array", "items": {"$ref": "#/definitions/iso_4217:CurrencyCodeType"}}]}, "nc:VehicleIdentification": {"description": "A unique identification for a specific vehicle.", "oneOf": [{"$ref": "#/definitions/nc:IdentificationType"}, {"type": "array", "items": {"$ref": "#/definitions/nc:IdentificationType"}}]}, "nc:IdentificationID": {"description": "An identifier.", "oneOf": [{"$ref": "#/definitions/niem-xs:string"}, {"type": "array", "items": {"$ref": "#/definitions/niem-xs:string"}}]}}, "definitions": {"nc:VehicleType": {"description": "A data type for a conveyance designed to carry an operator, passengers and/or cargo, over land.", "allOf": [{"$ref": "#/definitions/nc:ConveyanceType"}, {"type": "object", "properties": {"nc:VehicleAxleQuantity": {"$ref": "#/properties/nc:VehicleAxleQuantity"}, "nc:VehicleIdentification": {"$ref": "#/properties/nc:VehicleIdentification"}, "nc:VehicleMSRPAmount": {"$ref": "#/properties/nc:VehicleMSRPAmount"}}}]}, "nc:ConveyanceType": {"description": "A data type for a means of transport from place to place.", "allOf": [{"$ref": "#/definitions/_base"}, {"$ref": "#/definitions/nc:ItemType"}, {"type": "object", "properties": {}}]}, "nc:ItemType": {"description": "A data type for an article or thing.", "allOf": [{"$ref": "#/definitions/_base"}, {"type": "object", "properties": {}}]}, "nc:AmountType": {"description": "A data type for an amount of money.", "type": "object", "properties": {"nc:Amount": {"$ref": "#/properties/nc:Amount"}, "nc:Currency": {"$ref": "#/properties/nc:Currency"}}}, "iso_4217:CurrencyCodeType": {"description": "A data type for a currency that qualifies a monetary amount.", "oneOf": [{"$ref": "#/definitions/iso_4217:CurrencyCodeSimpleType"}, {"type": "object", "properties": {"rdf:value": {"$ref": "#/definitions/iso_4217:CurrencyCodeSimpleType"}}}]}, "iso_4217:CurrencyCodeSimpleType": {"type": "string", "description": "A data type for a currency that qualifies a monetary amount.", "oneOf": [{"enum": ["EUR"], "description": "Euro"}, {"enum": ["GBP"], "description": "Pound Sterling"}, {"enum": ["USD"], "description": "US Dollar"}]}, "nc:IdentificationType": {"description": "A data type for a representation of an identity.", "type": "object", "properties": {"nc:IdentificationID": {"$ref": "#/properties/nc:IdentificationID"}}}, "niem-xs:decimal": {"description": "A data type for arbitrary precision decimal numbers.", "type": "number"}, "niem-xs:nonNegativeInteger": {"description": "A data type for an integer with a minimum value of 0.", "type": "number"}, "niem-xs:string": {"description": "A data type for character strings in XML.", "type": "string"}, "_base": {"type": "object", "patternProperties": {"^ism:.*": {"type": "string"}, "^ntk:.*": {"type": "string"}}, "properties": 
{"@id": {"format": "uriref"}, "@base": {"format": "uriref"}}}}}' - section = lpar word rpar ws - pair = key equal value ws? + test_chars_per_iter = 1000 + unicode_chars = [chr(i) for i in range(test_chars_per_iter)] - key = word+ - value = (word / quoted)+ - word = ~r"[-\w]+" - quoted = ~'"[^\"]+"' - equal = ws? "=" ws? - lpar = "[" - rpar = "]" - ws = ~"\s*" - emptyline = ws+ - """ + import time + start = time.time() + for char in complex_json_file: + parser.step_seq(char) + for ch in unicode_chars: + parser.is_valid_next_seq(ch) - pda_config = convert_grammar_to_pda(Grammar(grammar)) - import pdb;pdb.set_trace() + print("took", + (time.time() - start) / (len(complex_json_file)), + "seconds per step with", + test_chars_per_iter, "characters in vocabulary") - try_parsley() - json_grammar = r""" - ?start: value - ?value: object - | array - | string - | SIGNED_NUMBER -> number - | "true" -> true - | "false" -> false - | "null" -> null +def profile_predictor(): + import pstats + from io import StringIO + import cProfile + hello_grammar = """ + ?value: "hello" | "world" + """ + tokenizer = transformers.AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") + ntv = NextTokenValidator(tokenizer, hello_grammar, "value") - array : "[" [value ("," value)*] "]" - object : "{" [pair ("," pair)*] "}" - pair : string ":" value + profile = cProfile.Profile() + profile.enable() + ##### - string : ESCAPED_STRING + valid_toks = ntv.valid_token_str_set + ntv.step_seq("h") + valid_toks = ntv.valid_token_str_set + ntv.step_seq("e") + valid_toks = ntv.valid_token_str_set + ntv.step_seq("l") + valid_toks = ntv.valid_token_str_set + ntv.step_seq("l") + valid_toks = ntv.valid_token_str_set + ntv.step_seq("o") - %import common.ESCAPED_STRING - %import common.SIGNED_NUMBER - %import common.WS + ##### + profile.disable() - %ignore WS - """ - #lark_to_pushdown_automata_spec(json_grammar) + # Sorting the statistics by cumulative time + s = StringIO() + sortby = 'cumulative' + ps = pstats.Stats(profile, stream=s).sort_stats(sortby) + ps.print_stats() + print(s.getvalue()) - tokenizer = transformers.AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") - sample_from_logits = lambda lgts: np.random.choice(len(lgts), p=np.exp(lgts)/np.sum(np.exp(lgts))) - for i in range(4): +def main(): + test_generate_json_randomly_via_logit_processor() - logit_processor = TokenConstraintLogitProcessor( - tokenizer=tokenizer, - nfa=EpsilonNFA(nfa=regex_to_nfa.regex_to_nfa( - r"(large )?(language )((models )+(inference engines ))(are )((useful)+((very )*complex))." 
- )), - ) - token_ids = [] - while True: - logits = logit_processor( - token_ids=token_ids, - logits=np.random.uniform(-10, 10, len(tokenizer.vocab)) - ) - new_token_id = sample_from_logits(logits) - token_ids.append(new_token_id) - if new_token_id == tokenizer.eos_token_id: - break - print(f"run #{i}") - print("\ttokenid", token_ids) - print("\ttokens:", [tokenizer.decode(tok_id, ) for tok_id in token_ids]) - print("\tresult:", tokenizer.decode(token_ids, skip_special_tokens=False)) +if __name__ == "__main__": + import transformers + import numpy as np + + profile = False + if profile: + import cProfile + import pstats + from io import StringIO + profile = cProfile.Profile() + profile.enable() + main() + profile.disable() + + # Sorting the statistics by cumulative time + s = StringIO() + sortby = 'cumulative' + ps = pstats.Stats(profile, stream=s).sort_stats(sortby) + ps.print_stats() + print(s.getvalue()) + else: + main() diff --git a/vllm/lark_interactive.py b/vllm/lark_interactive.py deleted file mode 100644 index 4bd3446c22bc..000000000000 --- a/vllm/lark_interactive.py +++ /dev/null @@ -1,538 +0,0 @@ -import collections -from copy import deepcopy, copy -import os -import regex -from typing import Optional - -from lark import Lark -from lark.parsers.lalr_interactive_parser import InteractiveParser -from lark.parsers.lalr_parser_state import ParserState -from lark.lexer import Token, LexerState, PatternStr, PatternRE -from lark.exceptions import UnexpectedCharacters, UnexpectedToken - - -######################################################################### -# Fix Lark Speed Issue -# https://github.com/lark-parser/lark/issues/1142#issuecomment-1863209804 -######################################################################### -class FastParserState(ParserState): - copy_memo = {} - - def __copy__(self): - new_value_stack = [] - for value in self.value_stack: - key = f"{id(self)}_{id(value)}" - if key not in self.copy_memo: - self.copy_memo[key] = deepcopy(value, self.copy_memo) - new_value_stack.append(self.copy_memo[key]) - - new_instance = type(self)( - self.parse_conf, - self.lexer, # XXX copy - copy(self.state_stack), - new_value_stack, - ) - - self.copy_memo[id(self)] = new_instance - return new_instance - - -class FastInteractiveParser(InteractiveParser): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.parser_state = FastParserState( - self.parser_state.parse_conf, - self.parser_state.lexer, - self.parser_state.state_stack, - self.parser_state.value_stack, - ) - - def __copy__(self): - return type(self)( - self.parser, - copy(self.parser_state), - copy(self.lexer_thread), - ) -######################################################################### -######################################################################### - - -def get_partial_pattern_validator(pattern): - """ - Accepts a pattern object, either lark.lexer.PatternStr or lark.lexer.PatternRE - Returns a function which validates a partial string - - e.g. 
for PatternRE "abc*", returns true for "a", "ab", "abc", "abcccc" - """ - if isinstance(pattern, PatternRE): - compiled_pattern = regex.compile(pattern.value) - return ( - lambda seq: compiled_pattern.fullmatch(seq, partial=True) is not None - ) - elif isinstance(pattern, PatternStr): - base_str = pattern.value - return ( - lambda seq: base_str.startswith(seq) - ) - else: - raise TypeError(f"Invalid pattern type: {type(pattern)}") - - -class InteractivePredictiveLALRParser: - """ - Parser which consumes an EBNF grammar and provides helpers to determine allowable language model tokens - - Interfaces: - - step_seq(new_seq): Update the parser with a new sequence to append - - is_valid_next_seq(new_seq): Determine whether a candidate sequence is valid - - Core components for terminal level, and sub-terminal level processing: - - 1) Lark LALR parser: Applies state transitions, determining set of valid next-terminals - - 2) Incremental terminal filter: Eliminates next-terminal candidates if terminal pattern doesn't match - """ - - def __init__(self, grammar: str, start: str): - self.parser = Lark( - grammar, - regex=True, # use `regex` not `re` - start=start, - parser='lalr', - ) - base_interactive_parser = self.parser.parse_interactive() - self.interactive_parser = FastInteractiveParser( - base_interactive_parser.parser, - base_interactive_parser.parser_state, - base_interactive_parser.lexer_thread - ) - - self.partial_seq_validator = { - term.name: get_partial_pattern_validator(term.pattern) - for term in self.parser.terminals - } - - self._ignored_terms = set(self.parser.lexer_conf.ignore) - - # for processing terminals interactively - self.last_terminal_pos = 0 - self.valid_next_terminals = None - - # for calculating `accepts()` efficiently - self._accepts_cache = {} - - self.sequence_history = "" - - # initiate - self.step_seq("") - - def _accepts(self): - if self.sequence_history not in self._accepts_cache: - accepted_terminals = self.interactive_parser.accepts() - self._accepts_cache[self.sequence_history] = accepted_terminals - return self._accepts_cache[self.sequence_history] - - @property - def terminal_partial_seq(self): - """ - Return the incomplete subsequence which will eventually comprise a terminal - """ - return self.sequence_history[self.last_terminal_pos:] - - def step_seq(self, new_seq: str): - """ - Append sequence to parser and apply state updates - - Append the sequence to the canonical self.sequence_history - - Parse the changes - - Update the character position of the last complete terminal - - Update the set of candidate terminals - """ - self._append_to_sequence(new_seq) - - try: - self.interactive_parser.exhaust_lexer() - except UnexpectedCharacters as e: - self.last_terminal_pos = e.pos_in_stream - else: - self.last_terminal_pos = len(self.sequence_history) - - self._update_candidate_terminals() - - if not self.valid_next_terminals: - raise ValueError(f"Invalid continuation for `{self.sequence_history}` `{new_seq}`") - - def _append_to_sequence(self, new_seq: str): - """Set the complete sequences value in the lexer and base""" - self.sequence_history += new_seq - self.interactive_parser.lexer_thread.state.text = self.sequence_history - - def _update_candidate_terminals(self): - """ - Update the set of candidate terminals - - If a new terminal is reached, get the accepted set of terminals from the parser - - If the new sequence doesn't comprise a full terminal, filter based on partial pattern match - """ - if not self.terminal_partial_seq: - self.valid_next_terminals 
= self._accepts() | self._ignored_terms - else: - self.valid_next_terminals = set([ - term for term in self.valid_next_terminals - if self.partial_seq_validator[term](self.terminal_partial_seq) - ]) - - def is_valid_next_seq(self, new_seq: Optional[str]): - """ - Check if current un-terminalized sequence + new_seq is valid for any terminal - - new_seq can be a string or None representing EOS - """ - if new_seq is None: - return "$END" in self.valid_next_terminals - for term in self.valid_next_terminals: - if term != "$END": - if self.partial_seq_validator[term](self.terminal_partial_seq + new_seq): - return True - return False - - -class TokenTrie: - IS_TOKEN = (None, "is complete token") - - def __init__(self, tokenizer, legal_chars: Optional[set[str]] = None): - """ - Trie structure for efficiently finding tokens which are suffixes of other sequences - """ - self.norm_vocab = {} - for token_id in tokenizer.vocab.values(): - norm_token = tokenizer.decode([tokenizer.bos_token_id, token_id])[len(tokenizer.bos_token):] - if legal_chars is None or all([char in legal_chars for char in norm_token]): - self.norm_vocab[norm_token] = token_id - - self.token_to_id_set = collections.defaultdict(set) - for token_str, token_id in self.norm_vocab.items(): - self.token_to_id_set[token_str].add(token_id) - - self.trie = {} - for word in self.norm_vocab: - current_dict = self.trie - for char in word: - if char not in current_dict: - current_dict[char] = {} - current_dict = current_dict[char] - current_dict[self.IS_TOKEN] = True - - def get_next_level_token_prefixes(self, subprefix: str, _node=None): - """ - Traverse the trie starting from a specified subprefix to identify all child nodes that represent - the longest possible strings without omitting any nodes that contain complete tokens. - """ - # if not first level of recursion, and at a branching point or is a token, or return self - if _node is not None and (len(_node) > 1 or self.IS_TOKEN in _node): - return {subprefix} - - # get the current node if at the first level of recursion - if _node is None: - _node = self.trie - for char in subprefix: - if char not in _node: - return set() - _node = _node[char] - - # Single child, need to go deeper - results = set() - for char, next_node in _node.items(): - if char != self.IS_TOKEN: - results |= self.get_next_level_token_prefixes(subprefix + char, _node=next_node) - return results - - def is_token(self, seq): - return seq in self.norm_vocab - - -class NextTokenValidator: - """ - Given a grammar and a tokenset, construct a parser and token trie. 
- - Interface: - - step_seq(new_seq): Append a sequence, update internal states - - property valid_token_str_set: The valid set of vocabulary tokens strings which can occur next - """ - def __init__( - self, - tokenizer, - grammar: str, - grammar_start: str = "start", - num_threads: Optional[int] = None - ): - self.tokenizer = tokenizer - self.token_trie = TokenTrie(tokenizer) - - self.parser = InteractivePredictiveLALRParser( - grammar=grammar, - start=grammar_start - ) - - # TODO: threading - if num_threads is None: - self.num_threads = os.cpu_count() // 2 - - def step_seq(self, new_seq): - self.parser.step_seq(new_seq) - - @property - def valid_token_str_set(self): - """ - Generate the set of valid tokens given the current sequence - - 1) Push all first level token prefixes to the stack - 2) for each token in the stack, validate against the parser - - if valid, add all children to the stack for later processing - - if valid AND a token, add to valid_token_set - """ - valid_token_str_set = set() - token_prefix_stack = collections.deque([""]) - while token_prefix_stack: - token_prefix = token_prefix_stack.pop() - for child_token_prefix in self.token_trie.get_next_level_token_prefixes(token_prefix): - # TODO: Handle EOS token by passing None - if self.parser.is_valid_next_seq(child_token_prefix): - token_prefix_stack.append(child_token_prefix) - if self.token_trie.is_token(child_token_prefix): - valid_token_str_set.add(child_token_prefix) - - return valid_token_str_set - - @property - def valid_token_id_set(self): - """ - get valid token id based on self.valid_token_str_set - note that some token strings correspond to multiple token IDs - """ - return set.union(*[ - self.token_trie.token_to_id_set[tok] - for tok in self.valid_token_str_set - ]) - - -class GrammarLogitProcessor(NextTokenValidator): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - self.generation_token_ids = [] - self.generation_text = "" - - def __call__(self, token_ids, logits): - - # ensure integrity - assert token_ids[:len(self.generation_token_ids)] == self.generation_token_ids - self.generation_token_ids = token_ids - - # step forward - all_text = self.tokenizer.decode(token_ids) - new_text = all_text[len(self.generation_text):] - self.generation_text = all_text - self.step_seq(new_text) - - # get valid token IDs and modify logits - valid_token_ids = self.valid_token_id_set - logits = [ - logit_val if tok_id in valid_token_ids else -float("inf") - for tok_id, logit_val in zip(sorted(self.tokenizer.vocab.values()), logits) - ] - return logits - - -def test_generate_json_randomly_via_logit_processor(): - json_grammar = """ - ?value: dict - | list - | string - | SIGNED_NUMBER -> number - | "true" -> true - | "false" -> false - | "null" -> null - - list : "[" [value ("," value)*] "]" - - dict : "{" [pair ("," pair)*] "}" - pair : string ":" value - - string : ESCAPED_STRING - - %import common.ESCAPED_STRING - %import common.SIGNED_NUMBER - %import common.WS - #%ignore WS # we don't ignore whitespace because that makes the json uninteresting - """ - tokenizer = transformers.AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") - - logit_processor = GrammarLogitProcessor( - tokenizer, - json_grammar, - grammar_start="value" - ) - - sample_from_logits = lambda lgts: np.random.choice(len(lgts), p=np.exp(lgts)/np.sum(np.exp(lgts))) - - token_ids = [] - for _ in range(20): - logits = logit_processor( - token_ids=token_ids, - logits=np.random.uniform(-10, 10, len(tokenizer.vocab)) - ) - 
new_token_id = sample_from_logits(logits) - token_ids.append(new_token_id) - - import pdb;pdb.set_trace() - - -def test_next_token_validator_simple(): - hello_grammar = """ - ?value: "hello" | "world" - """ - tokenizer = transformers.AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") - ntv = NextTokenValidator(tokenizer, hello_grammar, "value") - - assert ntv.valid_token_str_set == {'wo', 'hell', 'h', 'he', 'hel', 'world', 'wor', 'w', 'hello'} - assert ntv.valid_token_id_set == {265, 809, 107, 2805, 21558, 28727, 13436, 22493, 9471} - - -def test_token_trie_sanity_hf_tokenizer(): - """Ensure token trie produces the same number of N 3 letter tokens""" - tokenizer = transformers.AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") - toktrie = TokenTrie(tokenizer) - - all_prefixes = toktrie.get_next_level_token_prefixes("") - - # every token should be composable from a single unique char, so they will all be len of 1 - assert all([len(p) == 1 for p in all_prefixes]) - - # every token should have one of these prefixes as a start character - assert all([ - t[0] in all_prefixes - for t in toktrie.norm_vocab - ]) - - # construct the set of next level prefixes - all_subprefixes = set() - for pfx in all_prefixes: - all_subprefixes |= toktrie.get_next_level_token_prefixes(pfx) - - # these should have varying length because some tokens don't have level-2 prefixes - assert len(set([len(spfx) for spfx in all_subprefixes])) > 1 - - -def test_simple_sequence(parser): - for token in ['{', '"', 'k', 'ey', '":', '"val', 'ue', '"']: - print("full:", parser.sequence_history) - print("adding", token) - parser.step_seq(token) - print("partial:", parser.terminal_partial_seq) - print("valid terms:", parser.valid_next_terminals) - - -def test_valid_next_tokens(): - json_grammar = """ - ?value: dict - | list - | string - | SIGNED_NUMBER -> number - | "true" -> true - | "false" -> false - | "null" -> null - - list : "[" [value ("," value)*] "]" - - dict : "{" [pair ("," pair)*] "}" - pair : string ":" value - - string : ESCAPED_STRING - - %import common.ESCAPED_STRING - %import common.SIGNED_NUMBER - %import common.WS - %ignore WS - """ - - parser = InteractivePredictiveLALRParser(json_grammar, 'value') - # random complicated json file courtesy of https://github.com/simdjson/simdjson/issues/1316#issue-748663718 - complex_json_file = '{"$schema": "http://json-schema.org/draft-04/schema#", "additionalProperties": false, "properties": {"nc:Vehicle": {"description": "A conveyance designed to carry an operator, passengers and/or cargo, over land.", "oneOf": [{"$ref": "#/definitions/nc:VehicleType"}, {"type": "array", "items": {"$ref": "#/definitions/nc:VehicleType"}}]}, "nc:VehicleAxleQuantity": {"description": "A count of common axles of rotation of one or more wheels of a vehicle, whether power driven or freely rotating.", "oneOf": [{"$ref": "#/definitions/niem-xs:nonNegativeInteger"}, {"type": "array", "items": {"$ref": "#/definitions/niem-xs:nonNegativeInteger"}}]}, "nc:VehicleMSRPAmount": {"description": "A manufacturer\'s suggested retail price of a vehicle; a price at which a manufacturer recommends a vehicle be sold.", "oneOf": [{"$ref": "#/definitions/nc:AmountType"}, {"type": "array", "items": {"$ref": "#/definitions/nc:AmountType"}}]}, "nc:Amount": {"description": "An amount of money.", "oneOf": [{"$ref": "#/definitions/niem-xs:decimal"}, {"type": "array", "items": {"$ref": "#/definitions/niem-xs:decimal"}}]}, "nc:Currency": {"description": "A data concept for a unit of money or exchange.", 
"oneOf": [{"anyOf": [{"$ref": "#/properties/nc:CurrencyCode"}]}, {"type": "array", "items": {"anyOf": [{"$ref": "#/properties/nc:CurrencyCode"}]}}]}, "nc:CurrencyCode": {"description": "A unit of money or exchange.", "oneOf": [{"$ref": "#/definitions/iso_4217:CurrencyCodeType"}, {"type": "array", "items": {"$ref": "#/definitions/iso_4217:CurrencyCodeType"}}]}, "nc:VehicleIdentification": {"description": "A unique identification for a specific vehicle.", "oneOf": [{"$ref": "#/definitions/nc:IdentificationType"}, {"type": "array", "items": {"$ref": "#/definitions/nc:IdentificationType"}}]}, "nc:IdentificationID": {"description": "An identifier.", "oneOf": [{"$ref": "#/definitions/niem-xs:string"}, {"type": "array", "items": {"$ref": "#/definitions/niem-xs:string"}}]}}, "definitions": {"nc:VehicleType": {"description": "A data type for a conveyance designed to carry an operator, passengers and/or cargo, over land.", "allOf": [{"$ref": "#/definitions/nc:ConveyanceType"}, {"type": "object", "properties": {"nc:VehicleAxleQuantity": {"$ref": "#/properties/nc:VehicleAxleQuantity"}, "nc:VehicleIdentification": {"$ref": "#/properties/nc:VehicleIdentification"}, "nc:VehicleMSRPAmount": {"$ref": "#/properties/nc:VehicleMSRPAmount"}}}]}, "nc:ConveyanceType": {"description": "A data type for a means of transport from place to place.", "allOf": [{"$ref": "#/definitions/_base"}, {"$ref": "#/definitions/nc:ItemType"}, {"type": "object", "properties": {}}]}, "nc:ItemType": {"description": "A data type for an article or thing.", "allOf": [{"$ref": "#/definitions/_base"}, {"type": "object", "properties": {}}]}, "nc:AmountType": {"description": "A data type for an amount of money.", "type": "object", "properties": {"nc:Amount": {"$ref": "#/properties/nc:Amount"}, "nc:Currency": {"$ref": "#/properties/nc:Currency"}}}, "iso_4217:CurrencyCodeType": {"description": "A data type for a currency that qualifies a monetary amount.", "oneOf": [{"$ref": "#/definitions/iso_4217:CurrencyCodeSimpleType"}, {"type": "object", "properties": {"rdf:value": {"$ref": "#/definitions/iso_4217:CurrencyCodeSimpleType"}}}]}, "iso_4217:CurrencyCodeSimpleType": {"type": "string", "description": "A data type for a currency that qualifies a monetary amount.", "oneOf": [{"enum": ["EUR"], "description": "Euro"}, {"enum": ["GBP"], "description": "Pound Sterling"}, {"enum": ["USD"], "description": "US Dollar"}]}, "nc:IdentificationType": {"description": "A data type for a representation of an identity.", "type": "object", "properties": {"nc:IdentificationID": {"$ref": "#/properties/nc:IdentificationID"}}}, "niem-xs:decimal": {"description": "A data type for arbitrary precision decimal numbers.", "type": "number"}, "niem-xs:nonNegativeInteger": {"description": "A data type for an integer with a minimum value of 0.", "type": "number"}, "niem-xs:string": {"description": "A data type for character strings in XML.", "type": "string"}, "_base": {"type": "object", "patternProperties": {"^ism:.*": {"type": "string"}, "^ntk:.*": {"type": "string"}}, "properties": {"@id": {"format": "uriref"}, "@base": {"format": "uriref"}}}}}' - - test_chars_per_iter = 1000 - unicode_chars = [chr(i) for i in range(test_chars_per_iter)] - - import time - start = time.time() - for char in complex_json_file: - parser.step_seq(char) - for ch in unicode_chars: - parser.is_valid_next_seq(ch) - - print("took", - (time.time() - start) / (len(complex_json_file)), - "seconds per step with", - test_chars_per_iter, "characters in vocabulary") - - - -def profile_predictor(): - 
    import pstats
-    from io import StringIO
-    import cProfile
-    hello_grammar = """
-    ?value: "hello" | "world"
-    """
-    tokenizer = transformers.AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
-    ntv = NextTokenValidator(tokenizer, hello_grammar, "value")
-
-    profile = cProfile.Profile()
-    profile.enable()
-    #####
-
-    valid_toks = ntv.valid_token_str_set
-    ntv.step_seq("h")
-    valid_toks = ntv.valid_token_str_set
-    ntv.step_seq("e")
-    valid_toks = ntv.valid_token_str_set
-    ntv.step_seq("l")
-    valid_toks = ntv.valid_token_str_set
-    ntv.step_seq("l")
-    valid_toks = ntv.valid_token_str_set
-    ntv.step_seq("o")
-
-    #####
-    profile.disable()
-
-    # Sorting the statistics by cumulative time
-    s = StringIO()
-    sortby = 'cumulative'
-    ps = pstats.Stats(profile, stream=s).sort_stats(sortby)
-    ps.print_stats()
-    print(s.getvalue())
-
-
-def main():
-    test_generate_json_randomly_via_logit_processor()
-
-
-if __name__ == "__main__":
-    import transformers
-    import numpy as np
-
-    profile = False
-    if profile:
-        import cProfile
-        import pstats
-        from io import StringIO
-        profile = cProfile.Profile()
-        profile.enable()
-        main()
-        profile.disable()
-
-        # Sorting the statistics by cumulative time
-        s = StringIO()
-        sortby = 'cumulative'
-        ps = pstats.Stats(profile, stream=s).sort_stats(sortby)
-        ps.print_stats()
-        print(s.getvalue())
-    else:
-        main()

From 783bbffc96e6b528cb04882e338c6295fa61a6fd Mon Sep 17 00:00:00 2001
From: Andrew Lapp
Date: Tue, 19 Dec 2023 16:46:07 -0600
Subject: [PATCH 17/76] add grammar docs

---
 docs/source/grammars/grammars.rst | 92 +++++++++++++++++++++++++++++++
 1 file changed, 92 insertions(+)
 create mode 100644 docs/source/grammars/grammars.rst

diff --git a/docs/source/grammars/grammars.rst b/docs/source/grammars/grammars.rst
new file mode 100644
index 000000000000..24fb5672af8c
--- /dev/null
+++ b/docs/source/grammars/grammars.rst
@@ -0,0 +1,92 @@
+.. _grammars:
+
+Grammars
+========
+
+vLLM offers `Lark `_ based EBNF grammars via ``vllm.grammar.GrammarLogitsProcessor``.
+
+``GrammarLogitsProcessor`` ensures generated text follows the rules of a grammar. This provides the ability to guarantee your output is syntactically valid JSON, SQL, Python, RegEx, etc.
+
+Sample Code for JSON
+--------------------
+
+.. code-block:: python
+
+    json_grammar = """
+    ?value: dict
+          | list
+          | string
+          | SIGNED_NUMBER -> number
+          | "true" -> true
+          | "false" -> false
+          | "null" -> null
+
+    list : "[" [value ("," value)*] "]"
+
+    dict : "{" [pair ("," pair)*] "}"
+    pair : string ":" value
+
+    string : ESCAPED_STRING
+
+    %import common.ESCAPED_STRING
+    %import common.SIGNED_NUMBER
+    """
+    grammar_logits_processor = GrammarLogitsProcessor(
+        tokenizer,
+        json_grammar,
+        grammar_start="value"
+    )
+    SamplingParams(logits_processor=grammar_logits_processor)
+
+Resources
+---------
+
+- `How to write an EBNF grammar for Lark `_
+- `Wikipedia - EBNF `_
+- `Wikipedia - LALR Parser `_
+
+Example Lark Grammars
+---------------------
+
+- `JSON `_
+- `Python3 `_
+- `Resource with many grammars including SQLite, TOML, YAML, Lua, and more `_
+
+Performance
+-----------
+
+Expect between 3 and 30 new tokens per second as a baseline, however performance can be improved from the baseline.
+
+**Constrain legal characters**
+
+Every legal character in the alphabet must be checked against the parser by default. Mistral tokenizer, for example, has an alphabet of 3,298 characters, here are 40 random examples:
+
+..
code-block:: + + [ '堂', 'ู', 'ɔ', '🙌', 'Б', '레', '允', 'ả', '\ue934', '如', '試', 'K', '¯', '卷', '園', 'ए', '\\', '酒', 'थ', 'グ', '터', '연', 'Ș', 'ブ', '星', 'ြ', 'å', '軍', '案', '题', '银', '映', '표', '\x11', '級', '醒', 'ေ', '✭', '約', '😤'] + +Likely many of these characters aren't useful in your generation. + +Expect an ~10x speedup if you constrain your generation to UTF-8, eliminating 3,042 unnecessary characters. + +.. code-block:: + + GrammarLogitsProcessor( + ..., + legal_chars=set([chr(i) for i in range(256)]) + ) + +**Design your EBNF with minimal regexp** + +Regexp processing is the most expensive task for GrammarLogitsProcessor. When designing your EBNF, it's better to keep your regexp short and simple if at all possible. + +**Use more threads** + +By default ``GrammarLogitProcessor`` uses ``os.cpu_count() / 2`` threads. You may change this via + +.. code-block:: + + GrammarLogitsProcessor( + ..., + num_threads=4 + ) From 57081a64c4eafc1f6778c4ae55324646a690f4c5 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Tue, 19 Dec 2023 16:54:03 -0600 Subject: [PATCH 18/76] cleanup --- vllm/grammar.py | 52 +++++++++++++++++++++++++++---------------------- 1 file changed, 29 insertions(+), 23 deletions(-) diff --git a/vllm/grammar.py b/vllm/grammar.py index 4bd3446c22bc..c2c36b8f442e 100644 --- a/vllm/grammar.py +++ b/vllm/grammar.py @@ -57,27 +57,6 @@ def __copy__(self): ######################################################################### -def get_partial_pattern_validator(pattern): - """ - Accepts a pattern object, either lark.lexer.PatternStr or lark.lexer.PatternRE - Returns a function which validates a partial string - - e.g. for PatternRE "abc*", returns true for "a", "ab", "abc", "abcccc" - """ - if isinstance(pattern, PatternRE): - compiled_pattern = regex.compile(pattern.value) - return ( - lambda seq: compiled_pattern.fullmatch(seq, partial=True) is not None - ) - elif isinstance(pattern, PatternStr): - base_str = pattern.value - return ( - lambda seq: base_str.startswith(seq) - ) - else: - raise TypeError(f"Invalid pattern type: {type(pattern)}") - - class InteractivePredictiveLALRParser: """ Parser which consumes an EBNF grammar and provides helpers to determine allowable language model tokens @@ -106,7 +85,7 @@ def __init__(self, grammar: str, start: str): ) self.partial_seq_validator = { - term.name: get_partial_pattern_validator(term.pattern) + term.name: self._get_partial_pattern_validator(term.pattern) for term in self.parser.terminals } @@ -124,6 +103,28 @@ def __init__(self, grammar: str, start: str): # initiate self.step_seq("") + @staticmethod + def _get_partial_pattern_validator(pattern): + """ + Accepts a pattern object, either lark.lexer.PatternStr or lark.lexer.PatternRE + Returns a function which validates a partial string + + e.g. 
for PatternRE "abc*", returns true for "a", "ab", "abc", "abcccc" + """ + if isinstance(pattern, PatternRE): + compiled_pattern = regex.compile(pattern.value) + return ( + lambda seq: compiled_pattern.fullmatch(seq, partial=True) is not None + ) + elif isinstance(pattern, PatternStr): + base_str = pattern.value + return ( + lambda seq: base_str.startswith(seq) + ) + else: + raise TypeError(f"Invalid pattern type: {type(pattern)}") + + def _accepts(self): if self.sequence_history not in self._accepts_cache: accepted_terminals = self.interactive_parser.accepts() @@ -372,15 +373,20 @@ def test_generate_json_randomly_via_logit_processor(): sample_from_logits = lambda lgts: np.random.choice(len(lgts), p=np.exp(lgts)/np.sum(np.exp(lgts))) + np.random.seed = 42 + import time + start = time.time() token_ids = [] for _ in range(20): logits = logit_processor( token_ids=token_ids, - logits=np.random.uniform(-10, 10, len(tokenizer.vocab)) + logits=np.random.uniform(-10, 10, len(tokenizer.vocab),) ) new_token_id = sample_from_logits(logits) token_ids.append(new_token_id) + print("duration", time.time() - start) + import pdb;pdb.set_trace() From 9052a54a8a631867e091363a5a19a73e730b2ecf Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Tue, 19 Dec 2023 19:14:11 -0600 Subject: [PATCH 19/76] add cache to TokenTrie, resulting in 35% speedup --- vllm/grammar.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/vllm/grammar.py b/vllm/grammar.py index c2c36b8f442e..01296e6f0c02 100644 --- a/vllm/grammar.py +++ b/vllm/grammar.py @@ -220,11 +220,25 @@ def __init__(self, tokenizer, legal_chars: Optional[set[str]] = None): current_dict = current_dict[char] current_dict[self.IS_TOKEN] = True - def get_next_level_token_prefixes(self, subprefix: str, _node=None): + self._next_level_token_prefixes_cache = {} + + def get_next_level_token_prefixes(self, subprefix: str): + if subprefix not in self._next_level_token_prefixes_cache: + self._next_level_token_prefixes_cache[subprefix] = ( + self.get_next_level_token_prefixes_uncached(subprefix) + ) + return self._next_level_token_prefixes_cache[subprefix] + + def get_next_level_token_prefixes_uncached(self, subprefix: str, _node=None): """ Traverse the trie starting from a specified subprefix to identify all child nodes that represent the longest possible strings without omitting any nodes that contain complete tokens. 
""" + # cache + if _node is None: + if subprefix in self._next_level_token_prefixes_cache: + return self._next_level_token_prefixes_cache[subprefix] + # if not first level of recursion, and at a branching point or is a token, or return self if _node is not None and (len(_node) > 1 or self.IS_TOKEN in _node): return {subprefix} @@ -241,7 +255,11 @@ def get_next_level_token_prefixes(self, subprefix: str, _node=None): results = set() for char, next_node in _node.items(): if char != self.IS_TOKEN: - results |= self.get_next_level_token_prefixes(subprefix + char, _node=next_node) + results |= self.get_next_level_token_prefixes_uncached( + subprefix + char, + _node=next_node + ) + return results def is_token(self, seq): From e55ae6bdb0a348bc2b4a5f99f87cd87cbb765c53 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Wed, 20 Dec 2023 15:22:47 -0600 Subject: [PATCH 20/76] clean up --- requirements.txt | 1 + vllm/grammar.py | 258 +++++------------------------------------------ 2 files changed, 28 insertions(+), 231 deletions(-) diff --git a/requirements.txt b/requirements.txt index 9f8c729105d6..9a84a8d7e77d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,3 +12,4 @@ fastapi uvicorn[standard] pydantic == 1.10.13 # Required for OpenAI server. aioprometheus[starlette] +lark == 1.1.8 # Required for Grammars diff --git a/vllm/grammar.py b/vllm/grammar.py index 01296e6f0c02..5f345649c8dc 100644 --- a/vllm/grammar.py +++ b/vllm/grammar.py @@ -12,7 +12,7 @@ ######################################################################### -# Fix Lark Speed Issue +# Fix Lark Interactive LALR Parser Speed Issue # https://github.com/lark-parser/lark/issues/1142#issuecomment-1863209804 ######################################################################### class FastParserState(ParserState): @@ -28,7 +28,7 @@ def __copy__(self): new_instance = type(self)( self.parse_conf, - self.lexer, # XXX copy + self.lexer, copy(self.state_stack), new_value_stack, ) @@ -38,6 +38,7 @@ def __copy__(self): class FastInteractiveParser(InteractiveParser): + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.parser_state = FastParserState( @@ -81,8 +82,7 @@ def __init__(self, grammar: str, start: str): self.interactive_parser = FastInteractiveParser( base_interactive_parser.parser, base_interactive_parser.parser_state, - base_interactive_parser.lexer_thread - ) + base_interactive_parser.lexer_thread) self.partial_seq_validator = { term.name: self._get_partial_pattern_validator(term.pattern) @@ -113,18 +113,14 @@ def _get_partial_pattern_validator(pattern): """ if isinstance(pattern, PatternRE): compiled_pattern = regex.compile(pattern.value) - return ( - lambda seq: compiled_pattern.fullmatch(seq, partial=True) is not None - ) + return (lambda seq: compiled_pattern.fullmatch(seq, partial=True) + is not None) elif isinstance(pattern, PatternStr): base_str = pattern.value - return ( - lambda seq: base_str.startswith(seq) - ) + return (lambda seq: base_str.startswith(seq)) else: raise TypeError(f"Invalid pattern type: {type(pattern)}") - def _accepts(self): if self.sequence_history not in self._accepts_cache: accepted_terminals = self.interactive_parser.accepts() @@ -158,7 +154,9 @@ def step_seq(self, new_seq: str): self._update_candidate_terminals() if not self.valid_next_terminals: - raise ValueError(f"Invalid continuation for `{self.sequence_history}` `{new_seq}`") + raise ValueError( + f"Invalid continuation for `{self.sequence_history}` `{new_seq}`" + ) def _append_to_sequence(self, new_seq: str): 
"""Set the complete sequences value in the lexer and base""" @@ -189,7 +187,8 @@ def is_valid_next_seq(self, new_seq: Optional[str]): return "$END" in self.valid_next_terminals for term in self.valid_next_terminals: if term != "$END": - if self.partial_seq_validator[term](self.terminal_partial_seq + new_seq): + full_terminal_candidate = self.terminal_partial_seq + new_seq + if self.partial_seq_validator[term](full_terminal_candidate): return True return False @@ -203,8 +202,10 @@ def __init__(self, tokenizer, legal_chars: Optional[set[str]] = None): """ self.norm_vocab = {} for token_id in tokenizer.vocab.values(): - norm_token = tokenizer.decode([tokenizer.bos_token_id, token_id])[len(tokenizer.bos_token):] - if legal_chars is None or all([char in legal_chars for char in norm_token]): + bos_len = len(tokenizer.bos_token) + norm_token = tokenizer.decode([tokenizer.bos_token_id, token_id])[bos_len:] + if legal_chars is None or all( + [char in legal_chars for char in norm_token]): self.norm_vocab[norm_token] = token_id self.token_to_id_set = collections.defaultdict(set) @@ -274,20 +275,16 @@ class NextTokenValidator: - step_seq(new_seq): Append a sequence, update internal states - property valid_token_str_set: The valid set of vocabulary tokens strings which can occur next """ - def __init__( - self, - tokenizer, - grammar: str, - grammar_start: str = "start", - num_threads: Optional[int] = None - ): + def __init__(self, + tokenizer, + grammar: str, + grammar_start: str = "start", + num_threads: Optional[int] = None): self.tokenizer = tokenizer self.token_trie = TokenTrie(tokenizer) - self.parser = InteractivePredictiveLALRParser( - grammar=grammar, - start=grammar_start - ) + self.parser = InteractivePredictiveLALRParser(grammar=grammar, + start=grammar_start) # TODO: threading if num_threads is None: @@ -331,7 +328,10 @@ def valid_token_id_set(self): ]) -class GrammarLogitProcessor(NextTokenValidator): +class GrammarLogitsProcessor(NextTokenValidator): + """ + Apply NextTokenValidator in __call__ and set excluded tokens logits to -inf + """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -339,7 +339,6 @@ def __init__(self, *args, **kwargs): self.generation_text = "" def __call__(self, token_ids, logits): - # ensure integrity assert token_ids[:len(self.generation_token_ids)] == self.generation_token_ids self.generation_token_ids = token_ids @@ -357,206 +356,3 @@ def __call__(self, token_ids, logits): for tok_id, logit_val in zip(sorted(self.tokenizer.vocab.values()), logits) ] return logits - - -def test_generate_json_randomly_via_logit_processor(): - json_grammar = """ - ?value: dict - | list - | string - | SIGNED_NUMBER -> number - | "true" -> true - | "false" -> false - | "null" -> null - - list : "[" [value ("," value)*] "]" - - dict : "{" [pair ("," pair)*] "}" - pair : string ":" value - - string : ESCAPED_STRING - - %import common.ESCAPED_STRING - %import common.SIGNED_NUMBER - %import common.WS - #%ignore WS # we don't ignore whitespace because that makes the json uninteresting - """ - tokenizer = transformers.AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") - - logit_processor = GrammarLogitProcessor( - tokenizer, - json_grammar, - grammar_start="value" - ) - - sample_from_logits = lambda lgts: np.random.choice(len(lgts), p=np.exp(lgts)/np.sum(np.exp(lgts))) - - np.random.seed = 42 - import time - start = time.time() - token_ids = [] - for _ in range(20): - logits = logit_processor( - token_ids=token_ids, - logits=np.random.uniform(-10, 10, 
len(tokenizer.vocab),) - ) - new_token_id = sample_from_logits(logits) - token_ids.append(new_token_id) - - print("duration", time.time() - start) - - import pdb;pdb.set_trace() - - -def test_next_token_validator_simple(): - hello_grammar = """ - ?value: "hello" | "world" - """ - tokenizer = transformers.AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") - ntv = NextTokenValidator(tokenizer, hello_grammar, "value") - - assert ntv.valid_token_str_set == {'wo', 'hell', 'h', 'he', 'hel', 'world', 'wor', 'w', 'hello'} - assert ntv.valid_token_id_set == {265, 809, 107, 2805, 21558, 28727, 13436, 22493, 9471} - - -def test_token_trie_sanity_hf_tokenizer(): - """Ensure token trie produces the same number of N 3 letter tokens""" - tokenizer = transformers.AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") - toktrie = TokenTrie(tokenizer) - - all_prefixes = toktrie.get_next_level_token_prefixes("") - - # every token should be composable from a single unique char, so they will all be len of 1 - assert all([len(p) == 1 for p in all_prefixes]) - - # every token should have one of these prefixes as a start character - assert all([ - t[0] in all_prefixes - for t in toktrie.norm_vocab - ]) - - # construct the set of next level prefixes - all_subprefixes = set() - for pfx in all_prefixes: - all_subprefixes |= toktrie.get_next_level_token_prefixes(pfx) - - # these should have varying length because some tokens don't have level-2 prefixes - assert len(set([len(spfx) for spfx in all_subprefixes])) > 1 - - -def test_simple_sequence(parser): - for token in ['{', '"', 'k', 'ey', '":', '"val', 'ue', '"']: - print("full:", parser.sequence_history) - print("adding", token) - parser.step_seq(token) - print("partial:", parser.terminal_partial_seq) - print("valid terms:", parser.valid_next_terminals) - - -def test_valid_next_tokens(): - json_grammar = """ - ?value: dict - | list - | string - | SIGNED_NUMBER -> number - | "true" -> true - | "false" -> false - | "null" -> null - - list : "[" [value ("," value)*] "]" - - dict : "{" [pair ("," pair)*] "}" - pair : string ":" value - - string : ESCAPED_STRING - - %import common.ESCAPED_STRING - %import common.SIGNED_NUMBER - %import common.WS - %ignore WS - """ - - parser = InteractivePredictiveLALRParser(json_grammar, 'value') - # random complicated json file courtesy of https://github.com/simdjson/simdjson/issues/1316#issue-748663718 - complex_json_file = '{"$schema": "http://json-schema.org/draft-04/schema#", "additionalProperties": false, "properties": {"nc:Vehicle": {"description": "A conveyance designed to carry an operator, passengers and/or cargo, over land.", "oneOf": [{"$ref": "#/definitions/nc:VehicleType"}, {"type": "array", "items": {"$ref": "#/definitions/nc:VehicleType"}}]}, "nc:VehicleAxleQuantity": {"description": "A count of common axles of rotation of one or more wheels of a vehicle, whether power driven or freely rotating.", "oneOf": [{"$ref": "#/definitions/niem-xs:nonNegativeInteger"}, {"type": "array", "items": {"$ref": "#/definitions/niem-xs:nonNegativeInteger"}}]}, "nc:VehicleMSRPAmount": {"description": "A manufacturer\'s suggested retail price of a vehicle; a price at which a manufacturer recommends a vehicle be sold.", "oneOf": [{"$ref": "#/definitions/nc:AmountType"}, {"type": "array", "items": {"$ref": "#/definitions/nc:AmountType"}}]}, "nc:Amount": {"description": "An amount of money.", "oneOf": [{"$ref": "#/definitions/niem-xs:decimal"}, {"type": "array", "items": {"$ref": "#/definitions/niem-xs:decimal"}}]}, "nc:Currency": 
{"description": "A data concept for a unit of money or exchange.", "oneOf": [{"anyOf": [{"$ref": "#/properties/nc:CurrencyCode"}]}, {"type": "array", "items": {"anyOf": [{"$ref": "#/properties/nc:CurrencyCode"}]}}]}, "nc:CurrencyCode": {"description": "A unit of money or exchange.", "oneOf": [{"$ref": "#/definitions/iso_4217:CurrencyCodeType"}, {"type": "array", "items": {"$ref": "#/definitions/iso_4217:CurrencyCodeType"}}]}, "nc:VehicleIdentification": {"description": "A unique identification for a specific vehicle.", "oneOf": [{"$ref": "#/definitions/nc:IdentificationType"}, {"type": "array", "items": {"$ref": "#/definitions/nc:IdentificationType"}}]}, "nc:IdentificationID": {"description": "An identifier.", "oneOf": [{"$ref": "#/definitions/niem-xs:string"}, {"type": "array", "items": {"$ref": "#/definitions/niem-xs:string"}}]}}, "definitions": {"nc:VehicleType": {"description": "A data type for a conveyance designed to carry an operator, passengers and/or cargo, over land.", "allOf": [{"$ref": "#/definitions/nc:ConveyanceType"}, {"type": "object", "properties": {"nc:VehicleAxleQuantity": {"$ref": "#/properties/nc:VehicleAxleQuantity"}, "nc:VehicleIdentification": {"$ref": "#/properties/nc:VehicleIdentification"}, "nc:VehicleMSRPAmount": {"$ref": "#/properties/nc:VehicleMSRPAmount"}}}]}, "nc:ConveyanceType": {"description": "A data type for a means of transport from place to place.", "allOf": [{"$ref": "#/definitions/_base"}, {"$ref": "#/definitions/nc:ItemType"}, {"type": "object", "properties": {}}]}, "nc:ItemType": {"description": "A data type for an article or thing.", "allOf": [{"$ref": "#/definitions/_base"}, {"type": "object", "properties": {}}]}, "nc:AmountType": {"description": "A data type for an amount of money.", "type": "object", "properties": {"nc:Amount": {"$ref": "#/properties/nc:Amount"}, "nc:Currency": {"$ref": "#/properties/nc:Currency"}}}, "iso_4217:CurrencyCodeType": {"description": "A data type for a currency that qualifies a monetary amount.", "oneOf": [{"$ref": "#/definitions/iso_4217:CurrencyCodeSimpleType"}, {"type": "object", "properties": {"rdf:value": {"$ref": "#/definitions/iso_4217:CurrencyCodeSimpleType"}}}]}, "iso_4217:CurrencyCodeSimpleType": {"type": "string", "description": "A data type for a currency that qualifies a monetary amount.", "oneOf": [{"enum": ["EUR"], "description": "Euro"}, {"enum": ["GBP"], "description": "Pound Sterling"}, {"enum": ["USD"], "description": "US Dollar"}]}, "nc:IdentificationType": {"description": "A data type for a representation of an identity.", "type": "object", "properties": {"nc:IdentificationID": {"$ref": "#/properties/nc:IdentificationID"}}}, "niem-xs:decimal": {"description": "A data type for arbitrary precision decimal numbers.", "type": "number"}, "niem-xs:nonNegativeInteger": {"description": "A data type for an integer with a minimum value of 0.", "type": "number"}, "niem-xs:string": {"description": "A data type for character strings in XML.", "type": "string"}, "_base": {"type": "object", "patternProperties": {"^ism:.*": {"type": "string"}, "^ntk:.*": {"type": "string"}}, "properties": {"@id": {"format": "uriref"}, "@base": {"format": "uriref"}}}}}' - - test_chars_per_iter = 1000 - unicode_chars = [chr(i) for i in range(test_chars_per_iter)] - - import time - start = time.time() - for char in complex_json_file: - parser.step_seq(char) - for ch in unicode_chars: - parser.is_valid_next_seq(ch) - - print("took", - (time.time() - start) / (len(complex_json_file)), - "seconds per step with", - test_chars_per_iter, 
"characters in vocabulary") - - - -def profile_predictor(): - import pstats - from io import StringIO - import cProfile - hello_grammar = """ - ?value: "hello" | "world" - """ - tokenizer = transformers.AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") - ntv = NextTokenValidator(tokenizer, hello_grammar, "value") - - profile = cProfile.Profile() - profile.enable() - ##### - - valid_toks = ntv.valid_token_str_set - ntv.step_seq("h") - valid_toks = ntv.valid_token_str_set - ntv.step_seq("e") - valid_toks = ntv.valid_token_str_set - ntv.step_seq("l") - valid_toks = ntv.valid_token_str_set - ntv.step_seq("l") - valid_toks = ntv.valid_token_str_set - ntv.step_seq("o") - - ##### - profile.disable() - - # Sorting the statistics by cumulative time - s = StringIO() - sortby = 'cumulative' - ps = pstats.Stats(profile, stream=s).sort_stats(sortby) - ps.print_stats() - print(s.getvalue()) - - - -def main(): - test_generate_json_randomly_via_logit_processor() - - -if __name__ == "__main__": - import transformers - import numpy as np - - profile = False - if profile: - import cProfile - import pstats - from io import StringIO - profile = cProfile.Profile() - profile.enable() - main() - profile.disable() - - # Sorting the statistics by cumulative time - s = StringIO() - sortby = 'cumulative' - ps = pstats.Stats(profile, stream=s).sort_stats(sortby) - ps.print_stats() - print(s.getvalue()) - else: - main() From 7706a9c67e999e6f74f66d86b5728e2da80ef7c4 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Wed, 20 Dec 2023 15:22:56 -0600 Subject: [PATCH 21/76] WIP: tests --- tests/samplers/test_grammar.py | 207 +++++++++++++++++++++++++++++++++ 1 file changed, 207 insertions(+) create mode 100644 tests/samplers/test_grammar.py diff --git a/tests/samplers/test_grammar.py b/tests/samplers/test_grammar.py new file mode 100644 index 000000000000..a6352849a5d8 --- /dev/null +++ b/tests/samplers/test_grammar.py @@ -0,0 +1,207 @@ +import pytest +import random + +from transformers import AutoTokenizer + +from vllm.grammar import TokenTrie + + +MODELS = ["codellama/CodeLlama-7b-hf"] + + +@pytest.fixture +def json_grammar(): + return """ + start: value + value: dict + | list + | string + | SIGNED_NUMBER -> number + | "true" -> true + | "false" -> false + | "null" -> null + + list : "[" [value ("," value)*] "]" + + dict : "{" [pair ("," pair)*] "}" + pair : string ":" value + + string : ESCAPED_STRING + + %import common.ESCAPED_STRING + %import common.SIGNED_NUMBER + %import common.WS + %ignore WS + """ + + +@pytest.fixture +def json_example(): + return """ + {"widget": { + "debug": "on", + "window": { + "title": "Sample Konfabulator Widget", + "name": "main_window", + "width": 500, + "height": 500 + }, + "image": { + "src": "Images/Sun.png", + "name": "sun1", + "hOffset": 250, + "vOffset": 250, + "alignment": "center" + }, + "text": { + "data": "Click Here", + "size": 36, + "style": "bold", + "name": "text1", + "hOffset": 250, + "vOffset": 100, + "alignment": "center", + "onMouseUp": "sun1.opacity = (sun1.opacity / 100) * 90;" + } + }}""".strip() + + +@pytest.fixture +def yaml_grammar(): + return """ +start : yaml + +yaml : data +data : ( scalar | sequence | mapping ) + +scalar : ( number | string | date | BOOLEAN | NIL ) +sequence : ( inline_seq| indented_seq ) +mapping : ( inline_map | indented_map ) + +inline_seq : "[" data ( "," data )* "]" +indented_seq : OPTIONAL_TAB "-" data ( "\n" OPTIONAL_TAB "-" data )* +inline_map : "{" key ":" data ( "," key ":" data )* "}" +indented_map : TAB key ":" data ( "\n" TAB 
key ":" data )* + +alpha : LCASE_LETTER | UCASE_LETTER +alphanum : alpha | DIGIT +string : "\"" alphanum* "\"" | alphanum+ +key : scalar +number : ("+" | "-")? DIGIT+ ("." DIGIT+)? +date : DIGIT~4 "-" DIGIT~2 "-" DIGIT~2 ( DIGIT~2 ":" DIGIT~2 ":" DIGIT~2 )? + +LCASE_LETTER : "a".."z" +UCASE_LETTER : "A".."Z" +DIGIT : "0".."9" +BOOLEAN : "true" | "false" +NIL : "~" +SPACE : " " +OPTIONAL_TAB : SPACE* +TAB : SPACE+ +""".strip() + + +@pytest.fixture +def yaml_example(): + return """ +# Read the Docs configuration file +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +version: 2 + +build: + os: ubuntu-22.04 + tools: + python: "3.8" + +sphinx: + configuration: docs/source/conf.py + +# If using Sphinx, optionally build your docs in additional formats such as PDF +formats: + - pdf + +# Optionally declare the Python requirements required to build your docs +python: + install: + - requirements: docs/requirements-docs.txt +""".strip() + + +@pytest.mark.parametrize("model_id", MODELS) +def test_next_token_validator_simple( + model, +): + hello_grammar = """ + ?start: "hello" | "world" + """ + tokenizer = AutoTokenizer.from_pretrained(model_id) + ntv = NextTokenValidator(tokenizer, hello_grammar) + + assert ntv.valid_token_str_set == {'wo', 'hell', 'h', 'he', 'hel', 'world', 'wor', 'w', 'hello'} + assert ntv.valid_token_id_set == {265, 809, 107, 2805, 21558, 28727, 13436, 22493, 9471} + + +@pytest.mark.parametrize("model_id", MODELS) +@pytest.mark.parametrize("grammar_fixture, example_fixture", [ + ("json_grammar", "json_example"), + ("yaml_grammar", "yaml_example") +]) +def test_can_generate_with_grammar( + model_id, + grammar_fixture, + example_fixture +): + """Assert that example json file is legal to generate with GrammarLogitsProcessor""" + grammar = request.getfixturevalue(grammar_fixture) + example = request.getfixturevalue(example_fixture) + + tokenizer = AutoTokenizer.from_pretrained(model_id) + next_token_validator = NextTokenValidator( + tokenizer, + grammar, + ) + example_remainder = example + while exampleo_remainder: + legal_next_token_strs = list(next_token_validator.valid_token_str_set) + random.shuffle(legal_next_token_strs) + for tok in legal_next_token_strs: + if example_remainder.startswith(tok): + example_remainder = example_remainder[len(tok):] + else: + raise Exception("Couldn't find token to validate legal JSON given JSON grammar") + + +@pytest.mark.parametrize("model_id", MODELS) +def test_token_trie_sanity( + model_id +): + tokenizer = AutoTokenizer.from_pretrained(model_id) + toktrie = TokenTrie(tokenizer) + + all_prefixes = toktrie.get_next_level_token_prefixes("") + + # every token should be composable from a single unique char, so they will all be len of 1 + assert all([len(p) == 1 for p in all_prefixes]) + + # every token should have one of these prefixes as a start character + assert all([ + t[0] in all_prefixes + for t in toktrie.norm_vocab + ]) + + # construct the set of next level prefixes + all_subprefixes = set() + for pfx in all_prefixes: + all_subprefixes |= toktrie.get_next_level_token_prefixes(pfx) + + # these should have varying length because some tokens don't have level-2 prefixes + assert len(set([len(spfx) for spfx in all_subprefixes])) > 1 + + +def test_assert_ends_with_eos(): + assert False + + +def test_integration_with_vllm(): + assert False From 1fdb0f576b6971cb4832c97ce0e561dc8b15f4de Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Thu, 21 Dec 2023 05:41:44 -0600 Subject: [PATCH 22/76] misc bug fixes, optimizations + 
add EOS token generation --- docs/source/grammars/grammars.rst | 65 +++++++++++---- tests/samplers/test_grammar.py | 122 +++++++++++++--------------- vllm/grammar.py | 127 +++++++++++++++++------------- 3 files changed, 174 insertions(+), 140 deletions(-) diff --git a/docs/source/grammars/grammars.rst b/docs/source/grammars/grammars.rst index 24fb5672af8c..64f863de422b 100644 --- a/docs/source/grammars/grammars.rst +++ b/docs/source/grammars/grammars.rst @@ -55,38 +55,71 @@ Example Lark Grammars Performance ----------- -Expect between 3 and 30 new tokens per second as a baseline, however performance can be improved from the baseline. +For a simple JSON grammar, on the authors mid-end laptop using codeLlama-7b's vocabulary, generation occurred at ~10 validated logit sets per second. However performance was improved dramatically from the baseline with a few tweaks to ~400/s. These tweaks include +- Optimizing the grammar increased performance by ~10x +- Constraining legal characters increased performance by ~4x -**Constrain legal characters** +**Design your EBNF grammar with minimal regexp** -Every legal character in the alphabet must be checked against the parser by default. Mistral tokenizer, for example, has an alphabet of 3,298 characters, here are 40 random examples: +Regexp processing is the most expensive task for GrammarLogitsProcessor. When designing your EBNF, it's better to keep your regexp short and simple if at all possible. + +Breaking down the following expressions ESCAPE_STRING into an expression with many faster-terminating regex resulted in an ~8x speedup: .. code-block:: + start: value + ?value: dict + | list + | string + | signed_number -> number + | "true" -> true + | "false" -> false + | "null" -> null - [ '堂', 'ู', 'ɔ', '🙌', 'Б', '레', '允', 'ả', '\ue934', '如', '試', 'K', '¯', '卷', '園', 'ए', '\\', '酒', 'थ', 'グ', '터', '연', 'Ș', 'ブ', '星', 'ြ', 'å', '軍', '案', '题', '银', '映', '표', '\x11', '級', '醒', 'ေ', '✭', '約', '😤'] + list : "[" [value ("," value)*] "]" -Likely many of these characters aren't useful in your generation. + dict : "{" [pair ("," pair)*] "}" + pair : string ":" value -Expect an ~10x speedup if you constrain your generation to UTF-8, eliminating 3,042 unnecessary characters. + string : "\"" escaped_string_char* "\"" + escaped_string_char: STR_INNER_CHAR | ESCAPED_CHAR + ESCAPED_CHAR: "\\" ANY_CHAR + STR_INNER_CHAR: /[^\\\"]/ + ANY_CHAR: /[.]/ -.. code-block:: + signed_number: ["+"|"-"] number + number: float | int + float: int exp | decimal exp? + decimal: int "." int? | "." int + exp: ("e"|"E") signed_int + signed_int: ["+"|"-"] int + int: DIGIT+ + DIGIT: "0".."9" - GrammarLogitsProcessor( - ..., - legal_chars=set([chr(i) for i in range(256)]) - ) + WS: /[ \t\f\r\n]/ + %ignore WS -**Design your EBNF with minimal regexp** + # old slow regex-based expressions: -Regexp processing is the most expensive task for GrammarLogitsProcessor. When designing your EBNF, it's better to keep your regexp short and simple if at all possible. + # %import common.ESCAPED_STRING + # %import common.SIGNED_NUMBER + # %import common.WS -**Use more threads** -By default ``GrammarLogitProcessor`` uses ``os.cpu_count() / 2`` threads. You may change this via +**Constrain legal characters** + +Every legal character in the alphabet must be checked against the parser by default. Mistral tokenizer, for example, has an alphabet of 3,298 characters, here are 40 random examples: + +.. 
code-block:: + + [ '堂', 'ู', 'ɔ', '🙌', 'Б', '레', '允', 'ả', '\ue934', '如', '試', 'K', '¯', '卷', '園', 'ए', '\\', '酒', 'थ', 'グ', '터', '연', 'Ș', 'ブ', '星', 'ြ', 'å', '軍', '案', '题', '银', '映', '표', '\x11', '級', '醒', 'ေ', '✭', '約', '😤'] + +Likely many of these characters aren't useful in your generation. + +Expect an ~10x speedup if you constrain your generation to UTF-8, eliminating 3,042 unnecessary characters. .. code-block:: GrammarLogitsProcessor( ..., - num_threads=4 + legal_chars=set([chr(i) for i in range(256)]) ) diff --git a/tests/samplers/test_grammar.py b/tests/samplers/test_grammar.py index a6352849a5d8..668cac54d367 100644 --- a/tests/samplers/test_grammar.py +++ b/tests/samplers/test_grammar.py @@ -3,7 +3,7 @@ from transformers import AutoTokenizer -from vllm.grammar import TokenTrie +from vllm.grammar import TokenTrie, NextTokenValidator MODELS = ["codellama/CodeLlama-7b-hf"] @@ -11,12 +11,12 @@ @pytest.fixture def json_grammar(): - return """ + return r""" start: value value: dict | list | string - | SIGNED_NUMBER -> number + | signed_number -> number | "true" -> true | "false" -> false | "null" -> null @@ -26,11 +26,22 @@ def json_grammar(): dict : "{" [pair ("," pair)*] "}" pair : string ":" value - string : ESCAPED_STRING - - %import common.ESCAPED_STRING - %import common.SIGNED_NUMBER - %import common.WS + string : "\"" escaped_string_char* "\"" + escaped_string_char: STR_INNER_CHAR | ESCAPED_CHAR + ESCAPED_CHAR: "\\" ANY_CHAR + STR_INNER_CHAR: /[^\\\"]/ + ANY_CHAR: /[.]/ + + signed_number: ["+"|"-"] number + number: float | int + float: int exp | decimal exp? + decimal: int "." int? | "." int + exp: ("e"|"E") signed_int + signed_int: ["+"|"-"] int + int: DIGIT+ + DIGIT: "0".."9" + + WS: /[ \t\f\r\n]/ %ignore WS """ @@ -67,70 +78,37 @@ def json_example(): @pytest.fixture -def yaml_grammar(): +def csv_grammar(): return """ -start : yaml - -yaml : data -data : ( scalar | sequence | mapping ) - -scalar : ( number | string | date | BOOLEAN | NIL ) -sequence : ( inline_seq| indented_seq ) -mapping : ( inline_map | indented_map ) - -inline_seq : "[" data ( "," data )* "]" -indented_seq : OPTIONAL_TAB "-" data ( "\n" OPTIONAL_TAB "-" data )* -inline_map : "{" key ":" data ( "," key ":" data )* "}" -indented_map : TAB key ":" data ( "\n" TAB key ":" data )* - -alpha : LCASE_LETTER | UCASE_LETTER -alphanum : alpha | DIGIT -string : "\"" alphanum* "\"" | alphanum+ -key : scalar -number : ("+" | "-")? DIGIT+ ("." DIGIT+)? -date : DIGIT~4 "-" DIGIT~2 "-" DIGIT~2 ( DIGIT~2 ":" DIGIT~2 ":" DIGIT~2 )? - -LCASE_LETTER : "a".."z" -UCASE_LETTER : "A".."Z" -DIGIT : "0".."9" -BOOLEAN : "true" | "false" -NIL : "~" -SPACE : " " -OPTIONAL_TAB : SPACE* -TAB : SPACE+ -""".strip() + start: header _NL row+ + header: "#" " "? 
(WORD _SEPARATOR?)+ + row: (_anything _SEPARATOR?)+ _NL + _anything: INT | WORD | NON_SEPARATOR_STRING | FLOAT | SIGNED_FLOAT + NON_SEPARATOR_STRING: "/[a-zA-z.;\\\/]+/" + _SEPARATOR: "\t" + | "," + + # using these suboptimal common library terminals is a bad practice + %import common.NEWLINE -> _NL + %import common.WORD + %import common.INT + %import common.FLOAT + %import common.SIGNED_FLOAT + """ @pytest.fixture -def yaml_example(): +def csv_example(): return """ -# Read the Docs configuration file -# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details - -version: 2 - -build: - os: ubuntu-22.04 - tools: - python: "3.8" - -sphinx: - configuration: docs/source/conf.py - -# If using Sphinx, optionally build your docs in additional formats such as PDF -formats: - - pdf - -# Optionally declare the Python requirements required to build your docs -python: - install: - - requirements: docs/requirements-docs.txt +#foo\tbar\tbaz +1\t2\t3 +bif\t\bif\tbif """.strip() @pytest.mark.parametrize("model_id", MODELS) def test_next_token_validator_simple( - model, + model_id, ): hello_grammar = """ ?start: "hello" | "world" @@ -138,21 +116,23 @@ def test_next_token_validator_simple( tokenizer = AutoTokenizer.from_pretrained(model_id) ntv = NextTokenValidator(tokenizer, hello_grammar) + # tokens specific to codeLlama assert ntv.valid_token_str_set == {'wo', 'hell', 'h', 'he', 'hel', 'world', 'wor', 'w', 'hello'} - assert ntv.valid_token_id_set == {265, 809, 107, 2805, 21558, 28727, 13436, 22493, 9471} + assert sorted(ntv.valid_token_id_set) == [107, 122, 354, 827, 3952, 11526, 12199, 13762, 14181, 29882, 29893] @pytest.mark.parametrize("model_id", MODELS) @pytest.mark.parametrize("grammar_fixture, example_fixture", [ ("json_grammar", "json_example"), - ("yaml_grammar", "yaml_example") + ("csv_grammar", "csv_example") ]) def test_can_generate_with_grammar( model_id, + request, grammar_fixture, example_fixture ): - """Assert that example json file is legal to generate with GrammarLogitsProcessor""" + """Assert that example file is legal to generate with GrammarLogitsProcessor""" grammar = request.getfixturevalue(grammar_fixture) example = request.getfixturevalue(example_fixture) @@ -160,16 +140,22 @@ def test_can_generate_with_grammar( next_token_validator = NextTokenValidator( tokenizer, grammar, + legal_chars=set([chr(i) for i in range(256)]) ) example_remainder = example - while exampleo_remainder: + while example_remainder: legal_next_token_strs = list(next_token_validator.valid_token_str_set) random.shuffle(legal_next_token_strs) for tok in legal_next_token_strs: if example_remainder.startswith(tok): + next_token_validator.step_seq(tok) example_remainder = example_remainder[len(tok):] + break else: - raise Exception("Couldn't find token to validate legal JSON given JSON grammar") + raise Exception(f"Couldn't find token to create legal output given grammar: '{example_remainder}'") + + # EOS should be in the set of next legal tokens + assert None in next_token_validator.valid_token_str_set @pytest.mark.parametrize("model_id", MODELS) @@ -199,7 +185,7 @@ def test_token_trie_sanity( assert len(set([len(spfx) for spfx in all_subprefixes])) > 1 -def test_assert_ends_with_eos(): +def test_assert_fails_for_invalid_examples(): assert False diff --git a/vllm/grammar.py b/vllm/grammar.py index 5f345649c8dc..b2fa6b7a28c6 100644 --- a/vllm/grammar.py +++ b/vllm/grammar.py @@ -77,6 +77,7 @@ def __init__(self, grammar: str, start: str): regex=True, # use `regex` not `re` start=start, 
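            # Lark's interactive (resumable) parsing used below via
            # parse_interactive() is only available with the LALR parser.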
parser='lalr', + cache=True, # results in 2-3x faster loading ) base_interactive_parser = self.parser.parse_interactive() self.interactive_parser = FastInteractiveParser( @@ -91,17 +92,13 @@ def __init__(self, grammar: str, start: str): self._ignored_terms = set(self.parser.lexer_conf.ignore) - # for processing terminals interactively - self.last_terminal_pos = 0 - self.valid_next_terminals = None - # for calculating `accepts()` efficiently self._accepts_cache = {} self.sequence_history = "" - # initiate - self.step_seq("") + # for processing terminals interactively + self.valid_next_terminals = {"": self._accepts() | self._ignored_terms} @staticmethod def _get_partial_pattern_validator(pattern): @@ -127,13 +124,6 @@ def _accepts(self): self._accepts_cache[self.sequence_history] = accepted_terminals return self._accepts_cache[self.sequence_history] - @property - def terminal_partial_seq(self): - """ - Return the incomplete subsequence which will eventually comprise a terminal - """ - return self.sequence_history[self.last_terminal_pos:] - def step_seq(self, new_seq: str): """ Append sequence to parser and apply state updates @@ -142,21 +132,30 @@ def step_seq(self, new_seq: str): - Update the character position of the last complete terminal - Update the set of candidate terminals """ - self._append_to_sequence(new_seq) - - try: - self.interactive_parser.exhaust_lexer() - except UnexpectedCharacters as e: - self.last_terminal_pos = e.pos_in_stream - else: - self.last_terminal_pos = len(self.sequence_history) - - self._update_candidate_terminals() - - if not self.valid_next_terminals: - raise ValueError( - f"Invalid continuation for `{self.sequence_history}` `{new_seq}`" - ) + for char in new_seq: + self._append_to_sequence(char) + + filter_candidate_terminals = True + try: + self.interactive_parser.exhaust_lexer() + except UnexpectedCharacters as e: + pass + except UnexpectedToken as e: + filter_candidate_terminals = False + + self.valid_next_terminals = { + (incomplete_seq + char): term + for incomplete_seq, term in self.valid_next_terminals.items() + } + self.valid_next_terminals[""] = self._accepts() | self._ignored_terms + + if filter_candidate_terminals: + self._update_candidate_terminals() + + if not self.valid_next_terminals: + raise ValueError( + f"Invalid continuation for `{self.sequence_history}` `{new_seq}`" + ) def _append_to_sequence(self, new_seq: str): """Set the complete sequences value in the lexer and base""" @@ -168,14 +167,22 @@ def _update_candidate_terminals(self): Update the set of candidate terminals - If a new terminal is reached, get the accepted set of terminals from the parser - If the new sequence doesn't comprise a full terminal, filter based on partial pattern match + + Handles ambiguity by allowing terminals which are potentially complete """ - if not self.terminal_partial_seq: - self.valid_next_terminals = self._accepts() | self._ignored_terms - else: - self.valid_next_terminals = set([ - term for term in self.valid_next_terminals - if self.partial_seq_validator[term](self.terminal_partial_seq) - ]) + to_prune_sequences = set() + for incomplete_seq, terminals in self.valid_next_terminals.items(): + if incomplete_seq != "": + self.valid_next_terminals[incomplete_seq] = set([ + term for term in self.valid_next_terminals[incomplete_seq] + if term != "$END" + and self.partial_seq_validator[term](incomplete_seq) + ]) + if not self.valid_next_terminals[incomplete_seq]: + to_prune_sequences.add(incomplete_seq) + + for to_prune_seq in to_prune_sequences: + del 
self.valid_next_terminals[to_prune_seq] def is_valid_next_seq(self, new_seq: Optional[str]): """ @@ -184,12 +191,16 @@ def is_valid_next_seq(self, new_seq: Optional[str]): new_seq can be a string or None representing EOS """ if new_seq is None: - return "$END" in self.valid_next_terminals - for term in self.valid_next_terminals: - if term != "$END": - full_terminal_candidate = self.terminal_partial_seq + new_seq - if self.partial_seq_validator[term](full_terminal_candidate): - return True + return "$END" in [ + term for terminals in self.valid_next_terminals.values() + for term in terminals + ] + for incomplete_seq, terminals in self.valid_next_terminals.items(): + candidate = incomplete_seq + new_seq + for term in terminals: + if term != "$END": + if self.partial_seq_validator[term](candidate): + return True return False @@ -200,21 +211,25 @@ def __init__(self, tokenizer, legal_chars: Optional[set[str]] = None): """ Trie structure for efficiently finding tokens which are suffixes of other sequences """ - self.norm_vocab = {} + self.norm_vocab = collections.defaultdict(set) for token_id in tokenizer.vocab.values(): + if token_id == tokenizer.eos_token_id: + self.norm_vocab[None].add(token_id) + continue bos_len = len(tokenizer.bos_token) norm_token = tokenizer.decode([tokenizer.bos_token_id, token_id])[bos_len:] if legal_chars is None or all( [char in legal_chars for char in norm_token]): - self.norm_vocab[norm_token] = token_id + self.norm_vocab[norm_token].add(token_id) - self.token_to_id_set = collections.defaultdict(set) - for token_str, token_id in self.norm_vocab.items(): - self.token_to_id_set[token_str].add(token_id) + # faster lookups, reduce time by 10% + self.norm_vocab_set = set(self.norm_vocab) self.trie = {} for word in self.norm_vocab: current_dict = self.trie + if word is None: + continue for char in word: if char not in current_dict: current_dict[char] = {} @@ -264,7 +279,7 @@ def get_next_level_token_prefixes_uncached(self, subprefix: str, _node=None): return results def is_token(self, seq): - return seq in self.norm_vocab + return seq in self.norm_vocab_set class NextTokenValidator: @@ -279,17 +294,14 @@ def __init__(self, tokenizer, grammar: str, grammar_start: str = "start", - num_threads: Optional[int] = None): + legal_chars: Optional[set[str]] = None, + ): self.tokenizer = tokenizer - self.token_trie = TokenTrie(tokenizer) + self.token_trie = TokenTrie(tokenizer, legal_chars=legal_chars) self.parser = InteractivePredictiveLALRParser(grammar=grammar, start=grammar_start) - # TODO: threading - if num_threads is None: - self.num_threads = os.cpu_count() // 2 - def step_seq(self, new_seq): self.parser.step_seq(new_seq) @@ -302,13 +314,16 @@ def valid_token_str_set(self): 2) for each token in the stack, validate against the parser - if valid, add all children to the stack for later processing - if valid AND a token, add to valid_token_set + + TODO: this can be improved with multi-threading """ valid_token_str_set = set() + if self.parser.is_valid_next_seq(None): + valid_token_str_set.add(self.tokenizer.eos_token) token_prefix_stack = collections.deque([""]) while token_prefix_stack: token_prefix = token_prefix_stack.pop() for child_token_prefix in self.token_trie.get_next_level_token_prefixes(token_prefix): - # TODO: Handle EOS token by passing None if self.parser.is_valid_next_seq(child_token_prefix): token_prefix_stack.append(child_token_prefix) if self.token_trie.is_token(child_token_prefix): @@ -323,8 +338,8 @@ def valid_token_id_set(self): note that some token 
strings correspond to multiple token IDs """ return set.union(*[ - self.token_trie.token_to_id_set[tok] - for tok in self.valid_token_str_set + self.token_trie.norm_vocab[tok_str] + for tok_str in self.valid_token_str_set ]) From b3f6502450a0bd7481d9a2fbfaaa3fa58c32588d Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Thu, 21 Dec 2023 05:46:02 -0600 Subject: [PATCH 23/76] fix doc rendering --- docs/source/grammars/grammars.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/grammars/grammars.rst b/docs/source/grammars/grammars.rst index 64f863de422b..807aa103cd05 100644 --- a/docs/source/grammars/grammars.rst +++ b/docs/source/grammars/grammars.rst @@ -66,6 +66,7 @@ Regexp processing is the most expensive task for GrammarLogitsProcessor. When de Breaking down the following expressions ESCAPE_STRING into an expression with many faster-terminating regex resulted in an ~8x speedup: .. code-block:: + start: value ?value: dict | list @@ -104,7 +105,6 @@ Breaking down the following expressions ESCAPE_STRING into an expression with ma # %import common.SIGNED_NUMBER # %import common.WS - **Constrain legal characters** Every legal character in the alphabet must be checked against the parser by default. Mistral tokenizer, for example, has an alphabet of 3,298 characters, here are 40 random examples: From 517e32986c5537b61c3a5f74d3e74b49746d3f03 Mon Sep 17 00:00:00 2001 From: lapp0 Date: Thu, 21 Dec 2023 11:46:59 +0000 Subject: [PATCH 24/76] Update grammars.rst --- docs/source/grammars/grammars.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/grammars/grammars.rst b/docs/source/grammars/grammars.rst index 807aa103cd05..c366a20533ad 100644 --- a/docs/source/grammars/grammars.rst +++ b/docs/source/grammars/grammars.rst @@ -56,6 +56,7 @@ Performance ----------- For a simple JSON grammar, on the authors mid-end laptop using codeLlama-7b's vocabulary, generation occurred at ~10 validated logit sets per second. However performance was improved dramatically from the baseline with a few tweaks to ~400/s. These tweaks include + - Optimizing the grammar increased performance by ~10x - Constraining legal characters increased performance by ~4x From 2087a82348e89a5bfb8addf027694f6a714e01d5 Mon Sep 17 00:00:00 2001 From: lapp0 Date: Thu, 21 Dec 2023 11:47:53 +0000 Subject: [PATCH 25/76] I had an 8x speedup when fixing ESCAPE_STRING, but a 10x speedup from fixing WS and SIGNED_NUMBER in addition --- docs/source/grammars/grammars.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/grammars/grammars.rst b/docs/source/grammars/grammars.rst index c366a20533ad..7c7688220ae1 100644 --- a/docs/source/grammars/grammars.rst +++ b/docs/source/grammars/grammars.rst @@ -64,7 +64,7 @@ For a simple JSON grammar, on the authors mid-end laptop using codeLlama-7b's vo Regexp processing is the most expensive task for GrammarLogitsProcessor. When designing your EBNF, it's better to keep your regexp short and simple if at all possible. -Breaking down the following expressions ESCAPE_STRING into an expression with many faster-terminating regex resulted in an ~8x speedup: +Breaking down the following expressions ESCAPE_STRING into an expression with many faster-terminating regex resulted in an ~10x speedup: .. 
code-block::


From 45cce24306d0a58a809fe7339d8e740b8bceb858 Mon Sep 17 00:00:00 2001
From: Andrew Lapp
Date: Thu, 21 Dec 2023 18:09:34 -0600
Subject: [PATCH 26/76] clean up and fix bugs

---
 docs/source/grammars/grammars.rst |   5 +-
 tests/samplers/test_grammar.py    | 140 ++++++++++++++++++++++++------
 vllm/grammar.py                   |  42 ++++++---
 3 files changed, 143 insertions(+), 44 deletions(-)

diff --git a/docs/source/grammars/grammars.rst b/docs/source/grammars/grammars.rst
index 807aa103cd05..6ebc9c3b3c82 100644
--- a/docs/source/grammars/grammars.rst
+++ b/docs/source/grammars/grammars.rst
@@ -120,6 +120,7 @@ Expect an ~10x speedup if you constrain your generation to UTF-8, eliminating 3,
 .. code-block::
 
     GrammarLogitsProcessor(
-        ...,
-        legal_chars=set([chr(i) for i in range(256)])
+        tokenizer,
+        grammar,
+        legal_chars=set(map(chr, range(256))),
     )
diff --git a/tests/samplers/test_grammar.py b/tests/samplers/test_grammar.py
index 668cac54d367..67fc2947a007 100644
--- a/tests/samplers/test_grammar.py
+++ b/tests/samplers/test_grammar.py
@@ -1,19 +1,25 @@
 import pytest
+import numpy as np
 import random
+import json
 
 from transformers import AutoTokenizer
 
-from vllm.grammar import TokenTrie, NextTokenValidator
+from vllm.grammar import TokenTrie, NextTokenValidator, GrammarLogitsProcessor
 
 
-MODELS = ["codellama/CodeLlama-7b-hf"]
+@pytest.fixture
+def tokenizer():
+    model_id = "codellama/CodeLlama-7b-hf"
+    return AutoTokenizer.from_pretrained(model_id)
 
 
 @pytest.fixture
 def json_grammar():
     return r"""
     start: value
-    value: dict
+    value: WS* object WS*
+    object: dict
          | list
          | string
          | signed_number -> number
@@ -24,13 +30,13 @@ def json_grammar():
     list : "[" [value ("," value)*] "]"
 
     dict : "{" [pair ("," pair)*] "}"
-    pair : string ":" value
+    pair : WS* string WS* ":" value
 
     string : "\"" escaped_string_char* "\""
-    escaped_string_char: STR_INNER_CHAR | ESCAPED_CHAR
-    ESCAPED_CHAR: "\\" ANY_CHAR
-    STR_INNER_CHAR: /[^\\\"]/
-    ANY_CHAR: /[.]/
+    escaped_string_char: _STR_INNER_CHAR | _ESCAPED_CHAR
+    _ESCAPED_CHAR: "\\" _ANY_CHAR
+    _STR_INNER_CHAR: /[^\\\"]/
+    _ANY_CHAR: /[.]/
 
     signed_number: ["+"|"-"] number
     number: float | int
@@ -42,7 +48,6 @@ def json_grammar():
     DIGIT: "0".."9"
 
     WS: /[ \t\f\r\n]/
-    %ignore WS
     """
 
 
@@ -102,18 +107,23 @@ def csv_example():
     return """
 #foo\tbar\tbaz
 1\t2\t3
-bif\t\bif\tbif
-""".strip()
+bif\tbif\tbif
+""".strip() + "\n"  # grammar requires newline before eos
 
 
-@pytest.mark.parametrize("model_id", MODELS)
-def test_next_token_validator_simple(
-    model_id,
-):
+def sample_from_logits(logits):
+    probs = np.exp(logits) / np.sum(np.exp(logits))
+    return np.random.choice(
+        len(logits),
+        p=probs
+    )
+
+
+def test_next_token_validator_simple(tokenizer):
     hello_grammar = """
     ?start: "hello" | "world"
     """
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
     ntv = NextTokenValidator(tokenizer, hello_grammar)
 
     # tokens specific to codeLlama
@@ -121,32 +131,32 @@ def test_next_token_validator_simple(
     assert sorted(ntv.valid_token_id_set) == [107, 122, 354, 827, 3952, 11526, 12199, 13762, 14181, 29882, 29893]
 
 
-@pytest.mark.parametrize("model_id", MODELS)
 @pytest.mark.parametrize("grammar_fixture, example_fixture", [
     ("json_grammar", "json_example"),
     ("csv_grammar", "csv_example")
 ])
 def test_can_generate_with_grammar(
-    model_id,
+    tokenizer,
     request,
     grammar_fixture,
     example_fixture
 ):
     """Assert that example file is legal to generate with NextTokenValidator"""
     grammar = request.getfixturevalue(grammar_fixture)
example = request.getfixturevalue(example_fixture) - tokenizer = AutoTokenizer.from_pretrained(model_id) next_token_validator = NextTokenValidator( tokenizer, grammar, - legal_chars=set([chr(i) for i in range(256)]) + legal_chars=set(map(chr, range(256))), ) example_remainder = example while example_remainder: legal_next_token_strs = list(next_token_validator.valid_token_str_set) random.shuffle(legal_next_token_strs) for tok in legal_next_token_strs: + if tok is None: + continue if example_remainder.startswith(tok): next_token_validator.step_seq(tok) example_remainder = example_remainder[len(tok):] @@ -158,11 +168,7 @@ def test_can_generate_with_grammar( assert None in next_token_validator.valid_token_str_set -@pytest.mark.parametrize("model_id", MODELS) -def test_token_trie_sanity( - model_id -): - tokenizer = AutoTokenizer.from_pretrained(model_id) +def test_token_trie_sanity(tokenizer): toktrie = TokenTrie(tokenizer) all_prefixes = toktrie.get_next_level_token_prefixes("") @@ -174,6 +180,7 @@ def test_token_trie_sanity( assert all([ t[0] in all_prefixes for t in toktrie.norm_vocab + if t is not None ]) # construct the set of next level prefixes @@ -185,8 +192,85 @@ def test_token_trie_sanity( assert len(set([len(spfx) for spfx in all_subprefixes])) > 1 -def test_assert_fails_for_invalid_examples(): - assert False +@pytest.mark.parametrize("start_tok, validator", [ + (29945, float), # 5 - float + (285, lambda s: bool(json.dumps(s))), # f for false + (260, lambda s: bool(json.dumps(s))), # t for false + (376, lambda s: str(json.dumps(s))), # " for string +]) +def test_gen_primative(json_grammar, tokenizer, start_tok, validator): + # Note: string may last a + for _ in range(4): + grammar_logits_processor = GrammarLogitsProcessor( + tokenizer, + json_grammar, + legal_chars=set(map(chr, range(256))), + ) + + token_ids = [start_tok] + while True: + logits = grammar_logits_processor( + token_ids=token_ids, + logits=np.random.uniform(-10, 10, len(tokenizer.vocab)) + ) + new_token_id = sample_from_logits(logits) + if new_token_id == tokenizer.eos_token_id: + break + token_ids.append(new_token_id) + + validator(tokenizer.decode(token_ids)) + + +def test_random_grammared_generation(json_grammar, tokenizer): + # Generate JSON token-by-token with random logits until EOS is hit, then validate JSON + # Bias logits so open syntax such that closing syntax such as ]}", + # occur more frequently as time goes on so we don't get stuck in generation + + num_repeats = 8 + + grammar_logits_processor = GrammarLogitsProcessor( + tokenizer, + json_grammar, + legal_chars=set(map(chr, range(256))), + ) + + # bias closing tokens logits to prevent infinite generation + closing_token_ids = set([ + tok_id + for tok_str in ["]", "}", '"', ",", None] + for tok_id in grammar_logits_processor.token_trie.norm_vocab[tok_str] + ]) + closing_tokens_bias = -10 + + # without this it mostly generates numbers since numbers represent far + # more tokens than these, and numbers close much more quickly, are less + # gramatically complicated and result in a less interesting test + opening_token_ids = set([ + tok_id + for tok_str in ["[", "{", '"', ","] + for tok_id in grammar_logits_processor.token_trie.norm_vocab[tok_str] + ]) + opening_tokens_bias = 5 + + token_ids = [] + while True: + logits = grammar_logits_processor( + token_ids=token_ids, + logits=np.random.uniform(-10, 10, len(tokenizer.vocab)) + ) + + for closing_token_id in closing_token_ids: + logits[closing_token_id] += closing_tokens_bias + for opening_token_id in 
opening_token_ids: + logits[opening_token_id] += opening_tokens_bias + + new_token_id = sample_from_logits(logits) + if new_token_id == tokenizer.eos_token_id: + break + token_ids.append(new_token_id) + closing_tokens_bias += 0.2 + opening_tokens_bias -= 0.1 + def test_integration_with_vllm(): diff --git a/vllm/grammar.py b/vllm/grammar.py index b2fa6b7a28c6..af91b98bd6cc 100644 --- a/vllm/grammar.py +++ b/vllm/grammar.py @@ -85,6 +85,9 @@ def __init__(self, grammar: str, start: str): base_interactive_parser.parser_state, base_interactive_parser.lexer_thread) + # fallback parser from start of terminal in case of ambiguous (LR(1)) + self._terminal_start_parser = self.interactive_parser.copy() + self.partial_seq_validator = { term.name: self._get_partial_pattern_validator(term.pattern) for term in self.parser.terminals @@ -133,38 +136,45 @@ def step_seq(self, new_seq: str): - Update the set of candidate terminals """ for char in new_seq: - self._append_to_sequence(char) - filter_candidate_terminals = True + # update canonical sequence and lexer sequence + self.sequence_history += new_seq + self.interactive_parser.lexer_thread.state.text = self.sequence_history + + success = False try: self.interactive_parser.exhaust_lexer() except UnexpectedCharacters as e: pass except UnexpectedToken as e: - filter_candidate_terminals = False + # fall back so full token can be reprocessed + self.interactive_parser = self._terminal_start_parser.copy() + self.interactive_parser.lexer_thread.state.text = self.sequence_history + else: + success = True self.valid_next_terminals = { (incomplete_seq + char): term for incomplete_seq, term in self.valid_next_terminals.items() } - self.valid_next_terminals[""] = self._accepts() | self._ignored_terms - if filter_candidate_terminals: - self._update_candidate_terminals() + # if successfully parsed new token, add blank state and set fallback checkpoint + if success: + self.valid_next_terminals[""] = self._accepts() | self._ignored_terms + self._terminal_start_parser = self.interactive_parser.copy() + + self._filter_candidate_terminals() if not self.valid_next_terminals: raise ValueError( f"Invalid continuation for `{self.sequence_history}` `{new_seq}`" ) - def _append_to_sequence(self, new_seq: str): - """Set the complete sequences value in the lexer and base""" - self.sequence_history += new_seq - self.interactive_parser.lexer_thread.state.text = self.sequence_history + print(self.valid_next_terminals) - def _update_candidate_terminals(self): + def _filter_candidate_terminals(self): """ - Update the set of candidate terminals + Filter the set of candidate terminals - If a new terminal is reached, get the accepted set of terminals from the parser - If the new sequence doesn't comprise a full terminal, filter based on partial pattern match @@ -319,7 +329,7 @@ def valid_token_str_set(self): """ valid_token_str_set = set() if self.parser.is_valid_next_seq(None): - valid_token_str_set.add(self.tokenizer.eos_token) + valid_token_str_set.add(None) token_prefix_stack = collections.deque([""]) while token_prefix_stack: token_prefix = token_prefix_stack.pop() @@ -353,7 +363,8 @@ def __init__(self, *args, **kwargs): self.generation_token_ids = [] self.generation_text = "" - def __call__(self, token_ids, logits): + + def _update_seen_token_ids(self, token_ids): # ensure integrity assert token_ids[:len(self.generation_token_ids)] == self.generation_token_ids self.generation_token_ids = token_ids @@ -364,6 +375,9 @@ def __call__(self, token_ids, logits): self.generation_text = 
all_text self.step_seq(new_text) + def __call__(self, token_ids, logits): + self._update_seen_token_ids(token_ids) + # get valid token IDs and modify logits valid_token_ids = self.valid_token_id_set logits = [ From b8a625fde613cfafd59ef9b15a6a593da8d3d66f Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Thu, 21 Dec 2023 18:13:13 -0600 Subject: [PATCH 27/76] update docs --- docs/source/grammars/grammars.rst | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/grammars/grammars.rst b/docs/source/grammars/grammars.rst index d91ccc7af896..4c66a0454210 100644 --- a/docs/source/grammars/grammars.rst +++ b/docs/source/grammars/grammars.rst @@ -57,14 +57,14 @@ Performance For a simple JSON grammar, on the authors mid-end laptop using codeLlama-7b's vocabulary, generation occurred at ~10 validated logit sets per second. However performance was improved dramatically from the baseline with a few tweaks to ~400/s. These tweaks include -- Optimizing the grammar increased performance by ~10x -- Constraining legal characters increased performance by ~4x +- Optimizing the grammar +- Constraining legal characters **Design your EBNF grammar with minimal regexp** Regexp processing is the most expensive task for GrammarLogitsProcessor. When designing your EBNF, it's better to keep your regexp short and simple if at all possible. -Breaking down the following expressions ESCAPE_STRING into an expression with many faster-terminating regex resulted in an ~10x speedup: +Breaking down the following expressions ESCAPE_STRING into an expression with many faster-terminating regex resulted in a dramatic speedup: .. code-block:: @@ -116,7 +116,7 @@ Every legal character in the alphabet must be checked against the parser by defa Likely many of these characters aren't useful in your generation. -Expect an ~10x speedup if you constrain your generation to UTF-8, eliminating 3,042 unnecessary characters. +Expect increased performance if you constrain your generation to UTF-8, eliminating 3,042 unnecessary characters. .. 
code-block:: From 9d2b1f01b74c021c48b518b5cf567490e34df28f Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Thu, 21 Dec 2023 18:40:25 -0600 Subject: [PATCH 28/76] bug fix --- requirements-dev.txt | 2 +- vllm/grammar.py | 148 ++++++++++++++++++++++++++++++++++++++----- 2 files changed, 133 insertions(+), 17 deletions(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index c9b212c923a4..bebc1c395be4 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -12,4 +12,4 @@ types-setuptools pytest pytest-forked pytest-asyncio - +httpx # https://github.com/vllm-project/vllm/issues/1975 diff --git a/vllm/grammar.py b/vllm/grammar.py index af91b98bd6cc..caaa04af212d 100644 --- a/vllm/grammar.py +++ b/vllm/grammar.py @@ -2,12 +2,14 @@ from copy import deepcopy, copy import os import regex -from typing import Optional +import torch +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +from typing import Optional, List, Set, Union from lark import Lark from lark.parsers.lalr_interactive_parser import InteractiveParser from lark.parsers.lalr_parser_state import ParserState -from lark.lexer import Token, LexerState, PatternStr, PatternRE +from lark.lexer import Token, LexerState, Pattern, PatternStr, PatternRE from lark.exceptions import UnexpectedCharacters, UnexpectedToken @@ -104,7 +106,7 @@ def __init__(self, grammar: str, start: str): self.valid_next_terminals = {"": self._accepts() | self._ignored_terms} @staticmethod - def _get_partial_pattern_validator(pattern): + def _get_partial_pattern_validator(pattern: Pattern): """ Accepts a pattern object, either lark.lexer.PatternStr or lark.lexer.PatternRE Returns a function which validates a partial string @@ -136,9 +138,8 @@ def step_seq(self, new_seq: str): - Update the set of candidate terminals """ for char in new_seq: - # update canonical sequence and lexer sequence - self.sequence_history += new_seq + self.sequence_history += char self.interactive_parser.lexer_thread.state.text = self.sequence_history success = False @@ -149,7 +150,6 @@ def step_seq(self, new_seq: str): except UnexpectedToken as e: # fall back so full token can be reprocessed self.interactive_parser = self._terminal_start_parser.copy() - self.interactive_parser.lexer_thread.state.text = self.sequence_history else: success = True @@ -215,12 +215,15 @@ def is_valid_next_seq(self, new_seq: Optional[str]): class TokenTrie: + """ + Trie structure for efficiently finding tokens which are suffixes of other sequences + """ + IS_TOKEN = (None, "is complete token") - def __init__(self, tokenizer, legal_chars: Optional[set[str]] = None): - """ - Trie structure for efficiently finding tokens which are suffixes of other sequences - """ + def __init__(self, + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + legal_chars: Optional[Set[str]] = None): self.norm_vocab = collections.defaultdict(set) for token_id in tokenizer.vocab.values(): if token_id == tokenizer.eos_token_id: @@ -248,14 +251,14 @@ def __init__(self, tokenizer, legal_chars: Optional[set[str]] = None): self._next_level_token_prefixes_cache = {} - def get_next_level_token_prefixes(self, subprefix: str): + def get_next_level_token_prefixes(self, subprefix: str) -> Set[str]: if subprefix not in self._next_level_token_prefixes_cache: self._next_level_token_prefixes_cache[subprefix] = ( self.get_next_level_token_prefixes_uncached(subprefix) ) return self._next_level_token_prefixes_cache[subprefix] - def get_next_level_token_prefixes_uncached(self, subprefix: str, _node=None): + def 
get_next_level_token_prefixes_uncached(self, subprefix: str, _node: dict = None) -> Set[str]: """ Traverse the trie starting from a specified subprefix to identify all child nodes that represent the longest possible strings without omitting any nodes that contain complete tokens. @@ -288,7 +291,7 @@ def get_next_level_token_prefixes_uncached(self, subprefix: str, _node=None): return results - def is_token(self, seq): + def is_token(self, seq: Optional[str]) -> bool: return seq in self.norm_vocab_set @@ -312,7 +315,7 @@ def __init__(self, self.parser = InteractivePredictiveLALRParser(grammar=grammar, start=grammar_start) - def step_seq(self, new_seq): + def step_seq(self, new_seq: str): self.parser.step_seq(new_seq) @property @@ -347,6 +350,7 @@ def valid_token_id_set(self): get valid token id based on self.valid_token_str_set note that some token strings correspond to multiple token IDs """ + print(self.valid_token_str_set) return set.union(*[ self.token_trie.norm_vocab[tok_str] for tok_str in self.valid_token_str_set @@ -364,7 +368,7 @@ def __init__(self, *args, **kwargs): self.generation_text = "" - def _update_seen_token_ids(self, token_ids): + def _update_seen_token_ids(self, token_ids: List[int]): # ensure integrity assert token_ids[:len(self.generation_token_ids)] == self.generation_token_ids self.generation_token_ids = token_ids @@ -375,7 +379,7 @@ def _update_seen_token_ids(self, token_ids): self.generation_text = all_text self.step_seq(new_text) - def __call__(self, token_ids, logits): + def __call__(self, token_ids: List[int], logits: torch.Tensor) -> torch.Tensor: self._update_seen_token_ids(token_ids) # get valid token IDs and modify logits @@ -385,3 +389,115 @@ def __call__(self, token_ids, logits): for tok_id, logit_val in zip(sorted(self.tokenizer.vocab.values()), logits) ] return logits + + +if __name__ == "__main__": + from transformers import AutoTokenizer + import numpy as np + model_id = "codellama/CodeLlama-7b-hf" + tokenizer = AutoTokenizer.from_pretrained(model_id) + + json_grammar = r""" + start: value + value: WS* object WS* + object: dict + | list + | string + | signed_number -> number + | "true" -> true + | "false" -> false + | "null" -> null + + list : "[" [value ("," value)*] "]" + + dict : "{" [pair ("," pair)*] "}" + pair : string ":" value + + string : "\"" escaped_string_char* "\"" + escaped_string_char: _STR_INNER_CHAR | _ESCAPED_CHAR + _ESCAPED_CHAR: "\\" _ANY_CHAR + _STR_INNER_CHAR: /[^\\\"]/ + _ANY_CHAR: /[.]/ + + signed_number: ["+"|"-"] number + number: float | int + float: int exp | decimal exp? + decimal: int "." int? | "." 
int + exp: ("e"|"E") signed_int + signed_int: ["+"|"-"] int + int: DIGIT+ + DIGIT: "0".."9" + + WS: /[ \t\f\r\n]/ + """ + + + sample_from_logits = lambda lgts: np.random.choice(len(lgts), p=np.exp(lgts)/np.sum(np.exp(lgts))) + + + grammar_logits_processor = GrammarLogitsProcessor( + tokenizer, + json_grammar, + legal_chars=set(map(chr, range(256))), + ) + + start_tok = 260 + token_ids = [start_tok] + while True: + logits = grammar_logits_processor( + token_ids=token_ids, + logits=np.random.uniform(-10, 10, len(tokenizer.vocab)) + ) + new_token_id = sample_from_logits(logits) + if new_token_id == tokenizer.eos_token_id: + break + token_ids.append(new_token_id) + + import pdb;pdb.set_trace() + + """ + legal_chars = set([c for c in map(chr, range(256)) if c.isprintable()]) + print(len(legal_chars)) + grammar_logits_processor = GrammarLogitsProcessor( + tokenizer, + json_grammar, + legal_chars=legal_chars + ) + + closing_token_ids = set([ + tok_id + for tok_str in ["]", "}", '"', ","] + for tok_id in grammar_logits_processor.token_trie.norm_vocab[tok_str] + ]) + closing_tokens_bias = -5 + + + + token_ids = [] + grammar_logits_processor._update_seen_token_ids(token_ids) + for _ in range(100000): + print(tokenizer.decode(token_ids)) + print(repr(tokenizer.decode(token_ids))) + logits = grammar_logits_processor( + token_ids=token_ids, + logits=np.random.uniform(-10, 10, len(tokenizer.vocab)) + ) + + + for closing_token_id in closing_token_ids: + logits[closing_token_id] += closing_tokens_bias + + for opening_token_id in opening_token_ids: + logits[opening_token_id] += opening_tokens_bias + + + new_token_id = sample_from_logits(logits) + if new_token_id == tokenizer.eos_token_id: + break + token_ids.append(new_token_id) + closing_tokens_bias += 0.2 + opening_tokens_bias -= 0.1 + + + print(tokenizer.decode(token_ids)) + """ From 64ef88161bda150db3722bca1c0d3a08bd1e3a21 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Thu, 21 Dec 2023 18:41:14 -0600 Subject: [PATCH 29/76] remove stray prints --- vllm/grammar.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/grammar.py b/vllm/grammar.py index caaa04af212d..684da6ad2418 100644 --- a/vllm/grammar.py +++ b/vllm/grammar.py @@ -170,8 +170,6 @@ def step_seq(self, new_seq: str): f"Invalid continuation for `{self.sequence_history}` `{new_seq}`" ) - print(self.valid_next_terminals) - def _filter_candidate_terminals(self): """ Filter the set of candidate terminals @@ -350,7 +348,6 @@ def valid_token_id_set(self): get valid token id based on self.valid_token_str_set note that some token strings correspond to multiple token IDs """ - print(self.valid_token_str_set) return set.union(*[ self.token_trie.norm_vocab[tok_str] for tok_str in self.valid_token_str_set From 90b54d70ce1b7884b719c72afdd84abaa6f9908e Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Thu, 21 Dec 2023 18:41:53 -0600 Subject: [PATCH 30/76] remove stray debug code --- vllm/grammar.py | 112 ------------------------------------------------ 1 file changed, 112 deletions(-) diff --git a/vllm/grammar.py b/vllm/grammar.py index 684da6ad2418..571754b4cc6c 100644 --- a/vllm/grammar.py +++ b/vllm/grammar.py @@ -386,115 +386,3 @@ def __call__(self, token_ids: List[int], logits: torch.Tensor) -> torch.Tensor: for tok_id, logit_val in zip(sorted(self.tokenizer.vocab.values()), logits) ] return logits - - -if __name__ == "__main__": - from transformers import AutoTokenizer - import numpy as np - model_id = "codellama/CodeLlama-7b-hf" - tokenizer = AutoTokenizer.from_pretrained(model_id) - - 
json_grammar = r""" - start: value - value: WS* object WS* - object: dict - | list - | string - | signed_number -> number - | "true" -> true - | "false" -> false - | "null" -> null - - list : "[" [value ("," value)*] "]" - - dict : "{" [pair ("," pair)*] "}" - pair : string ":" value - - string : "\"" escaped_string_char* "\"" - escaped_string_char: _STR_INNER_CHAR | _ESCAPED_CHAR - _ESCAPED_CHAR: "\\" _ANY_CHAR - _STR_INNER_CHAR: /[^\\\"]/ - _ANY_CHAR: /[.]/ - - signed_number: ["+"|"-"] number - number: float | int - float: int exp | decimal exp? - decimal: int "." int? | "." int - exp: ("e"|"E") signed_int - signed_int: ["+"|"-"] int - int: DIGIT+ - DIGIT: "0".."9" - - WS: /[ \t\f\r\n]/ - """ - - - sample_from_logits = lambda lgts: np.random.choice(len(lgts), p=np.exp(lgts)/np.sum(np.exp(lgts))) - - - grammar_logits_processor = GrammarLogitsProcessor( - tokenizer, - json_grammar, - legal_chars=set(map(chr, range(256))), - ) - - start_tok = 260 - token_ids = [start_tok] - while True: - logits = grammar_logits_processor( - token_ids=token_ids, - logits=np.random.uniform(-10, 10, len(tokenizer.vocab)) - ) - new_token_id = sample_from_logits(logits) - if new_token_id == tokenizer.eos_token_id: - break - token_ids.append(new_token_id) - - import pdb;pdb.set_trace() - - """ - legal_chars = set([c for c in map(chr, range(256)) if c.isprintable()]) - print(len(legal_chars)) - grammar_logits_processor = GrammarLogitsProcessor( - tokenizer, - json_grammar, - legal_chars=legal_chars - ) - - closing_token_ids = set([ - tok_id - for tok_str in ["]", "}", '"', ","] - for tok_id in grammar_logits_processor.token_trie.norm_vocab[tok_str] - ]) - closing_tokens_bias = -5 - - - - token_ids = [] - grammar_logits_processor._update_seen_token_ids(token_ids) - for _ in range(100000): - print(tokenizer.decode(token_ids)) - print(repr(tokenizer.decode(token_ids))) - logits = grammar_logits_processor( - token_ids=token_ids, - logits=np.random.uniform(-10, 10, len(tokenizer.vocab)) - ) - - - for closing_token_id in closing_token_ids: - logits[closing_token_id] += closing_tokens_bias - - for opening_token_id in opening_token_ids: - logits[opening_token_id] += opening_tokens_bias - - - new_token_id = sample_from_logits(logits) - if new_token_id == tokenizer.eos_token_id: - break - token_ids.append(new_token_id) - closing_tokens_bias += 0.2 - opening_tokens_bias -= 0.1 - - - print(tokenizer.decode(token_ids)) - """ From e441584839bbccd99a9d60ac337e3a3927ff73e3 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Thu, 21 Dec 2023 19:03:39 -0600 Subject: [PATCH 31/76] clean up tests, add more tests --- tests/samplers/test_grammar.py | 64 ++++++++++++++++++++++++++++++++-- 1 file changed, 61 insertions(+), 3 deletions(-) diff --git a/tests/samplers/test_grammar.py b/tests/samplers/test_grammar.py index 67fc2947a007..fce28b4eb9ff 100644 --- a/tests/samplers/test_grammar.py +++ b/tests/samplers/test_grammar.py @@ -152,9 +152,7 @@ def test_can_generate_with_grammar( ) example_remainder = example while example_remainder: - legal_next_token_strs = list(next_token_validator.valid_token_str_set) - random.shuffle(legal_next_token_strs) - for tok in legal_next_token_strs: + for tok in next_token_validator.valid_token_str_set: if tok is None: continue if example_remainder.startswith(tok): @@ -168,6 +166,66 @@ def test_can_generate_with_grammar( assert None in next_token_validator.valid_token_str_set +def test_json_valid_with_edge_cases(tokenizer, json_grammar): + valid_edgecase_jsons = [ + "{\n \"emptyObject\": {\n 
\"innerEmptyObject\": {}\n }\n}", # empty obj + "{\n \"mixedArray\": [null, 123, \"text\", true, {\"key\": \"value\"}]\n}", # mixed array + "{\n \"deepArray\": [[[[[\"deep\"]]]]]\n}" ,# deeply nested list + "{\n \"\": true,\n \"regularKey\": false\n}", # empty keys + "{\n \"\u043a\u043b\u044e\u0447\": \"\u0437\u043d\u0430\u0447\u0435\u043d\u0438\u0435\",\n \"emoji\ud83d\ude42\": \"value\ud83d\ude00\"\n}", # unicode keys + ] + + next_token_validator = NextTokenValidator( + tokenizer, + json_grammar, + ) + for example in valid_edgecase_jsons: + example_remainder = example + while example_remainder: + for tok in next_token_validator.valid_token_str_set: + if tok is None: + continue + if example_remainder.startswith(tok): + next_token_validator.step_seq(tok) + example_remainder = example_remainder[len(tok):] + break + else: + raise Exception(f"Couldn't find token to create legal output given grammar: '{example_remainder}'") + + # EOS should be in the set of next legal tokens + assert None in next_token_validator.valid_token_str_set + + +def test_json_fails_with_edge_cases(tokenizer, json_grammar): + invalid_edgecase_jsons = [ + "{\n \"key1\": \"value1\",\n \"key2\": \"value2\",\n}", # trailing comma + "{\n \"key\": \"value\" // This is a comment\n}\n", # comment + "{\n \"number\": 1.2.3\n}", # incorrect decimal format + "{\n \"key\": \"value\"unexpected\"\n}", # incorrect str format + "{\n \"object\": {\"key\": \"value\"}\n}\n", # unclosed object + "{\n \"array\": [1, 2,, 3]\n}\n", # double comma + ] + + next_token_validator = NextTokenValidator( + tokenizer, + json_grammar, + ) + for example in invalid_edgecase_jsons: + example_remainder = example + while example_remainder: + for tok in next_token_validator.valid_token_str_set: + if tok is None: + continue + if example_remainder.startswith(tok): + next_token_validator.step_seq(tok) + example_remainder = example_remainder[len(tok):] + break + else: + return True + + assert False, "Invalid json was accepted" + + def test_token_trie_sanity(tokenizer): toktrie = TokenTrie(tokenizer) From 8ba330a1e11ddff96530ea899039a35bc49954ac Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Thu, 21 Dec 2023 19:04:00 -0600 Subject: [PATCH 32/76] remove unused import --- tests/samplers/test_grammar.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/samplers/test_grammar.py b/tests/samplers/test_grammar.py index fce28b4eb9ff..2b4217a55b17 100644 --- a/tests/samplers/test_grammar.py +++ b/tests/samplers/test_grammar.py @@ -1,6 +1,5 @@ import pytest import numpy as np -import random import json from transformers import AutoTokenizer From df704b1b8ce53fc9b00e3c37d52d0fc9c636911c Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Thu, 21 Dec 2023 19:21:17 -0600 Subject: [PATCH 33/76] don't modify requirements-dev.txt --- requirements-dev.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index bebc1c395be4..dd0fab9fa3d2 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -12,4 +12,3 @@ types-setuptools pytest pytest-forked pytest-asyncio -httpx # https://github.com/vllm-project/vllm/issues/1975 From f7eb37f9ce244cf5a3d273127af7036c7423c00e Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Thu, 21 Dec 2023 20:11:40 -0600 Subject: [PATCH 34/76] remove unused imports --- vllm/grammar.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/grammar.py b/vllm/grammar.py index 571754b4cc6c..bdc56f1ca655 100644 --- a/vllm/grammar.py +++ b/vllm/grammar.py @@ -1,6 +1,5 @@ import collections from 
copy import deepcopy, copy -import os import regex import torch from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast @@ -9,7 +8,7 @@ from lark import Lark from lark.parsers.lalr_interactive_parser import InteractiveParser from lark.parsers.lalr_parser_state import ParserState -from lark.lexer import Token, LexerState, Pattern, PatternStr, PatternRE +from lark.lexer import Pattern, PatternStr, PatternRE from lark.exceptions import UnexpectedCharacters, UnexpectedToken From 577e02e403fc89e5bf97dc5e9bb31c55f8e60287 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Thu, 21 Dec 2023 20:12:33 -0600 Subject: [PATCH 35/76] fix and clean tests --- tests/samplers/test_grammar.py | 36 +++++++++++++++++++--------------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/tests/samplers/test_grammar.py b/tests/samplers/test_grammar.py index 2b4217a55b17..3a176aff2b7b 100644 --- a/tests/samplers/test_grammar.py +++ b/tests/samplers/test_grammar.py @@ -7,6 +7,9 @@ from vllm.grammar import TokenTrie, NextTokenValidator, GrammarLogitsProcessor +INTEGRATION_TEST_MODELS = ["facebook/opt-125m"] + + @pytest.fixture def tokenizer(): model_id = "codellama/CodeLlama-7b-hf" @@ -35,7 +38,7 @@ def json_grammar(): escaped_string_char: _STR_INNER_CHAR | _ESCAPED_CHAR _ESCAPED_CHAR: "\\" _ANY_CHAR _STR_INNER_CHAR: /[^\\\"]/ - _ANY_CHAR: /[.]/ + _ANY_CHAR: /./ signed_number: ["+"|"-"] number number: float | int @@ -159,7 +162,7 @@ def test_can_generate_with_grammar( example_remainder = example_remainder[len(tok):] break else: - raise Exception(f"Couldn't find token to create legal output given grammar: '{example_remainder}'") + raise Exception(f"Couldn't find token to create legal output given grammar, remaining output: '{example_remainder}'") # EOS should be in the set of next legal tokens assert None in next_token_validator.valid_token_str_set @@ -169,16 +172,17 @@ def test_json_valid_with_edge_cases(tokenizer, json_grammar): valid_edgecase_jsons = [ "{\n \"emptyObject\": {\n \"innerEmptyObject\": {}\n }\n}", # empty obj "{\n \"mixedArray\": [null, 123, \"text\", true, {\"key\": \"value\"}]\n}", # mixed array - "{\n \"deepArray\": [[[[[\"deep\"]]]]]\n}" ,# deeply nested list + "{\n \"deepArray\": [[[[[\"deep\"]]]]]\n}", # deeply nested list "{\n \"\": true,\n \"regularKey\": false\n}", # empty keys - "{\n \"\u043a\u043b\u044e\u0447\": \"\u0437\u043d\u0430\u0447\u0435\u043d\u0438\u0435\",\n \"emoji\ud83d\ude42\": \"value\ud83d\ude00\"\n}", # unicode keys + "{\n \"\\u043a\\u043b\\u044e\\u0447\": \"\\u0437\\u043d\\u0430\\u0447\\u0435\\u043d\\u0438\\u0435\",\n \"emoji\\ud83d\\ude42\": \"value\\ud83d\\ude00\"\n}", # unicode keys ] - next_token_validator = NextTokenValidator( - tokenizer, - json_grammar, - ) + for example in valid_edgecase_jsons: + next_token_validator = NextTokenValidator( + tokenizer, + json_grammar, + ) example_remainder = example while example_remainder: for tok in next_token_validator.valid_token_str_set: @@ -189,7 +193,7 @@ def test_json_valid_with_edge_cases(tokenizer, json_grammar): example_remainder = example_remainder[len(tok):] break else: - raise Exception(f"Couldn't find token to create legal output given grammar: '{example_remainder}'") + raise Exception(f"Couldn't find token to create legal output given grammar, remaining output: '{example_remainder}'") # EOS should be in the set of next legal tokens assert None in next_token_validator.valid_token_str_set @@ -205,11 +209,11 @@ def test_json_fails_with_edge_cases(tokenizer, json_grammar): "{\n \"array\": [1, 2,, 
3]\n}\n", # double comma ] - next_token_validator = NextTokenValidator( - tokenizer, - json_grammar, - ) - for example in invalid_edgecase_jsons: + for example in valid_edgecase_jsons: + next_token_validator = NextTokenValidator( + tokenizer, + json_grammar, + ) example_remainder = example while example_remainder: for tok in next_token_validator.valid_token_str_set: @@ -329,6 +333,6 @@ def test_random_grammared_generation(json_grammar, tokenizer): opening_tokens_bias -= 0.1 - -def test_integration_with_vllm(): +@pytest.mark.parametrize("model_id", INTEGRATION_TEST_MODELS) +def test_integration_with_vllm(model_id): assert False From e1bc7acc1e3b423a28e1a6be2667f4e294643b42 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Thu, 21 Dec 2023 20:13:03 -0600 Subject: [PATCH 36/76] return tensor --- vllm/grammar.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm/grammar.py b/vllm/grammar.py index bdc56f1ca655..414a013b5e80 100644 --- a/vllm/grammar.py +++ b/vllm/grammar.py @@ -378,10 +378,10 @@ def _update_seen_token_ids(self, token_ids: List[int]): def __call__(self, token_ids: List[int], logits: torch.Tensor) -> torch.Tensor: self._update_seen_token_ids(token_ids) - # get valid token IDs and modify logits - valid_token_ids = self.valid_token_id_set - logits = [ - logit_val if tok_id in valid_token_ids else -float("inf") - for tok_id, logit_val in zip(sorted(self.tokenizer.vocab.values()), logits) - ] + # modify logits given valid token IDs + N = len(logits) + mask = torch.zeros(N, dtype=torch.bool) + valid = torch.tensor(list(self.valid_token_id_set), dtype=torch.long) + mask[valid] = True + logits[~mask] = float('-inf') return logits From 3a55edc3aaa0c57b552a06e0cd043485b270b451 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Fri, 22 Dec 2023 04:57:56 -0600 Subject: [PATCH 37/76] update tests --- tests/samplers/test_grammar.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/tests/samplers/test_grammar.py b/tests/samplers/test_grammar.py index 3a176aff2b7b..a030e865ca43 100644 --- a/tests/samplers/test_grammar.py +++ b/tests/samplers/test_grammar.py @@ -5,9 +5,7 @@ from transformers import AutoTokenizer from vllm.grammar import TokenTrie, NextTokenValidator, GrammarLogitsProcessor - - -INTEGRATION_TEST_MODELS = ["facebook/opt-125m"] +from vllm import LLM, SamplingParams @pytest.fixture @@ -209,7 +207,7 @@ def test_json_fails_with_edge_cases(tokenizer, json_grammar): "{\n \"array\": [1, 2,, 3]\n}\n", # double comma ] - for example in valid_edgecase_jsons: + for example in invalid_edgecase_jsons: next_token_validator = NextTokenValidator( tokenizer, json_grammar, @@ -333,6 +331,24 @@ def test_random_grammared_generation(json_grammar, tokenizer): opening_tokens_bias -= 0.1 -@pytest.mark.parametrize("model_id", INTEGRATION_TEST_MODELS) -def test_integration_with_vllm(model_id): +def test_integration_with_vllm(vllm_runner, hf_runner): + model_id = "facebook/opt-125m" + dtype = "half" + + grammar_logits_processor = GrammarLogitsProcessor( + hf_runner(model_id, dtype=dtype).tokenizer, + """?start: "hello" | "world" """ + ) + sampling_params = SamplingParams(temperature=0.01, + top_p=0.1, + max_tokens=256, + logits_processors=[grammar_logits_processor]) + llm = LLM(model=model_id, + max_num_batched_tokens=4096, + tensor_parallel_size=1) + prompts = ["Who is the president of Jamaica?", "What is 1+1?"] + outputs = llm.generate(prompts, sampling_params=sampling_params) + + print(outputs) + assert False From 
236553d3f3b6f8f8b9d26003d91ef9ade305e2cf Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Fri, 22 Dec 2023 05:28:36 -0600 Subject: [PATCH 38/76] handle batch requests --- vllm/grammar.py | 73 +++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 61 insertions(+), 12 deletions(-) diff --git a/vllm/grammar.py b/vllm/grammar.py index 414a013b5e80..36aa5f1882c9 100644 --- a/vllm/grammar.py +++ b/vllm/grammar.py @@ -1,5 +1,6 @@ import collections from copy import deepcopy, copy +from dataclasses import dataclass import regex import torch from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast @@ -353,30 +354,78 @@ def valid_token_id_set(self): ]) +# TODO: replace with subclass called NextTokenIDValidator to make things cleaner +@dataclass +class BatchDataItemParser: + text: str + token_ids: List[str] + parser: NextTokenValidator +) + + class GrammarLogitsProcessor(NextTokenValidator): """ Apply NextTokenValidator in __call__ and set excluded tokens logits to -inf """ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, + tokenizer, + grammar: str, + grammar_start: str = "start", + legal_chars: Optional[set[str]] = None, + ): + self.tokenizer = tokenizer + self.grammar = grammar + self.grammar_start = grammar_start + self.legal_chars = legal_chars + + # track multiple parsers for batch requests + self.batch_data_item_parsers: List[BatchDataItemParser] = [] + + def _new_batch_data_item_parser(self): + return BatchDataItemParser( + "", + [], + NextTokenValidator( + tokenizer=self.tokenizer, + grammar=self.grammar, + grammar_start=self.grammar_start, + legal_chars=self.legal_chars + ) + ) - self.generation_token_ids = [] - self.generation_text = "" + def _get_batch_data_item_parser(self, token_ids: List[int]): + """ + Get longest batch data item parser which matches the seen tokens. 
+ This is generally the corresponding parser, but if there's a collision + their parsers are interchangable + """ + for batch_data_item_parser in sorted( + self.batch_data_item_parsers, + key=lambda bdip: -len(bdip.token_ids) + ): + if token_ids[:len(bdip.token_ids)] == bdip.token_ids: + return bdip + + # no match, make new + return self._new_batch_data_item_parser() - def _update_seen_token_ids(self, token_ids: List[int]): - # ensure integrity - assert token_ids[:len(self.generation_token_ids)] == self.generation_token_ids - self.generation_token_ids = token_ids + def _update_seen_token_ids(self, bdip: BatchDataItemParser, token_ids: List[int]): + + # update batch item token tracker + bdip.token_ids = token_ids # step forward all_text = self.tokenizer.decode(token_ids) - new_text = all_text[len(self.generation_text):] - self.generation_text = all_text - self.step_seq(new_text) + new_text = all_text[len(bdip.text):] + bdip.text = all_text + bdip.parser.step_seq(new_text) def __call__(self, token_ids: List[int], logits: torch.Tensor) -> torch.Tensor: - self._update_seen_token_ids(token_ids) + # get the batch item data and parser for batch item, given provided token sequence + bdip = self._get_batch_data_item_parser(token_ids) + + self._update_seen_token_ids(bdip, token_ids) # modify logits given valid token IDs N = len(logits) From cda47115d1fedf9c713421148290601f6f6315ad Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Fri, 22 Dec 2023 05:30:29 -0600 Subject: [PATCH 39/76] bug fixes --- vllm/grammar.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/grammar.py b/vllm/grammar.py index 36aa5f1882c9..69b29d03c818 100644 --- a/vllm/grammar.py +++ b/vllm/grammar.py @@ -360,10 +360,10 @@ class BatchDataItemParser: text: str token_ids: List[str] parser: NextTokenValidator -) -class GrammarLogitsProcessor(NextTokenValidator): + +class GrammarLogitsProcessor: """ Apply NextTokenValidator in __call__ and set excluded tokens logits to -inf """ @@ -430,7 +430,7 @@ def __call__(self, token_ids: List[int], logits: torch.Tensor) -> torch.Tensor: # modify logits given valid token IDs N = len(logits) mask = torch.zeros(N, dtype=torch.bool) - valid = torch.tensor(list(self.valid_token_id_set), dtype=torch.long) + valid = torch.tensor(list(bdip.valid_token_id_set), dtype=torch.long) mask[valid] = True logits[~mask] = float('-inf') return logits From bc66e56b7d2816e8c1e0639a4f26b438251a7375 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Fri, 22 Dec 2023 05:31:16 -0600 Subject: [PATCH 40/76] bugfix --- vllm/grammar.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/grammar.py b/vllm/grammar.py index 69b29d03c818..6838a2a5d133 100644 --- a/vllm/grammar.py +++ b/vllm/grammar.py @@ -430,7 +430,7 @@ def __call__(self, token_ids: List[int], logits: torch.Tensor) -> torch.Tensor: # modify logits given valid token IDs N = len(logits) mask = torch.zeros(N, dtype=torch.bool) - valid = torch.tensor(list(bdip.valid_token_id_set), dtype=torch.long) + valid = torch.tensor(list(bdip.parser.valid_token_id_set), dtype=torch.long) mask[valid] = True logits[~mask] = float('-inf') return logits From 4c9de04af2124ca497d9dc39fcb405f6bd4a027a Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Fri, 22 Dec 2023 05:36:00 -0600 Subject: [PATCH 41/76] fix test --- tests/samplers/test_grammar.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/tests/samplers/test_grammar.py b/tests/samplers/test_grammar.py index a030e865ca43..5652c9adce90 
100644 --- a/tests/samplers/test_grammar.py +++ b/tests/samplers/test_grammar.py @@ -335,9 +335,12 @@ def test_integration_with_vllm(vllm_runner, hf_runner): model_id = "facebook/opt-125m" dtype = "half" + tokenizer = hf_runner(model_id, dtype=dtype).tokenizer + grammar = """?start: "hello" | "world" """ + grammar_logits_processor = GrammarLogitsProcessor( - hf_runner(model_id, dtype=dtype).tokenizer, - """?start: "hello" | "world" """ + tokenizer, + grammar ) sampling_params = SamplingParams(temperature=0.01, top_p=0.1, @@ -346,9 +349,18 @@ def test_integration_with_vllm(vllm_runner, hf_runner): llm = LLM(model=model_id, max_num_batched_tokens=4096, tensor_parallel_size=1) - prompts = ["Who is the president of Jamaica?", "What is 1+1?"] - outputs = llm.generate(prompts, sampling_params=sampling_params) - print(outputs) + prompts = [ + "Who is the president of Jamaica?", + "What is 1+1?", + "Random prompt unrelated to output", + "Seriously, no matter what the prompt is..." + "it will always follow the grammar" + ] + + request_outputs = llm.generate(prompts, sampling_params=sampling_params) + assert len(request_outputs) == len(prompts) - assert False + for request_output in llm.generate(prompts, sampling_params=sampling_params): + assert len(request_output.output) == 1 + assert request_output.output[0].text in ("hello", "world") From 461d31879b140a2a1e7c90dc13c9c6af7db22f5d Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Fri, 22 Dec 2023 05:37:32 -0600 Subject: [PATCH 42/76] reduce number of prompts --- tests/samplers/test_grammar.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/samplers/test_grammar.py b/tests/samplers/test_grammar.py index 5652c9adce90..9c02fa7c1baf 100644 --- a/tests/samplers/test_grammar.py +++ b/tests/samplers/test_grammar.py @@ -351,8 +351,6 @@ def test_integration_with_vllm(vllm_runner, hf_runner): tensor_parallel_size=1) prompts = [ - "Who is the president of Jamaica?", - "What is 1+1?", "Random prompt unrelated to output", "Seriously, no matter what the prompt is..." 
"it will always follow the grammar" From 6fe8bd40b3031420cda8ef3e06d2f525a858e5ee Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Fri, 22 Dec 2023 05:37:56 -0600 Subject: [PATCH 43/76] fix test --- tests/samplers/test_grammar.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/samplers/test_grammar.py b/tests/samplers/test_grammar.py index 9c02fa7c1baf..be63817159aa 100644 --- a/tests/samplers/test_grammar.py +++ b/tests/samplers/test_grammar.py @@ -360,5 +360,5 @@ def test_integration_with_vllm(vllm_runner, hf_runner): assert len(request_outputs) == len(prompts) for request_output in llm.generate(prompts, sampling_params=sampling_params): - assert len(request_output.output) == 1 - assert request_output.output[0].text in ("hello", "world") + assert len(request_output.outputs) == 1 + assert request_output.outputs[0].text in ("hello", "world") From 7d7f8cf7c152a4a97494cd33e716f24ce8a58659 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Fri, 22 Dec 2023 05:54:01 -0600 Subject: [PATCH 44/76] fix yarp --- tests/samplers/test_grammar.py | 6 ++---- vllm/grammar.py | 13 ++++++------- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/tests/samplers/test_grammar.py b/tests/samplers/test_grammar.py index be63817159aa..2ae3fda77400 100644 --- a/tests/samplers/test_grammar.py +++ b/tests/samplers/test_grammar.py @@ -222,9 +222,9 @@ def test_json_fails_with_edge_cases(tokenizer, json_grammar): example_remainder = example_remainder[len(tok):] break else: - return True + return - assert False, "Invalid json was accepted" + raise Exception("Invalid json was accepted") def test_token_trie_sanity(tokenizer): @@ -285,8 +285,6 @@ def test_random_grammared_generation(json_grammar, tokenizer): # Bias logits so open syntax such that closing syntax such as ]}", # occur more frequently as time goes on so we don't get stuck in generation - num_repeats = 8 - grammar_logits_processor = GrammarLogitsProcessor( tokenizer, json_grammar, diff --git a/vllm/grammar.py b/vllm/grammar.py index 6838a2a5d133..6c463951ae18 100644 --- a/vllm/grammar.py +++ b/vllm/grammar.py @@ -145,9 +145,9 @@ def step_seq(self, new_seq: str): success = False try: self.interactive_parser.exhaust_lexer() - except UnexpectedCharacters as e: + except UnexpectedCharacters: pass - except UnexpectedToken as e: + except UnexpectedToken: # fall back so full token can be reprocessed self.interactive_parser = self._terminal_start_parser.copy() else: @@ -182,7 +182,7 @@ def _filter_candidate_terminals(self): for incomplete_seq, terminals in self.valid_next_terminals.items(): if incomplete_seq != "": self.valid_next_terminals[incomplete_seq] = set([ - term for term in self.valid_next_terminals[incomplete_seq] + term for term in terminals if term != "$END" and self.partial_seq_validator[term](incomplete_seq) ]) @@ -206,9 +206,8 @@ def is_valid_next_seq(self, new_seq: Optional[str]): for incomplete_seq, terminals in self.valid_next_terminals.items(): candidate = incomplete_seq + new_seq for term in terminals: - if term != "$END": - if self.partial_seq_validator[term](candidate): - return True + if term != "$END" and self.partial_seq_validator[term](candidate): + return True return False @@ -399,7 +398,7 @@ def _get_batch_data_item_parser(self, token_ids: List[int]): This is generally the corresponding parser, but if there's a collision their parsers are interchangable """ - for batch_data_item_parser in sorted( + for bdip in sorted( self.batch_data_item_parsers, key=lambda bdip: -len(bdip.token_ids) ): From 
fcb13d582982f355449deaa9b7dab27707d9d199 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Fri, 22 Dec 2023 05:59:10 -0600 Subject: [PATCH 45/76] yarp ruff --- tests/samplers/test_grammar.py | 99 ++++++++++++++++------------------ vllm/grammar.py | 92 ++++++++++++++++--------------- 2 files changed, 94 insertions(+), 97 deletions(-) diff --git a/tests/samplers/test_grammar.py b/tests/samplers/test_grammar.py index 2ae3fda77400..7109c6da073a 100644 --- a/tests/samplers/test_grammar.py +++ b/tests/samplers/test_grammar.py @@ -111,13 +111,9 @@ def csv_example(): """.strip() + "\n" # grammar requires newline before eos - def sample_from_logits(logits): probs = np.exp(logits) / np.sum(np.exp(logits)) - return np.random.choice( - len(logits), - p=probs - ) + return np.random.choice(len(logits), p=probs) def test_next_token_validator_simple(tokenizer): @@ -127,20 +123,19 @@ def test_next_token_validator_simple(tokenizer): ntv = NextTokenValidator(tokenizer, hello_grammar) # tokens specific to codeLlama - assert ntv.valid_token_str_set == {'wo', 'hell', 'h', 'he', 'hel', 'world', 'wor', 'w', 'hello'} - assert sorted(ntv.valid_token_id_set) == [107, 122, 354, 827, 3952, 11526, 12199, 13762, 14181, 29882, 29893] + assert ntv.valid_token_str_set == { + 'wo', 'hell', 'h', 'he', 'hel', 'world', 'wor', 'w', 'hello' + } + assert sorted(ntv.valid_token_id_set) == [ + 107, 122, 354, 827, 3952, 11526, 12199, 13762, 14181, 29882, 29893 + ] -@pytest.mark.parametrize("grammar_fixture, example_fixture", [ - ("json_grammar", "json_example"), - ("csv_grammar", "csv_example") -]) -def test_can_generate_with_grammar( - tokenizer, - request, - grammar_fixture, - example_fixture -): +@pytest.mark.parametrize("grammar_fixture, example_fixture", + [("json_grammar", "json_example"), + ("csv_grammar", "csv_example")]) +def test_can_generate_with_grammar(tokenizer, request, grammar_fixture, + example_fixture): """Assert that example file is legal to generate with NextTokenValidator""" grammar = request.getfixturevalue(grammar_fixture) example = request.getfixturevalue(example_fixture) @@ -160,7 +155,9 @@ def test_can_generate_with_grammar( example_remainder = example_remainder[len(tok):] break else: - raise Exception(f"Couldn't find token to create legal output given grammar, remaining output: '{example_remainder}'") + raise Exception( + f"Couldn't find token to create legal output given grammar, remaining output: '{example_remainder}'" + ) # EOS should be in the set of next legal tokens assert None in next_token_validator.valid_token_str_set @@ -175,7 +172,6 @@ def test_json_valid_with_edge_cases(tokenizer, json_grammar): "{\n \"\\u043a\\u043b\\u044e\\u0447\": \"\\u0437\\u043d\\u0430\\u0447\\u0435\\u043d\\u0438\\u0435\",\n \"emoji\\ud83d\\ude42\": \"value\\ud83d\\ude00\"\n}", # unicode keys ] - for example in valid_edgecase_jsons: next_token_validator = NextTokenValidator( tokenizer, @@ -191,7 +187,9 @@ def test_json_valid_with_edge_cases(tokenizer, json_grammar): example_remainder = example_remainder[len(tok):] break else: - raise Exception(f"Couldn't find token to create legal output given grammar, remaining output: '{example_remainder}'") + raise Exception( + f"Couldn't find token to create legal output given grammar, remaining output: '{example_remainder}'" + ) # EOS should be in the set of next legal tokens assert None in next_token_validator.valid_token_str_set @@ -236,11 +234,8 @@ def test_token_trie_sanity(tokenizer): assert all([len(p) == 1 for p in all_prefixes]) # every token should have one of these prefixes 
as a start character - assert all([ - t[0] in all_prefixes - for t in toktrie.norm_vocab - if t is not None - ]) + assert all( + [t[0] in all_prefixes for t in toktrie.norm_vocab if t is not None]) # construct the set of next level prefixes all_subprefixes = set() @@ -251,12 +246,14 @@ def test_token_trie_sanity(tokenizer): assert len(set([len(spfx) for spfx in all_subprefixes])) > 1 -@pytest.mark.parametrize("start_tok, validator", [ - (29945, float), # 5 - float - (285, lambda s: bool(json.dumps(s))), # f for false - (260, lambda s: bool(json.dumps(s))), # t for false - (376, lambda s: str(json.dumps(s))), # " for string -]) +@pytest.mark.parametrize( + "start_tok, validator", + [ + (29945, float), # 5 - float + (285, lambda s: bool(json.dumps(s))), # f for false + (260, lambda s: bool(json.dumps(s))), # t for false + (376, lambda s: str(json.dumps(s))), # " for string + ]) def test_gen_primative(json_grammar, tokenizer, start_tok, validator): # Note: string may last a for _ in range(4): @@ -268,10 +265,10 @@ def test_gen_primative(json_grammar, tokenizer, start_tok, validator): token_ids = [start_tok] while True: - logits = grammar_logits_processor( - token_ids=token_ids, - logits=np.random.uniform(-10, 10, len(tokenizer.vocab)) - ) + logits = grammar_logits_processor(token_ids=token_ids, + logits=np.random.uniform( + -10, 10, + len(tokenizer.vocab))) new_token_id = sample_from_logits(logits) if new_token_id == tokenizer.eos_token_id: break @@ -289,12 +286,11 @@ def test_random_grammared_generation(json_grammar, tokenizer): tokenizer, json_grammar, legal_chars=set(map(chr, range(256))), - ) + ) # bias closing tokens logits to prevent infinite generation closing_token_ids = set([ - tok_id - for tok_str in ["]", "}", '"', ",", None] + tok_id for tok_str in ["]", "}", '"', ",", None] for tok_id in grammar_logits_processor.token_trie.norm_vocab[tok_str] ]) closing_tokens_bias = -10 @@ -303,18 +299,16 @@ def test_random_grammared_generation(json_grammar, tokenizer): # more tokens than these, and numbers close much more quickly, are less # gramatically complicated and result in a less interesting test opening_token_ids = set([ - tok_id - for tok_str in ["[", "{", '"', ","] + tok_id for tok_str in ["[", "{", '"', ","] for tok_id in grammar_logits_processor.token_trie.norm_vocab[tok_str] ]) opening_tokens_bias = 5 token_ids = [] while True: - logits = grammar_logits_processor( - token_ids=token_ids, - logits=np.random.uniform(-10, 10, len(tokenizer.vocab)) - ) + logits = grammar_logits_processor(token_ids=token_ids, + logits=np.random.uniform( + -10, 10, len(tokenizer.vocab))) for closing_token_id in closing_token_ids: logits[closing_token_id] += closing_tokens_bias @@ -336,14 +330,12 @@ def test_integration_with_vllm(vllm_runner, hf_runner): tokenizer = hf_runner(model_id, dtype=dtype).tokenizer grammar = """?start: "hello" | "world" """ - grammar_logits_processor = GrammarLogitsProcessor( - tokenizer, - grammar - ) - sampling_params = SamplingParams(temperature=0.01, - top_p=0.1, - max_tokens=256, - logits_processors=[grammar_logits_processor]) + grammar_logits_processor = GrammarLogitsProcessor(tokenizer, grammar) + sampling_params = SamplingParams( + temperature=0.01, + top_p=0.1, + max_tokens=256, + logits_processors=[grammar_logits_processor]) llm = LLM(model=model_id, max_num_batched_tokens=4096, tensor_parallel_size=1) @@ -357,6 +349,7 @@ def test_integration_with_vllm(vllm_runner, hf_runner): request_outputs = llm.generate(prompts, sampling_params=sampling_params) assert 
len(request_outputs) == len(prompts) - for request_output in llm.generate(prompts, sampling_params=sampling_params): + for request_output in llm.generate(prompts, + sampling_params=sampling_params): assert len(request_output.outputs) == 1 assert request_output.outputs[0].text in ("hello", "world") diff --git a/vllm/grammar.py b/vllm/grammar.py index 6c463951ae18..2ec1930086c0 100644 --- a/vllm/grammar.py +++ b/vllm/grammar.py @@ -56,6 +56,8 @@ def __copy__(self): copy(self.parser_state), copy(self.lexer_thread), ) + + ######################################################################### ######################################################################### @@ -160,7 +162,8 @@ def step_seq(self, new_seq: str): # if successfully parsed new token, add blank state and set fallback checkpoint if success: - self.valid_next_terminals[""] = self._accepts() | self._ignored_terms + self.valid_next_terminals[""] = self._accepts( + ) | self._ignored_terms self._terminal_start_parser = self.interactive_parser.copy() self._filter_candidate_terminals() @@ -182,8 +185,7 @@ def _filter_candidate_terminals(self): for incomplete_seq, terminals in self.valid_next_terminals.items(): if incomplete_seq != "": self.valid_next_terminals[incomplete_seq] = set([ - term for term in terminals - if term != "$END" + term for term in terminals if term != "$END" and self.partial_seq_validator[term](incomplete_seq) ]) if not self.valid_next_terminals[incomplete_seq]: @@ -206,7 +208,8 @@ def is_valid_next_seq(self, new_seq: Optional[str]): for incomplete_seq, terminals in self.valid_next_terminals.items(): candidate = incomplete_seq + new_seq for term in terminals: - if term != "$END" and self.partial_seq_validator[term](candidate): + if term != "$END" and self.partial_seq_validator[term]( + candidate): return True return False @@ -219,7 +222,8 @@ class TokenTrie: IS_TOKEN = (None, "is complete token") def __init__(self, - tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + tokenizer: Union[PreTrainedTokenizer, + PreTrainedTokenizerFast], legal_chars: Optional[Set[str]] = None): self.norm_vocab = collections.defaultdict(set) for token_id in tokenizer.vocab.values(): @@ -227,9 +231,10 @@ def __init__(self, self.norm_vocab[None].add(token_id) continue bos_len = len(tokenizer.bos_token) - norm_token = tokenizer.decode([tokenizer.bos_token_id, token_id])[bos_len:] + norm_token = tokenizer.decode([tokenizer.bos_token_id, + token_id])[bos_len:] if legal_chars is None or all( - [char in legal_chars for char in norm_token]): + [char in legal_chars for char in norm_token]): self.norm_vocab[norm_token].add(token_id) # faster lookups, reduce time by 10% @@ -251,11 +256,12 @@ def __init__(self, def get_next_level_token_prefixes(self, subprefix: str) -> Set[str]: if subprefix not in self._next_level_token_prefixes_cache: self._next_level_token_prefixes_cache[subprefix] = ( - self.get_next_level_token_prefixes_uncached(subprefix) - ) + self.get_next_level_token_prefixes_uncached(subprefix)) return self._next_level_token_prefixes_cache[subprefix] - def get_next_level_token_prefixes_uncached(self, subprefix: str, _node: dict = None) -> Set[str]: + def get_next_level_token_prefixes_uncached(self, + subprefix: str, + _node: dict = None) -> Set[str]: """ Traverse the trie starting from a specified subprefix to identify all child nodes that represent the longest possible strings without omitting any nodes that contain complete tokens. 
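# A minimal toy sketch of the prefix expansion described in the docstring
# above -- added purely for illustration, not part of vllm/grammar.py or of
# this patch series. The vocabulary and the names `toy_vocab` and
# `next_level_prefixes` are invented; the real TokenTrie builds its trie from
# a tokenizer vocabulary instead of a hand-written word set.
IS_TOKEN = (None, "is complete token")
toy_vocab = {"t", "tr", "true", "ten"}

trie = {}
for word in toy_vocab:
    node = trie
    for ch in word:
        node = node.setdefault(ch, {})
    node[IS_TOKEN] = None  # mark the end of a complete token


def next_level_prefixes(subprefix, _node=None):
    # mid-recursion: stop at a branching point or at a complete token
    if _node is not None and (len(_node) > 1 or IS_TOKEN in _node):
        return {subprefix}
    if _node is None:
        # first call: walk down to the node for the requested subprefix
        _node = trie
        for ch in subprefix:
            _node = _node.get(ch, {})
    results = set()
    for ch, child in _node.items():
        if ch != IS_TOKEN:
            results |= next_level_prefixes(subprefix + ch, child)
    return results


print(next_level_prefixes(""))   # {'t'}: 't' is itself a token, so expansion stops there
print(next_level_prefixes("t"))  # {'tr', 'ten'}: 'te' is not a token, so that branch
                                 # extends until it reaches the complete token 'ten'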
@@ -282,9 +288,7 @@ def get_next_level_token_prefixes_uncached(self, subprefix: str, _node: dict = N for char, next_node in _node.items(): if char != self.IS_TOKEN: results |= self.get_next_level_token_prefixes_uncached( - subprefix + char, - _node=next_node - ) + subprefix + char, _node=next_node) return results @@ -300,12 +304,14 @@ class NextTokenValidator: - step_seq(new_seq): Append a sequence, update internal states - property valid_token_str_set: The valid set of vocabulary tokens strings which can occur next """ - def __init__(self, - tokenizer, - grammar: str, - grammar_start: str = "start", - legal_chars: Optional[set[str]] = None, - ): + + def __init__( + self, + tokenizer, + grammar: str, + grammar_start: str = "start", + legal_chars: Optional[set[str]] = None, + ): self.tokenizer = tokenizer self.token_trie = TokenTrie(tokenizer, legal_chars=legal_chars) @@ -333,7 +339,8 @@ def valid_token_str_set(self): token_prefix_stack = collections.deque([""]) while token_prefix_stack: token_prefix = token_prefix_stack.pop() - for child_token_prefix in self.token_trie.get_next_level_token_prefixes(token_prefix): + for child_token_prefix in self.token_trie.get_next_level_token_prefixes( + token_prefix): if self.parser.is_valid_next_seq(child_token_prefix): token_prefix_stack.append(child_token_prefix) if self.token_trie.is_token(child_token_prefix): @@ -361,17 +368,18 @@ class BatchDataItemParser: parser: NextTokenValidator - class GrammarLogitsProcessor: """ Apply NextTokenValidator in __call__ and set excluded tokens logits to -inf """ - def __init__(self, - tokenizer, - grammar: str, - grammar_start: str = "start", - legal_chars: Optional[set[str]] = None, - ): + + def __init__( + self, + tokenizer, + grammar: str, + grammar_start: str = "start", + legal_chars: Optional[set[str]] = None, + ): self.tokenizer = tokenizer self.grammar = grammar self.grammar_start = grammar_start @@ -382,15 +390,11 @@ def __init__(self, def _new_batch_data_item_parser(self): return BatchDataItemParser( - "", - [], - NextTokenValidator( - tokenizer=self.tokenizer, - grammar=self.grammar, - grammar_start=self.grammar_start, - legal_chars=self.legal_chars - ) - ) + "", [], + NextTokenValidator(tokenizer=self.tokenizer, + grammar=self.grammar, + grammar_start=self.grammar_start, + legal_chars=self.legal_chars)) def _get_batch_data_item_parser(self, token_ids: List[int]): """ @@ -398,18 +402,16 @@ def _get_batch_data_item_parser(self, token_ids: List[int]): This is generally the corresponding parser, but if there's a collision their parsers are interchangable """ - for bdip in sorted( - self.batch_data_item_parsers, - key=lambda bdip: -len(bdip.token_ids) - ): + for bdip in sorted(self.batch_data_item_parsers, + key=lambda bdip: -len(bdip.token_ids)): if token_ids[:len(bdip.token_ids)] == bdip.token_ids: return bdip # no match, make new return self._new_batch_data_item_parser() - - def _update_seen_token_ids(self, bdip: BatchDataItemParser, token_ids: List[int]): + def _update_seen_token_ids(self, bdip: BatchDataItemParser, + token_ids: List[int]): # update batch item token tracker bdip.token_ids = token_ids @@ -420,7 +422,8 @@ def _update_seen_token_ids(self, bdip: BatchDataItemParser, token_ids: List[int] bdip.text = all_text bdip.parser.step_seq(new_text) - def __call__(self, token_ids: List[int], logits: torch.Tensor) -> torch.Tensor: + def __call__(self, token_ids: List[int], + logits: torch.Tensor) -> torch.Tensor: # get the batch item data and parser for batch item, given provided token sequence bdip = 
self._get_batch_data_item_parser(token_ids) @@ -429,7 +432,8 @@ def __call__(self, token_ids: List[int], logits: torch.Tensor) -> torch.Tensor: # modify logits given valid token IDs N = len(logits) mask = torch.zeros(N, dtype=torch.bool) - valid = torch.tensor(list(bdip.parser.valid_token_id_set), dtype=torch.long) + valid = torch.tensor(list(bdip.parser.valid_token_id_set), + dtype=torch.long) mask[valid] = True logits[~mask] = float('-inf') return logits From e71b4ed9b283cfb7aaa4517e4077982d86c98cd1 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Fri, 22 Dec 2023 06:00:52 -0600 Subject: [PATCH 46/76] fix ruff yarp --- vllm/grammar.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/grammar.py b/vllm/grammar.py index 2ec1930086c0..b2a17663ead8 100644 --- a/vllm/grammar.py +++ b/vllm/grammar.py @@ -267,9 +267,8 @@ def get_next_level_token_prefixes_uncached(self, the longest possible strings without omitting any nodes that contain complete tokens. """ # cache - if _node is None: - if subprefix in self._next_level_token_prefixes_cache: - return self._next_level_token_prefixes_cache[subprefix] + if _node is None and subprefix in self._next_level_token_prefixes_cache: + return self._next_level_token_prefixes_cache[subprefix] # if not first level of recursion, and at a branching point or is a token, or return self if _node is not None and (len(_node) > 1 or self.IS_TOKEN in _node): From b11256c73d9c23875881ffaeea4a81c354bf3b3b Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Sat, 23 Dec 2023 11:20:01 -0600 Subject: [PATCH 47/76] write test case for issue noticed: can't use more than one char token inside an incomplete rule --- tests/samplers/test_grammar.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/samplers/test_grammar.py b/tests/samplers/test_grammar.py index 7109c6da073a..38456ba4699e 100644 --- a/tests/samplers/test_grammar.py +++ b/tests/samplers/test_grammar.py @@ -131,6 +131,15 @@ def test_next_token_validator_simple(tokenizer): ] +def test_can_span_multiple_terminals(tokenizer): + true_segmented_grammar = """ + ?start: "is" "t" "r" "u" "e" + """ + ntv = NextTokenValidator(tokenizer, true_segmented_grammar) + ntv.step_seq("is") + assert "true" in ntv.valid_token_str_set + + @pytest.mark.parametrize("grammar_fixture, example_fixture", [("json_grammar", "json_example"), ("csv_grammar", "csv_example")]) From a400fc8dad929f8455e243ec19d1cd2037094aa1 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Sat, 23 Dec 2023 11:23:16 -0600 Subject: [PATCH 48/76] use new recursive IncrementalParser instead, cleaner, fixes prev commit bug --- vllm/grammar.py | 306 ++++++++++++++++++++++++++++-------------------- 1 file changed, 180 insertions(+), 126 deletions(-) diff --git a/vllm/grammar.py b/vllm/grammar.py index b2a17663ead8..817e6086d92e 100644 --- a/vllm/grammar.py +++ b/vllm/grammar.py @@ -1,6 +1,7 @@ import collections from copy import deepcopy, copy -from dataclasses import dataclass +from dataclasses import dataclass, fields +from functools import wraps import regex import torch from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast @@ -9,7 +10,7 @@ from lark import Lark from lark.parsers.lalr_interactive_parser import InteractiveParser from lark.parsers.lalr_parser_state import ParserState -from lark.lexer import Pattern, PatternStr, PatternRE +from lark.lexer import Token, Pattern, PatternStr, PatternRE from lark.exceptions import UnexpectedCharacters, UnexpectedToken @@ -62,74 +63,184 @@ def __copy__(self): 
######################################################################### -class InteractivePredictiveLALRParser: +def get_pattern_validator(pattern: Pattern, is_complete: bool): """ - Parser which consumes an EBNF grammar and provides helpers to determine allowable language model tokens + Accepts a pattern object, either lark.lexer.PatternStr or lark.lexer.PatternRE + Returns a function which validates a partial string + + e.g. for PatternRE "abc*", returns true for "a", "ab", "abc", "abcccc" + """ + if isinstance(pattern, PatternRE): + compiled_pattern = regex.compile(pattern.value) + if is_complete: + return (lambda seq: compiled_pattern.fullmatch(seq) + is not None) + else: + return (lambda seq: compiled_pattern.fullmatch(seq, partial=True) + is not None) + elif isinstance(pattern, PatternStr): + base_str = pattern.value + if is_complete: + return (lambda seq: seq == base_str) + else: + return (lambda seq: base_str.startswith(seq)) + else: + raise TypeError(f"Invalid pattern type: {type(pattern)}") - Interfaces: - - step_seq(new_seq): Update the parser with a new sequence to append - - is_valid_next_seq(new_seq): Determine whether a candidate sequence is valid - Core components for terminal level, and sub-terminal level processing: - - 1) Lark LALR parser: Applies state transitions, determining set of valid next-terminals - - 2) Incremental terminal filter: Eliminates next-terminal candidates if terminal pattern doesn't match +def memoize_with_key(*key_attrs): + """ + Decorator for memoizing class methods based on specified instance attributes. """ + def decorator(method): + @wraps(method) + def wrapper(self, *args, **kwargs): + # create a unique key based on the specified attributes of instance + key_elements = [getattr(self, attr) for attr in key_attrs] + key = (method.__name__, tuple(key_elements), args, tuple(sorted(kwargs.items()))) - def __init__(self, grammar: str, start: str): - self.parser = Lark( - grammar, - regex=True, # use `regex` not `re` - start=start, - parser='lalr', - cache=True, # results in 2-3x faster loading - ) - base_interactive_parser = self.parser.parse_interactive() - self.interactive_parser = FastInteractiveParser( - base_interactive_parser.parser, - base_interactive_parser.parser_state, - base_interactive_parser.lexer_thread) - - # fallback parser from start of terminal in case of ambiguous (LR(1)) - self._terminal_start_parser = self.interactive_parser.copy() - - self.partial_seq_validator = { - term.name: self._get_partial_pattern_validator(term.pattern) - for term in self.parser.terminals - } + # check if cached + if key in self._memo: + return self._memo[key] - self._ignored_terms = set(self.parser.lexer_conf.ignore) + # call + result = method(self, *args, **kwargs) + self._memo[key] = result + return result - # for calculating `accepts()` efficiently - self._accepts_cache = {} + return wrapper + return decorator - self.sequence_history = "" - # for processing terminals interactively - self.valid_next_terminals = {"": self._accepts() | self._ignored_terms} +@dataclass +class IncrementalParser: + interactive_parser: FastInteractiveParser + tokens: tuple + partial_token: str + terminal_candidates: list + _ignored_terms: set + _seq_validator: dict + _memo: dict + + @classmethod + def from_lark_parser(cls, lark_parser): + print(lark_parser.terminals) + base_interactive_parser = lark_parser.parse_interactive() + interactive_parser = FastInteractiveParser( + base_interactive_parser.parser, + base_interactive_parser.parser_state, + 
base_interactive_parser.lexer_thread) + interactive_parser.lexer_thread.state.text = "" + + _seq_validator = { + (term.name, "partial"): get_pattern_validator(term.pattern, is_complete=False) + for term in lark_parser.terminals + } + _seq_validator.update({ + (term.name, "complete"): get_pattern_validator(term.pattern, is_complete=True) + for term in lark_parser.terminals + }) + + _seq_validator[("$END", "partial")] = lambda seq: seq is None + _seq_validator[("$END", "complete")] = lambda seq: seq is None + + + return cls( + interactive_parser=interactive_parser, + tokens=tuple(), + partial_token="", + terminal_candidates=None, + _ignored_terms=set(lark_parser.lexer_conf.ignore), + _seq_validator=_seq_validator, + _memo={} + ) - @staticmethod - def _get_partial_pattern_validator(pattern: Pattern): - """ - Accepts a pattern object, either lark.lexer.PatternStr or lark.lexer.PatternRE - Returns a function which validates a partial string + def new(self, **kwargs): + instance_dict = { + f.name: getattr(self, f.name) + for f in fields(self) + } + instance_dict.update(kwargs) + return self.__class__(**instance_dict) - e.g. for PatternRE "abc*", returns true for "a", "ab", "abc", "abcccc" + @memoize_with_key('tokens', 'partial_token') + def new_parser_for_appended_char(self, char: str): """ - if isinstance(pattern, PatternRE): - compiled_pattern = regex.compile(pattern.value) - return (lambda seq: compiled_pattern.fullmatch(seq, partial=True) - is not None) - elif isinstance(pattern, PatternStr): - base_str = pattern.value - return (lambda seq: base_str.startswith(seq)) + - Construct extended (maybe-partial) token candidate + - If no partial matches, None + - If partial matches, but not complete, + return new parser with updated partial token str and updated terminal candidates + - If complete match, reset partial token, return parser with token-updated parser state + """ + assert len(char) == 1 + + new_maybe_partial_token = self.partial_token + char + new_allowed_terminals = self.filter_terminals( + self.allowed_terminals, + new_maybe_partial_token, + require_complete=False + ) + if not new_allowed_terminals: + return None + + complete_terminals = self.filter_terminals( + self.allowed_terminals, + new_maybe_partial_token, + require_complete=True + ) + if complete_terminals: + assert len(complete_terminals) == 1 + new_token_str = next(iter(complete_terminals)) + return self.new( + interactive_parser=self.get_stepped_parser_state(new_token_str), + tokens=tuple(list(self.tokens) + [new_token_str]), + partial_token="", + terminal_candidates=None, + ) else: - raise TypeError(f"Invalid pattern type: {type(pattern)}") + return self.new( + partial_token=new_maybe_partial_token, + terminal_candidates=new_allowed_terminals, + ) + + def filter_terminals(self, checked_terminals, seq, require_complete): + validator_type = "complete" if require_complete else "partial" + return set([ + term for term in checked_terminals + if self._seq_validator[(term, validator_type)](seq) + ]) + + @memoize_with_key('tokens') + def get_stepped_parser_state(self, new_token_str): + ip = copy(self.interactive_parser) + ip.feed_token( + Token(new_token_str, '') + ) + return ip + + + @memoize_with_key('tokens') + def accepts(self): + return set(self.interactive_parser.accepts()) | self._ignored_terms + + @property + def allowed_terminals(self): + if self.terminal_candidates is not None: + return self.terminal_candidates + return self.accepts() - def _accepts(self): - if self.sequence_history not in self._accepts_cache: - 
accepted_terminals = self.interactive_parser.accepts() - self._accepts_cache[self.sequence_history] = accepted_terminals - return self._accepts_cache[self.sequence_history] + +class SpeculativeParser: + def __init__(self, grammar: str, start: str): + self.parser = Lark( + grammar, + regex=True, # use `regex` not `re` + start=start, + parser='lalr', + cache=True, # results in 2-3x faster loading + ) + self.incr_parser = IncrementalParser.from_lark_parser(self.parser) + self.fallback_incr_parser = copy(self.incr_parser) def step_seq(self, new_seq: str): """ @@ -140,78 +251,21 @@ def step_seq(self, new_seq: str): - Update the set of candidate terminals """ for char in new_seq: - # update canonical sequence and lexer sequence - self.sequence_history += char - self.interactive_parser.lexer_thread.state.text = self.sequence_history - - success = False - try: - self.interactive_parser.exhaust_lexer() - except UnexpectedCharacters: - pass - except UnexpectedToken: - # fall back so full token can be reprocessed - self.interactive_parser = self._terminal_start_parser.copy() + new_incr_parser = self.incr_parser.new_parser_for_appended_char(char) + if new_incr_parser is None: + self.incr_parser = self.fallback_incr_parser else: - success = True - - self.valid_next_terminals = { - (incomplete_seq + char): term - for incomplete_seq, term in self.valid_next_terminals.items() - } - - # if successfully parsed new token, add blank state and set fallback checkpoint - if success: - self.valid_next_terminals[""] = self._accepts( - ) | self._ignored_terms - self._terminal_start_parser = self.interactive_parser.copy() - - self._filter_candidate_terminals() - - if not self.valid_next_terminals: - raise ValueError( - f"Invalid continuation for `{self.sequence_history}` `{new_seq}`" - ) - - def _filter_candidate_terminals(self): - """ - Filter the set of candidate terminals - - If a new terminal is reached, get the accepted set of terminals from the parser - - If the new sequence doesn't comprise a full terminal, filter based on partial pattern match - - Handles ambiguity by allowing terminals which are potentially complete - """ - to_prune_sequences = set() - for incomplete_seq, terminals in self.valid_next_terminals.items(): - if incomplete_seq != "": - self.valid_next_terminals[incomplete_seq] = set([ - term for term in terminals if term != "$END" - and self.partial_seq_validator[term](incomplete_seq) - ]) - if not self.valid_next_terminals[incomplete_seq]: - to_prune_sequences.add(incomplete_seq) - - for to_prune_seq in to_prune_sequences: - del self.valid_next_terminals[to_prune_seq] + self.incr_parser = new_incr_parser def is_valid_next_seq(self, new_seq: Optional[str]): - """ - Check if current un-terminalized sequence + new_seq is valid for any terminal - - new_seq can be a string or None representing EOS - """ if new_seq is None: - return "$END" in [ - term for terminals in self.valid_next_terminals.values() - for term in terminals - ] - for incomplete_seq, terminals in self.valid_next_terminals.items(): - candidate = incomplete_seq + new_seq - for term in terminals: - if term != "$END" and self.partial_seq_validator[term]( - candidate): - return True - return False + return "$END" in self.incr_parser.allowed_terminals + new_incr_parser = self.incr_parser + for i, char in enumerate(new_seq): + new_incr_parser = new_incr_parser.new_parser_for_appended_char(char) + if new_incr_parser is None: + return False + return True class TokenTrie: @@ -314,8 +368,8 @@ def __init__( self.tokenizer = tokenizer 
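# A note on the partial matching used by get_pattern_validator earlier in this
# patch: it relies on the third-party `regex` package (not the stdlib `re`), whose
# partial=True flag reports whether a string could still grow into a full match.
# A standalone sketch of that behaviour (illustrative, not part of the patch):
import regex

pat = regex.compile("abc*")
assert pat.fullmatch("a", partial=True) is not None     # could still become a match
assert pat.fullmatch("ab", partial=True) is not None    # already a complete match
assert pat.fullmatch("abccc", partial=True) is not None
assert pat.fullmatch("abx", partial=True) is None       # can never match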
self.token_trie = TokenTrie(tokenizer, legal_chars=legal_chars) - self.parser = InteractivePredictiveLALRParser(grammar=grammar, - start=grammar_start) + self.parser = SpeculativeParser(grammar=grammar, + start=grammar_start) def step_seq(self, new_seq: str): self.parser.step_seq(new_seq) From 2bf537fbc94bd40378af08ecea48b4679efc0840 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Sat, 23 Dec 2023 12:47:15 -0600 Subject: [PATCH 49/76] improve with memoization --- vllm/grammar.py | 35 ++++++++++++++++------------------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/vllm/grammar.py b/vllm/grammar.py index 817e6086d92e..0a4fe47eae25 100644 --- a/vllm/grammar.py +++ b/vllm/grammar.py @@ -66,9 +66,9 @@ def __copy__(self): def get_pattern_validator(pattern: Pattern, is_complete: bool): """ Accepts a pattern object, either lark.lexer.PatternStr or lark.lexer.PatternRE - Returns a function which validates a partial string + Returns a function which validates a complete or partial string - e.g. for PatternRE "abc*", returns true for "a", "ab", "abc", "abcccc" + e.g. for PatternRE "abc*", is_complete=False returns true for "a", "ab", "abc", "abcccc" """ if isinstance(pattern, PatternRE): compiled_pattern = regex.compile(pattern.value) @@ -89,22 +89,20 @@ def get_pattern_validator(pattern: Pattern, is_complete: bool): def memoize_with_key(*key_attrs): - """ - Decorator for memoizing class methods based on specified instance attributes. - """ def decorator(method): + mname = method.__name__ @wraps(method) - def wrapper(self, *args, **kwargs): - # create a unique key based on the specified attributes of instance - key_elements = [getattr(self, attr) for attr in key_attrs] - key = (method.__name__, tuple(key_elements), args, tuple(sorted(kwargs.items()))) + def wrapper(self, *args): + # Construct a simple key from key attributes and method arguments + key_elements = tuple(getattr(self, attr, None) for attr in key_attrs) + key = (mname, key_elements, args) - # check if cached + # Check cache for existing result if key in self._memo: return self._memo[key] - # call - result = method(self, *args, **kwargs) + # Call the method and store the result + result = method(self, *args) self._memo[key] = result return result @@ -115,7 +113,7 @@ def wrapper(self, *args, **kwargs): @dataclass class IncrementalParser: interactive_parser: FastInteractiveParser - tokens: tuple + tokens_key: str # "\n" separated, for caching purposes partial_token: str terminal_candidates: list _ignored_terms: set @@ -147,7 +145,7 @@ def from_lark_parser(cls, lark_parser): return cls( interactive_parser=interactive_parser, - tokens=tuple(), + tokens_key="", partial_token="", terminal_candidates=None, _ignored_terms=set(lark_parser.lexer_conf.ignore), @@ -163,7 +161,7 @@ def new(self, **kwargs): instance_dict.update(kwargs) return self.__class__(**instance_dict) - @memoize_with_key('tokens', 'partial_token') + @memoize_with_key('tokens_key', 'partial_token') def new_parser_for_appended_char(self, char: str): """ - Construct extended (maybe-partial) token candidate @@ -193,7 +191,7 @@ def new_parser_for_appended_char(self, char: str): new_token_str = next(iter(complete_terminals)) return self.new( interactive_parser=self.get_stepped_parser_state(new_token_str), - tokens=tuple(list(self.tokens) + [new_token_str]), + tokens_key=self.tokens_key + "\n" + new_token_str, partial_token="", terminal_candidates=None, ) @@ -210,7 +208,7 @@ def filter_terminals(self, checked_terminals, seq, require_complete): if 
self._seq_validator[(term, validator_type)](seq) ]) - @memoize_with_key('tokens') + @memoize_with_key('tokens_key') def get_stepped_parser_state(self, new_token_str): ip = copy(self.interactive_parser) ip.feed_token( @@ -218,8 +216,7 @@ def get_stepped_parser_state(self, new_token_str): ) return ip - - @memoize_with_key('tokens') + @memoize_with_key('tokens_key') def accepts(self): return set(self.interactive_parser.accepts()) | self._ignored_terms From 4d6eeb2b7c998b85015aba34478149831c9023fe Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Sun, 24 Dec 2023 07:44:22 -0600 Subject: [PATCH 50/76] much better memoization + bugfixes --- vllm/grammar.py | 200 +++++++++++++++++++++++++++++++----------------- 1 file changed, 128 insertions(+), 72 deletions(-) diff --git a/vllm/grammar.py b/vllm/grammar.py index 0a4fe47eae25..352786b0911b 100644 --- a/vllm/grammar.py +++ b/vllm/grammar.py @@ -73,56 +73,80 @@ def get_pattern_validator(pattern: Pattern, is_complete: bool): if isinstance(pattern, PatternRE): compiled_pattern = regex.compile(pattern.value) if is_complete: - return (lambda seq: compiled_pattern.fullmatch(seq) - is not None) + # False: No match + # int: match length + def get_fullmatch_length(seq): + r = compiled_pattern.match(seq) + if r is None or not r.spans(): + return None + spans = r.spans()[0] + return spans[1] - spans[0] + return get_fullmatch_length else: return (lambda seq: compiled_pattern.fullmatch(seq, partial=True) is not None) elif isinstance(pattern, PatternStr): base_str = pattern.value if is_complete: - return (lambda seq: seq == base_str) + def get_strmatch_length(seq): + if not seq.startswith(base_str): + return None + return len(base_str) + return get_strmatch_length else: return (lambda seq: base_str.startswith(seq)) else: raise TypeError(f"Invalid pattern type: {type(pattern)}") -def memoize_with_key(*key_attrs): - def decorator(method): - mname = method.__name__ - @wraps(method) - def wrapper(self, *args): - # Construct a simple key from key attributes and method arguments - key_elements = tuple(getattr(self, attr, None) for attr in key_attrs) - key = (mname, key_elements, args) +def memoize_by_instance(method): + """ + Memoize by id(self) and fn args + """ + mname = method.__name__ + @wraps(method) + def wrapper(self, *args): + key = (mname, id(self), args) + if key in self._memo: + return self._memo[key] + result = method(self, *args) + self._memo[key] = result + return result - # Check cache for existing result - if key in self._memo: - return self._memo[key] + return wrapper - # Call the method and store the result - result = method(self, *args) - self._memo[key] = result - return result - return wrapper - return decorator +@dataclass +class IncrementalParserState: + """ + Parsing utility which tracks state provided + - sequence of prior terminal ids + - incomplete `partial_token` string + the set of prior terminal_ids and a partial token comprise a unique parser state + Core function exposed is `self.new_parser_for_appended(new_seq)` + - Returns a new IncrementalParserState based with new_seq applied -@dataclass -class IncrementalParser: - interactive_parser: FastInteractiveParser - tokens_key: str # "\n" separated, for caching purposes + Memoization strategy is + - 1) Ensure uniqueness of (prior_terminal_ids, partial_token) + - 2) Cache class methods via `memoize_by_instance` which considers id(self) and fn arguments + """ + + # unique state key + prior_terminal_ids: tuple[str] partial_token: str + + # function of key + interactive_parser: 
FastInteractiveParser terminal_candidates: list + + # shared across instances _ignored_terms: set _seq_validator: dict _memo: dict @classmethod def from_lark_parser(cls, lark_parser): - print(lark_parser.terminals) base_interactive_parser = lark_parser.parse_interactive() interactive_parser = FastInteractiveParser( base_interactive_parser.parser, @@ -145,7 +169,7 @@ def from_lark_parser(cls, lark_parser): return cls( interactive_parser=interactive_parser, - tokens_key="", + prior_terminal_ids=tuple(), partial_token="", terminal_candidates=None, _ignored_terms=set(lark_parser.lexer_conf.ignore), @@ -153,62 +177,95 @@ def from_lark_parser(cls, lark_parser): _memo={} ) - def new(self, **kwargs): + def new(self, prior_terminal_ids, partial_token, **kwargs): + # cache + key = (prior_terminal_ids, partial_token) + if key in self._memo: + return self._memo[key] instance_dict = { f.name: getattr(self, f.name) for f in fields(self) } instance_dict.update(kwargs) - return self.__class__(**instance_dict) + inst = self.__class__(**instance_dict) + self._memo[key] = inst + return inst - @memoize_with_key('tokens_key', 'partial_token') - def new_parser_for_appended_char(self, char: str): + @memoize_by_instance + def new_parser_for_appended(self, new_seq: str): """ - Construct extended (maybe-partial) token candidate - - If no partial matches, None - - If partial matches, but not complete, + - If complete match, create new-terminal incremented parser state + - there is leftover from new_seq, recurse on the new parser + - If partial matches, return new parser with updated partial token str and updated terminal candidates - - If complete match, reset partial token, return parser with token-updated parser state + - If no partial matches, return None """ - assert len(char) == 1 - - new_maybe_partial_token = self.partial_token + char - new_allowed_terminals = self.filter_terminals( - self.allowed_terminals, - new_maybe_partial_token, - require_complete=False - ) - if not new_allowed_terminals: - return None + new_maybe_partial_token = self.partial_token + new_seq - complete_terminals = self.filter_terminals( - self.allowed_terminals, + complete_terminal = self.get_complete_terminal( + tuple(sorted(self.allowed_terminals)), new_maybe_partial_token, - require_complete=True ) - if complete_terminals: - assert len(complete_terminals) == 1 - new_token_str = next(iter(complete_terminals)) - return self.new( - interactive_parser=self.get_stepped_parser_state(new_token_str), - tokens_key=self.tokens_key + "\n" + new_token_str, + if complete_terminal is not None: + terminal_name = complete_terminal["terminal_id"] + if terminal_name in self._ignored_terms: + new_interactive_parser = self.interactive_parser + else: + new_interactive_parser = self.get_stepped_parser_state(terminal_name) + new_parser = self.new( + interactive_parser=new_interactive_parser, + prior_terminal_ids=tuple(list(self.prior_terminal_ids) + [terminal_name]), partial_token="", terminal_candidates=None, ) - else: + + leftover = len(new_seq) - complete_terminal["match_length"] + if leftover: + return new_parser.new_parser_for_appended( + new_maybe_partial_token[-leftover:] + ) + else: + return new_parser + + partial_terminal_ids = self.get_partial_terminal_ids( + tuple(sorted(self.allowed_terminals)), + new_maybe_partial_token, + ) + if partial_terminal_ids: return self.new( + prior_terminal_ids=self.prior_terminal_ids, partial_token=new_maybe_partial_token, - terminal_candidates=new_allowed_terminals, + terminal_candidates=partial_terminal_ids, ) + 
else: + return None - def filter_terminals(self, checked_terminals, seq, require_complete): - validator_type = "complete" if require_complete else "partial" + + def get_complete_terminal(self, checked_terminals, seq): + terminal_matchlens = { + term: self._seq_validator[(term, "complete")](seq) + for term in checked_terminals + } + terminal_matchlens = {term: ml for term, ml in terminal_matchlens.items() if ml} + if not terminal_matchlens: + return None + if len(terminal_matchlens) > 1: + terminal_matchlens = { + t: ml for t, ml in terminal_matchlens.items() + if t not in self._ignored_terms + } + assert len(terminal_matchlens) == 1 + result = next(iter(terminal_matchlens.items())) + return {"terminal_id": result[0], "match_length": result[1]} + + def get_partial_terminal_ids(self, checked_terminals, seq): return set([ term for term in checked_terminals - if self._seq_validator[(term, validator_type)](seq) + if self._seq_validator[(term, "partial")](seq) ]) - @memoize_with_key('tokens_key') + @memoize_by_instance def get_stepped_parser_state(self, new_token_str): ip = copy(self.interactive_parser) ip.feed_token( @@ -216,11 +273,12 @@ def get_stepped_parser_state(self, new_token_str): ) return ip - @memoize_with_key('tokens_key') + @memoize_by_instance def accepts(self): return set(self.interactive_parser.accepts()) | self._ignored_terms @property + @memoize_by_instance def allowed_terminals(self): if self.terminal_candidates is not None: return self.terminal_candidates @@ -236,8 +294,8 @@ def __init__(self, grammar: str, start: str): parser='lalr', cache=True, # results in 2-3x faster loading ) - self.incr_parser = IncrementalParser.from_lark_parser(self.parser) - self.fallback_incr_parser = copy(self.incr_parser) + self.incr_parser = IncrementalParserState.from_lark_parser(self.parser) + self.fallback_incr_parser = self.incr_parser def step_seq(self, new_seq: str): """ @@ -247,21 +305,18 @@ def step_seq(self, new_seq: str): - Update the character position of the last complete terminal - Update the set of candidate terminals """ - for char in new_seq: - new_incr_parser = self.incr_parser.new_parser_for_appended_char(char) - if new_incr_parser is None: - self.incr_parser = self.fallback_incr_parser - else: - self.incr_parser = new_incr_parser + new_incr_parser = self.incr_parser.new_parser_for_appended(new_seq) + if new_incr_parser is None: + self.incr_parser = self.fallback_incr_parser + else: + self.incr_parser = new_incr_parser def is_valid_next_seq(self, new_seq: Optional[str]): if new_seq is None: return "$END" in self.incr_parser.allowed_terminals - new_incr_parser = self.incr_parser - for i, char in enumerate(new_seq): - new_incr_parser = new_incr_parser.new_parser_for_appended_char(char) - if new_incr_parser is None: - return False + return self.incr_parser.new_parser_for_appended(new_seq) is not None + if new_incr_parser is None: + return False return True @@ -288,6 +343,7 @@ def __init__(self, [char in legal_chars for char in norm_token]): self.norm_vocab[norm_token].add(token_id) + # faster lookups, reduce time by 10% self.norm_vocab_set = set(self.norm_vocab) From 988e95f705520e9b2954330e9e5460e1640ce1e9 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Sun, 24 Dec 2023 07:56:18 -0600 Subject: [PATCH 51/76] don't use a trie, with memoization it's inefficient --- vllm/grammar.py | 93 +++++++------------------------------------------ 1 file changed, 12 insertions(+), 81 deletions(-) diff --git a/vllm/grammar.py b/vllm/grammar.py index 352786b0911b..36779d1149a5 100644 --- 
a/vllm/grammar.py +++ b/vllm/grammar.py @@ -320,17 +320,16 @@ def is_valid_next_seq(self, new_seq: Optional[str]): return True -class TokenTrie: +class TokenVocab: """ - Trie structure for efficiently finding tokens which are suffixes of other sequences + Normalized token vocabulary accounting for whitespace and multiple IDs per token """ - IS_TOKEN = (None, "is complete token") - def __init__(self, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], legal_chars: Optional[Set[str]] = None): + self.norm_vocab = collections.defaultdict(set) for token_id in tokenizer.vocab.values(): if token_id == tokenizer.eos_token_id: @@ -343,63 +342,11 @@ def __init__(self, [char in legal_chars for char in norm_token]): self.norm_vocab[norm_token].add(token_id) + def __iter__(self): + return iter(self.norm_vocab) - # faster lookups, reduce time by 10% - self.norm_vocab_set = set(self.norm_vocab) - - self.trie = {} - for word in self.norm_vocab: - current_dict = self.trie - if word is None: - continue - for char in word: - if char not in current_dict: - current_dict[char] = {} - current_dict = current_dict[char] - current_dict[self.IS_TOKEN] = True - - self._next_level_token_prefixes_cache = {} - - def get_next_level_token_prefixes(self, subprefix: str) -> Set[str]: - if subprefix not in self._next_level_token_prefixes_cache: - self._next_level_token_prefixes_cache[subprefix] = ( - self.get_next_level_token_prefixes_uncached(subprefix)) - return self._next_level_token_prefixes_cache[subprefix] - - def get_next_level_token_prefixes_uncached(self, - subprefix: str, - _node: dict = None) -> Set[str]: - """ - Traverse the trie starting from a specified subprefix to identify all child nodes that represent - the longest possible strings without omitting any nodes that contain complete tokens. 
- """ - # cache - if _node is None and subprefix in self._next_level_token_prefixes_cache: - return self._next_level_token_prefixes_cache[subprefix] - - # if not first level of recursion, and at a branching point or is a token, or return self - if _node is not None and (len(_node) > 1 or self.IS_TOKEN in _node): - return {subprefix} - - # get the current node if at the first level of recursion - if _node is None: - _node = self.trie - for char in subprefix: - if char not in _node: - return set() - _node = _node[char] - - # Single child, need to go deeper - results = set() - for char, next_node in _node.items(): - if char != self.IS_TOKEN: - results |= self.get_next_level_token_prefixes_uncached( - subprefix + char, _node=next_node) - - return results - - def is_token(self, seq: Optional[str]) -> bool: - return seq in self.norm_vocab_set + def __get__(self, tok_str): + return self.norm_vocab[tok_str] class NextTokenValidator: @@ -419,7 +366,7 @@ def __init__( legal_chars: Optional[set[str]] = None, ): self.tokenizer = tokenizer - self.token_trie = TokenTrie(tokenizer, legal_chars=legal_chars) + self.vocab = TokenVocab(tokenizer, legal_chars=legal_chars) self.parser = SpeculativeParser(grammar=grammar, start=grammar_start) @@ -431,27 +378,11 @@ def step_seq(self, new_seq: str): def valid_token_str_set(self): """ Generate the set of valid tokens given the current sequence - - 1) Push all first level token prefixes to the stack - 2) for each token in the stack, validate against the parser - - if valid, add all children to the stack for later processing - - if valid AND a token, add to valid_token_set - - TODO: this can be improved with multi-threading """ valid_token_str_set = set() - if self.parser.is_valid_next_seq(None): - valid_token_str_set.add(None) - token_prefix_stack = collections.deque([""]) - while token_prefix_stack: - token_prefix = token_prefix_stack.pop() - for child_token_prefix in self.token_trie.get_next_level_token_prefixes( - token_prefix): - if self.parser.is_valid_next_seq(child_token_prefix): - token_prefix_stack.append(child_token_prefix) - if self.token_trie.is_token(child_token_prefix): - valid_token_str_set.add(child_token_prefix) - + for tok in self.vocab: + if self.parser.is_valid_next_seq(tok): + valid_token_str_set.add(tok) return valid_token_str_set @property @@ -461,7 +392,7 @@ def valid_token_id_set(self): note that some token strings correspond to multiple token IDs """ return set.union(*[ - self.token_trie.norm_vocab[tok_str] + self.vocab[tok_str] for tok_str in self.valid_token_str_set ]) From 8db40b278cd40a3ad4cc990ed072f8d40cc3c562 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Sun, 24 Dec 2023 07:59:37 -0600 Subject: [PATCH 52/76] improve performance 15% with generator --- vllm/grammar.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/grammar.py b/vllm/grammar.py index 36779d1149a5..be3b546d1faf 100644 --- a/vllm/grammar.py +++ b/vllm/grammar.py @@ -379,11 +379,9 @@ def valid_token_str_set(self): """ Generate the set of valid tokens given the current sequence """ - valid_token_str_set = set() for tok in self.vocab: if self.parser.is_valid_next_seq(tok): - valid_token_str_set.add(tok) - return valid_token_str_set + yield tok @property def valid_token_id_set(self): From 40f94386be5ba5d06c84a40a5b89d75a472bdab4 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Sun, 24 Dec 2023 08:51:21 -0600 Subject: [PATCH 53/76] refactor to be cleaner, faster, stateless --- vllm/grammar.py | 169 +++++++++++------------------------------------- 
1 file changed, 39 insertions(+), 130 deletions(-) diff --git a/vllm/grammar.py b/vllm/grammar.py index be3b546d1faf..072fa810b59c 100644 --- a/vllm/grammar.py +++ b/vllm/grammar.py @@ -1,7 +1,7 @@ import collections from copy import deepcopy, copy from dataclasses import dataclass, fields -from functools import wraps +from functools import wraps, lru_cache import regex import torch from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast @@ -124,7 +124,7 @@ class IncrementalParserState: - incomplete `partial_token` string the set of prior terminal_ids and a partial token comprise a unique parser state - Core function exposed is `self.new_parser_for_appended(new_seq)` + Core function exposed is `self.step_seq(new_seq)` - Returns a new IncrementalParserState based with new_seq applied Memoization strategy is @@ -145,8 +145,17 @@ class IncrementalParserState: _seq_validator: dict _memo: dict + @classmethod - def from_lark_parser(cls, lark_parser): + @lru_cache(1000) + def from_grammar(cls, grammar: str, start: str): + lark_parser = Lark( + grammar, + regex=True, # use `regex` not `re` + start=start, + parser='lalr', + cache=True, # results in 2-3x faster loading + ) base_interactive_parser = lark_parser.parse_interactive() interactive_parser = FastInteractiveParser( base_interactive_parser.parser, @@ -192,7 +201,7 @@ def new(self, prior_terminal_ids, partial_token, **kwargs): return inst @memoize_by_instance - def new_parser_for_appended(self, new_seq: str): + def step_seq(self, new_seq: str): """ - Construct extended (maybe-partial) token candidate - If complete match, create new-terminal incremented parser state @@ -222,7 +231,7 @@ def new_parser_for_appended(self, new_seq: str): leftover = len(new_seq) - complete_terminal["match_length"] if leftover: - return new_parser.new_parser_for_appended( + return new_parser.step_seq( new_maybe_partial_token[-leftover:] ) else: @@ -284,45 +293,17 @@ def allowed_terminals(self): return self.terminal_candidates return self.accepts() - -class SpeculativeParser: - def __init__(self, grammar: str, start: str): - self.parser = Lark( - grammar, - regex=True, # use `regex` not `re` - start=start, - parser='lalr', - cache=True, # results in 2-3x faster loading - ) - self.incr_parser = IncrementalParserState.from_lark_parser(self.parser) - self.fallback_incr_parser = self.incr_parser - - def step_seq(self, new_seq: str): - """ - Append sequence to parser and apply state updates - - Append the sequence to the canonical self.sequence_history - - Parse the changes - - Update the character position of the last complete terminal - - Update the set of candidate terminals - """ - new_incr_parser = self.incr_parser.new_parser_for_appended(new_seq) - if new_incr_parser is None: - self.incr_parser = self.fallback_incr_parser - else: - self.incr_parser = new_incr_parser - def is_valid_next_seq(self, new_seq: Optional[str]): if new_seq is None: - return "$END" in self.incr_parser.allowed_terminals - return self.incr_parser.new_parser_for_appended(new_seq) is not None - if new_incr_parser is None: - return False - return True + return "$END" in self.allowed_terminals + return self.step_seq(new_seq) is not None class TokenVocab: """ Normalized token vocabulary accounting for whitespace and multiple IDs per token + - iter: iterate over normalized token strings + - vocab[token_str]: return token id set """ def __init__(self, @@ -345,19 +326,11 @@ def __init__(self, def __iter__(self): return iter(self.norm_vocab) - def __get__(self, tok_str): + def 
__getitem__(self, tok_str): return self.norm_vocab[tok_str] class NextTokenValidator: - """ - Given a grammar and a tokenset, construct a parser and token trie. - - Interface: - - step_seq(new_seq): Append a sequence, update internal states - - property valid_token_str_set: The valid set of vocabulary tokens strings which can occur next - """ - def __init__( self, tokenizer, @@ -368,107 +341,43 @@ def __init__( self.tokenizer = tokenizer self.vocab = TokenVocab(tokenizer, legal_chars=legal_chars) - self.parser = SpeculativeParser(grammar=grammar, - start=grammar_start) - - def step_seq(self, new_seq: str): - self.parser.step_seq(new_seq) + self.root_parser = IncrementalParserState.from_grammar( + grammar, + grammar_start + ) - @property - def valid_token_str_set(self): + def get_valid_next_token_strs(self, full_seq): """ - Generate the set of valid tokens given the current sequence + Generate valid token strings given the full sequence """ - for tok in self.vocab: - if self.parser.is_valid_next_seq(tok): - yield tok + parser = self.root_parser.step_seq(full_seq) + for tok_str in self.vocab: + if parser.is_valid_next_seq(tok_str): + yield tok_str - @property - def valid_token_id_set(self): + + def get_valid_next_token_ids(self, full_seq): """ - get valid token id based on self.valid_token_str_set - note that some token strings correspond to multiple token IDs + Generate valid token ids given the full sequence """ - return set.union(*[ - self.vocab[tok_str] - for tok_str in self.valid_token_str_set - ]) + for tok_str in self.get_valid_next_token_strs(full_seq): + yield from self.vocab[tok_str] -# TODO: replace with subclass called NextTokenIDValidator to make things cleaner -@dataclass -class BatchDataItemParser: - text: str - token_ids: List[str] - parser: NextTokenValidator - - -class GrammarLogitsProcessor: +class GrammarLogitsProcessor(NextTokenValidator): """ Apply NextTokenValidator in __call__ and set excluded tokens logits to -inf """ - - def __init__( - self, - tokenizer, - grammar: str, - grammar_start: str = "start", - legal_chars: Optional[set[str]] = None, - ): - self.tokenizer = tokenizer - self.grammar = grammar - self.grammar_start = grammar_start - self.legal_chars = legal_chars - - # track multiple parsers for batch requests - self.batch_data_item_parsers: List[BatchDataItemParser] = [] - - def _new_batch_data_item_parser(self): - return BatchDataItemParser( - "", [], - NextTokenValidator(tokenizer=self.tokenizer, - grammar=self.grammar, - grammar_start=self.grammar_start, - legal_chars=self.legal_chars)) - - def _get_batch_data_item_parser(self, token_ids: List[int]): - """ - Get longest batch data item parser which matches the seen tokens. 
- This is generally the corresponding parser, but if there's a collision - their parsers are interchangable - """ - for bdip in sorted(self.batch_data_item_parsers, - key=lambda bdip: -len(bdip.token_ids)): - if token_ids[:len(bdip.token_ids)] == bdip.token_ids: - return bdip - - # no match, make new - return self._new_batch_data_item_parser() - - def _update_seen_token_ids(self, bdip: BatchDataItemParser, - token_ids: List[int]): - - # update batch item token tracker - bdip.token_ids = token_ids - - # step forward - all_text = self.tokenizer.decode(token_ids) - new_text = all_text[len(bdip.text):] - bdip.text = all_text - bdip.parser.step_seq(new_text) - def __call__(self, token_ids: List[int], logits: torch.Tensor) -> torch.Tensor: - # get the batch item data and parser for batch item, given provided token sequence - bdip = self._get_batch_data_item_parser(token_ids) - - self._update_seen_token_ids(bdip, token_ids) + # get valid token IDs given prior tokens + sequence = self.tokenizer.decode(token_ids) + valid_token_ids = self.get_valid_next_token_id_set(sequence) # modify logits given valid token IDs N = len(logits) mask = torch.zeros(N, dtype=torch.bool) - valid = torch.tensor(list(bdip.parser.valid_token_id_set), - dtype=torch.long) + valid = torch.tensor(valid_token_ids, dtype=torch.long) mask[valid] = True logits[~mask] = float('-inf') return logits From db143ba7d42151ddecda664bdb29a9e685b13cdd Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Sun, 24 Dec 2023 09:54:04 -0600 Subject: [PATCH 54/76] do lookup rather than recursive cache check to prevent recursion limit excess --- vllm/grammar.py | 94 ++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 78 insertions(+), 16 deletions(-) diff --git a/vllm/grammar.py b/vllm/grammar.py index 072fa810b59c..3f16108a37fc 100644 --- a/vllm/grammar.py +++ b/vllm/grammar.py @@ -116,6 +116,42 @@ def wrapper(self, *args): return wrapper +class TrieNode: + def __init__(self): + self.children = {} + self.is_end_of_word = False + self.value = None + +class Trie: + def __init__(self): + self.root = TrieNode() + + def __setitem__(self, key, value): + node = self.root + for char in key: + if char not in node.children: + node.children[char] = TrieNode() + node = node.children[char] + node.is_end_of_word = True + node.value = value + + def get_best(self, word): + node = self.root + prefix = "" + best_value = node.value + for char in word: + if char in node.children: + prefix += char + node = node.children[char] + if node.is_end_of_word: + best_value = node.value + else: + break # break if char not in trie + remainder = word[len(prefix):] + assert best_value is not None + return prefix, best_value, remainder + + @dataclass class IncrementalParserState: """ @@ -136,6 +172,9 @@ class IncrementalParserState: prior_terminal_ids: tuple[str] partial_token: str + # orthogonal unique state key + full_seq: str + # function of key interactive_parser: FastInteractiveParser terminal_candidates: list @@ -144,6 +183,7 @@ class IncrementalParserState: _ignored_terms: set _seq_validator: dict _memo: dict + _full_seq_trie: Trie @classmethod @@ -176,30 +216,48 @@ def from_grammar(cls, grammar: str, start: str): _seq_validator[("$END", "complete")] = lambda seq: seq is None - return cls( + parser = cls( interactive_parser=interactive_parser, prior_terminal_ids=tuple(), + full_seq="", partial_token="", terminal_candidates=None, _ignored_terms=set(lark_parser.lexer_conf.ignore), _seq_validator=_seq_validator, - _memo={} + _memo={}, + _full_seq_trie=Trie() ) + 
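# A quick illustration of the longest-prefix lookup provided by the Trie added
# above (toy string values stand in for parser states; illustrative only, not
# part of the patch):
_t = Trie()
_t[""] = "root state"
_t['{"k"'] = "state after key"
assert _t.get_best('{"k": 1') == ('{"k"', "state after key", ': 1')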
parser._full_seq_trie[""] = parser + return parser - def new(self, prior_terminal_ids, partial_token, **kwargs): - # cache - key = (prior_terminal_ids, partial_token) - if key in self._memo: - return self._memo[key] + def new(self, prior_terminal_ids, partial_token, full_seq, **kwargs): + # check cache + term_key = (prior_terminal_ids, partial_token) + if term_key in self._memo: + return self._memo[term_key] + + # create instance_dict = { f.name: getattr(self, f.name) for f in fields(self) } instance_dict.update(kwargs) inst = self.__class__(**instance_dict) - self._memo[key] = inst + + # update cache + self._memo[term_key] = inst + self._full_seq_trie[full_seq] = inst return inst + def __getitem__(self, full_seq): + """ + Get the parser of a full sequence + """ + match_seq, parser, remainder_seq = self._full_seq_trie.get_best(full_seq) + if remainder_seq: + parser = parser.step_seq(remainder_seq) + return parser + @memoize_by_instance def step_seq(self, new_seq: str): """ @@ -222,18 +280,21 @@ def step_seq(self, new_seq: str): new_interactive_parser = self.interactive_parser else: new_interactive_parser = self.get_stepped_parser_state(terminal_name) + + ml = complete_terminal["match_length"] + remainder_seq = new_seq[ml:] + processed_seq = new_seq[:ml] + new_parser = self.new( + full_seq=self.full_seq + processed_seq, interactive_parser=new_interactive_parser, prior_terminal_ids=tuple(list(self.prior_terminal_ids) + [terminal_name]), partial_token="", terminal_candidates=None, ) - leftover = len(new_seq) - complete_terminal["match_length"] - if leftover: - return new_parser.step_seq( - new_maybe_partial_token[-leftover:] - ) + if remainder_seq: + return new_parser.step_seq(remainder_seq) else: return new_parser @@ -243,6 +304,7 @@ def step_seq(self, new_seq: str): ) if partial_terminal_ids: return self.new( + full_seq=self.full_seq + new_seq, prior_terminal_ids=self.prior_terminal_ids, partial_token=new_maybe_partial_token, terminal_candidates=partial_terminal_ids, @@ -350,7 +412,7 @@ def get_valid_next_token_strs(self, full_seq): """ Generate valid token strings given the full sequence """ - parser = self.root_parser.step_seq(full_seq) + parser = self.root_parser[full_seq] for tok_str in self.vocab: if parser.is_valid_next_seq(tok_str): yield tok_str @@ -372,12 +434,12 @@ def __call__(self, token_ids: List[int], logits: torch.Tensor) -> torch.Tensor: # get valid token IDs given prior tokens sequence = self.tokenizer.decode(token_ids) - valid_token_ids = self.get_valid_next_token_id_set(sequence) + valid_token_ids = self.get_valid_next_token_ids(sequence) + valid = torch.tensor(list(valid_token_ids), dtype=torch.long) # modify logits given valid token IDs N = len(logits) mask = torch.zeros(N, dtype=torch.bool) - valid = torch.tensor(valid_token_ids, dtype=torch.long) mask[valid] = True logits[~mask] = float('-inf') return logits From 40df2cfddaf93fb2db8ee1a72b505df3ed4583f8 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Sun, 24 Dec 2023 18:26:45 -0600 Subject: [PATCH 55/76] refactor: cleaner, faster --- vllm/grammar.py | 186 ++++++++++++++++++++++++++---------------------- 1 file changed, 99 insertions(+), 87 deletions(-) diff --git a/vllm/grammar.py b/vllm/grammar.py index 3f16108a37fc..a56add0543f7 100644 --- a/vllm/grammar.py +++ b/vllm/grammar.py @@ -63,38 +63,58 @@ def __copy__(self): ######################################################################### -def get_pattern_validator(pattern: Pattern, is_complete: bool): + + +def get_pattern_validator(pattern: Pattern): """ Accepts 
a pattern object, either lark.lexer.PatternStr or lark.lexer.PatternRE Returns a function which validates a complete or partial string e.g. for PatternRE "abc*", is_complete=False returns true for "a", "ab", "abc", "abcccc" + + Returns Tuple with 2 values + - 0) The processed sequence + - 1) None if doesn't complete terminal, "" if completes terminal with no remainder, or "remainder" """ if isinstance(pattern, PatternRE): compiled_pattern = regex.compile(pattern.value) - if is_complete: - # False: No match - # int: match length - def get_fullmatch_length(seq): - r = compiled_pattern.match(seq) - if r is None or not r.spans(): - return None - spans = r.spans()[0] - return spans[1] - spans[0] - return get_fullmatch_length - else: - return (lambda seq: compiled_pattern.fullmatch(seq, partial=True) - is not None) + @lru_cache(int(1e6)) + def get_re_matched_parts(seq): + # match complete terminal, potentially with leftover seq + complete_terminal_match = compiled_pattern.match(seq) + if complete_terminal_match: + spans = complete_terminal_match.spans() + if spans: + span = complete_terminal_match.spans()[0] + if span[0] == 0: + processed_seq = seq[:span[1]] + remainder_seq = seq[span[1]:] + return processed_seq, remainder_seq + + # match doesn't complete terminal, but the sequence is fully allowed + partial_terminal_match = compiled_pattern.fullmatch(seq, partial=True) + if partial_terminal_match: + return seq, None + + return None, None + + return get_re_matched_parts + elif isinstance(pattern, PatternStr): base_str = pattern.value - if is_complete: - def get_strmatch_length(seq): - if not seq.startswith(base_str): - return None - return len(base_str) - return get_strmatch_length - else: - return (lambda seq: base_str.startswith(seq)) + @lru_cache(int(1e6)) + def get_str_matched_parts(seq): + if seq.startswith(base_str): + processed_seq = seq[:len(base_str)] + remainder_seq = seq[len(base_str):] + return processed_seq, remainder_seq + elif base_str.startswith(seq): + return seq, None + else: + return None, None + + return get_str_matched_parts + else: raise TypeError(f"Invalid pattern type: {type(pattern)}") @@ -126,7 +146,7 @@ class Trie: def __init__(self): self.root = TrieNode() - def __setitem__(self, key, value): + def insert(self, key, value): node = self.root for char in key: if char not in node.children: @@ -147,6 +167,9 @@ def get_best(self, word): best_value = node.value else: break # break if char not in trie + if node.is_end_of_word: + best_value = node.value + remainder = word[len(prefix):] assert best_value is not None return prefix, best_value, remainder @@ -204,16 +227,10 @@ def from_grammar(cls, grammar: str, start: str): interactive_parser.lexer_thread.state.text = "" _seq_validator = { - (term.name, "partial"): get_pattern_validator(term.pattern, is_complete=False) + (term.name): get_pattern_validator(term.pattern) for term in lark_parser.terminals } - _seq_validator.update({ - (term.name, "complete"): get_pattern_validator(term.pattern, is_complete=True) - for term in lark_parser.terminals - }) - - _seq_validator[("$END", "partial")] = lambda seq: seq is None - _seq_validator[("$END", "complete")] = lambda seq: seq is None + _seq_validator["$END"] = lambda seq: tuple(["" if seq is None else None] * 2) parser = cls( @@ -227,26 +244,22 @@ def from_grammar(cls, grammar: str, start: str): _memo={}, _full_seq_trie=Trie() ) - parser._full_seq_trie[""] = parser + parser._full_seq_trie.insert("", parser) return parser - def new(self, prior_terminal_ids, partial_token, full_seq, 
**kwargs): - # check cache - term_key = (prior_terminal_ids, partial_token) + def new(self, **kwargs): + term_key = (kwargs["prior_terminal_ids"], kwargs["partial_token"]) if term_key in self._memo: return self._memo[term_key] - # create instance_dict = { f.name: getattr(self, f.name) for f in fields(self) } instance_dict.update(kwargs) inst = self.__class__(**instance_dict) - - # update cache self._memo[term_key] = inst - self._full_seq_trie[full_seq] = inst + return inst def __getitem__(self, full_seq): @@ -256,6 +269,7 @@ def __getitem__(self, full_seq): match_seq, parser, remainder_seq = self._full_seq_trie.get_best(full_seq) if remainder_seq: parser = parser.step_seq(remainder_seq) + self._full_seq_trie.insert(full_seq, parser) return parser @memoize_by_instance @@ -268,72 +282,67 @@ def step_seq(self, new_seq: str): return new parser with updated partial token str and updated terminal candidates - If no partial matches, return None """ + if new_seq == "": + return self + new_maybe_partial_token = self.partial_token + new_seq - complete_terminal = self.get_complete_terminal( - tuple(sorted(self.allowed_terminals)), - new_maybe_partial_token, + best_terminal, processed_seq, remainder_seq = self.get_best_matched_terminal( + self.allowed_terminals, + new_maybe_partial_token ) - if complete_terminal is not None: - terminal_name = complete_terminal["terminal_id"] - if terminal_name in self._ignored_terms: + + if best_terminal is None: + return None + + # candidate doesn't complete terminal + if remainder_seq is None: + partial_terminal_ids = self.get_partial_terminal_ids( + self.allowed_terminals, + new_maybe_partial_token, + ) + return self.new( + full_seq=self.full_seq + new_seq, + prior_terminal_ids=self.prior_terminal_ids, + partial_token=new_maybe_partial_token, + terminal_candidates=partial_terminal_ids, + ) + + # terminal completes rule + else: + if best_terminal in self._ignored_terms: new_interactive_parser = self.interactive_parser else: - new_interactive_parser = self.get_stepped_parser_state(terminal_name) - - ml = complete_terminal["match_length"] - remainder_seq = new_seq[ml:] - processed_seq = new_seq[:ml] + new_interactive_parser = self.get_stepped_parser_state(best_terminal) new_parser = self.new( - full_seq=self.full_seq + processed_seq, + full_seq=self.full_seq[:-len(self.partial_token)] + processed_seq, interactive_parser=new_interactive_parser, - prior_terminal_ids=tuple(list(self.prior_terminal_ids) + [terminal_name]), + prior_terminal_ids=hash((self.prior_terminal_ids, best_terminal)), partial_token="", terminal_candidates=None, ) - if remainder_seq: - return new_parser.step_seq(remainder_seq) - else: + # no leftover to process + if remainder_seq == "": return new_parser - partial_terminal_ids = self.get_partial_terminal_ids( - tuple(sorted(self.allowed_terminals)), - new_maybe_partial_token, - ) - if partial_terminal_ids: - return self.new( - full_seq=self.full_seq + new_seq, - prior_terminal_ids=self.prior_terminal_ids, - partial_token=new_maybe_partial_token, - terminal_candidates=partial_terminal_ids, - ) - else: - return None + # process remainder + else: + return new_parser.step_seq(remainder_seq) + def get_best_matched_terminal(self, checked_terminals, seq): + for terminal in checked_terminals: + processed_seq, remainder_seq = self._seq_validator[terminal](seq) + if processed_seq: + return terminal, processed_seq, remainder_seq - def get_complete_terminal(self, checked_terminals, seq): - terminal_matchlens = { - term: self._seq_validator[(term, 
"complete")](seq) - for term in checked_terminals - } - terminal_matchlens = {term: ml for term, ml in terminal_matchlens.items() if ml} - if not terminal_matchlens: - return None - if len(terminal_matchlens) > 1: - terminal_matchlens = { - t: ml for t, ml in terminal_matchlens.items() - if t not in self._ignored_terms - } - assert len(terminal_matchlens) == 1 - result = next(iter(terminal_matchlens.items())) - return {"terminal_id": result[0], "match_length": result[1]} + return None, None, None def get_partial_terminal_ids(self, checked_terminals, seq): return set([ term for term in checked_terminals - if self._seq_validator[(term, "partial")](seq) + if self._seq_validator[term](seq)[0] is not None ]) @memoize_by_instance @@ -352,9 +361,10 @@ def accepts(self): @memoize_by_instance def allowed_terminals(self): if self.terminal_candidates is not None: - return self.terminal_candidates - return self.accepts() + return tuple(sorted(self.terminal_candidates)) + return tuple(sorted(self.accepts())) + @memoize_by_instance def is_valid_next_seq(self, new_seq: Optional[str]): if new_seq is None: return "$END" in self.allowed_terminals @@ -413,6 +423,8 @@ def get_valid_next_token_strs(self, full_seq): Generate valid token strings given the full sequence """ parser = self.root_parser[full_seq] + if parser is None: + return for tok_str in self.vocab: if parser.is_valid_next_seq(tok_str): yield tok_str From aee2a8ef06652e7a384fa9430c494471e3981353 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Mon, 25 Dec 2023 13:54:45 -0600 Subject: [PATCH 56/76] bug fixes and test fixes --- tests/samplers/test_grammar.py | 159 +++++++++++++++------------------ vllm/grammar.py | 35 +++++--- 2 files changed, 95 insertions(+), 99 deletions(-) diff --git a/tests/samplers/test_grammar.py b/tests/samplers/test_grammar.py index 38456ba4699e..ca2eb3430caf 100644 --- a/tests/samplers/test_grammar.py +++ b/tests/samplers/test_grammar.py @@ -4,7 +4,7 @@ from transformers import AutoTokenizer -from vllm.grammar import TokenTrie, NextTokenValidator, GrammarLogitsProcessor +from vllm.grammar import NextTokenValidator, GrammarLogitsProcessor from vllm import LLM, SamplingParams @@ -34,9 +34,9 @@ def json_grammar(): string : "\"" escaped_string_char* "\"" escaped_string_char: _STR_INNER_CHAR | _ESCAPED_CHAR - _ESCAPED_CHAR: "\\" _ANY_CHAR + _ESCAPED_CHAR: "\\" _ESCAPABLE_CHAR _STR_INNER_CHAR: /[^\\\"]/ - _ANY_CHAR: /./ + _ESCAPABLE_CHAR: /[\\\/bfnrtu]/ signed_number: ["+"|"-"] number number: float | int @@ -48,7 +48,7 @@ def json_grammar(): DIGIT: "0".."9" WS: /[ \t\f\r\n]/ - """ + """.strip() @pytest.fixture @@ -121,12 +121,14 @@ def test_next_token_validator_simple(tokenizer): ?start: "hello" | "world" """ ntv = NextTokenValidator(tokenizer, hello_grammar) + valid_next_str_set = set(ntv.get_valid_next_token_strs("")) + valid_next_id_set = set(ntv.get_valid_next_token_ids("")) # tokens specific to codeLlama - assert ntv.valid_token_str_set == { + assert valid_next_str_set == { 'wo', 'hell', 'h', 'he', 'hel', 'world', 'wor', 'w', 'hello' } - assert sorted(ntv.valid_token_id_set) == [ + assert sorted(valid_next_id_set) == [ 107, 122, 354, 827, 3952, 11526, 12199, 13762, 14181, 29882, 29893 ] @@ -136,8 +138,7 @@ def test_can_span_multiple_terminals(tokenizer): ?start: "is" "t" "r" "u" "e" """ ntv = NextTokenValidator(tokenizer, true_segmented_grammar) - ntv.step_seq("is") - assert "true" in ntv.valid_token_str_set + assert "true" in set(ntv.get_valid_next_token_strs("is")) @pytest.mark.parametrize("grammar_fixture, 
example_fixture", @@ -155,12 +156,13 @@ def test_can_generate_with_grammar(tokenizer, request, grammar_fixture, legal_chars=set(map(chr, range(256))), ) example_remainder = example + generation = "" while example_remainder: - for tok in next_token_validator.valid_token_str_set: + for tok in next_token_validator.get_valid_next_token_strs(generation): if tok is None: continue if example_remainder.startswith(tok): - next_token_validator.step_seq(tok) + generation += tok example_remainder = example_remainder[len(tok):] break else: @@ -169,99 +171,82 @@ def test_can_generate_with_grammar(tokenizer, request, grammar_fixture, ) # EOS should be in the set of next legal tokens - assert None in next_token_validator.valid_token_str_set + assert None in next_token_validator.get_valid_next_token_strs(generation) -def test_json_valid_with_edge_cases(tokenizer, json_grammar): - valid_edgecase_jsons = [ +@pytest.mark.parametrize( + "json_example", + [ "{\n \"emptyObject\": {\n \"innerEmptyObject\": {}\n }\n}", # empty obj "{\n \"mixedArray\": [null, 123, \"text\", true, {\"key\": \"value\"}]\n}", # mixed array "{\n \"deepArray\": [[[[[\"deep\"]]]]]\n}", # deeply nested list "{\n \"\": true,\n \"regularKey\": false\n}", # empty keys "{\n \"\\u043a\\u043b\\u044e\\u0447\": \"\\u0437\\u043d\\u0430\\u0447\\u0435\\u043d\\u0438\\u0435\",\n \"emoji\\ud83d\\ude42\": \"value\\ud83d\\ude00\"\n}", # unicode keys ] +) +def test_json_valid_with_edge_cases(tokenizer, json_grammar, json_example): + next_token_validator = NextTokenValidator( + tokenizer, + json_grammar, + ) + example_remainder = json_example + generation = "" + while example_remainder: + for tok in next_token_validator.get_valid_next_token_strs(generation): + if tok is None: + continue + if example_remainder.startswith(tok): + generation += tok + example_remainder = example_remainder[len(tok):] + break + else: + raise Exception( + f"Couldn't find token to create legal output given grammar, remaining output: '{example_remainder}'" + ) - for example in valid_edgecase_jsons: - next_token_validator = NextTokenValidator( - tokenizer, - json_grammar, - ) - example_remainder = example - while example_remainder: - for tok in next_token_validator.valid_token_str_set: - if tok is None: - continue - if example_remainder.startswith(tok): - next_token_validator.step_seq(tok) - example_remainder = example_remainder[len(tok):] - break - else: - raise Exception( - f"Couldn't find token to create legal output given grammar, remaining output: '{example_remainder}'" - ) - - # EOS should be in the set of next legal tokens - assert None in next_token_validator.valid_token_str_set - - -def test_json_fails_with_edge_cases(tokenizer, json_grammar): - invalid_edgecase_jsons = [ + # EOS should be in the set of next legal tokens + assert None in next_token_validator.get_valid_next_token_strs(generation) + + +@pytest.mark.parametrize( + "json_example", + [ "{\n \"key1\": \"value1\",\n \"key2\": \"value2\",\n}", # trailing comma "{\n \"key\": \"value\" // This is a comment\n}\n", # comment "{\n \"number\": 1.2.3\n}", # incorrect decimal format "{\n \"key\": \"value\"unexpected\"\n}", # incorrect str format - "{\n \"object\": {\"key\": \"value\"}\n}\n", # unclosed object + "{\n \"object\": {\"key\": \"value\", }\n}", # trailing comma "{\n \"array\": [1, 2,, 3]\n}\n", # double comma ] +) +def test_json_fails_with_edge_cases(tokenizer, json_grammar, json_example): + next_token_validator = NextTokenValidator( + tokenizer, + json_grammar, + ) + example_remainder = json_example + 
generation = "" + while example_remainder: + for tok in next_token_validator.get_valid_next_token_strs(generation): + if tok is None: + continue + if example_remainder.startswith(tok): + generation += tok + example_remainder = example_remainder[len(tok):] + break + else: + return - for example in invalid_edgecase_jsons: - next_token_validator = NextTokenValidator( - tokenizer, - json_grammar, - ) - example_remainder = example - while example_remainder: - for tok in next_token_validator.valid_token_str_set: - if tok is None: - continue - if example_remainder.startswith(tok): - next_token_validator.step_seq(tok) - example_remainder = example_remainder[len(tok):] - break - else: - return - - raise Exception("Invalid json was accepted") - - -def test_token_trie_sanity(tokenizer): - toktrie = TokenTrie(tokenizer) - - all_prefixes = toktrie.get_next_level_token_prefixes("") - - # every token should be composable from a single unique char, so they will all be len of 1 - assert all([len(p) == 1 for p in all_prefixes]) - - # every token should have one of these prefixes as a start character - assert all( - [t[0] in all_prefixes for t in toktrie.norm_vocab if t is not None]) - - # construct the set of next level prefixes - all_subprefixes = set() - for pfx in all_prefixes: - all_subprefixes |= toktrie.get_next_level_token_prefixes(pfx) - - # these should have varying length because some tokens don't have level-2 prefixes - assert len(set([len(spfx) for spfx in all_subprefixes])) > 1 + raise Exception(f"Invalid json was accepted: '{json_example}'") @pytest.mark.parametrize( "start_tok, validator", [ - (29945, float), # 5 - float - (285, lambda s: bool(json.dumps(s))), # f for false - (260, lambda s: bool(json.dumps(s))), # t for false - (376, lambda s: str(json.dumps(s))), # " for string + ("5", float), # 5 - float + ("f", lambda s: bool(json.dumps(s))), # f for false + ("t", lambda s: bool(json.dumps(s))), # t for false + ('"', lambda s: str(json.dumps(s))), # " for string ]) def test_gen_primative(json_grammar, tokenizer, start_tok, validator): # Note: string may last a @@ -272,7 +257,9 @@ def test_gen_primative(json_grammar, tokenizer, start_tok, validator): legal_chars=set(map(chr, range(256))), ) - token_ids = [start_tok] + start_token_id = list(grammar_logits_processor.vocab[start_tok])[0] + + token_ids = [start_token_id] while True: logits = grammar_logits_processor(token_ids=token_ids, logits=np.random.uniform( @@ -300,7 +287,7 @@ def test_random_grammared_generation(json_grammar, tokenizer): # bias closing tokens logits to prevent infinite generation closing_token_ids = set([ tok_id for tok_str in ["]", "}", '"', ",", None] - for tok_id in grammar_logits_processor.token_trie.norm_vocab[tok_str] + for tok_id in grammar_logits_processor.vocab[tok_str] ]) closing_tokens_bias = -10 @@ -309,7 +296,7 @@ def test_random_grammared_generation(json_grammar, tokenizer): # gramatically complicated and result in a less interesting test opening_token_ids = set([ tok_id for tok_str in ["[", "{", '"', ","] - for tok_id in grammar_logits_processor.token_trie.norm_vocab[tok_str] + for tok_id in grammar_logits_processor.vocab[tok_str] ]) opening_tokens_bias = 5 @@ -337,7 +324,7 @@ def test_integration_with_vllm(vllm_runner, hf_runner): dtype = "half" tokenizer = hf_runner(model_id, dtype=dtype).tokenizer - grammar = """?start: "hello" | "world" """ + grammar = ""?start: "hello" | "world" "" grammar_logits_processor = GrammarLogitsProcessor(tokenizer, grammar) sampling_params = SamplingParams( diff --git 
a/vllm/grammar.py b/vllm/grammar.py index a56add0543f7..fb815ee50a5b 100644 --- a/vllm/grammar.py +++ b/vllm/grammar.py @@ -158,21 +158,20 @@ def insert(self, key, value): def get_best(self, word): node = self.root prefix = "" + best_prefix = "" best_value = node.value for char in word: if char in node.children: - prefix += char node = node.children[char] + prefix += char if node.is_end_of_word: + best_prefix = prefix best_value = node.value else: break # break if char not in trie - if node.is_end_of_word: - best_value = node.value - remainder = word[len(prefix):] - assert best_value is not None - return prefix, best_value, remainder + remainder = word[len(best_prefix):] + return best_prefix, best_value, remainder @dataclass @@ -183,7 +182,7 @@ class IncrementalParserState: - incomplete `partial_token` string the set of prior terminal_ids and a partial token comprise a unique parser state - Core function exposed is `self.step_seq(new_seq)` + Core function exposed is `self.step(new_seq)` - Returns a new IncrementalParserState based with new_seq applied Memoization strategy is @@ -208,6 +207,10 @@ class IncrementalParserState: _memo: dict _full_seq_trie: Trie + def __repr__(self): + shown = ["partial_token", "full_seq", "terminal_candidates", "prior_terminal_ids"] + attrs_str = ", ".join(f"{s}={repr(getattr(self, s))}" for s in shown) + return f"{self.__class__.__name__}({attrs_str})" @classmethod @lru_cache(1000) @@ -267,13 +270,15 @@ def __getitem__(self, full_seq): Get the parser of a full sequence """ match_seq, parser, remainder_seq = self._full_seq_trie.get_best(full_seq) + if parser is None: + return if remainder_seq: - parser = parser.step_seq(remainder_seq) + parser = parser.step(remainder_seq) self._full_seq_trie.insert(full_seq, parser) return parser @memoize_by_instance - def step_seq(self, new_seq: str): + def step(self, new_seq: str): """ - Construct extended (maybe-partial) token candidate - If complete match, create new-terminal incremented parser state @@ -291,7 +296,6 @@ def step_seq(self, new_seq: str): self.allowed_terminals, new_maybe_partial_token ) - if best_terminal is None: return None @@ -315,8 +319,13 @@ def step_seq(self, new_seq: str): else: new_interactive_parser = self.get_stepped_parser_state(best_terminal) + if self.partial_token: + base_seq = self.full_seq[:-len(self.partial_token)] + else: + base_seq = self.full_seq + new_parser = self.new( - full_seq=self.full_seq[:-len(self.partial_token)] + processed_seq, + full_seq=base_seq + processed_seq, interactive_parser=new_interactive_parser, prior_terminal_ids=hash((self.prior_terminal_ids, best_terminal)), partial_token="", @@ -329,7 +338,7 @@ def step_seq(self, new_seq: str): # process remainder else: - return new_parser.step_seq(remainder_seq) + return new_parser.step(remainder_seq) def get_best_matched_terminal(self, checked_terminals, seq): for terminal in checked_terminals: @@ -368,7 +377,7 @@ def allowed_terminals(self): def is_valid_next_seq(self, new_seq: Optional[str]): if new_seq is None: return "$END" in self.allowed_terminals - return self.step_seq(new_seq) is not None + return self.step(new_seq) is not None class TokenVocab: From 8fd06fa3bef18d290447c25b896041c07906be0f Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Mon, 25 Dec 2023 22:36:00 -0600 Subject: [PATCH 57/76] memoize instances where a unique instance is defined by (state stack, partial terminal) --- tests/samplers/test_grammar.py | 2 +- vllm/grammar.py | 16 ++++++++++++---- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git 
a/tests/samplers/test_grammar.py b/tests/samplers/test_grammar.py index ca2eb3430caf..314e184527cb 100644 --- a/tests/samplers/test_grammar.py +++ b/tests/samplers/test_grammar.py @@ -324,7 +324,7 @@ def test_integration_with_vllm(vllm_runner, hf_runner): dtype = "half" tokenizer = hf_runner(model_id, dtype=dtype).tokenizer - grammar = ""?start: "hello" | "world" "" + grammar = """?start: "hello" | "world" """ grammar_logits_processor = GrammarLogitsProcessor(tokenizer, grammar) sampling_params = SamplingParams( diff --git a/vllm/grammar.py b/vllm/grammar.py index fb815ee50a5b..a5b499fa4069 100644 --- a/vllm/grammar.py +++ b/vllm/grammar.py @@ -50,6 +50,12 @@ def __init__(self, *args, **kwargs): self.parser_state.state_stack, self.parser_state.value_stack, ) + self.hash_val = None + + def __hash__(self): + if self.hash_val is None: + self.hash_val = hash(tuple(self.parser_state.state_stack)) + return self.hash_val def __copy__(self): return type(self)( @@ -251,9 +257,9 @@ def from_grammar(cls, grammar: str, start: str): return parser def new(self, **kwargs): - term_key = (kwargs["prior_terminal_ids"], kwargs["partial_token"]) - if term_key in self._memo: - return self._memo[term_key] + parser_state_key = (hash(kwargs["interactive_parser"]), kwargs["partial_token"]) + if parser_state_key in self._memo: + return self._memo[parser_state_key] instance_dict = { f.name: getattr(self, f.name) @@ -261,7 +267,8 @@ def new(self, **kwargs): } instance_dict.update(kwargs) inst = self.__class__(**instance_dict) - self._memo[term_key] = inst + + self._memo[parser_state_key] = inst return inst @@ -306,6 +313,7 @@ def step(self, new_seq: str): new_maybe_partial_token, ) return self.new( + interactive_parser=self.interactive_parser, full_seq=self.full_seq + new_seq, prior_terminal_ids=self.prior_terminal_ids, partial_token=new_maybe_partial_token, From ab51de98994c524eaa8698abb1f97df33a3d4639 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Tue, 26 Dec 2023 10:20:21 -0600 Subject: [PATCH 58/76] update comments --- vllm/grammar.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/vllm/grammar.py b/vllm/grammar.py index a5b499fa4069..bc98a2a4c198 100644 --- a/vllm/grammar.py +++ b/vllm/grammar.py @@ -76,10 +76,8 @@ def get_pattern_validator(pattern: Pattern): Accepts a pattern object, either lark.lexer.PatternStr or lark.lexer.PatternRE Returns a function which validates a complete or partial string - e.g. 
for PatternRE "abc*", is_complete=False returns true for "a", "ab", "abc", "abcccc" - Returns Tuple with 2 values - - 0) The processed sequence + - 0) The processed portion of the sequence (None if no match at all) - 1) None if doesn't complete terminal, "" if completes terminal with no remainder, or "remainder" """ if isinstance(pattern, PatternRE): @@ -183,28 +181,28 @@ def get_best(self, word): @dataclass class IncrementalParserState: """ - Parsing utility which tracks state provided - - sequence of prior terminal ids + Parsing utility which enforces uniqueness of + - interactive parser state stack - incomplete `partial_token` string - the set of prior terminal_ids and a partial token comprise a unique parser state + the state of the parser and the incomplete token comprise a unique parser state Core function exposed is `self.step(new_seq)` - Returns a new IncrementalParserState based with new_seq applied Memoization strategy is - - 1) Ensure uniqueness of (prior_terminal_ids, partial_token) + - 1) Ensure uniqueness of (interactive_parser, partial_token) - 2) Cache class methods via `memoize_by_instance` which considers id(self) and fn arguments """ # unique state key - prior_terminal_ids: tuple[str] + interactive_parser: FastInteractiveParser partial_token: str # orthogonal unique state key full_seq: str # function of key - interactive_parser: FastInteractiveParser + prior_terminal_ids: tuple[str] terminal_candidates: list # shared across instances @@ -257,6 +255,7 @@ def from_grammar(cls, grammar: str, start: str): return parser def new(self, **kwargs): + """Cached create now state""" parser_state_key = (hash(kwargs["interactive_parser"]), kwargs["partial_token"]) if parser_state_key in self._memo: return self._memo[parser_state_key] @@ -273,9 +272,7 @@ def new(self, **kwargs): return inst def __getitem__(self, full_seq): - """ - Get the parser of a full sequence - """ + """Get the parser state, given a full sequence""" match_seq, parser, remainder_seq = self._full_seq_trie.get_best(full_seq) if parser is None: return From eb4b7a635eaeaa7382abf080a48ee2aea9f1b867 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Tue, 26 Dec 2023 12:19:14 -0600 Subject: [PATCH 59/76] update docs --- docs/source/grammars/grammars.rst | 94 +++++++++++++++++++++++-------- 1 file changed, 69 insertions(+), 25 deletions(-) diff --git a/docs/source/grammars/grammars.rst b/docs/source/grammars/grammars.rst index 4c66a0454210..8b035868bd81 100644 --- a/docs/source/grammars/grammars.rst +++ b/docs/source/grammars/grammars.rst @@ -1,9 +1,13 @@ .. _grammars: +.. toctree:: + :maxdepth: 2 + + Grammars ======== -vLLM offers `Lark `_ based EBNF grammars via ``vllm.grammar.GrammarLogitsProcessor``. +vLLM offers `Lark `_ style EBNF grammars via ``vllm.grammar.GrammarLogitsProcessor``. ``GrammarLogitsProcessor`` ensures generated text follows the rules of a grammar. This provides the ability to guarantee your output is syntactically valid JSON, SQL, Python, RegEx, etc. @@ -12,11 +16,13 @@ Sample Code for JSON .. 
code-block:: python - json_grammar = """ - ?value: dict + json_grammar = r""" + start: value + value: WS* object WS* + object: dict | list | string - | SIGNED_NUMBER -> number + | signed_number -> number | "true" -> true | "false" -> false | "null" -> null @@ -24,12 +30,24 @@ Sample Code for JSON list : "[" [value ("," value)*] "]" dict : "{" [pair ("," pair)*] "}" - pair : string ":" value + pair : WS* string WS* ":" value + + string : "\"" escaped_string_char* "\"" + escaped_string_char: _STR_INNER_CHAR | _ESCAPED_CHAR + _ESCAPED_CHAR: "\\" _ESCAPABLE_CHAR + _STR_INNER_CHAR: /[^\\\"]/ + _ESCAPABLE_CHAR: /[\\\/bfnrtu]/ - string : ESCAPED_STRING + signed_number: ["+"|"-"] number + number: float | int + float: int exp | decimal exp? + decimal: int "." int? | "." int + exp: ("e"|"E") signed_int + signed_int: ["+"|"-"] int + int: DIGIT+ + DIGIT: "0".."9" - %import common.ESCAPED_STRING - %import common.SIGNED_NUMBER + WS: /[ \t\f\r\n]/ """ grammar_logits_processor = GrammarLogitsProcessor( tokenizer, @@ -38,27 +56,18 @@ Sample Code for JSON ) SamplingParams(logits_processor=grammar_logits_processor) -Resources ---------- - -- `How to write an EBNF grammar for Lark `_ -- `Wikipedia - EBNF `_ -- `Wikipedia - LALR Parser `_ - -Example Lark Grammars ---------------------- - -- `JSON `_ -- `Python3 `_ -- `Resource with many grammars including SQLite, TOML, YAML, Lua, and more `_ Performance ----------- -For a simple JSON grammar, on the authors mid-end laptop using codeLlama-7b's vocabulary, generation occurred at ~10 validated logit sets per second. However performance was improved dramatically from the baseline with a few tweaks to ~400/s. These tweaks include +For the provided JSON grammar in the subsection below, constrained to only keyboard characters, on the authors mid-end laptop using codeLlama-7b's vocabulary, generation occurred at the following rates: -- Optimizing the grammar -- Constraining legal characters +- first 10 tokens: 3.47 tokens / second +- first 100 tokens: 8.61 tokens / second +- first 1,000 tokens: 14.41 tokens / second +- first 10,000 tokens: 23.80 tokens / second + +There is a "warmup" period where token legality is cached based on parser state. The first generation and first tokens within that generation are the slowest. **Design your EBNF grammar with minimal regexp** @@ -76,7 +85,7 @@ Breaking down the following expressions ESCAPE_STRING into an expression with ma | "true" -> true | "false" -> false | "null" -> null - +python parser test case list : "[" [value ("," value)*] "]" dict : "{" [pair ("," pair)*] "}" @@ -125,3 +134,38 @@ Expect increased performance if you constrain your generation to UTF-8, eliminat grammar, legal_chars=set(map(chr, range(256))),, ) + +Example 2: constrain the grammar to the set of keyboard typeable characters: + +.. code-block:: + + def keyboard_chars(): + keyboard_chars = "" + keyboard_chars += "abcdefghijklmnopqrstuvwxyz" + keyboard_chars += "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + keyboard_chars += "0123456789" + keyboard_chars += "`~!@#$%^&*()-_=+[{]}\\|;:'\",<.>/? 
" + keyboard_chars += "\t\n" + return keyboard_chars + GrammarLogitsProcessor( + tokenizer, + grammar, + legal_chars=set(keyboard_chars()), + ) + + +Resources +--------- + +- `How to write an EBNF grammar for Lark `_ +- `Wikipedia - EBNF `_ +- `Wikipedia - LALR Parser `_ + +Example Lark Grammars +--------------------- + +Note: These grammars should + +- `JSON `_ +- `Python3 `_ +- `Resource with many grammars including SQLite, TOML, YAML, Lua, and more `_ From af59f2c6657044807b3c0ca11f027ad76b28f6b7 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Tue, 26 Dec 2023 12:20:44 -0600 Subject: [PATCH 60/76] try this for docs? --- docs/source/grammars/grammars.rst | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/source/grammars/grammars.rst b/docs/source/grammars/grammars.rst index 8b035868bd81..4b8f891e5b65 100644 --- a/docs/source/grammars/grammars.rst +++ b/docs/source/grammars/grammars.rst @@ -1,9 +1,10 @@ -.. _grammars: - .. toctree:: :maxdepth: 2 +.. _grammars: + + Grammars ======== From 6eb38207cb3073d2ef57d8dceb9308cd1dd6cb16 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Tue, 26 Dec 2023 12:21:32 -0600 Subject: [PATCH 61/76] try adding toc --- docs/source/grammars/grammars.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/grammars/grammars.rst b/docs/source/grammars/grammars.rst index 4b8f891e5b65..949d620acdaf 100644 --- a/docs/source/grammars/grammars.rst +++ b/docs/source/grammars/grammars.rst @@ -1,5 +1,5 @@ -.. toctree:: - :maxdepth: 2 +.. contents:: Table of Contents + :depth: 3 .. _grammars: From befb92ca0b048f260ac908faa9952e03b2f25f6e Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Tue, 26 Dec 2023 13:28:05 -0600 Subject: [PATCH 62/76] ruff yapf --- tests/samplers/test_grammar.py | 6 +-- vllm/grammar.py | 87 +++++++++++++++++----------------- 2 files changed, 45 insertions(+), 48 deletions(-) diff --git a/tests/samplers/test_grammar.py b/tests/samplers/test_grammar.py index 314e184527cb..b71ab522cf60 100644 --- a/tests/samplers/test_grammar.py +++ b/tests/samplers/test_grammar.py @@ -182,8 +182,7 @@ def test_can_generate_with_grammar(tokenizer, request, grammar_fixture, "{\n \"deepArray\": [[[[[\"deep\"]]]]]\n}", # deeply nested list "{\n \"\": true,\n \"regularKey\": false\n}", # empty keys "{\n \"\\u043a\\u043b\\u044e\\u0447\": \"\\u0437\\u043d\\u0430\\u0447\\u0435\\u043d\\u0438\\u0435\",\n \"emoji\\ud83d\\ude42\": \"value\\ud83d\\ude00\"\n}", # unicode keys - ] -) + ]) def test_json_valid_with_edge_cases(tokenizer, json_grammar, json_example): next_token_validator = NextTokenValidator( tokenizer, @@ -217,8 +216,7 @@ def test_json_valid_with_edge_cases(tokenizer, json_grammar, json_example): "{\n \"key\": \"value\"unexpected\"\n}", # incorrect str format "{\n \"object\": {\"key\": \"value\", }\n}", # trailing comma "{\n \"array\": [1, 2,, 3]\n}\n", # double comma - ] -) + ]) def test_json_fails_with_edge_cases(tokenizer, json_grammar, json_example): next_token_validator = NextTokenValidator( tokenizer, diff --git a/vllm/grammar.py b/vllm/grammar.py index bc98a2a4c198..cd04d44d716c 100644 --- a/vllm/grammar.py +++ b/vllm/grammar.py @@ -11,7 +11,6 @@ from lark.parsers.lalr_interactive_parser import InteractiveParser from lark.parsers.lalr_parser_state import ParserState from lark.lexer import Token, Pattern, PatternStr, PatternRE -from lark.exceptions import UnexpectedCharacters, UnexpectedToken ######################################################################### @@ -69,8 +68,6 @@ def __copy__(self): 
######################################################################### - - def get_pattern_validator(pattern: Pattern): """ Accepts a pattern object, either lark.lexer.PatternStr or lark.lexer.PatternRE @@ -82,6 +79,7 @@ def get_pattern_validator(pattern: Pattern): """ if isinstance(pattern, PatternRE): compiled_pattern = regex.compile(pattern.value) + @lru_cache(int(1e6)) def get_re_matched_parts(seq): # match complete terminal, potentially with leftover seq @@ -96,7 +94,8 @@ def get_re_matched_parts(seq): return processed_seq, remainder_seq # match doesn't complete terminal, but the sequence is fully allowed - partial_terminal_match = compiled_pattern.fullmatch(seq, partial=True) + partial_terminal_match = compiled_pattern.fullmatch(seq, + partial=True) if partial_terminal_match: return seq, None @@ -106,6 +105,7 @@ def get_re_matched_parts(seq): elif isinstance(pattern, PatternStr): base_str = pattern.value + @lru_cache(int(1e6)) def get_str_matched_parts(seq): if seq.startswith(base_str): @@ -128,6 +128,7 @@ def memoize_by_instance(method): Memoize by id(self) and fn args """ mname = method.__name__ + @wraps(method) def wrapper(self, *args): key = (mname, id(self), args) @@ -141,12 +142,15 @@ def wrapper(self, *args): class TrieNode: + def __init__(self): self.children = {} self.is_end_of_word = False self.value = None + class Trie: + def __init__(self): self.root = TrieNode() @@ -212,7 +216,10 @@ class IncrementalParserState: _full_seq_trie: Trie def __repr__(self): - shown = ["partial_token", "full_seq", "terminal_candidates", "prior_terminal_ids"] + shown = [ + "partial_token", "full_seq", "terminal_candidates", + "prior_terminal_ids" + ] attrs_str = ", ".join(f"{s}={repr(getattr(self, s))}" for s in shown) return f"{self.__class__.__name__}({attrs_str})" @@ -228,42 +235,36 @@ def from_grammar(cls, grammar: str, start: str): ) base_interactive_parser = lark_parser.parse_interactive() interactive_parser = FastInteractiveParser( - base_interactive_parser.parser, - base_interactive_parser.parser_state, - base_interactive_parser.lexer_thread) + base_interactive_parser.parser, + base_interactive_parser.parser_state, + base_interactive_parser.lexer_thread) interactive_parser.lexer_thread.state.text = "" - _seq_validator = { - (term.name): get_pattern_validator(term.pattern) - for term in lark_parser.terminals - } - _seq_validator["$END"] = lambda seq: tuple(["" if seq is None else None] * 2) - - - parser = cls( - interactive_parser=interactive_parser, - prior_terminal_ids=tuple(), - full_seq="", - partial_token="", - terminal_candidates=None, - _ignored_terms=set(lark_parser.lexer_conf.ignore), - _seq_validator=_seq_validator, - _memo={}, - _full_seq_trie=Trie() - ) + _seq_validator = {(term.name): get_pattern_validator(term.pattern) + for term in lark_parser.terminals} + _seq_validator["$END"] = lambda seq: tuple( + ["" if seq is None else None] * 2) + + parser = cls(interactive_parser=interactive_parser, + prior_terminal_ids=tuple(), + full_seq="", + partial_token="", + terminal_candidates=None, + _ignored_terms=set(lark_parser.lexer_conf.ignore), + _seq_validator=_seq_validator, + _memo={}, + _full_seq_trie=Trie()) parser._full_seq_trie.insert("", parser) return parser def new(self, **kwargs): """Cached create now state""" - parser_state_key = (hash(kwargs["interactive_parser"]), kwargs["partial_token"]) + parser_state_key = (hash(kwargs["interactive_parser"]), + kwargs["partial_token"]) if parser_state_key in self._memo: return self._memo[parser_state_key] - instance_dict = { - 
f.name: getattr(self, f.name) - for f in fields(self) - } + instance_dict = {f.name: getattr(self, f.name) for f in fields(self)} instance_dict.update(kwargs) inst = self.__class__(**instance_dict) @@ -273,7 +274,8 @@ def new(self, **kwargs): def __getitem__(self, full_seq): """Get the parser state, given a full sequence""" - match_seq, parser, remainder_seq = self._full_seq_trie.get_best(full_seq) + match_seq, parser, remainder_seq = self._full_seq_trie.get_best( + full_seq) if parser is None: return if remainder_seq: @@ -297,9 +299,7 @@ def step(self, new_seq: str): new_maybe_partial_token = self.partial_token + new_seq best_terminal, processed_seq, remainder_seq = self.get_best_matched_terminal( - self.allowed_terminals, - new_maybe_partial_token - ) + self.allowed_terminals, new_maybe_partial_token) if best_terminal is None: return None @@ -322,7 +322,8 @@ def step(self, new_seq: str): if best_terminal in self._ignored_terms: new_interactive_parser = self.interactive_parser else: - new_interactive_parser = self.get_stepped_parser_state(best_terminal) + new_interactive_parser = self.get_stepped_parser_state( + best_terminal) if self.partial_token: base_seq = self.full_seq[:-len(self.partial_token)] @@ -332,7 +333,8 @@ def step(self, new_seq: str): new_parser = self.new( full_seq=base_seq + processed_seq, interactive_parser=new_interactive_parser, - prior_terminal_ids=hash((self.prior_terminal_ids, best_terminal)), + prior_terminal_ids=hash( + (self.prior_terminal_ids, best_terminal)), partial_token="", terminal_candidates=None, ) @@ -362,9 +364,7 @@ def get_partial_terminal_ids(self, checked_terminals, seq): @memoize_by_instance def get_stepped_parser_state(self, new_token_str): ip = copy(self.interactive_parser) - ip.feed_token( - Token(new_token_str, '') - ) + ip.feed_token(Token(new_token_str, '')) return ip @memoize_by_instance @@ -417,6 +417,7 @@ def __getitem__(self, tok_str): class NextTokenValidator: + def __init__( self, tokenizer, @@ -428,9 +429,7 @@ def __init__( self.vocab = TokenVocab(tokenizer, legal_chars=legal_chars) self.root_parser = IncrementalParserState.from_grammar( - grammar, - grammar_start - ) + grammar, grammar_start) def get_valid_next_token_strs(self, full_seq): """ @@ -443,7 +442,6 @@ def get_valid_next_token_strs(self, full_seq): if parser.is_valid_next_seq(tok_str): yield tok_str - def get_valid_next_token_ids(self, full_seq): """ Generate valid token ids given the full sequence @@ -456,6 +454,7 @@ class GrammarLogitsProcessor(NextTokenValidator): """ Apply NextTokenValidator in __call__ and set excluded tokens logits to -inf """ + def __call__(self, token_ids: List[int], logits: torch.Tensor) -> torch.Tensor: # get valid token IDs given prior tokens From 7afc27bbbf6a4fc9f4522801bb402a5f5cecea53 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Tue, 26 Dec 2023 20:17:35 -0600 Subject: [PATCH 63/76] add grammars to openai --- vllm/entrypoints/openai/api_server.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 0f131ce6f4dc..bce4a93907cd 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -243,6 +243,16 @@ async def create_chat_completion(request: ChatCompletionRequest, if error_check_ret is not None: return error_check_ret + grammar = request_dict.pop("grammar") + if grammar: + grammar_logits_processor = GrammarLogitsProcessor( + tokenizer=llm_engine.model_config.tokenizer, + grammar=grammar + ) + logits_processors 
= [grammar_logits_processor] + else: + logits_processors = [] + model_name = request.model request_id = f"cmpl-{random_uuid()}" created_time = int(time.monotonic()) @@ -266,6 +276,7 @@ async def create_chat_completion(request: ChatCompletionRequest, use_beam_search=request.use_beam_search, skip_special_tokens=request.skip_special_tokens, spaces_between_special_tokens=spaces_between_special_tokens, + logits_processors=logits_processors, ) except ValueError as e: return create_error_response(HTTPStatus.BAD_REQUEST, str(e)) From c41f8e44df524964e7b24806dc440cb7e12dd1ef Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Tue, 26 Dec 2023 20:23:13 -0600 Subject: [PATCH 64/76] fixes --- vllm/entrypoints/openai/api_server.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index bce4a93907cd..278edfe1415c 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -18,6 +18,7 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse, Response +from vllm.grammar import GrammarLogitsProcessor from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.metrics import add_global_metrics_labels @@ -243,12 +244,9 @@ async def create_chat_completion(request: ChatCompletionRequest, if error_check_ret is not None: return error_check_ret - grammar = request_dict.pop("grammar") - if grammar: + if request.grammar: grammar_logits_processor = GrammarLogitsProcessor( - tokenizer=llm_engine.model_config.tokenizer, - grammar=grammar - ) + tokenizer=engine.model_config.tokenizer, grammar=request.grammar) logits_processors = [grammar_logits_processor] else: logits_processors = [] From 2b2b024dfd7a7048a06884e44c70a36d7357e3e3 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Wed, 27 Dec 2023 13:59:18 -0600 Subject: [PATCH 65/76] use grammars in /v1/completions only, shouldn't apply to /v1/chat/completions --- vllm/entrypoints/openai/api_server.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 278edfe1415c..bbcc24d73cf5 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -244,13 +244,6 @@ async def create_chat_completion(request: ChatCompletionRequest, if error_check_ret is not None: return error_check_ret - if request.grammar: - grammar_logits_processor = GrammarLogitsProcessor( - tokenizer=engine.model_config.tokenizer, grammar=request.grammar) - logits_processors = [grammar_logits_processor] - else: - logits_processors = [] - model_name = request.model request_id = f"cmpl-{random_uuid()}" created_time = int(time.monotonic()) @@ -274,7 +267,6 @@ async def create_chat_completion(request: ChatCompletionRequest, use_beam_search=request.use_beam_search, skip_special_tokens=request.skip_special_tokens, spaces_between_special_tokens=spaces_between_special_tokens, - logits_processors=logits_processors, ) except ValueError as e: return create_error_response(HTTPStatus.BAD_REQUEST, str(e)) @@ -499,6 +491,14 @@ async def create_completion(request: CompletionRequest, raw_request: Request): if error_check_ret is not None: return error_check_ret + logger.info(f"grammar: {request.grammar}") + if request.grammar: + grammar_logits_processor = GrammarLogitsProcessor( + tokenizer=engine.model_config.tokenizer, 
grammar=request.grammar) + logits_processors = [grammar_logits_processor] + else: + logits_processors = [] + created_time = int(time.monotonic()) try: spaces_between_special_tokens = request.spaces_between_special_tokens @@ -522,6 +522,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request): prompt_logprobs=request.logprobs if request.echo else None, skip_special_tokens=request.skip_special_tokens, spaces_between_special_tokens=spaces_between_special_tokens, + logits_processors=logits_processors ) except ValueError as e: return create_error_response(HTTPStatus.BAD_REQUEST, str(e)) From c2b026d6c6eb11ced825ac2f8d468b5917cb5fb5 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Wed, 27 Dec 2023 14:14:28 -0600 Subject: [PATCH 66/76] yapf --- vllm/entrypoints/openai/api_server.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index bbcc24d73cf5..d1182ec030c0 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -522,8 +522,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request): prompt_logprobs=request.logprobs if request.echo else None, skip_special_tokens=request.skip_special_tokens, spaces_between_special_tokens=spaces_between_special_tokens, - logits_processors=logits_processors - ) + logits_processors=logits_processors) except ValueError as e: return create_error_response(HTTPStatus.BAD_REQUEST, str(e)) From 0fa9d0753e80e7006ec08d9bf4296575e42a1f80 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Wed, 27 Dec 2023 15:01:25 -0600 Subject: [PATCH 67/76] add grammar to protocol --- vllm/entrypoints/openai/protocol.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 7a86a19c4bf8..37bda516b9be 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -106,6 +106,7 @@ class CompletionRequest(BaseModel): spaces_between_special_tokens: Optional[bool] = True repetition_penalty: Optional[float] = 1.0 min_p: Optional[float] = 0.0 + grammar: Optional[str] = None class LogProbs(BaseModel): From e23f93c3feea7533bd88a458b2fb1e5729e782ce Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Wed, 27 Dec 2023 15:08:00 -0600 Subject: [PATCH 68/76] bug fix in api server --- format.sh | 2 +- vllm/entrypoints/openai/api_server.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/format.sh b/format.sh index 251839893c97..d38daf4d018e 100755 --- a/format.sh +++ b/format.sh @@ -34,7 +34,7 @@ tool_version_check() { } tool_version_check "yapf" $YAPF_VERSION "$(grep yapf requirements-dev.txt | cut -d'=' -f3)" -tool_version_check "ruff" $RUFF_VERSION "$(grep "ruff==" requirements-dev.txt | cut -d'=' -f3)" +#tool_version_check "ruff" $RUFF_VERSION "$(grep "ruff==" requirements-dev.txt | cut -d'=' -f3)" tool_version_check "mypy" "$MYPY_VERSION" "$(grep mypy requirements-dev.txt | cut -d'=' -f3)" YAPF_FLAGS=( diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index d1182ec030c0..d30ef6cf0453 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -494,7 +494,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request): logger.info(f"grammar: {request.grammar}") if request.grammar: grammar_logits_processor = GrammarLogitsProcessor( - tokenizer=engine.model_config.tokenizer, grammar=request.grammar) + 
tokenizer=tokenizer, grammar=request.grammar) logits_processors = [grammar_logits_processor] else: logits_processors = [] From 81ec332f9495824269f0a207140c2ae968b0ab02 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Thu, 28 Dec 2023 03:04:50 -0600 Subject: [PATCH 69/76] integrate into engine --- vllm/entrypoints/openai/api_server.py | 11 +++++++---- vllm/grammar.py | 21 +++++++++++++++++++++ 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index d30ef6cf0453..17b863ba4480 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -18,7 +18,7 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse, Response -from vllm.grammar import GrammarLogitsProcessor +from vllm.grammar import GrammarLogitsProcessor, RayRemoteGrammarLogitsProcessor from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.metrics import add_global_metrics_labels @@ -491,10 +491,13 @@ async def create_completion(request: CompletionRequest, raw_request: Request): if error_check_ret is not None: return error_check_ret - logger.info(f"grammar: {request.grammar}") if request.grammar: - grammar_logits_processor = GrammarLogitsProcessor( - tokenizer=tokenizer, grammar=request.grammar) + if engine.worker_use_ray: + grammar_logits_processor = RayRemoteGrammarLogitsProcessor( + tokenizer=tokenizer, grammar=request.grammar) + else: + grammar_logits_processor = GrammarLogitsProcessor( + tokenizer=tokenizer, grammar=request.grammar) logits_processors = [grammar_logits_processor] else: logits_processors = [] diff --git a/vllm/grammar.py b/vllm/grammar.py index cd04d44d716c..c8ff77a97d0b 100644 --- a/vllm/grammar.py +++ b/vllm/grammar.py @@ -7,6 +7,8 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from typing import Optional, List, Set, Union +import ray + from lark import Lark from lark.parsers.lalr_interactive_parser import InteractiveParser from lark.parsers.lalr_parser_state import ParserState @@ -468,3 +470,22 @@ def __call__(self, token_ids: List[int], mask[valid] = True logits[~mask] = float('-inf') return logits + + +@ray.remote +class GrammarLogitsProcessorActor: + def __init__(self, *args, **kwargs): + self.processor = GrammarLogitsProcessor(*args, **kwargs) + + def process_logits(self, token_ids: List[int], logits: torch.Tensor) -> torch.Tensor: + return self.processor(token_ids, logits) + + +class RayRemoteGrammarLogitsProcessor: + def __init__(self, *args, **kwargs): + self.actor = GrammarLogitsProcessorActor.remote(*args, **kwargs) + + def __call__(self, token_ids: List[int], logits: torch.Tensor) -> torch.Tensor: + logits_cpu = logits.cpu() + result_id = self.actor.process_logits.remote(token_ids, logits_cpu) + return ray.get(result_id) From a1f9352f8705f7f52142f01dfff7243549248896 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Thu, 28 Dec 2023 21:20:25 -0600 Subject: [PATCH 70/76] remove dead code --- vllm/grammar.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/vllm/grammar.py b/vllm/grammar.py index c8ff77a97d0b..9c5937a26056 100644 --- a/vllm/grammar.py +++ b/vllm/grammar.py @@ -208,7 +208,6 @@ class IncrementalParserState: full_seq: str # function of key - prior_terminal_ids: tuple[str] terminal_candidates: list # shared across instances @@ -220,7 +219,6 @@ class IncrementalParserState: def __repr__(self): shown = [ 
"partial_token", "full_seq", "terminal_candidates", - "prior_terminal_ids" ] attrs_str = ", ".join(f"{s}={repr(getattr(self, s))}" for s in shown) return f"{self.__class__.__name__}({attrs_str})" @@ -248,7 +246,6 @@ def from_grammar(cls, grammar: str, start: str): ["" if seq is None else None] * 2) parser = cls(interactive_parser=interactive_parser, - prior_terminal_ids=tuple(), full_seq="", partial_token="", terminal_candidates=None, @@ -314,7 +311,6 @@ def step(self, new_seq: str): return self.new( interactive_parser=self.interactive_parser, full_seq=self.full_seq + new_seq, - prior_terminal_ids=self.prior_terminal_ids, partial_token=new_maybe_partial_token, terminal_candidates=partial_terminal_ids, ) @@ -335,8 +331,6 @@ def step(self, new_seq: str): new_parser = self.new( full_seq=base_seq + processed_seq, interactive_parser=new_interactive_parser, - prior_terminal_ids=hash( - (self.prior_terminal_ids, best_terminal)), partial_token="", terminal_candidates=None, ) @@ -429,7 +423,6 @@ def __init__( ): self.tokenizer = tokenizer self.vocab = TokenVocab(tokenizer, legal_chars=legal_chars) - self.root_parser = IncrementalParserState.from_grammar( grammar, grammar_start) From db09714d4598368bb3c2cd314b9f466ffc4b7b5c Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Thu, 28 Dec 2023 21:20:32 -0600 Subject: [PATCH 71/76] cleaner --- vllm/grammar.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/grammar.py b/vllm/grammar.py index 9c5937a26056..1137ab5e64a3 100644 --- a/vllm/grammar.py +++ b/vllm/grammar.py @@ -432,10 +432,8 @@ def get_valid_next_token_strs(self, full_seq): """ parser = self.root_parser[full_seq] if parser is None: - return - for tok_str in self.vocab: - if parser.is_valid_next_seq(tok_str): - yield tok_str + return [] + return filter(parser.is_valid_next_seq, self.vocab) def get_valid_next_token_ids(self, full_seq): """ From 350f364f7374dbc02a9e6a390893469b86a407e8 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Fri, 29 Dec 2023 00:17:26 -0600 Subject: [PATCH 72/76] implementation with no partial state tracked by parser --- vllm/grammar.py | 106 +++++++++++++++++++++--------------------------- 1 file changed, 47 insertions(+), 59 deletions(-) diff --git a/vllm/grammar.py b/vllm/grammar.py index 1137ab5e64a3..8667056ac866 100644 --- a/vllm/grammar.py +++ b/vllm/grammar.py @@ -202,10 +202,6 @@ class IncrementalParserState: # unique state key interactive_parser: FastInteractiveParser - partial_token: str - - # orthogonal unique state key - full_seq: str # function of key terminal_candidates: list @@ -217,11 +213,7 @@ class IncrementalParserState: _full_seq_trie: Trie def __repr__(self): - shown = [ - "partial_token", "full_seq", "terminal_candidates", - ] - attrs_str = ", ".join(f"{s}={repr(getattr(self, s))}" for s in shown) - return f"{self.__class__.__name__}({attrs_str})" + return f"{self.__class__.__name__}({self.interactive_parser.parser_state.state_stack})" @classmethod @lru_cache(1000) @@ -246,8 +238,6 @@ def from_grammar(cls, grammar: str, start: str): ["" if seq is None else None] * 2) parser = cls(interactive_parser=interactive_parser, - full_seq="", - partial_token="", terminal_candidates=None, _ignored_terms=set(lark_parser.lexer_conf.ignore), _seq_validator=_seq_validator, @@ -258,8 +248,7 @@ def from_grammar(cls, grammar: str, start: str): def new(self, **kwargs): """Cached create now state""" - parser_state_key = (hash(kwargs["interactive_parser"]), - kwargs["partial_token"]) + parser_state_key = hash(kwargs["interactive_parser"]) if 
parser_state_key in self._memo: return self._memo[parser_state_key] @@ -278,9 +267,15 @@ def __getitem__(self, full_seq): if parser is None: return if remainder_seq: - parser = parser.step(remainder_seq) - self._full_seq_trie.insert(full_seq, parser) - return parser + result = parser.step(remainder_seq) + if result is None: + return None + remainder_seq, parser = result + processed_seq = full_seq + if remainder_seq: + processed_seq = processed_seq[:-len(remainder_seq)] + self._full_seq_trie.insert(processed_seq, parser) + return remainder_seq, parser @memoize_by_instance def step(self, new_seq: str): @@ -293,58 +288,43 @@ def step(self, new_seq: str): - If no partial matches, return None """ if new_seq == "": - return self - - new_maybe_partial_token = self.partial_token + new_seq + return "", self best_terminal, processed_seq, remainder_seq = self.get_best_matched_terminal( - self.allowed_terminals, new_maybe_partial_token) + new_seq) + + # invalid if best_terminal is None: return None # candidate doesn't complete terminal - if remainder_seq is None: - partial_terminal_ids = self.get_partial_terminal_ids( - self.allowed_terminals, - new_maybe_partial_token, - ) - return self.new( - interactive_parser=self.interactive_parser, - full_seq=self.full_seq + new_seq, - partial_token=new_maybe_partial_token, - terminal_candidates=partial_terminal_ids, - ) - - # terminal completes rule - else: - if best_terminal in self._ignored_terms: - new_interactive_parser = self.interactive_parser - else: - new_interactive_parser = self.get_stepped_parser_state( - best_terminal) + elif remainder_seq is None: + return processed_seq, self - if self.partial_token: - base_seq = self.full_seq[:-len(self.partial_token)] + # candidate completes terminal + else: + new_parser = self._next_with_new_terminal(best_terminal) + if remainder_seq == "": + return "", new_parser else: - base_seq = self.full_seq + return new_parser.step(remainder_seq) - new_parser = self.new( - full_seq=base_seq + processed_seq, - interactive_parser=new_interactive_parser, - partial_token="", - terminal_candidates=None, - ) - # no leftover to process - if remainder_seq == "": - return new_parser + @memoize_by_instance + def _next_with_new_terminal(self, terminal): + if terminal in self._ignored_terms: + new_interactive_parser = self.interactive_parser + else: + new_interactive_parser = self.get_stepped_parser_state( + terminal) - # process remainder - else: - return new_parser.step(remainder_seq) + return self.new( + interactive_parser=new_interactive_parser, + terminal_candidates=None, + ) - def get_best_matched_terminal(self, checked_terminals, seq): - for terminal in checked_terminals: + def get_best_matched_terminal(self, seq): + for terminal in self.accepts(): processed_seq, remainder_seq = self._seq_validator[terminal](seq) if processed_seq: return terminal, processed_seq, remainder_seq @@ -430,10 +410,18 @@ def get_valid_next_token_strs(self, full_seq): """ Generate valid token strings given the full sequence """ - parser = self.root_parser[full_seq] - if parser is None: + + result = self.root_parser[full_seq] + if result is None: return [] - return filter(parser.is_valid_next_seq, self.vocab) + partial_term, parser = result + for token in self.vocab: + if token is None: + if partial_term == "" and parser.is_valid_next_seq(token): + yield None + else: + if parser.is_valid_next_seq(partial_term + token): + yield token def get_valid_next_token_ids(self, full_seq): """ From a88e50636995af1efd5eeef26fd3a385be8b0c04 Mon Sep 17 00:00:00 2001 
From: Andrew Lapp Date: Fri, 29 Dec 2023 02:43:49 -0600 Subject: [PATCH 73/76] faster cache decorator via lru_cache --- vllm/grammar.py | 43 ++++++++++++++++++++++--------------------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/vllm/grammar.py b/vllm/grammar.py index 8667056ac866..bcfac0630943 100644 --- a/vllm/grammar.py +++ b/vllm/grammar.py @@ -1,11 +1,12 @@ import collections from copy import deepcopy, copy from dataclasses import dataclass, fields -from functools import wraps, lru_cache +import functools import regex import torch from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from typing import Optional, List, Set, Union +import weakref import ray @@ -82,7 +83,7 @@ def get_pattern_validator(pattern: Pattern): if isinstance(pattern, PatternRE): compiled_pattern = regex.compile(pattern.value) - @lru_cache(int(1e6)) + @functools.lru_cache(int(1e6)) def get_re_matched_parts(seq): # match complete terminal, potentially with leftover seq complete_terminal_match = compiled_pattern.match(seq) @@ -108,7 +109,7 @@ def get_re_matched_parts(seq): elif isinstance(pattern, PatternStr): base_str = pattern.value - @lru_cache(int(1e6)) + @functools.lru_cache(int(1e6)) def get_str_matched_parts(seq): if seq.startswith(base_str): processed_seq = seq[:len(base_str)] @@ -125,22 +126,23 @@ def get_str_matched_parts(seq): raise TypeError(f"Invalid pattern type: {type(pattern)}") -def memoize_by_instance(method): - """ - Memoize by id(self) and fn args - """ - mname = method.__name__ - - @wraps(method) - def wrapper(self, *args): - key = (mname, id(self), args) - if key in self._memo: - return self._memo[key] - result = method(self, *args) - self._memo[key] = result - return result +def method_lru_cache(*lru_args, **lru_kwargs): + def decorator(func): + @functools.wraps(func) + def wrapped_func(self, *args, **kwargs): + # We're storing the wrapped method inside the instance. If we had + # a strong reference to self the instance would never die. 
+ self_weak = weakref.ref(self) + @functools.wraps(func) + @functools.lru_cache(*lru_args, **lru_kwargs) + def cached_method(*args, **kwargs): + return func(self_weak(), *args, **kwargs) + setattr(self, func.__name__, cached_method) + return cached_method(*args, **kwargs) + return wrapped_func + return decorator - return wrapper +memoize_by_instance = method_lru_cache(int(1e7)) class TrieNode: @@ -216,7 +218,7 @@ def __repr__(self): return f"{self.__class__.__name__}({self.interactive_parser.parser_state.state_stack})" @classmethod - @lru_cache(1000) + @functools.lru_cache(1000) def from_grammar(cls, grammar: str, start: str): lark_parser = Lark( grammar, @@ -347,7 +349,6 @@ def get_stepped_parser_state(self, new_token_str): def accepts(self): return set(self.interactive_parser.accepts()) | self._ignored_terms - @property @memoize_by_instance def allowed_terminals(self): if self.terminal_candidates is not None: @@ -357,7 +358,7 @@ def allowed_terminals(self): @memoize_by_instance def is_valid_next_seq(self, new_seq: Optional[str]): if new_seq is None: - return "$END" in self.allowed_terminals + return "$END" in self.allowed_terminals() return self.step(new_seq) is not None From 4f31fc058ac8aaa6c8b4369b9562a8287853cc2f Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Fri, 29 Dec 2023 02:48:34 -0600 Subject: [PATCH 74/76] add attribution --- vllm/grammar.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/grammar.py b/vllm/grammar.py index bcfac0630943..dce340289d99 100644 --- a/vllm/grammar.py +++ b/vllm/grammar.py @@ -127,11 +127,10 @@ def get_str_matched_parts(seq): def method_lru_cache(*lru_args, **lru_kwargs): + # https://stackoverflow.com/a/44078118 def decorator(func): @functools.wraps(func) def wrapped_func(self, *args, **kwargs): - # We're storing the wrapped method inside the instance. If we had - # a strong reference to self the instance would never die. 
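
A minimal end-to-end usage sketch, assuming the interfaces added in this series: `GrammarLogitsProcessor(tokenizer, grammar, legal_chars=...)`, `SamplingParams(logits_processors=[...])`, and the `grammar` field on `/v1/completions`. The model id, prompt, and server address are placeholders, and the HTTP client (`requests`) is not part of vLLM:

    # Offline generation constrained by a Lark-style EBNF grammar.
    from transformers import AutoTokenizer

    from vllm import LLM, SamplingParams
    from vllm.grammar import GrammarLogitsProcessor

    model_id = "codellama/CodeLlama-7b-hf"  # placeholder model
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Output must be exactly "hello" or "world".
    grammar = '?start: "hello" | "world"'

    grammar_logits_processor = GrammarLogitsProcessor(
        tokenizer,
        grammar,
        # Restricting legal characters shrinks the token set the parser must check.
        legal_chars=set(map(chr, range(256))),
    )

    llm = LLM(model=model_id)
    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=10,
        logits_processors=[grammar_logits_processor],
    )
    outputs = llm.generate(["Say hello or world: "], sampling_params)
    print(outputs[0].outputs[0].text)  # guaranteed to parse under the grammar

    # Against a server started from this branch, the same constraint can be
    # requested through the new `grammar` field on /v1/completions.
    import requests  # client-side only; placeholder host/port

    resp = requests.post(
        "http://localhost:8000/v1/completions",
        json={
            "model": model_id,
            "prompt": "Say hello or world: ",
            "max_tokens": 10,
            "temperature": 0,
            "grammar": grammar,
        },
    )
    print(resp.json()["choices"][0]["text"])

When the engine is launched with `worker_use_ray`, the server path instead wraps the processor in `RayRemoteGrammarLogitsProcessor`, which moves the logits to CPU and evaluates them inside a dedicated Ray actor rather than in the worker process.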
self_weak = weakref.ref(self) @functools.wraps(func) @functools.lru_cache(*lru_args, **lru_kwargs) From f130c98b7e1e0605ec8c406ea6e302754abc3170 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Fri, 29 Dec 2023 03:04:47 -0600 Subject: [PATCH 75/76] remove dead code --- vllm/grammar.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/vllm/grammar.py b/vllm/grammar.py index dce340289d99..453162cdbf23 100644 --- a/vllm/grammar.py +++ b/vllm/grammar.py @@ -141,6 +141,7 @@ def cached_method(*args, **kwargs): return wrapped_func return decorator + memoize_by_instance = method_lru_cache(int(1e7)) @@ -332,12 +333,6 @@ def get_best_matched_terminal(self, seq): return None, None, None - def get_partial_terminal_ids(self, checked_terminals, seq): - return set([ - term for term in checked_terminals - if self._seq_validator[term](seq)[0] is not None - ]) - @memoize_by_instance def get_stepped_parser_state(self, new_token_str): ip = copy(self.interactive_parser) From 344f27b84c41034067227d740967aca5254b0ade Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Fri, 29 Dec 2023 20:04:23 -0600 Subject: [PATCH 76/76] yapf --- vllm/grammar.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/vllm/grammar.py b/vllm/grammar.py index 453162cdbf23..7f30c77f847e 100644 --- a/vllm/grammar.py +++ b/vllm/grammar.py @@ -129,16 +129,21 @@ def get_str_matched_parts(seq): def method_lru_cache(*lru_args, **lru_kwargs): # https://stackoverflow.com/a/44078118 def decorator(func): + @functools.wraps(func) def wrapped_func(self, *args, **kwargs): self_weak = weakref.ref(self) + @functools.wraps(func) @functools.lru_cache(*lru_args, **lru_kwargs) def cached_method(*args, **kwargs): return func(self_weak(), *args, **kwargs) + setattr(self, func.__name__, cached_method) return cached_method(*args, **kwargs) + return wrapped_func + return decorator @@ -293,7 +298,7 @@ def step(self, new_seq: str): return "", self best_terminal, processed_seq, remainder_seq = self.get_best_matched_terminal( - new_seq) + new_seq) # invalid if best_terminal is None: @@ -311,14 +316,12 @@ def step(self, new_seq: str): else: return new_parser.step(remainder_seq) - @memoize_by_instance def _next_with_new_terminal(self, terminal): if terminal in self._ignored_terms: new_interactive_parser = self.interactive_parser else: - new_interactive_parser = self.get_stepped_parser_state( - terminal) + new_interactive_parser = self.get_stepped_parser_state(terminal) return self.new( interactive_parser=new_interactive_parser, @@ -448,18 +451,22 @@ def __call__(self, token_ids: List[int], @ray.remote class GrammarLogitsProcessorActor: + def __init__(self, *args, **kwargs): self.processor = GrammarLogitsProcessor(*args, **kwargs) - def process_logits(self, token_ids: List[int], logits: torch.Tensor) -> torch.Tensor: + def process_logits(self, token_ids: List[int], + logits: torch.Tensor) -> torch.Tensor: return self.processor(token_ids, logits) class RayRemoteGrammarLogitsProcessor: + def __init__(self, *args, **kwargs): self.actor = GrammarLogitsProcessorActor.remote(*args, **kwargs) - def __call__(self, token_ids: List[int], logits: torch.Tensor) -> torch.Tensor: + def __call__(self, token_ids: List[int], + logits: torch.Tensor) -> torch.Tensor: logits_cpu = logits.cpu() result_id = self.actor.process_logits.remote(token_ids, logits_cpu) return ray.get(result_id)