diff --git a/controllers/guidance_ctrl/run_g.py b/controllers/guidance_ctrl/run_g.py
index a9fe8bab..783893eb 100644
--- a/controllers/guidance_ctrl/run_g.py
+++ b/controllers/guidance_ctrl/run_g.py
@@ -83,7 +83,8 @@ def main():
     grm = "Tweak this proverb to apply to model instructions instead.\n" + gen(
         "verse", max_tokens=2
     )
-    # grm = "How much is 2 + 2? " + gen(name="test", max_tokens=10, regex=r"\(")
+    grm = "How much is 2 + 2? " + gen(name="test", max_tokens=10, regex=r"\(")
+    grm = "red\n" + gen(stop="") + " and test2"
 
     # read current script file
     # with open(__file__) as f:
diff --git a/controllers/guidance_ctrl/src/earley/parser.rs b/controllers/guidance_ctrl/src/earley/parser.rs
index d504a705..d7bd1520 100644
--- a/controllers/guidance_ctrl/src/earley/parser.rs
+++ b/controllers/guidance_ctrl/src/earley/parser.rs
@@ -1,4 +1,10 @@
-use std::{fmt::Debug, hash::Hash, ops::Range, rc::Rc, vec};
+use std::{
+    fmt::{Debug, Display},
+    hash::Hash,
+    ops::Range,
+    rc::Rc,
+    vec,
+};
 
 use aici_abi::{
     toktree::{Recognizer, SpecialToken, TokTrie},
@@ -7,7 +13,7 @@ use aici_abi::{
 
 use super::grammar::{CGrammar, CSymIdx, CSymbol, ModelVariable, RuleIdx};
 
-const DEBUG: bool = false;
+const DEBUG: bool = true;
 const INFO: bool = true;
 
 macro_rules! debug {
@@ -37,6 +43,16 @@ struct ItemProps {
     hidden_start: usize,
 }
 
+impl Display for ItemProps {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        if self.hidden_start == usize::MAX {
+            write!(f, "")
+        } else {
+            write!(f, "(hidden_start {})", self.hidden_start)
+        }
+    }
+}
+
 impl Default for ItemProps {
     fn default() -> Self {
         ItemProps {
@@ -167,7 +183,7 @@ impl Scratch {
     }
 
     #[inline(always)]
-    fn just_add(&mut self, item: Item, origin_item_idx: usize) {
+    fn just_add(&mut self, item: Item, origin_item_idx: usize, info: &str) {
         self.ensure_items(self.row_end + 1);
         // SAFETY: we just ensured that there is enough space
         unsafe {
@@ -181,28 +197,53 @@ impl Scratch {
                 self.item_props[self.row_end] = ItemProps::default();
             }
             self.merge_item_origin(self.row_end, origin_item_idx);
+
+            debug!(
+                " addu: {} ({})",
+                self.item_to_string(self.row_end),
+                info
+            );
         }
         self.row_end += 1;
     }
 
     #[inline(always)]
-    fn add_unique(&mut self, item: Item, origin_item_idx: usize, info: &str) {
-        if let Some(idx) = self.items[self.row_start..self.row_end]
+    fn find_item(&self, item: Item) -> Option<usize> {
+        self.items[self.row_start..self.row_end]
             .iter()
             .position(|&x| x == item)
-        {
+            .map(|x| x + self.row_start)
+    }
+
+    fn set_hidden_start(&mut self, item: Item, hidden_start: usize) {
+        let idx = self.find_item(item).unwrap();
+        self.item_props[idx].hidden_start =
+            std::cmp::min(self.item_props[idx].hidden_start, hidden_start);
+        debug!(
+            " hidden: {} {}",
+            hidden_start,
+            self.item_to_string(idx),
+        );
+    }
+
+    #[inline(always)]
+    fn add_unique(&mut self, item: Item, origin_item_idx: usize, info: &str) {
+        if let Some(idx) = self.find_item(item) {
             if self.definitive {
                 self.merge_item_origin(idx, origin_item_idx);
             }
         } else {
-            if self.definitive {
-                debug!(
-                    " addu: {} ({})",
-                    item_to_string(&self.grammar, &item),
-                    info
-                );
-            }
-            self.just_add(item, origin_item_idx);
+            self.just_add(item, origin_item_idx, info);
+        }
+    }
+
+    fn item_to_string(&self, idx: usize) -> String {
+        let r = item_to_string(&self.grammar, &self.items[idx]);
+        if self.definitive {
+            let props = &self.item_props[idx];
+            format!("{} {}", r, props)
+        } else {
+            r
         }
     }
 }
@@ -235,15 +276,15 @@ impl Parser {
         self.is_accepting
     }
 
-    fn item_to_string(&self, item: &Item) -> String {
-        item_to_string(&self.grammar, item)
+    fn item_to_string(&self, idx: usize) -> String {
+        self.scratch.item_to_string(idx)
     }
 
     pub fn print_row(&self, row_idx: usize) {
         let row = &self.rows[row_idx];
         println!("row {}", row_idx);
         for i in row.item_indices() {
-            println!("{}", self.item_to_string(&self.scratch.items[i]));
+            println!("{}", self.item_to_string(i));
         }
     }
 
@@ -286,6 +327,14 @@ impl Parser {
         self.grammar.sym_data(self.item_lhs(item))
     }
 
+    pub fn hidden_start(&self) -> usize {
+        self.curr_row()
+            .item_indices()
+            .map(|i| self.scratch.item_props[i].hidden_start)
+            .min()
+            .unwrap_or(usize::MAX)
+    }
+
     pub fn apply_tokens(
         &mut self,
         trie: &TokTrie,
@@ -355,7 +404,7 @@ impl Parser {
                     " remove: {}-{} {}",
                     self.token_idx,
                     start_token_idx,
-                    self.item_to_string(&item)
+                    self.item_to_string(i)
                 );
                 continue;
             }
@@ -489,7 +538,7 @@ impl Parser {
             let idx = self.grammar.sym_idx_at(item.rule_idx()).as_index();
             // idx == 0 => completed
             if idx < allowed.len() && allowed[idx] {
-                self.scratch.just_add(item.advance_dot(), i);
+                self.scratch.just_add(item.advance_dot(), i, "scan");
            }
             i += 1;
         }
@@ -513,7 +562,7 @@ impl Parser {
             let mut item = self.scratch.items[agenda_ptr];
             agenda_ptr += 1;
             if self.scratch.definitive {
-                debug!(" agenda: {}", self.item_to_string(&item));
+                debug!(" agenda: {}", self.item_to_string(item_idx));
             }
 
             let rule = item.rule_idx();
@@ -570,12 +619,10 @@ impl Parser {
                         self.scratch.item_props[agenda_ptr - 1] =
                             self.scratch.item_props[item_idx].clone();
                     }
-                    // better keep item_idx updated in case we use it in future
                     item_idx = agenda_ptr - 1;
-                    let _ = item_idx; // silence warning
                     commit_item = item;
                     if self.scratch.definitive {
-                        debug!(" commit point: {}", self.item_to_string(&item));
+                        debug!(" commit point: {}", self.item_to_string(item_idx));
                         if flags.hidden() {
                             return self.hide_item(lhs, item.start_pos());
                         }
@@ -601,6 +648,12 @@ impl Parser {
                     let new_item = Item::new(*rule, curr_idx);
                     self.scratch.add_unique(new_item, item_idx, "predict");
                 }
+                if self.scratch.definitive && sym_data.props.hidden {
+                    for rule in &sym_data.rules {
+                        let new_item = Item::new(*rule, curr_idx);
+                        self.scratch.set_hidden_start(new_item, curr_idx);
+                    }
+                }
             }
         }
 
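Note: the parser.rs hunks above give every Earley item a hidden_start in its ItemProps: the earliest byte offset at which hidden text (for example a stop string that must not be surfaced to the caller) begins, with usize::MAX as the "nothing hidden" sentinel. Parser::hidden_start() then folds the current row with min(). A self-contained sketch of that sentinel-plus-min convention (illustrative toy, not code from this patch):

    fn main() {
        const NO_HIDDEN: usize = usize::MAX; // sentinel: item carries no hidden bytes
        let hidden_starts = [NO_HIDDEN, 17, NO_HIDDEN, 12];
        // The row-level value is the earliest hidden byte across all live items;
        // if every item uses the sentinel, min() yields the sentinel itself, which
        // callers clamp against the buffer length, as the tokenparser.rs hunk
        // below does when slicing llm_bytes.
        let row_hidden = hidden_starts.iter().copied().min().unwrap_or(NO_HIDDEN);
        assert_eq!(row_hidden, 12);
    }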
diff --git a/controllers/guidance_ctrl/src/tokenparser.rs b/controllers/guidance_ctrl/src/tokenparser.rs
index 993c61b3..1666ebca 100644
--- a/controllers/guidance_ctrl/src/tokenparser.rs
+++ b/controllers/guidance_ctrl/src/tokenparser.rs
@@ -51,7 +51,8 @@ impl TokenParser {
         if idx >= self.llm_bytes.len() {
             return &[];
         }
-        &self.llm_bytes[idx..]
+        let endp = std::cmp::min(self.llm_bytes.len(), self.parser.hidden_start());
+        &self.llm_bytes[idx..endp]
     }
 
     pub fn process_prompt(&mut self, prompt: Vec<TokenId>) -> Vec<TokenId> {
@@ -103,13 +104,11 @@ impl TokenParser {
         infoln!("post tokens: {}", trie.tokens_dbg(&arg.tokens));
         arg.save_tokens(&mut self.llm_tokens);
 
-        if arg.backtrack == 0 {
-            let new_bytes = trie.decode(&arg.tokens);
-            self.llm_bytes.extend_from_slice(&new_bytes);
-        } else {
-            // recompute on backtrack
-            self.llm_bytes = trie.decode(&self.llm_tokens);
-        }
+        let new_bytes = trie.decode(&arg.tokens);
+        self.llm_bytes.extend_from_slice(&new_bytes);
+
+        // TODO maybe remove in future
+        assert!(self.llm_bytes == trie.decode(&self.llm_tokens));
 
         let res = self
             .parser
@@ -123,6 +122,8 @@ impl TokenParser {
         self.parser.force_bytes();
         let grm_bytes = self.grm_bytes();
 
+        let mut backtrack = 0;
+
         // now, see if we need to backtrack
         if self.llm_bytes.len() > grm_bytes.len()
             || self.llm_bytes != grm_bytes[0..self.llm_bytes.len()]
@@ -132,22 +133,25 @@ impl TokenParser {
                 let b = trie.token(*t);
                 let pend = ptr + b.len();
                 if pend > grm_bytes.len() || b != &grm_bytes[ptr..pend] {
-                    let tokens = self.token_env.tokenize_bytes(&grm_bytes[ptr..]);
-                    let backtrack = self.llm_tokens.len() - idx;
+                    backtrack = self.llm_tokens.len() - idx;
                     infoln!(
-                        "backtrack: {} tokens: {}",
+                        "backtrack: {} (deletes: {:?})",
                         backtrack,
-                        trie.tokens_dbg(&tokens)
+                        String::from_utf8_lossy(&self.llm_bytes[ptr..])
                     );
-                    return MidProcessResult::splice(backtrack as u32, tokens);
+                    assert!(backtrack > 0);
+                    self.llm_bytes.drain(ptr..);
+                    break;
                 }
                 ptr = pend;
             }
-            panic!(
-                "backtrack failed {:?} {:?}",
-                String::from_utf8_lossy(&self.llm_bytes),
-                String::from_utf8_lossy(&grm_bytes)
-            );
+            if backtrack == 0 {
+                panic!(
+                    "backtrack failed {:?} {:?}",
+                    String::from_utf8_lossy(&self.llm_bytes),
+                    String::from_utf8_lossy(&grm_bytes)
+                );
+            }
         }
 
         if arg.tokens.contains(&trie.eos_token()) {
@@ -157,7 +161,7 @@ impl TokenParser {
         let new_forced = grm_bytes[self.llm_bytes.len()..].to_vec();
         let mut token_prefix = Vec::new();
 
-        if new_forced.len() > 0 {
+        if new_forced.len() > 0 || backtrack > 0 {
             let mut grm_tokens = self.token_env.tokenize_bytes(&new_forced);
             infoln!("forced: {}", trie.tokens_dbg(&grm_tokens));
             let (chop_tokens, chop_bytes) = trie.chop_tokens(&mut self.parser, &grm_tokens);
@@ -167,7 +171,7 @@ impl TokenParser {
 
         if grm_tokens.len() > 0 {
             infoln!("fixed_tokens: {}", trie.tokens_dbg(&grm_tokens));
-            return MidProcessResult::splice(0, grm_tokens);
+            return MidProcessResult::splice(backtrack as u32, grm_tokens);
         } else {
             infoln!("no fixed tokens");
         }
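Note: the tokenparser.rs hunks replace the old early-return splice with record-and-fall-through backtracking: on a byte mismatch the code now remembers how many trailing tokens to delete, drains the stale bytes, and lets the forced-tokens path below emit a single splice that carries both the backtrack count and the re-tokenized tail. A toy reconstruction of that divergence scan (byte strings stand in for real token ids; the names here are illustrative, not the patch's own types):

    fn main() {
        // Bytes already produced by the LLM, one entry per token.
        let llm_tokens: Vec<&[u8]> = vec![b"He", b"llo", b" war", b"p"];
        // Bytes the grammar forces from this point on.
        let grm_bytes = b"Hello world".to_vec();

        let (mut ptr, mut backtrack) = (0, 0);
        for (idx, tok) in llm_tokens.iter().enumerate() {
            let pend = ptr + tok.len();
            // The first token whose bytes diverge from the forced bytes
            // determines how many trailing tokens must be deleted.
            if pend > grm_bytes.len() || *tok != &grm_bytes[ptr..pend] {
                backtrack = llm_tokens.len() - idx;
                break;
            }
            ptr = pend;
        }
        // One combined splice: delete `backtrack` tokens, then force the tail.
        let forced = &grm_bytes[ptr..];
        assert_eq!(backtrack, 2);
        assert_eq!(forced, b" world");
    }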