From daf9950677292a1ab96f3b5184954cb690fb7f16 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Tue, 2 Jul 2024 22:54:47 +0000
Subject: [PATCH] revamp eos handling

Decide EOS outside of Recognizer::special_allowed: the lexer spec now
tracks which lexemes may end at EOS, Parser::compute_bias allows the
EOS token when the lexer does, and TokenParser allows it when all
parsers on the stack are accepting.
---
 controllers/aici_abi/src/toktree.rs           | 19 +++----
 controllers/llguidance_ctrl/run_g.py          | 26 ++++++----
 .../llguidance_ctrl/src/earley/lexer.rs       |  6 ++-
 .../llguidance_ctrl/src/earley/lexerspec.rs   | 13 +++++
 .../llguidance_ctrl/src/earley/parser.rs      | 50 ++++++++++---------
 .../llguidance_ctrl/src/tokenparser.rs        | 30 +++++++----
 scripts/test-guidance.sh                      |  1 +
 7 files changed, 91 insertions(+), 54 deletions(-)

diff --git a/controllers/aici_abi/src/toktree.rs b/controllers/aici_abi/src/toktree.rs
index dd846d8e..99271f01 100644
--- a/controllers/aici_abi/src/toktree.rs
+++ b/controllers/aici_abi/src/toktree.rs
@@ -558,15 +558,6 @@ impl TokTrie {
                 }
             }
         }
-        // all prefixes of 'start' are also allowed
-        if start.len() > 0 {
-            for len in 1..=start.len() {
-                let bytes = &start[0..len];
-                if let Some(tok) = self.token_id(bytes) {
-                    logits.allow_token(tok);
-                }
-            }
-        }
         self.add_bias(r, logits, start);
         self.apply_duplicates(logits);
     }
@@ -682,6 +673,16 @@ impl TokTrie {
 
     #[inline(never)]
     pub fn add_bias(&self, r: &mut impl Recognizer, toks: &mut SimpleVob, start: &[u8]) {
+        // all prefixes of 'start' are also allowed
+        if start.len() > 0 {
+            for len in 1..=start.len() {
+                let bytes = &start[0..len];
+                if let Some(tok) = self.token_id(bytes) {
+                    toks.allow_token(tok);
+                }
+            }
+        }
+
         r.trie_started();
         let n = self.child_at_bytes(self.root(), start).unwrap();
         let defl_tok = self.vocab_size() as u32;
diff --git a/controllers/llguidance_ctrl/run_g.py b/controllers/llguidance_ctrl/run_g.py
index ad56ce09..2c275c0e 100644
--- a/controllers/llguidance_ctrl/run_g.py
+++ b/controllers/llguidance_ctrl/run_g.py
@@ -185,9 +185,7 @@ def character_maker2(lm, id, description, valid_weapons):
         )
     )
 
-    grm = "6 * 7 = " + greedy_grammar(
-        body = lexeme("[0-9]{1,3}")
-    ) + "\n"
+    grm = "6 * 7 = " + greedy_grammar(body=lexeme("[0-9]{1,3}")) + "\n"
     # assert grm.match("6 * 7 = 42\n")
 
     grm = (
@@ -203,14 +201,26 @@ def character_maker2(lm, id, description, valid_weapons):
 
     grm = "6 * 7 = " + gen("name", max_tokens=2)
 
-    grm = "Name: " + gen('name', max_tokens=2) + " Height: " + gen('height', max_tokens=3)
-    grm = "Name: " + gen('name', max_tokens=2) + "Emily Carter is great; Height: " + gen('height', max_tokens=3)
+    grm = (
+        "Name: " + gen("name", max_tokens=2) + " Height: " + gen("height", max_tokens=3)
+    )
+    grm = (
+        "Name: "
+        + gen("name", max_tokens=2)
+        + "Emily Carter is great; Height: "
+        + gen("height", max_tokens=3)
+    )
     grm = "123" + gen(name="numbers", regex=r"\d*233", max_tokens=5)
 
-    grm = greedy_grammar(body=lexeme("[0-9]+"),skip_regex=r"\s*") + "x"
+    grm = (
+        "Here: 2 + 2 = "
+        + greedy_grammar(body=lexeme("[0-9]+"), skip_regex=r"\s*")
+        + "x"
+    )
+
     grm = "Here: 2 + 2 = " + greedy_grammar(name="num", body=lexeme("[0-9]+"))
 
-    grm = "Here: 2 + 2 = " + guidance.json(name="num", schema={"type": "integer"})
+    # grm = "Here: 2 + 2 = " + guidance.json(name="num", schema={"type": "integer"})
     # grm = guidance.json(name="num", schema={"type": "integer"})
     # m = grm.match("123")
     # print(m)
@@ -218,8 +228,6 @@ def character_maker2(lm, id, description, valid_weapons):
 
     # grm = "Name: " + gen('name', max_tokens=2) + " Height: " + gen('height', max_tokens=3)
 
-
-
     # g = zero_or_more("a") + "b"
     # assert g.match("b")
     # assert g.match("ab")
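The toktree.rs hunks above move the allow-all-prefixes-of-`start` loop from
`compute_bias_ext` into `add_bias` itself, so tokens forming a prefix of the
forced bytes stay allowed for callers that invoke `add_bias` directly, as
`Parser::compute_bias` starts doing below. A minimal stand-alone sketch of
that loop (not the library API), with a plain `HashMap` and a `bool` slice
standing in for the real `TokTrie::token_id` lookup and `SimpleVob` mask:

    use std::collections::HashMap;

    /// Sketch only: allow every token whose bytes form a prefix of the
    /// forced byte string `start`, mirroring the loop now in add_bias.
    fn allow_start_prefixes(
        vocab: &HashMap<Vec<u8>, u32>, // stand-in for TokTrie::token_id
        allowed: &mut [bool],          // stand-in for SimpleVob
        start: &[u8],
    ) {
        for len in 1..=start.len() {
            if let Some(&tok) = vocab.get(&start[0..len]) {
                allowed[tok as usize] = true;
            }
        }
    }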
diff --git a/controllers/llguidance_ctrl/src/earley/lexer.rs b/controllers/llguidance_ctrl/src/earley/lexer.rs
index 3d06c148..e56888df 100644
--- a/controllers/llguidance_ctrl/src/earley/lexer.rs
+++ b/controllers/llguidance_ctrl/src/earley/lexer.rs
@@ -4,7 +4,7 @@ use std::fmt::Debug;
 
 use super::{
     lexerspec::{LexemeIdx, LexerSpec},
-    regexvec::{RegexVec, StateDesc, NextByte},
+    regexvec::{NextByte, RegexVec, StateDesc},
 };
 
 const DEBUG: bool = true;
@@ -73,7 +73,9 @@ impl Lexer {
     }
 
     pub fn allows_eos(&mut self, state: StateID) -> bool {
-        self.state_info(state).is_accepting()
+        let mut l = self.spec.eos_ending_lexemes();
+        l.and(&self.state_info(state).accepting);
+        !l.is_zero()
     }
 
     pub fn limit_state_to(&mut self, state: StateID, allowed_lexemes: &SimpleVob) -> StateID {
diff --git a/controllers/llguidance_ctrl/src/earley/lexerspec.rs b/controllers/llguidance_ctrl/src/earley/lexerspec.rs
index 5eb8e198..b547952c 100644
--- a/controllers/llguidance_ctrl/src/earley/lexerspec.rs
+++ b/controllers/llguidance_ctrl/src/earley/lexerspec.rs
@@ -18,6 +18,7 @@ pub struct LexemeSpec {
     name: String,
     pub(crate) rx: RegexAst,
     compiled_rx: ExprRef,
+    ends_at_eos: bool,
     lazy: bool,
     contextual: bool,
 }
@@ -94,6 +95,16 @@ impl LexerSpec {
         v
     }
 
+    pub fn eos_ending_lexemes(&self) -> SimpleVob {
+        let mut v = self.alloc_lexeme_set();
+        for (idx, lex) in self.lexemes.iter().enumerate() {
+            if lex.ends_at_eos {
+                v.set(idx, true);
+            }
+        }
+        v
+    }
+
     pub fn is_nullable(&self, idx: LexemeIdx) -> bool {
         self.regex_builder
             .is_nullable(self.lexemes[idx.0].compiled_rx)
@@ -138,6 +149,7 @@ impl LexerSpec {
             compiled_rx: ExprRef::INVALID,
             lazy: false,
             contextual: false,
+            ends_at_eos: false,
         }
     }
 
@@ -157,6 +169,7 @@ impl LexerSpec {
             name,
             rx,
             lazy,
+            ends_at_eos: !lazy,
             ..self.empty_spec()
         })
     }
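With the `ends_at_eos` flag above (set to `!lazy`, i.e. true for greedy
lexemes) and the new `eos_ending_lexemes` set, the `allows_eos` check in
lexer.rs becomes lexeme-aware: instead of accepting EOS whenever the lexer
state is accepting at all, it requires some lexeme that may end at EOS to be
accepting in that state. A rough model of the check, with `bool` slices
standing in for the `SimpleVob` bitsets:

    /// Sketch of the new Lexer::allows_eos: intersect the spec's
    /// eos-ending lexeme set with the state's accepting set and test
    /// the result for non-emptiness.
    fn allows_eos(eos_ending: &[bool], accepting: &[bool]) -> bool {
        eos_ending.iter().zip(accepting).any(|(&e, &a)| e && a)
    }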
diff --git a/controllers/llguidance_ctrl/src/earley/parser.rs b/controllers/llguidance_ctrl/src/earley/parser.rs
index aeea7731..a44b8d32 100644
--- a/controllers/llguidance_ctrl/src/earley/parser.rs
+++ b/controllers/llguidance_ctrl/src/earley/parser.rs
@@ -388,11 +388,17 @@ impl Parser {
     pub fn compute_bias(&mut self, trie: &TokTrie, start: &[u8]) -> SimpleVob {
         let mut set = trie.alloc_token_set();
-        trie.compute_bias_ext(self, &mut set, start);
+        trie.add_bias(self, &mut set, start);
+        trie.apply_duplicates(&mut set);
 
-        if set.num_set() == 1 && set.is_allowed(trie.eos_token()) {
+        if set.is_zero() {
+            // nothing allowed
             // we're going to be stopped outside - we better flush the lexer
-            self.flush_lexer();
+            let _ = self.flush_lexer();
+        }
+
+        if start.is_empty() && self.lexer_allows_eos() {
+            set.allow_token(trie.eos_token());
         }
 
         set
     }
@@ -846,6 +852,7 @@ impl Parser {
     }
 
     pub fn model_variables(&mut self) -> Vec<ModelVariable> {
+        // this can be used in future to allow "end-of-turn" special token and the like
        self.run_speculative(|s| {
            let mut vars = vec![];
            if s.flush_lexer() {
@@ -887,6 +894,9 @@ impl Parser {
         })
     }
+    /// Advance the parser as if the current lexeme (if any)
+    /// finished right here.
+    /// Returns true if the parser was able to advance (or there were no pending bytes for a lexeme).
     fn flush_lexer(&mut self) -> bool {
         if !self.has_pending_lexeme_bytes() {
             return true;
         }
@@ -1493,26 +1503,20 @@ impl Recognizer for Parser {
         self.last_collapse = self.num_rows();
     }
 
-    fn special_allowed(&mut self, tok: SpecialToken) -> bool {
-        if false {
-            self.print_row(self.num_rows() - 1);
-            println!(
-                "model vars: accpt={} {:?}",
-                self.is_accepting(),
-                self.model_variables()
-            );
-        }
-
-        if self
-            .model_variables()
-            .contains(&ModelVariable::SpecialToken(tok))
-        {
-            true
-        } else if tok == SpecialToken::EndOfSentence {
-            self.is_accepting() || self.lexer_allows_eos()
-        } else {
-            false
-        }
+    fn special_allowed(&mut self, _tok: SpecialToken) -> bool {
+        // handle EOS logic outside
+        unreachable!("special_allowed")
+
+        // if self
+        //     .model_variables()
+        //     .contains(&ModelVariable::SpecialToken(tok))
+        // {
+        //     true
+        // } else if tok == SpecialToken::EndOfSentence {
+        //     self.is_accepting() || self.lexer_allows_eos()
+        // } else {
+        //     false
+        // }
     }
 
     fn trie_started(&mut self) {
diff --git a/controllers/llguidance_ctrl/src/tokenparser.rs b/controllers/llguidance_ctrl/src/tokenparser.rs
index f5ed8f49..5435b90f 100644
--- a/controllers/llguidance_ctrl/src/tokenparser.rs
+++ b/controllers/llguidance_ctrl/src/tokenparser.rs
@@ -407,7 +407,7 @@ impl TokenParser {
         }
 
         if token_prefix.is_empty() {
-            if let Err(e) = self.maybe_gen_grammar() {
+            if let Err(e) = self.maybe_push_parser() {
                 warn!(self, "Error creating nested parser: {}", e);
                 return MidProcessResult::stop();
             }
@@ -434,22 +434,27 @@ impl TokenParser {
         // self.parser.print_row(self.parser.num_rows() - 1);
         let mut set = self.parser.compute_bias(trie, &token_prefix);
 
-        if inner_done
-            || self.max_tokens_parser == 0
-            || (set.num_set() == 1 && set.is_allowed(trie.eos_token()))
-        {
+        if inner_done || self.max_tokens_parser == 0 {
             if self.parser_stack.is_empty() {
                 self.mid_process_was_accepting = inner_accepting;
-                infoln!(self, "only eos token allowed, stopping; accepting: {}", inner_accepting);
+                infoln!(
+                    self,
+                    "only eos token allowed, stopping; accepting: {}",
+                    inner_accepting
+                );
                 return MidProcessResult::stop();
             } else {
                 infoln!(self, "pop_parser; tokens left {}", self.max_tokens_parser);
                 self.pop_parser();
                 // re-start the whole process with a nice tail-recursion
-                return self.mid_process_inner(MidProcessArg {
-                    backtrack: 0,
-                    tokens: Vec::new(),
-                    fork_group: Vec::new(),
+                return self.mid_process_inner(if has_eos {
+                    arg
+                } else {
+                    MidProcessArg {
+                        backtrack: 0,
+                        tokens: Vec::new(),
+                        fork_group: Vec::new(),
+                    }
                 });
             }
         }
@@ -480,6 +485,9 @@ impl TokenParser {
                 self.pop_tokens = Some(pop_tokens);
             }
             self.mid_process_was_accepting = all_accepting;
+            if all_accepting {
+                set.allow_token(trie.eos_token());
+            }
         }
 
         infoln!(
@@ -499,7 +507,7 @@ impl TokenParser {
         return MidProcessResult::sample_with_temp(set, Some(self.parser.temperature()));
     }
 
-    fn maybe_gen_grammar(&mut self) -> Result<()> {
+    fn maybe_push_parser(&mut self) -> Result<()> {
         if let Some((msg, symidx, gen_grammar)) = self.parser.maybe_gen_grammar() {
             if msg.len() > 0 {
                 warn!(self, "{}", msg);
diff --git a/scripts/test-guidance.sh b/scripts/test-guidance.sh
index 80970724..bb1f2c6f 100755
--- a/scripts/test-guidance.sh
+++ b/scripts/test-guidance.sh
@@ -24,6 +24,7 @@ fi
 function runtest() {
     pytest "$@"
     if [ $? -ne 0 -a $? -ne 5 ] ; then
+        :
         exit 1
     fi
 }
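Taken together, EOS handling now lives in two places instead of the
per-token `special_allowed` callback (which becomes unreachable):
`Parser::compute_bias` allows EOS when there is no pending token prefix and
the lexer allows it, and `TokenParser::mid_process_inner` additionally
allows it when every parser on the stack is accepting. A condensed,
hypothetical restatement of that policy (not a function in the patch; the
real logic is split across the two methods above):

    /// Hypothetical summary of the revamped EOS policy.
    fn eos_allowed(
        no_token_prefix: bool,
        lexer_allows_eos: bool,
        all_parsers_accepting: bool,
    ) -> bool {
        (no_token_prefix && lexer_allows_eos) || all_parsers_accepting
    }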