From be3077cfbb78cc8230e7f9dd67e23e22b6d7c9d5 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Tue, 11 Jun 2024 16:58:55 +0000 Subject: [PATCH] nested parsing fixes --- controllers/ag2_ctrl/run_g.py | 4 +- controllers/ag2_ctrl/src/earley/parser.rs | 48 +++++++++++++++-------- controllers/ag2_ctrl/src/tokenparser.rs | 13 +++++- py/guidance | 2 +- 4 files changed, 47 insertions(+), 20 deletions(-) diff --git a/controllers/ag2_ctrl/run_g.py b/controllers/ag2_ctrl/run_g.py index c36f6d63..b9d59bb2 100644 --- a/controllers/ag2_ctrl/run_g.py +++ b/controllers/ag2_ctrl/run_g.py @@ -218,8 +218,8 @@ def character_maker2(lm, id, description, valid_weapons): prompt = "How much is 2 + 2? " grm = gen(name="test", max_tokens=30, regex=r"\d+") - prompt = "About J. Random Hacker:\n" - grm = gen_json_object("hacker", max_tokens=50) + "\nScore (0-9): " + gen("score", regex=r"[0-9]") + prompt = "Three things about J. Random Hacker:\n" + grm = gen_json_object("hacker", max_tokens=150) + "\nScore (0-9): " + gen("score", regex=r"[0-9]") # grm = "Q: 7 * 8\nA: " + gen("text", regex="[0-9]+", max_tokens=20) + "\n" diff --git a/controllers/ag2_ctrl/src/earley/parser.rs b/controllers/ag2_ctrl/src/earley/parser.rs index 0d02a278..57cb8e91 100644 --- a/controllers/ag2_ctrl/src/earley/parser.rs +++ b/controllers/ag2_ctrl/src/earley/parser.rs @@ -21,7 +21,7 @@ use super::{ lexerspec::{Lexeme, LexemeIdx, LexerSpec}, }; -const TRACE: bool = true; +const TRACE: bool = false; const DEBUG: bool = true; const INFO: bool = true; @@ -365,13 +365,34 @@ impl Parser { &self.grammar } + fn after_dots(&self) -> impl Iterator + '_ { + self.curr_row() + .item_indices() + .map(|i| self.scratch.items[i].rule_idx()) + } + + fn after_dots_symdata(&self) -> impl Iterator + '_ { + self.after_dots().map(|pos| self.grammar.sym_data_at(pos)) + } + + pub fn can_advance(&self) -> bool { + let skip = self.grammar.lexeme_to_sym_idx(LexemeIdx::SKIP); + for data in self.after_dots_symdata() { + if data.idx == skip || data.idx == CSymIdx::NULL { + continue; + } + if data.is_terminal || data.gen_grammar.is_some() { + return true; + } + } + false + } + pub fn is_accepting(&self) -> bool { - for idx in self.curr_row().item_indices() { - let item = self.scratch.items[idx]; - let rule = item.rule_idx(); - let after_dot = self.grammar.sym_idx_at(rule); + for pos in self.after_dots() { + let after_dot = self.grammar.sym_idx_at(pos); if after_dot == CSymIdx::NULL { - let lhs = self.grammar.sym_idx_of(item.rule_idx()); + let lhs = self.grammar.sym_idx_of(pos); if lhs == self.grammar.start() { return true; } @@ -498,9 +519,7 @@ impl Parser { pub fn temperature(&self) -> f32 { let mut temp = 0.0f32; - for i in self.curr_row().item_indices() { - let item = self.scratch.items[i]; - let data = self.grammar.sym_data_at(item.rule_idx()); + for data in self.after_dots_symdata() { if data.is_terminal { temp = temp.max(data.props.temperature); } @@ -727,9 +746,7 @@ impl Parser { pub fn model_variables(&self) -> Vec { let mut vars = vec![]; - for i in self.curr_row().item_indices() { - let item = self.scratch.items[i]; - let sym_data = self.grammar.sym_data_at(item.rule_idx()); + for sym_data in self.after_dots_symdata() { if let Some(ref mv) = sym_data.props.model_variable { if !vars.contains(mv) { vars.push(mv.clone()); @@ -777,10 +794,9 @@ impl Parser { let mut res: Option = None; let mut res_idx = None; let mut gen_grm = vec![]; - for i in self.curr_row().item_indices() { - let item = self.scratch.items[i]; - let idx = self.grammar.sym_idx_at(item.rule_idx()); - let sym_data = self.grammar.sym_data_at(item.rule_idx()); + for pos in self.after_dots() { + let idx = self.grammar.sym_idx_at(pos); + let sym_data = self.grammar.sym_data_at(pos); if let Some(ref gg) = sym_data.gen_grammar { // break ties by preferring the one with the lowest grammar number if res.is_none() || res.as_ref().unwrap().grammar.0 > gg.grammar.0 { diff --git a/controllers/ag2_ctrl/src/tokenparser.rs b/controllers/ag2_ctrl/src/tokenparser.rs index 8a479d33..d316d34c 100644 --- a/controllers/ag2_ctrl/src/tokenparser.rs +++ b/controllers/ag2_ctrl/src/tokenparser.rs @@ -293,6 +293,14 @@ impl TokenParser { } } + let inner_done = { + let is_accepting = self.parser.is_accepting(); + let can_advance = self.parser.can_advance(); + let inner_done = is_accepting && !can_advance; + infoln!("inner_done: {inner_done}; can_advance: {can_advance}; accept: {is_accepting}"); + inner_done + }; + let trie = self.token_env.tok_trie(); let mut set = trie.alloc_token_set(); // self.parser.print_row(self.parser.num_rows() - 1); @@ -303,7 +311,10 @@ impl TokenParser { set.disallow_token(self.first_token_of_eos_marker); } - if self.max_tokens_parser == 0 || (set.num_set() == 1 && set.is_allowed(trie.eos_token())) { + if inner_done + || self.max_tokens_parser == 0 + || (set.num_set() == 1 && set.is_allowed(trie.eos_token())) + { if self.parser_stack.is_empty() { infoln!("only eos token allowed, stopping"); return MidProcessResult::stop(); diff --git a/py/guidance b/py/guidance index 0db8c1b6..46af3869 160000 --- a/py/guidance +++ b/py/guidance @@ -1 +1 @@ -Subproject commit 0db8c1b6fb839fdb47a86b20e9133ce848f5b79c +Subproject commit 46af38691fdc3259d887b55df3085ca74fed7eb2