From dd14b0d6a04843c8912a2ae8cb121451836e33bd Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Fri, 28 Jun 2024 00:45:01 +0000
Subject: [PATCH] add LLTokenizer.test_trace_tokens

---
 .vscode/settings.json                    |   7 +-
 controllers/aici_abi/src/toktree.rs      |  31 ++-
 controllers/llguidance_ctrl/run_g.py     | 250 +++++-------------
 controllers/llguidance_ctrl/src/api.rs   |   2 +
 .../llguidance_ctrl/src/tokenparser.rs   |  54 +++-
 py/guidance                              |   2 +-
 py/llguidance/python/llguidance/_lib.pyi |   6 +
 py/llguidance/rust/py.rs                 |   4 +
 8 files changed, 160 insertions(+), 196 deletions(-)

diff --git a/.vscode/settings.json b/.vscode/settings.json
index 1c9f480f..7b23c84f 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -154,5 +154,10 @@
   "files.readonlyInclude": {
     "**/dist/*": true,
     "**/aici-types.d.ts": true
-  }
+  },
+  "python.testing.pytestArgs": [
+    "py/guidance"
+  ],
+  "python.testing.unittestEnabled": false,
+  "python.testing.pytestEnabled": true
 }
\ No newline at end of file
diff --git a/controllers/aici_abi/src/toktree.rs b/controllers/aici_abi/src/toktree.rs
index 33d5f656..dd846d8e 100644
--- a/controllers/aici_abi/src/toktree.rs
+++ b/controllers/aici_abi/src/toktree.rs
@@ -241,30 +241,35 @@ impl TokTrie {
         vec![0.0; self.vocab_size() + 1]
     }
 
+    pub fn test_trace_tokens(&self, toks: &[u32]) -> String {
+        toks.iter()
+            .map(|t| {
+                let s = self.token_dbg(*t);
+                if s.starts_with("\"") {
+                    self.token_str(*t)
+                } else {
+                    format!("≺{}≻", s)
+                }
+            })
+            .collect::<Vec<_>>()
+            .join("‧")
+    }
+
     pub fn tokens_dbg(&self, toks: &[u32]) -> String {
-        let minimal = false;
-        let sep = "‧";
         let joined = toks
             .iter()
             .map(|t| {
                 let s = self.token_dbg(*t);
                 if s.starts_with("\"") {
-                    let inner = s[1..s.len() - 1].to_string();
-                    let b = s.as_bytes();
-                    // for " [\w]..." and " " the sep in front is implicit
-                    if minimal && b[1] == b' ' && ((b[2] as char).is_alphanumeric() || b.len() == 3)
-                    {
-                        inner
-                    } else {
-                        format!("{}{}", sep, inner)
-                    }
+                    s[1..s.len() - 1].to_string()
                 } else {
                     format!("≺{}≻", s)
                 }
             })
             .collect::<Vec<_>>()
-            .join("");
-        format!("\"{}\"", joined.trim_start_matches(sep))
+            .join("‧");
+
+        format!("\"{}\"", joined)
     }
 
     pub fn token_dbg(&self, idx: u32) -> String {
diff --git a/controllers/llguidance_ctrl/run_g.py b/controllers/llguidance_ctrl/run_g.py
index 643a9d54..707754fb 100644
--- a/controllers/llguidance_ctrl/run_g.py
+++ b/controllers/llguidance_ctrl/run_g.py
@@ -3,7 +3,7 @@
 import base64
 import ujson as json
 import binascii
-
+import os
 import guidance
 from guidance import (
@@ -24,146 +24,7 @@
 )
 
 
-@guidance(stateless=True)
-def number(lm):
-    n = one_or_more(select(["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]))
-    return lm + select(["-" + n, n])
-
-
-@guidance(stateless=True)
-def identifier(lm):
-    letter = select([byte_range(b"a", b"z"), byte_range(b"A", b"Z"), "_"])
-    num = byte_range(b"0", b"9")
-    return lm + letter + zero_or_more(select([letter, num]))
-
-
-@guidance(stateless=True)
-def assignment_stmt(lm):
-    return lm + identifier() + " = " + expression()
-
-
-@guidance(stateless=True)
-def while_stmt(lm):
-    return lm + "while " + expression() + ":" + stmt()
-
-
-@guidance(stateless=True)
-def stmt(lm):
-    return lm + select([assignment_stmt(), while_stmt()])
-
-
-@guidance(stateless=True)
-def operator(lm):
-    return lm + select(["+", "*", "**", "/", "-"])
-
-
-@guidance(stateless=True)
-def expression(lm):
-    return lm + select(
-        [
-            identifier(),
-            expression()
-            + zero_or_more(" ")
-            + operator()
-            + zero_or_more(" ")
-            + expression(),
-            "(" + expression() + ")",
-        ]
-    )
-
-
-@guidance(stateless=True) -def json_string(lm): - return lm + lexeme(r'"(\\(["\\\/bfnrt]|u[a-fA-F0-9]{4})|[^"\\\x00-\x1F\x7F]+)*"') - - -@guidance(stateless=True) -def json_number(lm): - return lm + lexeme(r"-?(?:0|[1-9][0-9]*)(?:\.[0-9]+)?(?:[eE][+-]?[0-9]+)?") - - -@guidance(stateless=True) -def json_value(lm): - return lm + select( - [ - json_string(), - json_number(), - json_object(), - json_array(), - "true", - "false", - "null", - ] - ) - - -@guidance(stateless=True) -def json_member(lm): - return lm + json_string() + ":" + json_value() - - -@guidance(stateless=True) -def json_object(lm): - return lm + "{" + optional(json_member() + one_or_more("," + json_member())) + "}" - - -@guidance(stateless=True) -def json_array(lm): - return lm + "[" + optional(json_value() + one_or_more("," + json_value())) + "]" - - -@guidance(stateless=True) -def gen_json_object(lm, name: str, max_tokens=100000000): - grm = greedy_grammar(body=json_object(), skip_regex=r"[\x20\x0A\x0D\x09]+") - return lm + grm - - def main(): - grm = ( - "Here's a sample arithmetic expression: " - + capture(expression(), "expr") - + " = " - + capture(number(), "num") - ) - grm = ( - "Parallel lines have so much in common. It’s a shame they’ll never meet.\nScore: 8/10\n" - + "" - + capture(gen(regex=r"[A-Z\(].*", max_tokens=50, stop=""), "joke") - + "\nScore: " - + capture(gen(regex=r"\d{1,3}"), "score") - + "/10\n" - ) - grm = "this is a test" + gen("test", max_tokens=10) - grm = "Tweak this proverb to apply to model instructions instead.\n" + gen( - "verse", max_tokens=2 - ) - grm = "How much is 2 + 2? " + gen(name="test", max_tokens=10, regex=r"\(") - grm = "red\n" + gen(stop="") + " and test2" - - lm = "Here's a " - lm += select(["joke", "poem"], name="type") - lm += ": " - lm += gen("words", regex=r"[A-Z ]+", stop="\n") - grm = lm - - @guidance(stateless=True, dedent=True) - def character_maker(lm, id, description, valid_weapons): - lm += f"""\ - The following is a character profile for an RPG game in JSON format. - ```json - {{ - "id": "{id}", - "description": "{description}", - "name": "{gen('name', stop='"')}", - "age": {gen('age', regex='[0-9]+', stop=',')}, - "armor": "{select(options=['leather', 'chainmail', 'plate'], name='armor')}", - "weapon": "{select(options=valid_weapons, name='weapon')}", - "class": "{gen('class', stop='"')}", - "mantra": "{gen('mantra', stop='"')}", - "strength": {gen('strength', regex='[0-9]+', stop=',')}, - "items": ["{gen('item', list_append=True, stop='"')}", "{gen('item', list_append=True, stop='"')}", "{gen('item', list_append=True, stop='"')}"] - }}```""" - return lm @guidance(stateless=True, dedent=True) def character_maker2(lm, id, description, valid_weapons): @@ -234,13 +95,6 @@ def character_maker2(lm, id, description, valid_weapons): ) ) - prompt = "Three things about J. 
Random Hacker:\n" - grm = ( - gen_json_object("hacker", max_tokens=150) - + "\nScore (0-9): " - + gen("score", regex=r"[0-9]") - ) - grm = character_maker2(1, "A nimble fighter", ["axe", "sword", "bow"]) prompt = "" @@ -280,6 +134,12 @@ def character_maker2(lm, id, description, valid_weapons): prompt = "" grm = optional("A") + grm = one_or_more(gen(regex="[a-z]")) + grm = "A odd number is " + gen( + "number", regex="[0-9]+", max_tokens=5, temperature=0 + ) + + grm = ( "Q: Are dolphins fish?\nA: " + gen("dolphins", regex="Yes|No", max_tokens=10) @@ -287,9 +147,20 @@ def character_maker2(lm, id, description, valid_weapons): + gen("sharks", regex="Yes|No", max_tokens=10) ) - grm = one_or_more(gen(regex="[a-z]")) + grm = ( + "Power frequency is " + + gen("number", regex="[0-9]+", max_tokens=5, temperature=0) + + "Hz; voltage is " + + gen("number", regex="[0-9]+", max_tokens=5, temperature=0) + + "V" + ) + + grm = "Q: 7 * 8\nA: " + gen("text", regex="[0-9]+", max_tokens=5) + + grm = "Dolphin name: " + commit_point( + '"' + byte_range(b"A", b"Z") + one_or_more(byte_range(b"a", b"z")) + '"' + ) + "," - # grm = "Q: 7 * 8\nA: " + gen("text", regex="[0-9]+", max_tokens=5) # g = zero_or_more("a") + "b" # assert g.match("b") @@ -299,13 +170,17 @@ def character_maker2(lm, id, description, valid_weapons): # grammar = one_or_more(select(["1", "2"])) # lm += grammar + # grm = greedy_grammar( + # body = lexeme("[0-9]+") + # ) + max_tokens = 250 serialized = grm.ll_serialize() # with open("tmp/long_json_grammar_req.json", "r") as f: - # with open("tmp/email_regex_grammar.json", "r") as f: - # max_tokens = 2000 + # # with open("tmp/email_regex_grammar.json", "r") as f: + # max_tokens = 1000 # serialized = json.load(f) x_serialized = { @@ -325,35 +200,8 @@ def character_maker2(lm, id, description, valid_weapons): ] } - x_serialized = { - "grammars": [ - { - "greedy_lexer": False, - "nodes": [ - { - "GenGrammar": { - "grammar": 1, - "stop_rx": "", - "no_initial_skip": True, - "temperature": 0.0, - } - } - ], - "rx_nodes": [], - }, - { - "greedy_lexer": True, - "greedy_skip_rx": "[\\x20\\x0A\\x0D\\x09]+", - "nodes": [ - {"Lexeme": {"rx": "-?(?:0|[1-9][0-9]*)", "contextual": False}} - # {"Lexeme": {"rx": "[ab][ab]", "contextual": False}} - ], - "rx_nodes": [], - }, - ] - } - serialized["max_tokens"] = max_tokens + serialized["test_trace"] = True llguidance_json = {"grammar": serialized} llguidance_arg = json.dumps(llguidance_json, indent=1) @@ -371,7 +219,10 @@ def character_maker2(lm, id, description, valid_weapons): # script = f.read() # grm = "```python\n" + substring(script[0:1400]) - mod_id = pyaici.cli.build_rust(".", features=["logging"]) + features = ["logging"] + if "FAST" in os.environ: + features = [] + mod_id = pyaici.cli.build_rust(".", features=features) if "127.0.0.1" in pyaici.rest.base_url: pyaici.rest.tag_module(mod_id, ["llguidance_ctrl-latest", "llguidance"]) pyaici.rest.log_level = 2 @@ -388,6 +239,8 @@ def character_maker2(lm, id, description, valid_weapons): print("Storage:", res["storage"]) print() + testcase_from_logs(res["logs"][0]) + text = b"" captures = {} for j in res["json_out"][0]: @@ -402,4 +255,41 @@ def character_maker2(lm, id, description, valid_weapons): print() +def testcase_from_logs(logs: str): + sep = "‧" + pairs = [] + prev_res = None + prompt = None + for line in logs.split("\n"): + if line.startswith("TEST: "): + obj = json.loads(line[6:]) + if prompt is None: + prompt = obj["res_prompt"] + continue + if prev_res: + pairs.append((prev_res, obj["arg"])) + prev_res = 
obj["res"]
+            print(obj)
+    assert prev_res == "stop"
+    testcase = [prompt]
+    gen_tokens = []
+
+    def flush_gen_tokens():
+        testcase.append(sep.join(gen_tokens))
+        gen_tokens.clear()
+
+    for res, arg in pairs:
+        if res["sample_mask"]:
+            gen_tokens.append(arg["tokens"])
+        else:
+            t0 = res["splices"][0]["tokens"]
+            assert t0 == arg["tokens"]
+            flush_gen_tokens()
+            testcase.append(t0)
+    if gen_tokens:
+        flush_gen_tokens()
+
+    print("Testcase:", testcase)
+
+
 main()
diff --git a/controllers/llguidance_ctrl/src/api.rs b/controllers/llguidance_ctrl/src/api.rs
index e3518e55..00e10023 100644
--- a/controllers/llguidance_ctrl/src/api.rs
+++ b/controllers/llguidance_ctrl/src/api.rs
@@ -7,6 +7,8 @@ use serde::{Deserialize, Serialize};
 pub struct TopLevelGrammar {
     pub grammars: Vec<GrammarWithLexer>,
     pub max_tokens: Option<usize>,
+    #[serde(default)]
+    pub test_trace: bool,
 }
 
 pub const DEFAULT_CONTEXTUAL: bool = true;
diff --git a/controllers/llguidance_ctrl/src/tokenparser.rs b/controllers/llguidance_ctrl/src/tokenparser.rs
index 55b5a395..619e39cf 100644
--- a/controllers/llguidance_ctrl/src/tokenparser.rs
+++ b/controllers/llguidance_ctrl/src/tokenparser.rs
@@ -6,6 +6,7 @@ use crate::{
 };
 use aici_abi::{MidProcessArg, MidProcessResult, TokenId, TokenizerEnv};
 use anyhow::Result;
+use serde_json::json;
 
 macro_rules! infoln {
     ($s:expr, $($arg:tt)*) => {
@@ -30,6 +31,7 @@ pub struct TokenParser {
     pub parser: Parser,
     pub log_level: isize,
     pub mid_process_start_time: std::time::Instant,
+    test_trace: bool,
     parser_stack: Vec,
     parser_llm_tokens_offset: usize,
     // this is empty for top-level parser,
@@ -63,6 +65,7 @@ impl TokenParser {
         log_level: isize,
     ) -> Result<Self> {
         let mid_process_start_time = std::time::Instant::now();
+        let test_trace = buf.test_trace;
        let max_tokens = buf.max_tokens.unwrap_or(usize::MAX);
         let compiled_grammars = grammars_from_json(buf, log_level >= 2)?;
         let parser = Parser::new(
@@ -72,6 +75,7 @@
 
         Ok(TokenParser {
             log_level,
+            test_trace,
             token_env,
             mid_process_start_time,
             mid_process_was_accepting: false,
@@ -156,6 +160,12 @@ impl TokenParser {
         }
 
         infoln!(self, "res_prompt: {}", trie.tokens_dbg(&res_prompt));
+        if self.test_trace {
+            self.test_trace_json(&json!({
+                "prompt": trie.test_trace_tokens(&prompt),
+                "res_prompt": trie.test_trace_tokens(&res_prompt),
+            }));
+        }
         res_prompt
     }
 
@@ -170,6 +180,12 @@ impl TokenParser {
         self.parser_stack.is_empty()
     }
 
+    fn test_trace_json(&self, j: &serde_json::Value) {
+        if self.test_trace {
+            infoln!(self, "TEST: {}", serde_json::to_string(j).unwrap());
+        }
+    }
+
     pub fn mid_process(&mut self, arg: MidProcessArg) -> MidProcessResult {
         self.mid_process_start_time = std::time::Instant::now();
         if self.max_tokens_total == 0 {
@@ -178,7 +194,43 @@
         }
         self.max_tokens_total -= 1;
         self.max_tokens_parser = self.max_tokens_parser.saturating_sub(1);
-        self.mid_process_inner(arg)
+
+        let trace = if self.test_trace {
+            let tokens = self.token_env.tok_trie().test_trace_tokens(&arg.tokens);
+            Some(json!({
+                "backtrack": arg.backtrack,
+                "tokens": tokens,
+            }))
+        } else {
+            None
+        };
+
+        let r = self.mid_process_inner(arg);
+
+        if self.test_trace {
+            let res = if r.is_stop() {
+                json!("stop")
+            } else {
+                let b = &r.branches[0];
+                json!({
+                    "sample_mask": b.sample_mask.is_some(),
+                    "temperature": b.temperature,
+                    "splices": b.splices.iter().map(|s| {
+                        json!({
+                            "when_sampled": s.when_sampled,
+                            "backtrack": s.backtrack,
+                            "tokens": self.token_env.tok_trie().test_trace_tokens(&s.ff_tokens),
+                        })
+                    }).collect::<Vec<_>>(),
+                })
+            };
+            self.test_trace_json(&json!({
+                
"arg": trace.unwrap(), + "res": res, + })); + } + + r } fn mid_process_inner(&mut self, mut arg: MidProcessArg) -> MidProcessResult { diff --git a/py/guidance b/py/guidance index 75806240..815ed6bc 160000 --- a/py/guidance +++ b/py/guidance @@ -1 +1 @@ -Subproject commit 7580624096c08b200de300638620ce5bb8227742 +Subproject commit 815ed6bceb9453288dbdd1d689a66259320eade1 diff --git a/py/llguidance/python/llguidance/_lib.pyi b/py/llguidance/python/llguidance/_lib.pyi index d022de46..b3e8d71e 100644 --- a/py/llguidance/python/llguidance/_lib.pyi +++ b/py/llguidance/python/llguidance/_lib.pyi @@ -39,6 +39,12 @@ class LLTokenizer: The result is double-quoted and tokens are separated by '‧'. """ + def test_trace_tokens(self, tokens: List[int]) -> str: + """ + Return a debug string representation of the tokens + for test traces. + """ + def decode_str(self, tokens: List[int]) -> str: """ Decode the tokens into a string. diff --git a/py/llguidance/rust/py.rs b/py/llguidance/rust/py.rs index 41cd5bfd..ad0fcadf 100644 --- a/py/llguidance/rust/py.rs +++ b/py/llguidance/rust/py.rs @@ -183,6 +183,10 @@ impl LLTokenizer { self.tok_trie.greedy_tokenize(text.as_bytes()) } + fn test_trace_tokens(&self, tokens: Vec) -> String { + self.tok_trie.test_trace_tokens(&tokens) + } + fn dbg_tokens(&self, tokens: Vec) -> String { self.tok_trie.tokens_dbg(&tokens) }