max_tokens fixes
mmoskal committed Jun 28, 2024
1 parent 6a31aa6 commit c03a0f0
Showing 8 changed files with 123 additions and 35 deletions.
10 changes: 10 additions & 0 deletions controllers/derivre/src/regexvec.rs
@@ -289,6 +289,16 @@ impl RegexVec {
next_byte
}

pub fn limit_state_to(&mut self, state: StateID, allowed_lexemes: &SimpleVob) -> StateID {
let mut vec_desc = vec![];
for (idx, e) in iter_state(&self.rx_sets, state) {
if allowed_lexemes.get(idx) {
Self::push_rx(&mut vec_desc, idx, e);
}
}
self.insert_state(vec_desc)
}

pub fn total_fuel_spent(&self) -> usize {
self.exprs.cost
}
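
For context, the new `limit_state_to` above rebuilds a lexer state so that only the regex entries of still-allowed lexemes survive. Below is a minimal, self-contained sketch of that idea; the `State` struct, the string-typed regex field, and the `BTreeSet` standing in for `SimpleVob`/`StateID` are illustrative assumptions, not the actual derivre types.

use std::collections::BTreeSet;

/// Illustrative stand-in for a lexer state: the set of (lexeme index,
/// remaining regex) pairs that are still alive in this state.
#[derive(Debug, Clone)]
struct State {
    entries: Vec<(usize, String)>,
}

/// Keep only entries whose lexeme index is allowed, mirroring the shape of
/// `RegexVec::limit_state_to`: the restricted state is rebuilt from the
/// surviving entries.
fn limit_state_to(state: &State, allowed_lexemes: &BTreeSet<usize>) -> State {
    let entries = state
        .entries
        .iter()
        .filter(|(idx, _)| allowed_lexemes.contains(idx))
        .cloned()
        .collect();
    State { entries }
}

fn main() {
    let state = State {
        entries: vec![
            (0, "a*".to_string()),
            (1, "[0-9]+".to_string()),
            (2, "b".to_string()),
        ],
    };
    // Only lexemes 0 and 2 remain allowed (e.g. lexeme 1 hit its max_tokens).
    let allowed: BTreeSet<usize> = [0, 2].into_iter().collect();
    let limited = limit_state_to(&state, &allowed);
    assert_eq!(limited.entries.len(), 2);
    println!("{:?}", limited.entries);
}
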
9 changes: 8 additions & 1 deletion controllers/llguidance_ctrl/run_g.py
@@ -201,6 +201,13 @@ def character_maker2(lm, id, description, valid_weapons):
grm = gen(regex="a*")
grm = "6 * 7 = " + gen(regex="5*") + gen(regex="[1-4][0-9]") + "\n"

grm = "6 * 7 = " + gen("name", max_tokens=2)

grm = "Name: " + gen('name', max_tokens=2) + " Height: " + gen('height', max_tokens=3)

# grm = "Name: " + gen('name', max_tokens=2) + " Height: " + gen('height', max_tokens=3)



# g = zero_or_more("a") + "b"
# assert g.match("b")
@@ -214,7 +221,7 @@ def character_maker2(lm, id, description, valid_weapons):
# body = lexeme("[0-9]+")
# )

max_tokens = 250
max_tokens = 50

serialized = grm.ll_serialize()

5 changes: 4 additions & 1 deletion controllers/llguidance_ctrl/src/earley/grammar.rs
@@ -221,7 +221,10 @@ impl Grammar {
return Ok(());
}
if lexer_spec.is_nullable(lex) {
let wrap = self.fresh_symbol(format!("rx_null_{}", self.sym_name(lhs)).as_str());
let wrap = self.fresh_symbol_ext(
format!("rx_null_{}", self.sym_name(lhs)).as_str(),
self.sym_data(lhs).props.for_wrapper(),
);
self.sym_data_mut(wrap).lexeme = Some(lex);
self.add_rule(lhs, vec![wrap])?;
self.add_rule(lhs, vec![])?;
8 changes: 8 additions & 0 deletions controllers/llguidance_ctrl/src/earley/lexer.rs
@@ -74,6 +74,14 @@ impl Lexer {
self.state_info(state).is_accepting()
}

pub fn limit_state_to(&mut self, state: StateID, allowed_lexemes: &SimpleVob) -> StateID {
self.dfa.limit_state_to(state, allowed_lexemes)
}

pub fn possible_lexemes(&self, state: StateID) -> &SimpleVob {
&self.state_info(state).possible
}

pub fn force_lexeme_end(&self, prev: StateID) -> LexerResult {
let info = self.state_info(prev);
match info.possible.first_bit_set() {
74 changes: 62 additions & 12 deletions controllers/llguidance_ctrl/src/earley/parser.rs
@@ -1,4 +1,5 @@
use std::{
collections::HashMap,
fmt::{Debug, Display},
hash::Hash,
ops::Range,
@@ -154,7 +155,7 @@ struct RowInfo {
lexeme: Lexeme,
token_idx_start: usize,
token_idx_stop: usize,
max_tokens: usize,
max_tokens: HashMap<LexemeIdx, usize>,
}

impl RowInfo {
@@ -169,10 +170,10 @@ impl RowInfo {
self.token_idx_start,
self.token_idx_stop,
lexspec.dbg_lexeme(&self.lexeme),
if self.max_tokens == usize::MAX {
if self.max_tokens.is_empty() {
"".to_string()
} else {
format!("max_tokens={}", self.max_tokens)
format!("max_tokens={:?}", self.max_tokens)
}
)
}
@@ -614,10 +615,37 @@ impl Parser {
0,
self.token_idx as isize + 1 - info.token_idx_start as isize,
) as usize;
if info_tokens >= info.max_tokens {
debug!(" max_tokens reached; {}", info.dbg(self.lexer_spec()));
if !self.try_push_byte_definitive(None) {
return Ok("parse reject on max_tokens");
let lex_state = self.lexer_state().lexer_state;
let mut limit = trie.alloc_token_set();
let mut num_limit = 0;
for idx in self.lexer.possible_lexemes(lex_state).iter() {
let lex = LexemeIdx::new(idx as usize);
let max_tokens = *info.max_tokens.get(&lex).unwrap_or(&usize::MAX);
trace!(
" max_tokens: {} max={} info={}",
self.lexer_spec().dbg_lexeme(&Lexeme::just_idx(lex)),
max_tokens,
info_tokens
);
if info_tokens < max_tokens {
limit.allow_token(idx);
} else {
num_limit += 1;
}
}
if num_limit > 0 {
debug!(
" max_tokens limiting to: {}",
self.lexer_spec().dbg_lexeme_set(&limit)
);
let new_state = self.lexer.limit_state_to(lex_state, &limit);
if new_state.is_dead() {
debug!(" limited everything; forcing EOI");
if !self.try_push_byte_definitive(None) {
return Ok("parse reject on max_tokens");
}
} else {
self.lexer_stack.last_mut().unwrap().lexer_state = new_state;
}
}
}
@@ -678,7 +706,7 @@ impl Parser {
start_byte_idx: 0,
token_idx_start: self.token_idx,
token_idx_stop: self.token_idx,
max_tokens: usize::MAX,
max_tokens: HashMap::default(),
});

for idx in 0..self.num_rows() {
@@ -1016,7 +1044,8 @@ impl Parser {
// with agenda pointer above
self.rows[added_row_idx].allowed_lexemes = allowed_lexemes;
if self.scratch.definitive {
self.row_infos[added_row_idx].max_tokens = self.row_infos[added_row_idx - 1].max_tokens;
self.row_infos[added_row_idx].max_tokens =
self.row_infos[added_row_idx - 1].max_tokens.clone();
}
true
}
@@ -1061,7 +1090,7 @@ impl Parser {
#[inline(always)]
fn push_row(&mut self, curr_idx: usize, mut agenda_ptr: usize, lexeme: &Lexeme) -> bool {
let mut allowed_lexemes = SimpleVob::alloc(self.grammar.num_terminals());
let mut max_tokens = 0;
let mut max_tokens = vec![];

while agenda_ptr < self.scratch.row_end {
let item_idx = agenda_ptr;
Expand Down Expand Up @@ -1129,7 +1158,9 @@ impl Parser {
let sym_data = self.grammar.sym_data(after_dot);
if let Some(lx) = self.grammar.lexeme_idx_of(after_dot) {
allowed_lexemes.set(lx.as_usize(), true);
max_tokens = max_tokens.max(sym_data.props.max_tokens);
if self.scratch.definitive {
max_tokens.push((lx, sym_data.props.max_tokens));
}
}
if sym_data.is_nullable {
self.scratch
@@ -1185,12 +1216,31 @@ impl Parser {
if self.row_infos.len() > idx {
self.row_infos.drain(idx..);
}
let mut max_tokens_map = HashMap::default();
for (lx, mx) in max_tokens {
if let Some(ex) = max_tokens_map.get(&lx) {
if *ex < mx {
max_tokens_map.insert(lx, mx);
}
} else {
max_tokens_map.insert(lx, mx);
}
}
let mut to_remove = vec![];
for (lx, mx) in max_tokens_map.iter() {
if *mx == usize::MAX {
to_remove.push(*lx);
}
}
for lx in to_remove {
max_tokens_map.remove(&lx);
}
self.row_infos.push(RowInfo {
lexeme: Lexeme::bogus(),
token_idx_start: self.token_idx,
token_idx_stop: self.token_idx,
start_byte_idx: self.byte_idx,
max_tokens,
max_tokens: max_tokens_map,
});
// debug!(" push: {idx} {} {}", self.rows.len(), self.row_infos.len());
}
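
The parser changes above replace the single per-row `max_tokens` value with a per-lexeme `HashMap` and, when computing the next token mask, drop every lexeme whose limit has been reached (forcing end-of-input if nothing survives). A small stand-alone sketch of that bookkeeping follows; the plain `usize` lexeme indices and the `HashSet` are assumptions standing in for `LexemeIdx` and `SimpleVob`, not the real parser types.

use std::collections::{HashMap, HashSet};

/// Collapse (lexeme, max_tokens) pairs gathered while building a row into a
/// per-lexeme map, keeping the largest limit per lexeme and dropping the
/// "unlimited" usize::MAX entries, like the map built at the end of push_row.
fn build_max_tokens_map(pairs: &[(usize, usize)]) -> HashMap<usize, usize> {
    let mut map: HashMap<usize, usize> = HashMap::new();
    for &(lex, mx) in pairs {
        let entry = map.entry(lex).or_insert(mx);
        if *entry < mx {
            *entry = mx;
        }
    }
    map.retain(|_, mx| *mx != usize::MAX);
    map
}

/// Given how many tokens the current row has already consumed, keep only the
/// lexemes that have not yet hit their limit (unlisted lexemes are unlimited).
fn lexemes_still_allowed(
    possible: &HashSet<usize>,
    max_tokens: &HashMap<usize, usize>,
    tokens_in_row: usize,
) -> HashSet<usize> {
    possible
        .iter()
        .copied()
        .filter(|lex| tokens_in_row < *max_tokens.get(lex).unwrap_or(&usize::MAX))
        .collect()
}

fn main() {
    // Lexeme 0 may span at most 2 tokens, lexeme 1 at most 3, lexeme 2 is unlimited.
    let map = build_max_tokens_map(&[(0, 2), (1, 3), (1, 2), (2, usize::MAX)]);
    assert_eq!(map.get(&0), Some(&2));
    assert_eq!(map.get(&1), Some(&3)); // the larger limit wins
    assert_eq!(map.get(&2), None); // unlimited entries are dropped

    let possible: HashSet<usize> = [0, 1, 2].into_iter().collect();
    // After 2 tokens in this row, lexeme 0 is exhausted; 1 and 2 remain possible.
    let allowed = lexemes_still_allowed(&possible, &map, 2);
    let expected: HashSet<usize> = [1, 2].into_iter().collect();
    assert_eq!(allowed, expected);
    // If nothing remained, the parser would instead force end-of-lexeme / EOI.
}
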
38 changes: 20 additions & 18 deletions controllers/llguidance_ctrl/src/tokenparser.rs
@@ -455,26 +455,28 @@ impl TokenParser {

if inner_accepting {
let mut all_accepting = true;
let mut pop_tokens = trie.alloc_token_set();
for pentry in self.parser_stack.iter_mut() {
if pentry.mask.is_none() {
assert!(token_prefix.is_empty());
let mask = pentry
.parser
.compute_bias_after_gen_grammar(trie, pentry.symidx);
infoln!(self, "bias for upper parser: {}", trie.token_set_dbg(&mask));
pentry.mask = Some(mask);
}
let m = pentry.mask.as_ref().unwrap();
pop_tokens.or_minus(m, &set);
set.or(m);
if !pentry.is_accepting {
all_accepting = false;
break;
if self.parser_stack.len() > 0 {
let mut pop_tokens = trie.alloc_token_set();
for pentry in self.parser_stack.iter_mut() {
if pentry.mask.is_none() {
assert!(token_prefix.is_empty());
let mask = pentry
.parser
.compute_bias_after_gen_grammar(trie, pentry.symidx);
infoln!(self, "bias for upper parser: {}", trie.token_set_dbg(&mask));
pentry.mask = Some(mask);
}
let m = pentry.mask.as_ref().unwrap();
pop_tokens.or_minus(m, &set);
set.or(m);
if !pentry.is_accepting {
all_accepting = false;
break;
}
}
infoln!(self, "pop_tokens: {}", trie.token_set_dbg(&pop_tokens));
self.pop_tokens = Some(pop_tokens);
}
infoln!(self, "pop_tokens: {}", trie.token_set_dbg(&pop_tokens));
self.pop_tokens = Some(pop_tokens);
self.mid_process_was_accepting = all_accepting;
}

2 changes: 1 addition & 1 deletion py/guidance
12 changes: 10 additions & 2 deletions scripts/test-guidance.sh
@@ -10,12 +10,20 @@ export AZURE_GUIDANCE_URL

FILES="tests/need_credentials/test_azure_guidance.py tests/model_integration/test_greedy.py"

cd $(dirname $0)/../py/guidance

if [ "X$1" != "X" ] ; then
if [ "X${1:0:2}" = "X::" ] ; then
FILES="tests/models/test_azure_guidance.py$1"
FILES="tests/need_credentials/test_azure_guidance.py$1"
shift
pytest --selected_model azure_guidance --durations=10 $FILES "$@"
exit $?
fi
fi

cd $(dirname $0)/../py/guidance
set -e
# quick tests first
pytest tests/unit/test_ll.py
pytest tests/unit
pytest --selected_model azure_guidance --durations=10 $FILES "$@"
pytest tests/model_integration
