Skip to content

Commit

Permalink
more JSON output from guidance parser
Browse files Browse the repository at this point in the history
  • Loading branch information
mmoskal committed Apr 30, 2024
1 parent fb5fb81 commit 81dba56
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 5 deletions.
49 changes: 44 additions & 5 deletions controllers/guidance_ctrl/src/gctrl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ macro_rules! infoln {
/// Controller state: the token parser plus cursors recording how much of its
/// output has already been reported.
pub struct Runner {
/// Parser constructed from the decoded guidance protobuf.
tok_parser: TokenParser,
/// Starts at 0; presumably the number of captures already emitted as JSON
/// (the code that advances it is not visible here — confirm).
reported_captures: usize,
/// Byte offset passed to `TokenParser::bytes_since`; advanced by the length
/// of each text chunk after it has been emitted.
text_ptr: usize,
/// Value of `TokenParser::num_tokens()` at the last text emission (or at
/// construction); used to compute how many tokens the next chunk covers.
token_ptr: usize,
}

#[derive(Serialize, Deserialize)]
Expand All @@ -34,13 +36,17 @@ impl Runner {
let guidance = base64::engine::general_purpose::STANDARD
.decode(arg.guidance_b64)
.expect("invalid base64");
let tok_parser = TokenParser::from_guidance_protobuf(
Box::new(aici_abi::WasmTokenizerEnv::default()),
&guidance,
)
.expect("invalid guidance protobuf");
let token_ptr = tok_parser.num_tokens();
Runner {
tok_parser: TokenParser::from_guidance_protobuf(
Box::new(aici_abi::WasmTokenizerEnv::default()),
&guidance,
)
.expect("invalid guidance protobuf"),
tok_parser,
reported_captures: 0,
text_ptr: 0,
token_ptr,
}
}

Expand All @@ -53,9 +59,20 @@ impl Runner {
name: name.clone(),
str: String::from_utf8_lossy(val).to_string(),
hex: to_hex_string(val),
log_prob: 0.0, // TODO
};
json_out(&cap);
}

let new_text = self.tok_parser.bytes_since(self.text_ptr);
if new_text.len() > 0 {
// TODO log_prob
let text =
Text::from_bytes(new_text, 0.0, self.tok_parser.num_tokens() - self.token_ptr);
json_out(&text);
self.text_ptr += new_text.len();
self.token_ptr = self.tok_parser.num_tokens();
}
}
}

Expand All @@ -69,6 +86,7 @@ struct Capture {
name: String,
str: String,
hex: String,
log_prob: f64,
}

#[derive(Serialize, Deserialize)]
Expand All @@ -78,6 +96,27 @@ struct FinalText {
hex: String,
}

/// JSON record describing a chunk of newly produced text.
#[derive(Serialize, Deserialize)]
struct Text {
/// Record discriminator; `Text::from_bytes` always sets this to "text".
object: &'static str, // "text"
/// The chunk decoded as UTF-8, with invalid sequences replaced (lossy).
str: String,
/// Output of `to_hex_string` over the raw bytes of the chunk
/// (presumably a hex encoding — helper is defined elsewhere).
hex: String,
/// Log-probability associated with the chunk, as supplied by the caller.
log_prob: f64,
/// Number of tokens the chunk accounts for, as supplied by the caller.
num_tokens: usize,
}

impl Text {
    /// Builds a `"text"` record from raw bytes.
    ///
    /// `bytes` is stored twice: lossily decoded as UTF-8 in `str`, and run
    /// through `to_hex_string` in `hex`. `log_prob` and `num_tokens` are
    /// copied into the record unchanged.
    pub fn from_bytes(bytes: &[u8], log_prob: f64, num_tokens: usize) -> Self {
        // Lossy decoding never fails; invalid sequences become U+FFFD.
        let decoded = String::from_utf8_lossy(bytes).into_owned();
        let encoded = to_hex_string(bytes);
        Self {
            object: "text",
            str: decoded,
            hex: encoded,
            log_prob,
            num_tokens,
        }
    }
}

impl FinalText {
pub fn from_bytes(bytes: &[u8]) -> Self {
FinalText {
Expand Down
8 changes: 8 additions & 0 deletions controllers/guidance_ctrl/src/tokenparser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,18 @@ impl TokenParser {
})
}

/// Returns the current length of `llm_tokens`.
pub fn num_tokens(&self) -> usize {
self.llm_tokens.len()
}

/// Returns the part of `llm_bytes` that follows the first
/// `grm_prefix.len()` bytes.
///
/// Panics if `grm_prefix` is longer than `llm_bytes`.
pub fn final_bytes(&self) -> &[u8] {
    let prefix_len = self.grm_prefix.len();
    &self.llm_bytes[prefix_len..]
}

/// Returns the bytes of `llm_bytes` starting `idx` bytes after the grammar
/// prefix (`idx == 0` yields the same slice as `final_bytes`).
///
/// If the resulting offset lies past the end of `llm_bytes`, an empty slice
/// is returned instead of panicking, so a stale cursor held by the caller
/// simply reports "no new bytes".
pub fn bytes_since(&self, idx: usize) -> &[u8] {
    // Saturate so an oversized `idx` cannot wrap around to a valid offset.
    let start = self.grm_prefix.len().saturating_add(idx);
    // `get` returns `None` for an out-of-range start; map that to empty.
    self.llm_bytes.get(start..).unwrap_or(&[])
}

pub fn process_prompt(&mut self, prompt: Vec<TokenId>) -> Vec<TokenId> {
assert!(self.llm_tokens.is_empty());

Expand Down

0 comments on commit 81dba56

Please sign in to comment.