diff --git a/controllers/guidance_ctrl/run_g.py b/controllers/guidance_ctrl/run_g.py
index a9fe8bab..783893eb 100644
--- a/controllers/guidance_ctrl/run_g.py
+++ b/controllers/guidance_ctrl/run_g.py
@@ -83,7 +83,8 @@ def main():
grm = "Tweak this proverb to apply to model instructions instead.\n" + gen(
"verse", max_tokens=2
)
- # grm = "How much is 2 + 2? " + gen(name="test", max_tokens=10, regex=r"\(")
+ grm = "How much is 2 + 2? " + gen(name="test", max_tokens=10, regex=r"\(")
+ grm = "red\n" + gen(stop="") + " and test2"
# read current script file
# with open(__file__) as f:
diff --git a/controllers/guidance_ctrl/src/earley/parser.rs b/controllers/guidance_ctrl/src/earley/parser.rs
index d504a705..d7bd1520 100644
--- a/controllers/guidance_ctrl/src/earley/parser.rs
+++ b/controllers/guidance_ctrl/src/earley/parser.rs
@@ -1,4 +1,10 @@
-use std::{fmt::Debug, hash::Hash, ops::Range, rc::Rc, vec};
+use std::{
+ fmt::{Debug, Display},
+ hash::Hash,
+ ops::Range,
+ rc::Rc,
+ vec,
+};
use aici_abi::{
toktree::{Recognizer, SpecialToken, TokTrie},
@@ -7,7 +13,7 @@ use aici_abi::{
use super::grammar::{CGrammar, CSymIdx, CSymbol, ModelVariable, RuleIdx};
-const DEBUG: bool = false;
+const DEBUG: bool = true;
const INFO: bool = true;
macro_rules! debug {
@@ -37,6 +43,16 @@ struct ItemProps {
hidden_start: usize,
}
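+
+// Debug rendering for ItemProps: prints nothing unless a hidden_start was recorded.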
+impl Display for ItemProps {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ if self.hidden_start == usize::MAX {
+ write!(f, "")
+ } else {
+ write!(f, "(hidden_start {})", self.hidden_start)
+ }
+ }
+}
+
impl Default for ItemProps {
fn default() -> Self {
ItemProps {
@@ -167,7 +183,7 @@ impl Scratch {
}
#[inline(always)]
- fn just_add(&mut self, item: Item, origin_item_idx: usize) {
+ fn just_add(&mut self, item: Item, origin_item_idx: usize, info: &str) {
self.ensure_items(self.row_end + 1);
// SAFETY: we just ensured that there is enough space
unsafe {
@@ -181,28 +197,53 @@ impl Scratch {
self.item_props[self.row_end] = ItemProps::default();
}
self.merge_item_origin(self.row_end, origin_item_idx);
+
+ debug!(
+ " addu: {} ({})",
+ self.item_to_string(self.row_end),
+ info
+ );
}
self.row_end += 1;
}
#[inline(always)]
- fn add_unique(&mut self, item: Item, origin_item_idx: usize, info: &str) {
- if let Some(idx) = self.items[self.row_start..self.row_end]
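+ // locate `item` in the current row, returning its absolute index into `items`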
+ fn find_item(&self, item: Item) -> Option<usize> {
+ self.items[self.row_start..self.row_end]
.iter()
.position(|&x| x == item)
- {
+ .map(|x| x + self.row_start)
+ }
+
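+ // the item must already be in the current row; keep the smallest
+ // hidden-start offset recorded so far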
+ fn set_hidden_start(&mut self, item: Item, hidden_start: usize) {
+ let idx = self.find_item(item).unwrap();
+ self.item_props[idx].hidden_start =
+ std::cmp::min(self.item_props[idx].hidden_start, hidden_start);
+ debug!(
+ " hidden: {} {}",
+ hidden_start,
+ self.item_to_string(idx),
+ );
+ }
+
+ #[inline(always)]
+ fn add_unique(&mut self, item: Item, origin_item_idx: usize, info: &str) {
+ if let Some(idx) = self.find_item(item) {
if self.definitive {
self.merge_item_origin(idx, origin_item_idx);
}
} else {
- if self.definitive {
- debug!(
- " addu: {} ({})",
- item_to_string(&self.grammar, &item),
- info
- );
- }
- self.just_add(item, origin_item_idx);
+ self.just_add(item, origin_item_idx, info);
+ }
+ }
+
+ fn item_to_string(&self, idx: usize) -> String {
+ let r = item_to_string(&self.grammar, &self.items[idx]);
+ if self.definitive {
+ let props = &self.item_props[idx];
+ format!("{} {}", r, props)
+ } else {
+ r
}
}
}
@@ -235,15 +276,15 @@ impl Parser {
self.is_accepting
}
- fn item_to_string(&self, item: &Item) -> String {
- item_to_string(&self.grammar, item)
+ fn item_to_string(&self, idx: usize) -> String {
+ self.scratch.item_to_string(idx)
}
pub fn print_row(&self, row_idx: usize) {
let row = &self.rows[row_idx];
println!("row {}", row_idx);
for i in row.item_indices() {
- println!("{}", self.item_to_string(&self.scratch.items[i]));
+ println!("{}", self.item_to_string(i));
}
}
@@ -286,6 +327,14 @@ impl Parser {
self.grammar.sym_data(self.item_lhs(item))
}
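+
+ // Earliest byte offset at which hidden grammar text starts in the current
+ // row, or usize::MAX when nothing is hidden; used to clamp visible output.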
+ pub fn hidden_start(&self) -> usize {
+ self.curr_row()
+ .item_indices()
+ .map(|i| self.scratch.item_props[i].hidden_start)
+ .min()
+ .unwrap_or(usize::MAX)
+ }
+
pub fn apply_tokens(
&mut self,
trie: &TokTrie,
@@ -355,7 +404,7 @@ impl Parser {
" remove: {}-{} {}",
self.token_idx,
start_token_idx,
- self.item_to_string(&item)
+ self.item_to_string(i)
);
continue;
}
@@ -489,7 +538,7 @@ impl Parser {
let idx = self.grammar.sym_idx_at(item.rule_idx()).as_index();
// idx == 0 => completed
if idx < allowed.len() && allowed[idx] {
- self.scratch.just_add(item.advance_dot(), i);
+ self.scratch.just_add(item.advance_dot(), i, "scan");
}
i += 1;
}
@@ -513,7 +562,7 @@ impl Parser {
let mut item = self.scratch.items[agenda_ptr];
agenda_ptr += 1;
if self.scratch.definitive {
- debug!(" agenda: {}", self.item_to_string(&item));
+ debug!(" agenda: {}", self.item_to_string(item_idx));
}
let rule = item.rule_idx();
@@ -570,12 +619,10 @@ impl Parser {
self.scratch.item_props[agenda_ptr - 1] =
self.scratch.item_props[item_idx].clone();
}
- // better keep item_idx updated in case we use it in future
item_idx = agenda_ptr - 1;
- let _ = item_idx; // silence warning
commit_item = item;
if self.scratch.definitive {
- debug!(" commit point: {}", self.item_to_string(&item));
+ debug!(" commit point: {}", self.item_to_string(item_idx));
if flags.hidden() {
return self.hide_item(lhs, item.start_pos());
}
@@ -601,6 +648,12 @@ impl Parser {
let new_item = Item::new(*rule, curr_idx);
self.scratch.add_unique(new_item, item_idx, "predict");
}
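+ // when predicting a hidden symbol, record the current position as the
+ // earliest point where its (hidden) text can start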
+ if self.scratch.definitive && sym_data.props.hidden {
+ for rule in &sym_data.rules {
+ let new_item = Item::new(*rule, curr_idx);
+ self.scratch.set_hidden_start(new_item, curr_idx);
+ }
+ }
}
}
diff --git a/controllers/guidance_ctrl/src/tokenparser.rs b/controllers/guidance_ctrl/src/tokenparser.rs
index 993c61b3..1666ebca 100644
--- a/controllers/guidance_ctrl/src/tokenparser.rs
+++ b/controllers/guidance_ctrl/src/tokenparser.rs
@@ -51,7 +51,8 @@ impl TokenParser {
if idx >= self.llm_bytes.len() {
return &[];
}
- &self.llm_bytes[idx..]
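+ // clamp to the parser's hidden-start, so hidden grammar text is never
+ // surfaced to the caller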
+ let endp = std::cmp::min(self.llm_bytes.len(), self.parser.hidden_start());
+ &self.llm_bytes[idx..endp]
}
pub fn process_prompt(&mut self, prompt: Vec<TokenId>) -> Vec<TokenId> {
@@ -103,13 +104,11 @@ impl TokenParser {
infoln!("post tokens: {}", trie.tokens_dbg(&arg.tokens));
arg.save_tokens(&mut self.llm_tokens);
- if arg.backtrack == 0 {
- let new_bytes = trie.decode(&arg.tokens);
- self.llm_bytes.extend_from_slice(&new_bytes);
- } else {
- // recompute on backtrack
- self.llm_bytes = trie.decode(&self.llm_tokens);
- }
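+ // llm_bytes is kept in sync on backtrack via drain() below, so appending
+ // only the newly decoded tokens is enough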
+ let new_bytes = trie.decode(&arg.tokens);
+ self.llm_bytes.extend_from_slice(&new_bytes);
+
+ // TODO: remove this consistency check eventually; it re-decodes all tokens on every step
+ assert!(self.llm_bytes == trie.decode(&self.llm_tokens));
let res = self
.parser
@@ -123,6 +122,8 @@ impl TokenParser {
self.parser.force_bytes();
let grm_bytes = self.grm_bytes();
+ let mut backtrack = 0;
+
// now, see if we need to backtrack
if self.llm_bytes.len() > grm_bytes.len()
|| self.llm_bytes != grm_bytes[0..self.llm_bytes.len()]
@@ -132,22 +133,25 @@ impl TokenParser {
let b = trie.token(*t);
let pend = ptr + b.len();
if pend > grm_bytes.len() || b != &grm_bytes[ptr..pend] {
- let tokens = self.token_env.tokenize_bytes(&grm_bytes[ptr..]);
- let backtrack = self.llm_tokens.len() - idx;
+ backtrack = self.llm_tokens.len() - idx;
infoln!(
- "backtrack: {} tokens: {}",
+ "backtrack: {} (deletes: {:?})",
backtrack,
- trie.tokens_dbg(&tokens)
+ String::from_utf8_lossy(&self.llm_bytes[ptr..])
);
- return MidProcessResult::splice(backtrack as u32, tokens);
+ assert!(backtrack > 0);
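+ // drop the mismatched bytes now; the actual splice is emitted further
+ // down, together with any newly forced tokens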
+ self.llm_bytes.drain(ptr..);
+ break;
}
ptr = pend;
}
- panic!(
- "backtrack failed {:?} {:?}",
- String::from_utf8_lossy(&self.llm_bytes),
- String::from_utf8_lossy(&grm_bytes)
- );
+ if backtrack == 0 {
+ panic!(
+ "backtrack failed {:?} {:?}",
+ String::from_utf8_lossy(&self.llm_bytes),
+ String::from_utf8_lossy(&grm_bytes)
+ );
+ }
}
if arg.tokens.contains(&trie.eos_token()) {
@@ -157,7 +161,7 @@ impl TokenParser {
let new_forced = grm_bytes[self.llm_bytes.len()..].to_vec();
let mut token_prefix = Vec::new();
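+ // when backtracking we must emit a splice even if no new bytes are forced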
- if new_forced.len() > 0 {
+ if new_forced.len() > 0 || backtrack > 0 {
let mut grm_tokens = self.token_env.tokenize_bytes(&new_forced);
infoln!("forced: {}", trie.tokens_dbg(&grm_tokens));
let (chop_tokens, chop_bytes) = trie.chop_tokens(&mut self.parser, &grm_tokens);
@@ -167,7 +171,7 @@ impl TokenParser {
if grm_tokens.len() > 0 {
infoln!("fixed_tokens: {}", trie.tokens_dbg(&grm_tokens));
- return MidProcessResult::splice(0, grm_tokens);
+ return MidProcessResult::splice(backtrack as u32, grm_tokens);
} else {
infoln!("no fixed tokens");
}