handle hidden start
mmoskal committed May 6, 2024
1 parent f49110f commit 0befe69
Showing 3 changed files with 102 additions and 44 deletions.
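In outline, this commit teaches the Earley parser to track, for each item, the byte offset at which a hidden symbol (such as a stop string) begins (hidden_start, with usize::MAX meaning nothing hidden), to expose the per-row minimum via Parser::hidden_start(), and to clamp the bytes that TokenParser::grm_bytes() surfaces at that offset. A minimal sketch of the two aggregations, using assumed free-function names rather than the crate's actual methods:

```rust
// Sketch only (assumed names): the per-row hidden start is the minimum
// of the items' hidden_start fields; usize::MAX means nothing hidden.
fn row_hidden_start(item_hidden_starts: &[usize]) -> usize {
    item_hidden_starts.iter().copied().min().unwrap_or(usize::MAX)
}

// Bytes at or past the hidden start are withheld from the caller.
fn visible_bytes(llm_bytes: &[u8], idx: usize, hidden_start: usize) -> &[u8] {
    let endp = llm_bytes.len().min(hidden_start);
    &llm_bytes[idx.min(endp)..endp]
}
```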
3 changes: 2 additions & 1 deletion controllers/guidance_ctrl/run_g.py
@@ -83,7 +83,8 @@ def main():
     grm = "Tweak this proverb to apply to model instructions instead.\n" + gen(
         "verse", max_tokens=2
     )
-    # grm = "How much is 2 + 2? " + gen(name="test", max_tokens=10, regex=r"\(")
+    grm = "How much is 2 + 2? " + gen(name="test", max_tokens=10, regex=r"\(")
+    grm = "<color>red</color>\n<color>" + gen(stop="</color>") + " and test2"
 
     # read current script file
     # with open(__file__) as f:
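These new grammar lines exercise the behavior this commit targets: gen(stop="</color>") requires the engine to consume the stop string without surfacing it to the caller, which is what the hidden-start tracking in parser.rs below implements.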
99 changes: 76 additions & 23 deletions controllers/guidance_ctrl/src/earley/parser.rs
@@ -1,4 +1,10 @@
-use std::{fmt::Debug, hash::Hash, ops::Range, rc::Rc, vec};
+use std::{
+    fmt::{Debug, Display},
+    hash::Hash,
+    ops::Range,
+    rc::Rc,
+    vec,
+};
 
 use aici_abi::{
     toktree::{Recognizer, SpecialToken, TokTrie},
@@ -7,7 +13,7 @@ use aici_abi::{
 
 use super::grammar::{CGrammar, CSymIdx, CSymbol, ModelVariable, RuleIdx};
 
-const DEBUG: bool = false;
+const DEBUG: bool = true;
 const INFO: bool = true;
 
 macro_rules! debug {
@@ -37,6 +43,16 @@ struct ItemProps {
     hidden_start: usize,
 }
 
+impl Display for ItemProps {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        if self.hidden_start == usize::MAX {
+            write!(f, "")
+        } else {
+            write!(f, "(hidden_start {})", self.hidden_start)
+        }
+    }
+}
+
 impl Default for ItemProps {
     fn default() -> Self {
         ItemProps {
@@ -167,7 +183,7 @@ impl Scratch {
     }
 
     #[inline(always)]
-    fn just_add(&mut self, item: Item, origin_item_idx: usize) {
+    fn just_add(&mut self, item: Item, origin_item_idx: usize, info: &str) {
        self.ensure_items(self.row_end + 1);
        // SAFETY: we just ensured that there is enough space
        unsafe {
@@ -181,28 +197,53 @@
                 self.item_props[self.row_end] = ItemProps::default();
             }
             self.merge_item_origin(self.row_end, origin_item_idx);
+
+            debug!(
+                " addu: {} ({})",
+                self.item_to_string(self.row_end),
+                info
+            );
         }
         self.row_end += 1;
     }
 
     #[inline(always)]
-    fn add_unique(&mut self, item: Item, origin_item_idx: usize, info: &str) {
-        if let Some(idx) = self.items[self.row_start..self.row_end]
+    fn find_item(&self, item: Item) -> Option<usize> {
+        self.items[self.row_start..self.row_end]
             .iter()
             .position(|&x| x == item)
-        {
+            .map(|x| x + self.row_start)
+    }
+
+    fn set_hidden_start(&mut self, item: Item, hidden_start: usize) {
+        let idx = self.find_item(item).unwrap();
+        self.item_props[idx].hidden_start =
+            std::cmp::min(self.item_props[idx].hidden_start, hidden_start);
+        debug!(
+            " hidden: {} {}",
+            hidden_start,
+            self.item_to_string(idx),
+        );
+    }
+
+    #[inline(always)]
+    fn add_unique(&mut self, item: Item, origin_item_idx: usize, info: &str) {
+        if let Some(idx) = self.find_item(item) {
             if self.definitive {
                 self.merge_item_origin(idx, origin_item_idx);
             }
         } else {
-            if self.definitive {
-                debug!(
-                    " addu: {} ({})",
-                    item_to_string(&self.grammar, &item),
-                    info
-                );
-            }
-            self.just_add(item, origin_item_idx);
+            self.just_add(item, origin_item_idx, info);
         }
     }
+
+    fn item_to_string(&self, idx: usize) -> String {
+        let r = item_to_string(&self.grammar, &self.items[idx]);
+        if self.definitive {
+            let props = &self.item_props[idx];
+            format!("{} {}", r, props)
+        } else {
+            r
+        }
+    }
 }
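One subtlety in the hunk above: position() returns an index relative to the row slice, so find_item adds row_start back before the result is used against the full items and item_props arrays. A standalone sketch of that arithmetic, with a hypothetical item type standing in for the crate's Item:

```rust
// Sketch of find_item's index arithmetic: position() is relative to
// the [row_start..row_end] slice, so the row offset must be added back
// to get an absolute index into the backing arrays.
fn find_in_row(items: &[u64], row_start: usize, row_end: usize, needle: u64) -> Option<usize> {
    items[row_start..row_end]
        .iter()
        .position(|&x| x == needle)
        .map(|i| i + row_start)
}
```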
@@ -235,15 +276,15 @@ impl Parser {
         self.is_accepting
     }
 
-    fn item_to_string(&self, item: &Item) -> String {
-        item_to_string(&self.grammar, item)
+    fn item_to_string(&self, idx: usize) -> String {
+        self.scratch.item_to_string(idx)
     }
 
     pub fn print_row(&self, row_idx: usize) {
         let row = &self.rows[row_idx];
         println!("row {}", row_idx);
         for i in row.item_indices() {
-            println!("{}", self.item_to_string(&self.scratch.items[i]));
+            println!("{}", self.item_to_string(i));
         }
     }
 
@@ -286,6 +327,14 @@ impl Parser {
         self.grammar.sym_data(self.item_lhs(item))
     }
 
+    pub fn hidden_start(&self) -> usize {
+        self.curr_row()
+            .item_indices()
+            .map(|i| self.scratch.item_props[i].hidden_start)
+            .min()
+            .unwrap_or(usize::MAX)
+    }
+
     pub fn apply_tokens(
         &mut self,
         trie: &TokTrie,
@@ -355,7 +404,7 @@ impl Parser {
                     " remove: {}-{} {}",
                     self.token_idx,
                     start_token_idx,
-                    self.item_to_string(&item)
+                    self.item_to_string(i)
                 );
                 continue;
             }
@@ -489,7 +538,7 @@ impl Parser {
             let idx = self.grammar.sym_idx_at(item.rule_idx()).as_index();
             // idx == 0 => completed
             if idx < allowed.len() && allowed[idx] {
-                self.scratch.just_add(item.advance_dot(), i);
+                self.scratch.just_add(item.advance_dot(), i, "scan");
             }
             i += 1;
         }
@@ -513,7 +562,7 @@
             let mut item = self.scratch.items[agenda_ptr];
             agenda_ptr += 1;
             if self.scratch.definitive {
-                debug!(" agenda: {}", self.item_to_string(&item));
+                debug!(" agenda: {}", self.item_to_string(item_idx));
             }
 
             let rule = item.rule_idx();
@@ -570,12 +619,10 @@
                     self.scratch.item_props[agenda_ptr - 1] =
                         self.scratch.item_props[item_idx].clone();
                 }
-                // better keep item_idx updated in case we use it in future
                 item_idx = agenda_ptr - 1;
-                let _ = item_idx; // silence warning
                 commit_item = item;
                 if self.scratch.definitive {
-                    debug!(" commit point: {}", self.item_to_string(&item));
+                    debug!(" commit point: {}", self.item_to_string(item_idx));
                     if flags.hidden() {
                         return self.hide_item(lhs, item.start_pos());
                     }
@@ -601,6 +648,12 @@
                 let new_item = Item::new(*rule, curr_idx);
                 self.scratch.add_unique(new_item, item_idx, "predict");
             }
+            if self.scratch.definitive && sym_data.props.hidden {
+                for rule in &sym_data.rules {
+                    let new_item = Item::new(*rule, curr_idx);
+                    self.scratch.set_hidden_start(new_item, curr_idx);
+                }
+            }
         }
     }
 
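Tying this back to the <color> example in run_g.py: when the hidden stop-string symbol for </color> is predicted, every item predicted for it gets its hidden_start clamped to the current byte position (the last hunk above), so hidden_start() reports that offset while the parser keeps consuming the stop string's bytes. A toy, self-contained illustration with assumed offsets:

```rust
// Toy illustration (assumed offsets): the model has produced b"red</col"
// and the hidden stop string "</color>" started at byte 3, so only
// b"red" is surfaced even though more bytes were consumed.
fn visible(llm_bytes: &[u8], hidden_start: usize) -> &[u8] {
    &llm_bytes[..llm_bytes.len().min(hidden_start)]
}

fn main() {
    let llm_bytes: &[u8] = b"red</col";
    assert_eq!(visible(llm_bytes, 3), b"red");
    assert_eq!(visible(llm_bytes, usize::MAX), llm_bytes); // nothing hidden
}
```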
44 changes: 24 additions & 20 deletions controllers/guidance_ctrl/src/tokenparser.rs
@@ -51,7 +51,8 @@ impl TokenParser {
         if idx >= self.llm_bytes.len() {
             return &[];
         }
-        &self.llm_bytes[idx..]
+        let endp = std::cmp::min(self.llm_bytes.len(), self.parser.hidden_start());
+        &self.llm_bytes[idx..endp]
     }
 
     pub fn process_prompt(&mut self, prompt: Vec<TokenId>) -> Vec<TokenId> {
@@ -103,13 +104,11 @@ impl TokenParser {
         infoln!("post tokens: {}", trie.tokens_dbg(&arg.tokens));
         arg.save_tokens(&mut self.llm_tokens);
 
-        if arg.backtrack == 0 {
-            let new_bytes = trie.decode(&arg.tokens);
-            self.llm_bytes.extend_from_slice(&new_bytes);
-        } else {
-            // recompute on backtrack
-            self.llm_bytes = trie.decode(&self.llm_tokens);
-        }
+        let new_bytes = trie.decode(&arg.tokens);
+        self.llm_bytes.extend_from_slice(&new_bytes);
+
+        // TODO maybe remove in future
+        assert!(self.llm_bytes == trie.decode(&self.llm_tokens));
 
         let res = self
             .parser
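The byte bookkeeping above is deliberately simpler than before: the decode of just the new tokens is always appended, and the temporary assert cross-checks this against a full re-decode of all accepted tokens. A sketch of the invariant the assert encodes, with an assumed closure standing in for the token trie:

```rust
// Sketch of the invariant behind the new assert: incrementally appended
// bytes must equal a full re-decode of every accepted token so far.
fn invariant_holds(
    decode: impl Fn(&[u32]) -> Vec<u8>, // stand-in for trie.decode
    llm_tokens: &[u32],
    llm_bytes: &[u8],
) -> bool {
    llm_bytes == decode(llm_tokens).as_slice()
}
```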
@@ -123,6 +122,8 @@ impl TokenParser {
         self.parser.force_bytes();
         let grm_bytes = self.grm_bytes();
 
+        let mut backtrack = 0;
+
         // now, see if we need to backtrack
         if self.llm_bytes.len() > grm_bytes.len()
             || self.llm_bytes != grm_bytes[0..self.llm_bytes.len()]
@@ -132,22 +133,25 @@ impl TokenParser {
                 let b = trie.token(*t);
                 let pend = ptr + b.len();
                 if pend > grm_bytes.len() || b != &grm_bytes[ptr..pend] {
-                    let tokens = self.token_env.tokenize_bytes(&grm_bytes[ptr..]);
-                    let backtrack = self.llm_tokens.len() - idx;
+                    backtrack = self.llm_tokens.len() - idx;
                     infoln!(
-                        "backtrack: {} tokens: {}",
+                        "backtrack: {} (deletes: {:?})",
                         backtrack,
-                        trie.tokens_dbg(&tokens)
+                        String::from_utf8_lossy(&self.llm_bytes[ptr..])
                     );
-                    return MidProcessResult::splice(backtrack as u32, tokens);
+                    assert!(backtrack > 0);
+                    self.llm_bytes.drain(ptr..);
+                    break;
                 }
                 ptr = pend;
             }
-            panic!(
-                "backtrack failed {:?} {:?}",
-                String::from_utf8_lossy(&self.llm_bytes),
-                String::from_utf8_lossy(&grm_bytes)
-            );
+            if backtrack == 0 {
+                panic!(
+                    "backtrack failed {:?} {:?}",
+                    String::from_utf8_lossy(&self.llm_bytes),
+                    String::from_utf8_lossy(&grm_bytes)
+                );
+            }
         }
 
         if arg.tokens.contains(&trie.eos_token()) {
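Rather than splicing immediately, the scan above now only records how many trailing tokens diverge from the grammar's expected bytes and drops the stale bytes; the actual splice happens later, combined with any forced tokens. A self-contained sketch of the scan, with hypothetical names and one byte slice per accepted token as in the loop above:

```rust
// Hedged sketch of the backtrack scan: walk each accepted token's bytes
// against the grammar's expected bytes; at the first mismatch, report
// (tokens to delete, byte offset where the stale bytes begin).
fn find_backtrack(token_bytes: &[Vec<u8>], grm_bytes: &[u8]) -> Option<(usize, usize)> {
    let mut ptr = 0;
    for (idx, b) in token_bytes.iter().enumerate() {
        let pend = ptr + b.len();
        if pend > grm_bytes.len() || b.as_slice() != &grm_bytes[ptr..pend] {
            return Some((token_bytes.len() - idx, ptr));
        }
        ptr = pend;
    }
    None // every token's bytes are a prefix of grm_bytes: no backtrack
}
```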
@@ -157,7 +161,7 @@ impl TokenParser {
         let new_forced = grm_bytes[self.llm_bytes.len()..].to_vec();
         let mut token_prefix = Vec::new();
 
-        if new_forced.len() > 0 {
+        if new_forced.len() > 0 || backtrack > 0 {
             let mut grm_tokens = self.token_env.tokenize_bytes(&new_forced);
             infoln!("forced: {}", trie.tokens_dbg(&grm_tokens));
             let (chop_tokens, chop_bytes) = trie.chop_tokens(&mut self.parser, &grm_tokens);
@@ -167,7 +171,7 @@ impl TokenParser {
             if grm_tokens.len() > 0 {
                 infoln!("fixed_tokens: {}", trie.tokens_dbg(&grm_tokens));
-                return MidProcessResult::splice(0, grm_tokens);
+                return MidProcessResult::splice(backtrack as u32, grm_tokens);
             } else {
                 infoln!("no fixed tokens");
             }
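Taken together, backtracking is now deferred to the end of mid_process: the diverging tokens are deleted and the re-tokenized grammar bytes are force-fed in a single MidProcessResult::splice, instead of splicing immediately inside the scan loop.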
