Skip to content

Commit

Permalink
Index by ref in Guide
Browse files Browse the repository at this point in the history
  • Loading branch information
torymur committed Jan 7, 2025
1 parent 6a81250 commit 88d609b
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 12 deletions.
25 changes: 23 additions & 2 deletions benchmarks/bench_regex_guide.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import os
from concurrent.futures import ThreadPoolExecutor

import psutil
from outlines_core.fsm import Index, Vocabulary
from outlines_core.fsm import Guide, Index, Vocabulary

regex_samples = {
"email": r"[a-z0-9!#$%&'*+/=?^_`{|}~-]+(?:\.[a-z0-9!#$%&'*+/=?^_`{|}~-]+)*@(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?",
Expand Down Expand Up @@ -60,5 +61,25 @@ def setup(self, pattern_name):
self.vocabulary = Vocabulary.from_pretrained("gpt2")
self.pattern = regex_samples[pattern_name]

def peakmem_regex_to_guide(self, pattern_name):
def peakmem_regex_to_index(self, pattern_name):
Index(self.pattern, self.vocabulary)


class MemoryStabilityBenchmark:
params = [1, 10_000]

def setup(self, num):
self.vocabulary = Vocabulary.from_pretrained("gpt2")
self.index = Index(".*", self.vocabulary)
self.process = psutil.Process(os.getpid())

def _memory_usage(self):
return self.process.memory_info().rss / 1024**2

def peakmem_guides_per_index(self, num_guides):
initial = self._memory_usage()
objects = [Guide(self.index) for i in range(num_guides)]
final = self._memory_usage()

assert len(objects) == num_guides
assert final - initial < 5
18 changes: 10 additions & 8 deletions src/python_bindings/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::Arc;

use crate::index::Index;
use crate::json_schema;
use crate::prelude::*;
Expand Down Expand Up @@ -28,11 +30,11 @@ pub struct PyGuide {
#[pymethods]
impl PyGuide {
#[new]
fn new(index: PyIndex) -> Self {
PyGuide {
fn __new__(index: PyIndex) -> PyResult<PyGuide> {
Ok(PyGuide {
state: index.get_initial_state(),
index,
}
})
}

fn get_start_tokens(&self) -> PyResult<Vec<TokenId>> {
Expand Down Expand Up @@ -102,15 +104,15 @@ impl PyGuide {

#[pyclass(name = "Index", module = "outlines_core.fsm.outlines_core_rs")]
#[derive(Clone, Debug, Encode, Decode)]
pub struct PyIndex(Index);
pub struct PyIndex(Arc<Index>);

#[pymethods]
impl PyIndex {
#[new]
fn new(py: Python<'_>, regex: &str, vocabulary: &PyVocabulary) -> PyResult<Self> {
fn __new__(py: Python<'_>, regex: &str, vocabulary: &PyVocabulary) -> PyResult<Self> {
py.allow_threads(|| {
Index::new(regex, &vocabulary.0)
.map(PyIndex)
.map(|x| PyIndex(Arc::new(x)))
.map_err(Into::into)
})
}
Expand Down Expand Up @@ -164,7 +166,7 @@ impl PyIndex {
bincode::decode_from_slice(&binary_data[..], config::standard()).map_err(|e| {
PyErr::new::<PyValueError, _>(format!("Deserialization of Index failed: {}", e))
})?;
Ok(PyIndex(index))
Ok(PyIndex(Arc::new(index)))
}
}

Expand Down Expand Up @@ -193,7 +195,7 @@ pub struct PyVocabulary(Vocabulary);
#[pymethods]
impl PyVocabulary {
#[new]
fn new(py: Python<'_>, eos_token_id: TokenId, map: Py<PyAny>) -> PyResult<PyVocabulary> {
fn __new__(py: Python<'_>, eos_token_id: TokenId, map: Py<PyAny>) -> PyResult<PyVocabulary> {
if let Ok(dict) = map.extract::<HashMap<String, Vec<TokenId>>>(py) {
return Ok(PyVocabulary(Vocabulary::from((eos_token_id, dict))));
}
Expand Down
1 change: 0 additions & 1 deletion tests/fsm/test_guide.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ def test_str_and_bytes_produce_the_same():
index2 = Index(regex, vocabulary2)
guide1 = Guide(index1)
guide2 = Guide(index2)
assert False

assert guide1.read_next_token(3) == guide2.read_next_token(3)
# `a` and `b` have similar transitions to `z`
Expand Down
2 changes: 1 addition & 1 deletion tests/fsm/test_json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class Model(BaseModel):
n: int

json_schema = json.dumps(Model.model_json_schema())
pattern = build_regex_from_schema(json_schema, whitespace_pattern=None)
build_regex_from_schema(json_schema, whitespace_pattern=None)


def test_match_object():
Expand Down

0 comments on commit 88d609b

Please sign in to comment.