Skip to content

Commit

Permalink
maintain parallel item_props[]
Browse files Browse the repository at this point in the history
  • Loading branch information
mmoskal committed May 6, 2024
1 parent 0c82655 commit f49110f
Showing 1 changed file with 66 additions and 11 deletions.
77 changes: 66 additions & 11 deletions controllers/guidance_ctrl/src/earley/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use aici_abi::{
TokenId,
};

use super::grammar::{CGrammar, CSymIdx, CSymbol, ModelVariable, RuleIdx, SimpleHash};
use super::grammar::{CGrammar, CSymIdx, CSymbol, ModelVariable, RuleIdx};

const DEBUG: bool = false;
const INFO: bool = true;
Expand All @@ -31,6 +31,26 @@ struct Item {
data: u64,
}

/// Per-item properties, kept separately from the `Item` itself.
///
/// These are only tracked in definitive mode.
#[derive(Debug, Clone)]
struct ItemProps {
    /// Starts out as `usize::MAX`; merging keeps the smaller of the two values.
    hidden_start: usize,
}

impl Default for ItemProps {
    /// Creates props with `hidden_start` set to `usize::MAX`, so that merging
    /// any other props into the default yields those other props' value.
    fn default() -> Self {
        Self {
            hidden_start: usize::MAX,
        }
    }
}

impl ItemProps {
    /// Folds `other` into `self`, keeping the smaller `hidden_start`.
    fn merge(&mut self, other: ItemProps) {
        if other.hidden_start < self.hidden_start {
            self.hidden_start = other.hidden_start;
        }
    }
}

#[derive(Debug, Default)]
pub struct Stats {
pub rows: usize,
Expand Down Expand Up @@ -87,6 +107,7 @@ struct Scratch {
row_start: usize,
row_end: usize,
items: Vec<Item>,
item_props: Vec<ItemProps>,
definitive: bool,
}

Expand Down Expand Up @@ -116,6 +137,7 @@ impl Scratch {
row_start: 0,
row_end: 0,
items: vec![],
item_props: vec![],
definitive: true,
}
}
Expand All @@ -139,27 +161,48 @@ impl Scratch {
}

#[inline(always)]
fn just_add(&mut self, item: Item, _parent_item_idx: usize) {
/// Merges the props stored at `origin_item_idx` into the props stored at
/// `target_item_idx`.
///
/// Both arguments index directly into `item_props`; an out-of-range index
/// panics.
fn merge_item_origin(&mut self, target_item_idx: usize, origin_item_idx: usize) {
    // Clone first so the shared borrow ends before the mutable one begins.
    let inherited = self.item_props[origin_item_idx].clone();
    let target = &mut self.item_props[target_item_idx];
    target.merge(inherited);
}

/// Writes `item` at position `row_end` (no duplicate check) and advances
/// `row_end` by one.
///
/// In definitive mode the `item_props` entry for the new position is reset to
/// its default and then merged with the props of the item at
/// `origin_item_idx`.
#[inline(always)]
fn just_add(&mut self, item: Item, origin_item_idx: usize) {
    self.ensure_items(self.row_end + 1);
    let slot = self.row_end;
    // SAFETY: we just ensured that there is enough space
    unsafe {
        self.items.as_mut_ptr().add(slot).write(item);
    }
    if self.definitive {
        // Reuse the existing props entry when there is one, otherwise grow
        // the vector by one element.
        if slot < self.item_props.len() {
            self.item_props[slot] = ItemProps::default();
        } else {
            self.item_props.push(ItemProps::default());
        }
        self.merge_item_origin(slot, origin_item_idx);
    }
    self.row_end += 1;
}

#[inline(always)]
fn add_unique(&mut self, item: Item, parent_item_idx: usize, info: &str) {
if !self.items[self.row_start..self.row_end].contains(&item) {
fn add_unique(&mut self, item: Item, origin_item_idx: usize, info: &str) {
if let Some(idx) = self.items[self.row_start..self.row_end]
.iter()
.position(|&x| x == item)
{
if self.definitive {
self.merge_item_origin(idx, origin_item_idx);
}
} else {
if self.definitive {
debug!(
" addu: {} ({})",
item_to_string(&self.grammar, &item),
info
);
}
self.just_add(item, parent_item_idx);
self.just_add(item, origin_item_idx);
}
}
}
Expand Down Expand Up @@ -209,7 +252,7 @@ impl Parser {
}

/// Removes the last `n` entries from `row_infos` and then pops `n` rows via
/// `pop_rows`. Asserts definitive mode first.
fn pop_row_infos(&mut self, n: usize) {
self.assert_non_trie();
self.assert_definitive();
// NOTE(review): `set_len` only shortens the vector; it does not run destructors
// for the removed entries. This presumably relies on `n <= self.row_infos.len()`
// and on `RowInfo` not needing drop — confirm against the callers and the type.
unsafe { self.row_infos.set_len(self.row_infos.len() - n) }
self.pop_rows(n);
}
Expand All @@ -225,13 +268,13 @@ impl Parser {
self.stats = Stats::default();
}

fn assert_non_trie(&self) {
/// Panics unless the scratch space is in definitive mode and the number of
/// rows equals the number of `row_infos` entries.
fn assert_definitive(&self) {
assert!(self.scratch.definitive);
assert!(self.num_rows() == self.row_infos.len());
}

/// Returns the `byte` field of every `row_infos` entry except the first one,
/// in order. Asserts definitive mode first.
pub fn get_bytes(&self) -> Vec<u8> {
self.assert_non_trie();
self.assert_definitive();
// The first entry is skipped; only the entries after it contribute bytes.
self.row_infos.iter().skip(1).map(|ri| ri.byte).collect()
}

Expand All @@ -250,7 +293,7 @@ impl Parser {
mut num_skip: usize,
) -> &'static str {
// this is unused!
self.assert_non_trie();
self.assert_definitive();
let mut byte_idx = 1; // row_infos[0] has just the 0 byte
let mut tok_idx = 0;
for t in tokens {
Expand Down Expand Up @@ -287,6 +330,8 @@ impl Parser {
}

pub fn filter_max_tokens(&mut self) {
self.assert_definitive();

let mut dst = 0;

self.row_infos.push(RowInfo {
Expand All @@ -300,6 +345,7 @@ impl Parser {
self.rows[idx].first_item = dst;
for i in range {
let item = self.scratch.items[i];
let item_props = &self.scratch.item_props[i];
let sym_data = self.item_sym_data(&item);
let max_tokens = sym_data.props.max_tokens;
if max_tokens != usize::MAX {
Expand All @@ -315,6 +361,7 @@ impl Parser {
}
}
self.scratch.items[dst] = item;
self.scratch.item_props[dst] = item_props.clone();
dst += 1;
}
self.rows[idx].last_item = dst;
Expand All @@ -324,7 +371,7 @@ impl Parser {
}

pub fn force_bytes(&mut self) -> Vec<u8> {
self.assert_non_trie();
self.assert_definitive();
let mut bytes = vec![];
while let Some(b) = self.forced_byte() {
let res = self.scan(b);
Expand Down Expand Up @@ -462,7 +509,7 @@ impl Parser {
self.is_accepting = false;

while agenda_ptr < self.scratch.row_end {
let item_idx = agenda_ptr;
let mut item_idx = agenda_ptr;
let mut item = self.scratch.items[agenda_ptr];
agenda_ptr += 1;
if self.scratch.definitive {
Expand Down Expand Up @@ -514,10 +561,18 @@ impl Parser {
{
// if so, use it
item = next_item;
item_idx = ptr;
}
}
self.scratch.row_end = agenda_ptr;
self.scratch.items[agenda_ptr - 1] = item;
if self.scratch.definitive {
self.scratch.item_props[agenda_ptr - 1] =
self.scratch.item_props[item_idx].clone();
}
// better keep item_idx updated in case we use it in future
item_idx = agenda_ptr - 1;
let _ = item_idx; // silence warning
commit_item = item;
if self.scratch.definitive {
debug!(" commit point: {}", self.item_to_string(&item));
Expand Down

0 comments on commit f49110f

Please sign in to comment.