add LLInterpreter.is_accepting()

mmoskal committed Jun 18, 2024
1 parent e1c7846 commit b841346

Showing 3 changed files with 33 additions and 7 deletions.
22 changes: 17 additions & 5 deletions controllers/llguidance_ctrl/src/tokenparser.rs

@@ -37,6 +37,7 @@ pub struct TokenParser {
     // this is empty for top-level parser,
     // and the previous grm_bytes for sub-parsers
     previous_grm_bytes: Vec<u8>,
+    mid_process_was_accepting: bool,
 
     first_token_of_eos_marker: TokenId,
     max_tokens_total: usize,
@@ -78,6 +79,7 @@ impl TokenParser {
             log_level,
             token_env,
             mid_process_start_time,
+            mid_process_was_accepting: false,
             parser,
             parser_llm_tokens_offset: 0,
             parser_stack: Vec::new(),
@@ -104,6 +106,10 @@ impl TokenParser {
         &self.llm_bytes[self.grm_prefix.len()..]
     }
 
+    pub fn mid_process_was_accepting(&self) -> bool {
+        self.mid_process_was_accepting
+    }
+
     pub fn bytes_since(&mut self, mut idx: usize) -> &[u8] {
         idx += self.grm_prefix.len();
         let endp = std::cmp::min(
@@ -182,6 +188,8 @@ impl TokenParser {
     fn mid_process_inner(&mut self, mut arg: MidProcessArg) -> MidProcessResult {
         let start_time = std::time::Instant::now();
 
+        self.mid_process_was_accepting = false;
+
         infoln!(self, "\n");
         let trie = self.token_env.tok_trie();
 
@@ -287,9 +295,9 @@
             }
         }
 
-        if arg.tokens.contains(&trie.eos_token()) {
-            return MidProcessResult::stop();
-        }
+        // if arg.tokens.contains(&trie.eos_token()) {
+        //     return MidProcessResult::stop();
+        // }
 
         let new_forced = grm_bytes[self.llm_bytes.len()..].to_vec();
         let mut token_prefix = Vec::new();
@@ -330,13 +338,16 @@
         }
 
         let inner_done = {
+            let empty_token_prefix = token_prefix.is_empty();
             let is_accepting = self.parser.is_accepting();
             let can_advance = self.parser.can_advance();
-            let inner_done = is_accepting && !can_advance;
+            let inner_done = empty_token_prefix && is_accepting && !can_advance;
             infoln!(
                 self,
-                "inner_done: {inner_done}; can_advance: {can_advance}; accept: {is_accepting}"
+                "inner_done: {inner_done}; can_advance: {can_advance}; accept: {is_accepting}; empty_token_prefix: {empty_token_prefix}"
             );
+            self.mid_process_was_accepting =
+                is_accepting && empty_token_prefix && self.parser_stack.is_empty();
             inner_done
         };
 
@@ -378,6 +389,7 @@ impl TokenParser {
         );
 
         if set.num_set() == 0 {
+            infoln!(self, "no tokens allowed, stopping");
             return MidProcessResult::stop();
         }
 
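Taken together, the tokenparser.rs changes report the parser as "accepting" only when the grammar itself accepts, no partially-matched token bytes are pending, and no sub-parser remains on the stack. A minimal sketch of that predicate, written in illustrative Python rather than the crate's Rust:

# Illustrative sketch of the condition set in mid_process_inner() above;
# it mirrors the Rust logic and is not part of the library itself.
def was_accepting(is_accepting: bool, token_prefix: bytes, parser_stack: list) -> bool:
    # "Accepting" only counts when no partial-token bytes are pending
    # and no nested sub-parser remains on the stack.
    return is_accepting and len(token_prefix) == 0 and len(parser_stack) == 0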

8 changes: 7 additions & 1 deletion py/llguidance/python/llguidance/_lib.pyi

@@ -27,7 +27,7 @@ class LLTokenizer:
         prefix of the text, and then fallback to greedy_tokenize() for the last
         few bytes.
         """
-    
+
     def tokenize_str(self, text: str) -> List[int]:
         """
        Same as tokenize_bytes, but for strings.
@@ -71,6 +71,12 @@ class LLInterpreter:
         Create a deep copy of the interpreter.
         """
 
+    def is_accepting(self) -> bool:
+        """
+        Check if the last mid_process() call resulted in overall accepting state
+        of the parser.
+        """
+
     def process_prompt(self, prompt: List[TokenId]) -> List[TokenId]:
         """
         Perform any adjustments to the prompt before completion.
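A hypothetical usage sketch of the new stub; MockInterp stands in for LLInterpreter, and mid_process() is stubbed out because its real signature is not shown in this diff:

# Hypothetical mock, not the real LLInterpreter.
class MockInterp:
    def __init__(self, steps_until_accept: int) -> None:
        self._left = steps_until_accept

    def mid_process(self) -> None:
        # Stands in for advancing the parser by one sampled token.
        self._left -= 1

    def is_accepting(self) -> bool:
        # Mirrors LLInterpreter.is_accepting(): true once the last
        # mid_process() call left the parser in an accepting state.
        return self._left <= 0

interp = MockInterp(steps_until_accept=3)
while not interp.is_accepting():
    interp.mid_process()
print("grammar accepts; EOS is now a legal next token")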

10 changes: 9 additions & 1 deletion py/llguidance/rust/py.rs

@@ -5,7 +5,11 @@ use aici_abi::{
     toktree::{self, TokTrie},
     MidProcessArg, TokenId, TokenizerEnv,
 };
-use aici_llguidance_ctrl::{api::TopLevelGrammar, output::{ParserOutput, Reporter}, TokenParser};
+use aici_llguidance_ctrl::{
+    api::TopLevelGrammar,
+    output::{ParserOutput, Reporter},
+    TokenParser,
+};
 use pyo3::{exceptions::PyValueError, prelude::*};
 use serde::{Deserialize, Serialize};
 
@@ -55,6 +59,10 @@ impl LLInterpreter {
         self.clone()
     }
 
+    fn is_accepting(&self) -> bool {
+        self.inner.mid_process_was_accepting()
+    }
+
     fn process_prompt(&mut self, prompt: Vec<TokenId>) -> Vec<TokenId> {
         self.inner.process_prompt(prompt)
     }
