diff --git a/controllers/derivre/src/regexvec.rs b/controllers/derivre/src/regexvec.rs
index f54fbcdd..5a66cfa7 100644
--- a/controllers/derivre/src/regexvec.rs
+++ b/controllers/derivre/src/regexvec.rs
@@ -289,6 +289,16 @@ impl RegexVec {
         next_byte
     }
 
+    pub fn limit_state_to(&mut self, state: StateID, allowed_lexemes: &SimpleVob) -> StateID {
+        let mut vec_desc = vec![];
+        for (idx, e) in iter_state(&self.rx_sets, state) {
+            if allowed_lexemes.get(idx) {
+                Self::push_rx(&mut vec_desc, idx, e);
+            }
+        }
+        self.insert_state(vec_desc)
+    }
+
     pub fn total_fuel_spent(&self) -> usize {
         self.exprs.cost
     }
diff --git a/controllers/llguidance_ctrl/run_g.py b/controllers/llguidance_ctrl/run_g.py
index 221a9472..bafb72b7 100644
--- a/controllers/llguidance_ctrl/run_g.py
+++ b/controllers/llguidance_ctrl/run_g.py
@@ -201,6 +201,13 @@ def character_maker2(lm, id, description, valid_weapons):
     grm = gen(regex="a*")
     grm = "6 * 7 = " + gen(regex="5*") + gen(regex="[1-4][0-9]") + "\n"
 
+    grm = "6 * 7 = " + gen("name", max_tokens=2)
+
+    grm = "Name: " + gen('name', max_tokens=2) + " Height: " + gen('height', max_tokens=3)
+
+    # grm = "Name: " + gen('name', max_tokens=2) + " Height: " + gen('height', max_tokens=3)
+
+
     # g = zero_or_more("a") + "b"
     # assert g.match("b")
@@ -214,7 +221,7 @@ def character_maker2(lm, id, description, valid_weapons):
     #     body = lexeme("[0-9]+")
     # )
 
-    max_tokens = 250
+    max_tokens = 50
 
     serialized = grm.ll_serialize()
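The run_g.py lines above exercise per-call token budgets. A minimal usage sketch, not part of the patch: it assumes guidance's AzureGuidance model wrapper (the class used by the Azure-guidance test suite) and a placeholder endpoint URL; the parser changes below are what enforce the caps.

    # Sketch only: the model class and URL are assumptions, not taken from this patch.
    from guidance import gen, models

    lm = models.AzureGuidance("https://example.invalid/guidance")  # placeholder endpoint
    # Each gen() carries its own max_tokens cap; the parser now tracks one budget
    # per lexeme, so the two limits are enforced independently.
    lm += "Name: " + gen("name", max_tokens=2) + " Height: " + gen("height", max_tokens=3)
    print(lm["name"], lm["height"])
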
diff --git a/controllers/llguidance_ctrl/src/earley/grammar.rs b/controllers/llguidance_ctrl/src/earley/grammar.rs
index 833b2451..ab891c72 100644
--- a/controllers/llguidance_ctrl/src/earley/grammar.rs
+++ b/controllers/llguidance_ctrl/src/earley/grammar.rs
@@ -221,7 +221,10 @@ impl Grammar {
            return Ok(());
        }
        if lexer_spec.is_nullable(lex) {
-            let wrap = self.fresh_symbol(format!("rx_null_{}", self.sym_name(lhs)).as_str());
+            let wrap = self.fresh_symbol_ext(
+                format!("rx_null_{}", self.sym_name(lhs)).as_str(),
+                self.sym_data(lhs).props.for_wrapper(),
+            );
            self.sym_data_mut(wrap).lexeme = Some(lex);
            self.add_rule(lhs, vec![wrap])?;
            self.add_rule(lhs, vec![])?;
diff --git a/controllers/llguidance_ctrl/src/earley/lexer.rs b/controllers/llguidance_ctrl/src/earley/lexer.rs
index 7c2d1fbe..07024463 100644
--- a/controllers/llguidance_ctrl/src/earley/lexer.rs
+++ b/controllers/llguidance_ctrl/src/earley/lexer.rs
@@ -74,6 +74,14 @@ impl Lexer {
        self.state_info(state).is_accepting()
    }
 
+    pub fn limit_state_to(&mut self, state: StateID, allowed_lexemes: &SimpleVob) -> StateID {
+        self.dfa.limit_state_to(state, allowed_lexemes)
+    }
+
+    pub fn possible_lexemes(&self, state: StateID) -> &SimpleVob {
+        &self.state_info(state).possible
+    }
+
    pub fn force_lexeme_end(&self, prev: StateID) -> LexerResult {
        let info = self.state_info(prev);
        match info.possible.first_bit_set() {
diff --git a/controllers/llguidance_ctrl/src/earley/parser.rs b/controllers/llguidance_ctrl/src/earley/parser.rs
index 19f3ff5b..9ce8ddf4 100644
--- a/controllers/llguidance_ctrl/src/earley/parser.rs
+++ b/controllers/llguidance_ctrl/src/earley/parser.rs
@@ -1,4 +1,5 @@
 use std::{
+    collections::HashMap,
     fmt::{Debug, Display},
     hash::Hash,
     ops::Range,
@@ -154,7 +155,7 @@ struct RowInfo {
     lexeme: Lexeme,
     token_idx_start: usize,
     token_idx_stop: usize,
-    max_tokens: usize,
+    max_tokens: HashMap<LexemeIdx, usize>,
 }
 
 impl RowInfo {
@@ -169,10 +170,10 @@
            self.token_idx_start,
            self.token_idx_stop,
            lexspec.dbg_lexeme(&self.lexeme),
-            if self.max_tokens == usize::MAX {
+            if self.max_tokens.is_empty() {
                "".to_string()
            } else {
-                format!("max_tokens={}", self.max_tokens)
+                format!("max_tokens={:?}", self.max_tokens)
            }
        )
    }
@@ -614,10 +615,37 @@
                0,
                self.token_idx as isize + 1 - info.token_idx_start as isize,
            ) as usize;
-            if info_tokens >= info.max_tokens {
-                debug!("  max_tokens reached; {}", info.dbg(self.lexer_spec()));
-                if !self.try_push_byte_definitive(None) {
-                    return Ok("parse reject on max_tokens");
+            let lex_state = self.lexer_state().lexer_state;
+            let mut limit = trie.alloc_token_set();
+            let mut num_limit = 0;
+            for idx in self.lexer.possible_lexemes(lex_state).iter() {
+                let lex = LexemeIdx::new(idx as usize);
+                let max_tokens = *info.max_tokens.get(&lex).unwrap_or(&usize::MAX);
+                trace!(
+                    "  max_tokens: {} max={} info={}",
+                    self.lexer_spec().dbg_lexeme(&Lexeme::just_idx(lex)),
+                    max_tokens,
+                    info_tokens
+                );
+                if info_tokens < max_tokens {
+                    limit.allow_token(idx);
+                } else {
+                    num_limit += 1;
+                }
+            }
+            if num_limit > 0 {
+                debug!(
+                    "  max_tokens limiting to: {}",
+                    self.lexer_spec().dbg_lexeme_set(&limit)
+                );
+                let new_state = self.lexer.limit_state_to(lex_state, &limit);
+                if new_state.is_dead() {
+                    debug!("  limited everything; forcing EOI");
+                    if !self.try_push_byte_definitive(None) {
+                        return Ok("parse reject on max_tokens");
+                    }
+                } else {
+                    self.lexer_stack.last_mut().unwrap().lexer_state = new_state;
                }
            }
        }
@@ -678,7 +706,7 @@
            start_byte_idx: 0,
            token_idx_start: self.token_idx,
            token_idx_stop: self.token_idx,
-            max_tokens: usize::MAX,
+            max_tokens: HashMap::default(),
        });
 
        for idx in 0..self.num_rows() {
@@ -1016,7 +1044,8 @@
            // with agenda pointer above
            self.rows[added_row_idx].allowed_lexemes = allowed_lexemes;
            if self.scratch.definitive {
-                self.row_infos[added_row_idx].max_tokens = self.row_infos[added_row_idx - 1].max_tokens;
+                self.row_infos[added_row_idx].max_tokens =
+                    self.row_infos[added_row_idx - 1].max_tokens.clone();
            }
            true
        }
@@ -1061,7 +1090,7 @@
    #[inline(always)]
    fn push_row(&mut self, curr_idx: usize, mut agenda_ptr: usize, lexeme: &Lexeme) -> bool {
        let mut allowed_lexemes = SimpleVob::alloc(self.grammar.num_terminals());
-        let mut max_tokens = 0;
+        let mut max_tokens = vec![];
 
        while agenda_ptr < self.scratch.row_end {
            let item_idx = agenda_ptr;
@@ -1129,7 +1158,9 @@
                let sym_data = self.grammar.sym_data(after_dot);
                if let Some(lx) = self.grammar.lexeme_idx_of(after_dot) {
                    allowed_lexemes.set(lx.as_usize(), true);
-                    max_tokens = max_tokens.max(sym_data.props.max_tokens);
+                    if self.scratch.definitive {
+                        max_tokens.push((lx, sym_data.props.max_tokens));
+                    }
                }
                if sym_data.is_nullable {
                    self.scratch
@@ -1185,12 +1216,31 @@
        if self.row_infos.len() > idx {
            self.row_infos.drain(idx..);
        }
+        let mut max_tokens_map = HashMap::default();
+        for (lx, mx) in max_tokens {
+            if let Some(ex) = max_tokens_map.get(&lx) {
+                if *ex < mx {
+                    max_tokens_map.insert(lx, mx);
+                }
+            } else {
+                max_tokens_map.insert(lx, mx);
+            }
+        }
+        let mut to_remove = vec![];
+        for (lx, mx) in max_tokens_map.iter() {
+            if *mx == usize::MAX {
+                to_remove.push(*lx);
+            }
+        }
+        for lx in to_remove {
+            max_tokens_map.remove(&lx);
+        }
        self.row_infos.push(RowInfo {
            lexeme: Lexeme::bogus(),
            token_idx_start: self.token_idx,
            token_idx_stop: self.token_idx,
            start_byte_idx: self.byte_idx,
-            max_tokens,
+            max_tokens: max_tokens_map,
        });
        // debug!("  push: {idx} {} {}", self.rows.len(), self.row_infos.len());
    }
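For readability, the map-building step added in Parser::push_row above can be restated as a small Python sketch (Python is used purely for illustration; lexeme indices are plain ints here and UNLIMITED stands in for Rust's usize::MAX):

    # Restatement of the added logic: keep the largest max_tokens budget seen
    # for each lexeme, then drop entries that are effectively unlimited.
    UNLIMITED = 2**64 - 1  # stand-in for usize::MAX

    def merge_max_tokens(pairs):
        merged = {}
        for lexeme_idx, max_tokens in pairs:
            if merged.get(lexeme_idx, -1) < max_tokens:
                merged[lexeme_idx] = max_tokens
        return {k: v for k, v in merged.items() if v != UNLIMITED}

    assert merge_max_tokens([(0, 2), (0, 5), (1, UNLIMITED)]) == {0: 5}
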
diff --git a/controllers/llguidance_ctrl/src/tokenparser.rs b/controllers/llguidance_ctrl/src/tokenparser.rs
index 104ec910..600b61b6 100644
--- a/controllers/llguidance_ctrl/src/tokenparser.rs
+++ b/controllers/llguidance_ctrl/src/tokenparser.rs
@@ -455,26 +455,28 @@ impl TokenParser {
 
        if inner_accepting {
            let mut all_accepting = true;
-            let mut pop_tokens = trie.alloc_token_set();
-            for pentry in self.parser_stack.iter_mut() {
-                if pentry.mask.is_none() {
-                    assert!(token_prefix.is_empty());
-                    let mask = pentry
-                        .parser
-                        .compute_bias_after_gen_grammar(trie, pentry.symidx);
-                    infoln!(self, "bias for upper parser: {}", trie.token_set_dbg(&mask));
-                    pentry.mask = Some(mask);
-                }
-                let m = pentry.mask.as_ref().unwrap();
-                pop_tokens.or_minus(m, &set);
-                set.or(m);
-                if !pentry.is_accepting {
-                    all_accepting = false;
-                    break;
+            if self.parser_stack.len() > 0 {
+                let mut pop_tokens = trie.alloc_token_set();
+                for pentry in self.parser_stack.iter_mut() {
+                    if pentry.mask.is_none() {
+                        assert!(token_prefix.is_empty());
+                        let mask = pentry
+                            .parser
+                            .compute_bias_after_gen_grammar(trie, pentry.symidx);
+                        infoln!(self, "bias for upper parser: {}", trie.token_set_dbg(&mask));
+                        pentry.mask = Some(mask);
+                    }
+                    let m = pentry.mask.as_ref().unwrap();
+                    pop_tokens.or_minus(m, &set);
+                    set.or(m);
+                    if !pentry.is_accepting {
+                        all_accepting = false;
+                        break;
+                    }
                }
+                infoln!(self, "pop_tokens: {}", trie.token_set_dbg(&pop_tokens));
+                self.pop_tokens = Some(pop_tokens);
            }
-            infoln!(self, "pop_tokens: {}", trie.token_set_dbg(&pop_tokens));
-            self.pop_tokens = Some(pop_tokens);
            self.mid_process_was_accepting = all_accepting;
        }
diff --git a/py/guidance b/py/guidance
index 20db158e..f8ef8129 160000
--- a/py/guidance
+++ b/py/guidance
@@ -1 +1 @@
-Subproject commit 20db158efdf6f254008232b74865e0340a5c96f4
+Subproject commit f8ef8129f4338b8241afe1f504d2e7e3636db2cc
diff --git a/scripts/test-guidance.sh b/scripts/test-guidance.sh
index c9d148c6..4c1ebae2 100755
--- a/scripts/test-guidance.sh
+++ b/scripts/test-guidance.sh
@@ -10,12 +10,20 @@ export AZURE_GUIDANCE_URL
 
 FILES="tests/need_credentials/test_azure_guidance.py tests/model_integration/test_greedy.py"
 
+cd $(dirname $0)/../py/guidance
+
 if [ "X$1" != "X" ] ; then
     if [ "X${1:0:2}" = "X::" ] ; then
-        FILES="tests/models/test_azure_guidance.py$1"
+        FILES="tests/need_credentials/test_azure_guidance.py$1"
         shift
+        pytest --selected_model azure_guidance --durations=10 $FILES "$@"
+        exit $?
     fi
 fi
 
-cd $(dirname $0)/../py/guidance
+set -e
+# quick tests first
+pytest tests/unit/test_ll.py
+pytest tests/unit
 pytest --selected_model azure_guidance --durations=10 $FILES "$@"
+pytest tests/model_integration