revamp eos handling
mmoskal committed Jul 2, 2024
1 parent 99ffe1a commit daf9950
Showing 7 changed files with 91 additions and 54 deletions.
19 changes: 10 additions & 9 deletions controllers/aici_abi/src/toktree.rs
@@ -558,15 +558,6 @@ impl TokTrie {
                 }
             }
         }
-        // all prefixes of 'start' are also allowed
-        if start.len() > 0 {
-            for len in 1..=start.len() {
-                let bytes = &start[0..len];
-                if let Some(tok) = self.token_id(bytes) {
-                    logits.allow_token(tok);
-                }
-            }
-        }
         self.add_bias(r, logits, start);
         self.apply_duplicates(logits);
     }
@@ -682,6 +673,16 @@ impl TokTrie {
 
     #[inline(never)]
     pub fn add_bias(&self, r: &mut impl Recognizer, toks: &mut SimpleVob, start: &[u8]) {
+        // all prefixes of 'start' are also allowed
+        if start.len() > 0 {
+            for len in 1..=start.len() {
+                let bytes = &start[0..len];
+                if let Some(tok) = self.token_id(bytes) {
+                    toks.allow_token(tok);
+                }
+            }
+        }
+
         r.trie_started();
         let n = self.child_at_bytes(self.root(), start).unwrap();
         let defl_tok = self.vocab_size() as u32;
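The relocated block above keeps every token whose bytes form a prefix of the forced bytes `start` samplable, since whatever forced bytes remain are enforced again on the next step. Below is a minimal runnable sketch of that rule, with a plain `HashMap` standing in for the real `TokTrie::token_id` lookup and a `Vec<bool>` for `SimpleVob` (both stand-ins are hypothetical, not the repo's API):

use std::collections::HashMap;

// Sketch only: `token_id` stands in for TokTrie::token_id.
fn allow_prefix_tokens(
    token_id: &HashMap<Vec<u8>, usize>,
    allowed: &mut [bool],
    start: &[u8],
) {
    // Every token whose bytes are a prefix of `start` stays allowed;
    // the rest of the forced bytes will be re-offered next step.
    for len in 1..=start.len() {
        if let Some(&tok) = token_id.get(&start[..len]) {
            allowed[tok] = true;
        }
    }
}

fn main() {
    let token_id: HashMap<Vec<u8>, usize> =
        [(b"4".to_vec(), 0), (b"42".to_vec(), 1), (b"43".to_vec(), 2)]
            .into_iter()
            .collect();
    let mut allowed = vec![false; 3];
    allow_prefix_tokens(&token_id, &mut allowed, b"42");
    assert_eq!(allowed, vec![true, true, false]); // "4" and "42", not "43"
}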
26 changes: 17 additions & 9 deletions controllers/llguidance_ctrl/run_g.py
@@ -185,9 +185,7 @@ def character_maker2(lm, id, description, valid_weapons):
     )
 )
 
-grm = "6 * 7 = " + greedy_grammar(
-    body = lexeme("[0-9]{1,3}")
-) + "\n"
+grm = "6 * 7 = " + greedy_grammar(body=lexeme("[0-9]{1,3}")) + "\n"
 # assert grm.match("6 * 7 = 42\n")
 
 grm = (
@@ -203,23 +201,33 @@ def character_maker2(lm, id, description, valid_weapons):
 
 grm = "6 * 7 = " + gen("name", max_tokens=2)
 
-grm = "Name: " + gen('name', max_tokens=2) + " Height: " + gen('height', max_tokens=3)
-grm = "Name: " + gen('name', max_tokens=2) + "Emily Carter is great; Height: " + gen('height', max_tokens=3)
+grm = (
+    "Name: " + gen("name", max_tokens=2) + " Height: " + gen("height", max_tokens=3)
+)
+grm = (
+    "Name: "
+    + gen("name", max_tokens=2)
+    + "Emily Carter is great; Height: "
+    + gen("height", max_tokens=3)
+)
 
 grm = "123" + gen(name="numbers", regex=r"\d*233", max_tokens=5)
 
-grm = greedy_grammar(body=lexeme("[0-9]+"),skip_regex=r"\s*") + "x"
+grm = (
+    "Here: 2 + 2 = "
+    + greedy_grammar(body=lexeme("[0-9]+"), skip_regex=r"\s*")
+    + "x"
+)
+grm = "Here: 2 + 2 = " + greedy_grammar(name="num", body=lexeme("[0-9]+"))
 
-grm = "Here: 2 + 2 = " + guidance.json(name="num", schema={"type": "integer"})
+# grm = "Here: 2 + 2 = " + guidance.json(name="num", schema={"type": "integer"})
 # grm = guidance.json(name="num", schema={"type": "integer"})
 # m = grm.match("123<s>")
 # print(m)
 # assert m["num"] == "123"
 
-# grm = "Name: " + gen('name', max_tokens=2) + " Height: " + gen('height', max_tokens=3)
-
 
 
 # g = zero_or_more("a") + "b"
 # assert g.match("b")
 # assert g.match("ab")
6 changes: 4 additions & 2 deletions controllers/llguidance_ctrl/src/earley/lexer.rs
@@ -4,7 +4,7 @@ use std::fmt::Debug;
 
 use super::{
     lexerspec::{LexemeIdx, LexerSpec},
-    regexvec::{RegexVec, StateDesc, NextByte},
+    regexvec::{NextByte, RegexVec, StateDesc},
 };
 
 const DEBUG: bool = true;
@@ -73,7 +73,9 @@ impl Lexer {
     }
 
     pub fn allows_eos(&mut self, state: StateID) -> bool {
-        self.state_info(state).is_accepting()
+        let mut l = self.spec.eos_ending_lexemes();
+        l.and(&self.state_info(state).accepting);
+        !l.is_zero()
     }
 
     pub fn limit_state_to(&mut self, state: StateID, allowed_lexemes: &SimpleVob) -> StateID {
13 changes: 13 additions & 0 deletions controllers/llguidance_ctrl/src/earley/lexerspec.rs
@@ -18,6 +18,7 @@ pub struct LexemeSpec {
     name: String,
     pub(crate) rx: RegexAst,
     compiled_rx: ExprRef,
+    ends_at_eos: bool,
     lazy: bool,
     contextual: bool,
 }
@@ -94,6 +95,16 @@ impl LexerSpec {
         v
     }
 
+    pub fn eos_ending_lexemes(&self) -> SimpleVob {
+        let mut v = self.alloc_lexeme_set();
+        for (idx, lex) in self.lexemes.iter().enumerate() {
+            if lex.ends_at_eos {
+                v.set(idx, true);
+            }
+        }
+        v
+    }
+
     pub fn is_nullable(&self, idx: LexemeIdx) -> bool {
         self.regex_builder
             .is_nullable(self.lexemes[idx.0].compiled_rx)
@@ -138,6 +149,7 @@ impl LexerSpec {
             compiled_rx: ExprRef::INVALID,
             lazy: false,
             contextual: false,
+            ends_at_eos: false,
         }
     }
 
@@ -157,6 +169,7 @@ impl LexerSpec {
             name,
             rx,
             lazy,
+            ends_at_eos: !lazy,
             ..self.empty_spec()
         })
     }
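Combined with the `allows_eos` change in lexer.rs above, EOS is now possible only when some lexeme the lexer currently accepts is also marked `ends_at_eos` (true by default for every non-lazy lexeme). A runnable sketch of that set intersection, with `Vec<bool>` standing in for `SimpleVob` (all names here are stand-ins, not the repo's types):

type LexemeSet = Vec<bool>; // stand-in for SimpleVob

struct LexemeSpec {
    ends_at_eos: bool,
}

// Mirrors LexerSpec::eos_ending_lexemes: which lexemes may be
// terminated by the EOS token.
fn eos_ending_lexemes(lexemes: &[LexemeSpec]) -> LexemeSet {
    lexemes.iter().map(|l| l.ends_at_eos).collect()
}

// Mirrors the new Lexer::allows_eos: EOS is allowed only if some
// lexeme accepted in the current state can also end at EOS.
fn allows_eos(lexemes: &[LexemeSpec], accepting: &[bool]) -> bool {
    eos_ending_lexemes(lexemes)
        .iter()
        .zip(accepting.iter())
        .any(|(e, a)| *e && *a)
}

fn main() {
    let lexemes = vec![
        LexemeSpec { ends_at_eos: true },  // greedy lexeme: !lazy
        LexemeSpec { ends_at_eos: false }, // lazy lexeme
    ];
    // Only the lazy lexeme is accepting: EOS stays banned.
    assert!(!allows_eos(&lexemes, &[false, true]));
    // A greedy lexeme is accepting: EOS becomes allowed.
    assert!(allows_eos(&lexemes, &[true, true]));
}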
50 changes: 27 additions & 23 deletions controllers/llguidance_ctrl/src/earley/parser.rs
@@ -388,11 +388,17 @@ impl Parser {
     pub fn compute_bias(&mut self, trie: &TokTrie, start: &[u8]) -> SimpleVob {
         let mut set = trie.alloc_token_set();
 
-        trie.compute_bias_ext(self, &mut set, start);
+        trie.add_bias(self, &mut set, start);
+        trie.apply_duplicates(&mut set);
 
-        if set.num_set() == 1 && set.is_allowed(trie.eos_token()) {
+        if set.is_zero() {
+            // nothing allowed
             // we're going to be stopped outside - we better flush the lexer
-            self.flush_lexer();
+            let _ = self.flush_lexer();
         }
+
+        if start.is_empty() && self.lexer_allows_eos() {
+            set.allow_token(trie.eos_token());
+        }
 
         set
@@ -846,6 +852,7 @@ impl Parser {
     }
 
     pub fn model_variables(&mut self) -> Vec<ModelVariable> {
+        // this can be used in future to allow "end-of-turn" special token and the like
        self.run_speculative(|s| {
             let mut vars = vec![];
             if s.flush_lexer() {
@@ -887,6 +894,9 @@ impl Parser {
         })
     }
 
+    /// Advance the parser as if the current lexeme (if any)
+    /// finished right here.
+    /// Returns true if the parser was able to advance (or there were no pending bytes for a lexeme).
     fn flush_lexer(&mut self) -> bool {
         if !self.has_pending_lexeme_bytes() {
             return true;
@@ -1493,26 +1503,20 @@ impl Recognizer for Parser {
         self.last_collapse = self.num_rows();
     }
 
-    fn special_allowed(&mut self, tok: SpecialToken) -> bool {
-        if false {
-            self.print_row(self.num_rows() - 1);
-            println!(
-                "model vars: accpt={} {:?}",
-                self.is_accepting(),
-                self.model_variables()
-            );
-        }
-
-        if self
-            .model_variables()
-            .contains(&ModelVariable::SpecialToken(tok))
-        {
-            true
-        } else if tok == SpecialToken::EndOfSentence {
-            self.is_accepting() || self.lexer_allows_eos()
-        } else {
-            false
-        }
+    fn special_allowed(&mut self, _tok: SpecialToken) -> bool {
+        // handle EOS logic outside
+        unreachable!("special_allowed")
+
+        // if self
+        //     .model_variables()
+        //     .contains(&ModelVariable::SpecialToken(tok))
+        // {
+        //     true
+        // } else if tok == SpecialToken::EndOfSentence {
+        //     self.is_accepting() || self.lexer_allows_eos()
+        // } else {
+        //     false
+        // }
     }
 
     fn trie_started(&mut self) {
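With `special_allowed` now unreachable, the EOS decision lives entirely in `compute_bias`: build the grammar bias, flush a partially-matched lexeme when nothing is allowed, then admit EOS only when no forced byte prefix is pending and the lexer permits it. A runnable mock of that flow, with simplified stand-ins for `Parser`, `TokTrie`, and `SimpleVob` (none of these mock types are the real API):

const EOS_TOKEN: usize = 0; // stand-in for trie.eos_token()

struct TokenSet(Vec<bool>); // stand-in for SimpleVob

impl TokenSet {
    fn is_zero(&self) -> bool {
        self.0.iter().all(|b| !b)
    }
    fn allow_token(&mut self, tok: usize) {
        self.0[tok] = true;
    }
}

struct MockParser {
    lexer_allows_eos: bool,
}

impl MockParser {
    fn flush_lexer(&mut self) -> bool {
        true // pretend the pending lexeme can always be finished
    }
    fn lexer_allows_eos(&self) -> bool {
        self.lexer_allows_eos
    }
}

fn compute_bias(parser: &mut MockParser, grammar_bias: Vec<bool>, start: &[u8]) -> TokenSet {
    // 1. Grammar-driven bias (add_bias + apply_duplicates in the real code).
    let mut set = TokenSet(grammar_bias);

    // 2. Nothing allowed: we will be stopped from outside,
    //    so flush any partially-matched lexeme first.
    if set.is_zero() {
        let _ = parser.flush_lexer();
    }

    // 3. EOS handled here, not in Recognizer::special_allowed:
    //    legal only with no forced bytes pending and an
    //    eos-ending lexeme accepting in the lexer.
    if start.is_empty() && parser.lexer_allows_eos() {
        set.allow_token(EOS_TOKEN);
    }
    set
}

fn main() {
    let mut p = MockParser { lexer_allows_eos: true };
    let set = compute_bias(&mut p, vec![false, true, false], b"");
    assert!(set.0[EOS_TOKEN]); // EOS allowed once the lexer permits it
}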
30 changes: 19 additions & 11 deletions controllers/llguidance_ctrl/src/tokenparser.rs
@@ -407,7 +407,7 @@ impl TokenParser {
         }
 
         if token_prefix.is_empty() {
-            if let Err(e) = self.maybe_gen_grammar() {
+            if let Err(e) = self.maybe_push_parser() {
                 warn!(self, "Error creating nested parser: {}", e);
                 return MidProcessResult::stop();
             }
@@ -434,22 +434,27 @@ impl TokenParser {
         // self.parser.print_row(self.parser.num_rows() - 1);
         let mut set = self.parser.compute_bias(trie, &token_prefix);
 
-        if inner_done
-            || self.max_tokens_parser == 0
-            || (set.num_set() == 1 && set.is_allowed(trie.eos_token()))
-        {
+        if inner_done || self.max_tokens_parser == 0 {
             if self.parser_stack.is_empty() {
                 self.mid_process_was_accepting = inner_accepting;
-                infoln!(self, "only eos token allowed, stopping; accepting: {}", inner_accepting);
+                infoln!(
+                    self,
+                    "only eos token allowed, stopping; accepting: {}",
+                    inner_accepting
+                );
                 return MidProcessResult::stop();
             } else {
                 infoln!(self, "pop_parser; tokens left {}", self.max_tokens_parser);
                 self.pop_parser();
                 // re-start the whole process with a nice tail-recursion
-                return self.mid_process_inner(MidProcessArg {
-                    backtrack: 0,
-                    tokens: Vec::new(),
-                    fork_group: Vec::new(),
+                return self.mid_process_inner(if has_eos {
+                    arg
+                } else {
+                    MidProcessArg {
+                        backtrack: 0,
+                        tokens: Vec::new(),
+                        fork_group: Vec::new(),
+                    }
                 });
             }
         }
@@ -480,6 +485,9 @@ impl TokenParser {
                 self.pop_tokens = Some(pop_tokens);
             }
             self.mid_process_was_accepting = all_accepting;
+            if all_accepting {
+                set.allow_token(trie.eos_token());
+            }
         }
 
         infoln!(
@@ -499,7 +507,7 @@ impl TokenParser {
         return MidProcessResult::sample_with_temp(set, Some(self.parser.temperature()));
     }
 
-    fn maybe_gen_grammar(&mut self) -> Result<()> {
+    fn maybe_push_parser(&mut self) -> Result<()> {
         if let Some((msg, symidx, gen_grammar)) = self.parser.maybe_gen_grammar() {
             if msg.len() > 0 {
                 warn!(self, "{}", msg);
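When a nested parser finishes, `mid_process_inner` now pops it and re-dispatches: if the step that finished it carried EOS, the original `arg` is forwarded so the outer parser sees the same EOS token; otherwise the re-run starts from an empty step. A small sketch of that decision, with `Arg` and `Step` as simplified stand-ins for the real `MidProcessArg` and result types:

// Sketch only: fields reduced to what the decision needs.
struct Arg {
    backtrack: u32,
    tokens: Vec<u32>,
    has_eos: bool,
}

enum Step {
    Stop,         // no outer parser: report stop to the caller
    Recurse(Arg), // pop_parser() happened: re-run mid_process_inner
}

fn on_inner_done(parser_stack_len: usize, arg: Arg) -> Step {
    if parser_stack_len == 0 {
        Step::Stop
    } else if arg.has_eos {
        // Forward the EOS-carrying arg so the outer parser sees the
        // same EOS token, instead of restarting from an empty step.
        Step::Recurse(arg)
    } else {
        Step::Recurse(Arg {
            backtrack: 0,
            tokens: Vec::new(),
            has_eos: false,
        })
    }
}

fn main() {
    // Inner grammar finished because the model emitted EOS:
    // the outer parser should consume that same EOS.
    let arg = Arg { backtrack: 0, tokens: vec![7], has_eos: true };
    match on_inner_done(1, arg) {
        Step::Recurse(a) => assert!(a.has_eos),
        Step::Stop => unreachable!(),
    }
}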
1 change: 1 addition & 0 deletions scripts/test-guidance.sh
@@ -24,6 +24,7 @@ fi
 function runtest() {
     pytest "$@"
     if [ $? -ne 0 -a $? -ne 5 ] ; then
+        :
         exit 1
     fi
 }
