diff --git a/controllers/aici_abi/src/lib.rs b/controllers/aici_abi/src/lib.rs index 9995152b..87864381 100644 --- a/controllers/aici_abi/src/lib.rs +++ b/controllers/aici_abi/src/lib.rs @@ -132,6 +132,21 @@ impl Branch { splices: self.splices.clone(), } } + + pub fn splice(backtrack: u32, ff_tokens: Vec) -> Self { + Branch { + sample_mask: None, + splices: vec![Splice { + when_sampled: vec![], + backtrack, + ff_tokens, + }], + } + } + + pub fn noop() -> Self { + Self::splice(0, vec![]) + } } #[derive(Debug)] @@ -159,14 +174,7 @@ impl MidProcessResult { pub fn splice(backtrack: u32, ff_tokens: Vec) -> Self { MidProcessResult { - branches: vec![Branch { - sample_mask: None, - splices: vec![Splice { - when_sampled: vec![], - backtrack, - ff_tokens, - }], - }], + branches: vec![Branch::splice(backtrack, ff_tokens)], } } diff --git a/controllers/declctrl/src/declctrl.rs b/controllers/declctrl/src/declctrl.rs index ff5e00f5..a7afb218 100644 --- a/controllers/declctrl/src/declctrl.rs +++ b/controllers/declctrl/src/declctrl.rs @@ -13,12 +13,11 @@ use aici_abi::{ aici_expose_all, bytes::limit_str, cfg::CfgParser, - rx::RecRx, - rx::RxStackRecognizer, + rx::{RecRx, RxStackRecognizer}, svob::SimpleVob, - tokenize, tokenize_bytes, + tokenize_bytes, toktree::{Recognizer, SpecialToken, TokTrie}, - AiciCtrl, InitPromptArg, InitPromptResult, MidProcessArg, MidProcessResult, TokenId, + AiciCtrl, Branch, InitPromptArg, InitPromptResult, MidProcessArg, MidProcessResult, TokenId, VariableStorage, }; use core::panic; @@ -747,6 +746,7 @@ impl StepState { self.check_eos(false) } + #[allow(dead_code)] fn attention_mask(&self, ctx: &RunnerCtx) -> Vec { if self.mask_tags.len() == 0 { vec![] @@ -1220,17 +1220,6 @@ impl AiciCtrl for Runner { } } - if let StepSpecific::Fork { branches } = &self.curr_state().specific { - let attention_masks = branches - .iter() - .map(|b| b[0].attention_mask(&self.ctx)) - .collect::>(); - PreProcessResult::new(attention_masks.len()) - } else { - let mask = self.curr_state().attention_mask(&self.ctx); - PreProcessResult::new(vec![mask].len()) - } - // // MID // @@ -1255,18 +1244,16 @@ impl AiciCtrl for Runner { } } + if let StepSpecific::Fork { branches } = &self.curr_state().specific { + assert!(branches.len() > 1); + return MidProcessResult { + branches: branches.iter().map(|_| Branch::noop()).collect(), + }; + } + if self.maybe_wait() { // this is a bit late in the game, but it's the best we can do - MidProcessResult::Splice { - backtrack: 0, - ff_tokens: tokenize(" "), - } - - // // we pop the useless generated token - // MidProcessResult::Splice { - // backtrack: 1, - // ff_tokens: vec![], - // } + MidProcessResult::noop() } else { self.try_backtrack() } diff --git a/controllers/pyctrl/driver.py b/controllers/pyctrl/driver.py index be34c83f..32faec05 100644 --- a/controllers/pyctrl/driver.py +++ b/controllers/pyctrl/driver.py @@ -13,7 +13,6 @@ def __init__(self): self.mod = pyaici.cli.build_rust(".") self.oks = [] self.failures = [] - pyaici.rest.log_level = 0 def fail(self, id: str, logs: str): print(f"FAIL") @@ -67,6 +66,14 @@ def main(): help="skip tests matching string", ) + parser.add_argument( + "--log-level", + "-l", + type=int, + default=0, + help="AICI log level", + ) + parser.add_argument( "--only", "-k", @@ -82,8 +89,11 @@ def main(): help="files to test", ) + args = parser.parse_args() + pyaici.rest.log_level = args.log_level + files = args.test_file if files[0].endswith(".js"): diff --git a/docs/aicirt-proto.md b/docs/aicirt-proto.md index e864e486..effb7eb4 100644 --- a/docs/aicirt-proto.md +++ b/docs/aicirt-proto.md @@ -67,8 +67,6 @@ The `ping` command is used to check if the AICIrt is alive. The requests always have an `op` field, and the responses always have a `type` field, which is either `"ok"` or `"error"`, as well as a `data` field. -TODO: add inference_caps - The `tokens` command gives the size of the vocabulary of the loaded tokenizer. ```json diff --git a/scripts/test-all.sh b/scripts/test-all.sh index a578e4d3..0201088a 100755 --- a/scripts/test-all.sh +++ b/scripts/test-all.sh @@ -4,4 +4,4 @@ set -x set -e ./scripts/test-pyctrl.sh ./scripts/test-jsctrl.sh -#pytest +pytest