Commit

fix declctrl
mmoskal committed Apr 12, 2024
1 parent 2d58639 commit 7804d11
Showing 5 changed files with 40 additions and 37 deletions.
24 changes: 16 additions & 8 deletions controllers/aici_abi/src/lib.rs
@@ -132,6 +132,21 @@ impl<S> Branch<S> {
             splices: self.splices.clone(),
         }
     }
+
+    pub fn splice(backtrack: u32, ff_tokens: Vec<TokenId>) -> Self {
+        Branch {
+            sample_mask: None,
+            splices: vec![Splice {
+                when_sampled: vec![],
+                backtrack,
+                ff_tokens,
+            }],
+        }
+    }
+
+    pub fn noop() -> Self {
+        Self::splice(0, vec![])
+    }
 }
 
 #[derive(Debug)]
@@ -159,14 +174,7 @@ impl MidProcessResult
 
     pub fn splice(backtrack: u32, ff_tokens: Vec<TokenId>) -> Self {
         MidProcessResult {
-            branches: vec![Branch {
-                sample_mask: None,
-                splices: vec![Splice {
-                    when_sampled: vec![],
-                    backtrack,
-                    ff_tokens,
-                }],
-            }],
+            branches: vec![Branch::splice(backtrack, ff_tokens)],
         }
     }
 
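The new `Branch::splice` and `Branch::noop` constructors save callers from spelling out the `Splice` struct by hand, and `MidProcessResult::splice` is now a thin wrapper over a single such branch. A minimal sketch of how calling code might use them (the `force_tokens` helper is purely illustrative, not part of this commit):

```rust
use aici_abi::{Branch, MidProcessResult, TokenId};

// Illustrative helper (not from this commit): either force `ff_tokens`
// into the output or leave generation untouched.
fn force_tokens(ff_tokens: Vec<TokenId>) -> MidProcessResult {
    if ff_tokens.is_empty() {
        // Branch::noop() is just Branch::splice(0, vec![]): no backtracking,
        // no fast-forward tokens, and no sampling mask.
        MidProcessResult {
            branches: vec![Branch::noop()],
        }
    } else {
        // Equivalent to the refactored MidProcessResult::splice(0, ff_tokens).
        MidProcessResult {
            branches: vec![Branch::splice(0, ff_tokens)],
        }
    }
}
```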
37 changes: 12 additions & 25 deletions controllers/declctrl/src/declctrl.rs
@@ -13,12 +13,11 @@ use aici_abi::{
     aici_expose_all,
     bytes::limit_str,
     cfg::CfgParser,
-    rx::RecRx,
-    rx::RxStackRecognizer,
+    rx::{RecRx, RxStackRecognizer},
     svob::SimpleVob,
-    tokenize, tokenize_bytes,
+    tokenize_bytes,
     toktree::{Recognizer, SpecialToken, TokTrie},
-    AiciCtrl, InitPromptArg, InitPromptResult, MidProcessArg, MidProcessResult, TokenId,
+    AiciCtrl, Branch, InitPromptArg, InitPromptResult, MidProcessArg, MidProcessResult, TokenId,
     VariableStorage,
 };
 use core::panic;
@@ -747,6 +746,7 @@ impl StepState {
         self.check_eos(false)
     }
 
+    #[allow(dead_code)]
     fn attention_mask(&self, ctx: &RunnerCtx) -> Vec<f32> {
         if self.mask_tags.len() == 0 {
             vec![]
@@ -1220,17 +1220,6 @@ impl AiciCtrl for Runner {
             }
         }
 
-        if let StepSpecific::Fork { branches } = &self.curr_state().specific {
-            let attention_masks = branches
-                .iter()
-                .map(|b| b[0].attention_mask(&self.ctx))
-                .collect::<Vec<_>>();
-            PreProcessResult::new(attention_masks.len())
-        } else {
-            let mask = self.curr_state().attention_mask(&self.ctx);
-            PreProcessResult::new(vec![mask].len())
-        }
-
         //
         // MID
         //
@@ -1255,18 +1244,16 @@
             }
         }
 
+        if let StepSpecific::Fork { branches } = &self.curr_state().specific {
+            assert!(branches.len() > 1);
+            return MidProcessResult {
+                branches: branches.iter().map(|_| Branch::noop()).collect(),
+            };
+        }
+
         if self.maybe_wait() {
             // this is a bit late in the game, but it's the best we can do
-            MidProcessResult::Splice {
-                backtrack: 0,
-                ff_tokens: tokenize(" "),
-            }
-
-            // // we pop the useless generated token
-            // MidProcessResult::Splice {
-            //     backtrack: 1,
-            //     ff_tokens: vec![],
-            // }
+            MidProcessResult::noop()
         } else {
             self.try_backtrack()
         }
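With the dead pre-process fork path removed, declctrl now expresses forking purely through the mid-process result: returning one no-op `Branch` per requested branch asks the runtime to split the sequence, and a plain `MidProcessResult::noop()` replaces the old whitespace-splice workaround while waiting. A rough sketch of the fork pattern on its own (the `fork` helper name is made up for illustration):

```rust
use aici_abi::{Branch, MidProcessResult};

// Illustrative helper (not from this commit): ask the runtime to fork the
// current sequence into `n` identical branches. Each branch is a no-op
// splice, so none of them masks or fast-forwards the next token.
fn fork(n: usize) -> MidProcessResult {
    assert!(n > 1, "a fork needs more than one branch");
    MidProcessResult {
        branches: (0..n).map(|_| Branch::noop()).collect(),
    }
}
```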
12 changes: 11 additions & 1 deletion controllers/pyctrl/driver.py
@@ -13,7 +13,6 @@ def __init__(self):
         self.mod = pyaici.cli.build_rust(".")
         self.oks = []
         self.failures = []
-        pyaici.rest.log_level = 0
 
     def fail(self, id: str, logs: str):
         print(f"FAIL")
@@ -67,6 +66,14 @@ def main():
         help="skip tests matching string",
     )
 
+    parser.add_argument(
+        "--log-level",
+        "-l",
+        type=int,
+        default=0,
+        help="AICI log level",
+    )
+
     parser.add_argument(
         "--only",
         "-k",
@@ -82,8 +89,11 @@
         help="files to test",
     )
 
+
     args = parser.parse_args()
 
+    pyaici.rest.log_level = args.log_level
+
     files = args.test_file
 
     if files[0].endswith(".js"):
2 changes: 0 additions & 2 deletions docs/aicirt-proto.md
@@ -67,8 +67,6 @@ The `ping` command is used to check if the AICIrt is alive.
 The requests always have an `op` field, and the responses always have a `type` field,
 which is either `"ok"` or `"error"`, as well as a `data` field.
 
-TODO: add inference_caps
-
 The `tokens` command gives the size of the vocabulary of the loaded tokenizer.
 
 ```json
2 changes: 1 addition & 1 deletion scripts/test-all.sh
@@ -4,4 +4,4 @@ set -x
 set -e
 ./scripts/test-pyctrl.sh
 ./scripts/test-jsctrl.sh
-#pytest
+pytest
