Correct interface in pyi, reprs for all python bindings
torymur committed Jan 3, 2025
1 parent 56f5791 commit c70e505
Showing 5 changed files with 101 additions and 68 deletions.
79 changes: 34 additions & 45 deletions python/outlines_core/fsm/outlines_core_rs.pyi
@@ -19,69 +19,53 @@ WHITESPACE: str
 
 class Guide:
     def __init__(self, index: Index):
-        """
-        Defines a guide object an index an initializes it in its start state.
-        """
+        """Defines a guide object from an index and initializes it in its start state."""
     def get_start_tokens(self) -> List[int]:
-        """
-        Gets the list of allowed tokens from the start state.
-        """
+        """Gets the list of allowed tokens from the start state."""
         ...
     def read_next_token(self, token_id: int) -> List[int]:
-        """
-        Reads the next token according to the model and returns a list of allowable tokens.
-        """
+        """Reads the next token according to the model and returns a list of allowable tokens."""
         ...
     def is_finished(self) -> bool:
-        """
-        Checks if the automaton is in a final state.
-        """
+        """Checks if the automaton is in a final state."""
         ...
+    def __repr__(self) -> str:
+        """Gets the debug string representation of the guide."""
+        ...
+    def __str__(self) -> str:
+        """Gets the string representation of the guide."""
 
 class Vocabulary:
     """
     Vocabulary of an LLM.
     """
 
-    @staticmethod
-    def from_dict(eos_token_id: int, map: Dict[str, List[int]]) -> "Vocabulary":
-        """
-        Creates a vocabulary from a map of tokens to token ids and eos token id.
-        """
+    def __init__(
+        self, eos_token_id: int, map: Dict[Union[str, bytes], List[int]]
+    ) -> "Vocabulary":
+        """Creates a vocabulary from a map of tokens to token ids and eos token id."""
         ...
     @staticmethod
-    def from_pretrained(model: str) -> "Vocabulary":
-        """
-        Creates the vocabulary of a pre-trained model.
-        """
+    def from_pretrained(
+        model: str, revision: Optional[str], token: Optional[str]
+    ) -> "Vocabulary":
+        """Creates the vocabulary of a pre-trained model."""
         ...
+    def get_eos_token_id(self) -> Optional[int]:
+        """Gets the end of sentence token id."""
+        ...
+    def get(self, token: Union[str, bytes]) -> Optional[List[int]]:
+        """Gets the token ids of the given token."""
+        ...
     def __repr__(self) -> str:
-        """
-        Gets the debug string representation of the vocabulary.
-        """
+        """Gets the debug string representation of the vocabulary."""
         ...
     def __str__(self) -> str:
-        """
-        Gets the string representation of the vocabulary.
-        """
+        """Gets the string representation of the vocabulary."""
         ...
     def __eq__(self, other: object) -> bool:
-        """
-        Gets whether two vocabularies are the same.
-        """
-        ...
-    def get_eos_token_id(self) -> Optional[int]:
-        """
-        Gets the end of sentence token id.
-        """
+        """Compares whether two vocabularies are the same."""
         ...
 
 class Index:
-    @staticmethod
-    def from_regex(regex: str, vocabulary: "Vocabulary") -> "Index":
-        """
-        Creates an index from a regex and vocabulary.
-        """
+    def __init__(self, regex: str, vocabulary: "Vocabulary") -> "Index":
+        """Creates an index from a regex and vocabulary."""
         ...
     def get_allowed_tokens(self, state: int) -> Optional[List[int]]:
         """Returns allowed tokens in this state."""
@@ -95,9 +79,14 @@ class Index:
     def final_states(self) -> List[int]:
         """Get all final states."""
         ...
-    def get_index_dict(self) -> Dict[int, Dict[int, int]]:
+    def get_transitions(self) -> Dict[int, Dict[int, int]]:
         """Returns the Index as a Python Dict object."""
         ...
+    def get_initial_state(self) -> int:
+        """Returns the ID of the initial state of the input FSM automaton."""
+        ...
+    def __repr__(self) -> str:
+        """Gets the debug string representation of the index."""
+        ...
+    def __str__(self) -> str:
+        """Gets the string representation of the index."""
10 changes: 10 additions & 0 deletions src/index.rs
@@ -123,6 +123,16 @@ impl Index {
     }
 }
 
+impl std::fmt::Display for Index {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        writeln!(f, "Index object with transitions:")?;
+        for (state_id, token_ids) in self.states_to_token_subsets.iter() {
+            writeln!(f, "{:?} -> {:#?}", state_id, token_ids)?;
+        }
+        Ok(())
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
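
This `Display` impl backs the new `Index.__str__` binding below: `print(index)` walks `states_to_token_subsets` and emits one `state -> transitions` entry per line. Roughly, for the hypothetical index above (state ids and layout are illustrative only):

    print(index)
    # Index object with transitions:
    # 0 -> {
    #     0: 1,
    #     1: 1,
    # }
    # ...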
63 changes: 42 additions & 21 deletions src/python_bindings/mod.rs
@@ -64,6 +64,20 @@ impl PyGuide
         self.index.is_final_state(self.state)
     }
 
+    fn __repr__(&self) -> String {
+        format!(
+            "Guide object with the state={:#?} and {:#?}",
+            self.state, self.index
+        )
+    }
+
+    fn __str__(&self) -> String {
+        format!(
+            "Guide object with the state={} and {}",
+            self.state, self.index.0
+        )
+    }
+
     fn __reduce__(&self) -> PyResult<(PyObject, (Vec<u8>,))> {
         Python::with_gil(|py| {
             let cls = PyModule::import_bound(py, "outlines_core.fsm.outlines_core_rs")?
@@ -101,27 +115,6 @@ impl PyIndex
         })
     }
 
-    fn __reduce__(&self) -> PyResult<(PyObject, (Vec<u8>,))> {
-        Python::with_gil(|py| {
-            let cls = PyModule::import_bound(py, "outlines_core.fsm.outlines_core_rs")?
-                .getattr("Index")?;
-            let binary_data: Vec<u8> = bincode::encode_to_vec(&self.0, config::standard())
-                .map_err(|e| {
-                    PyErr::new::<PyValueError, _>(format!("Serialization of Index failed: {}", e))
-                })?;
-            Ok((cls.getattr("from_binary")?.to_object(py), (binary_data,)))
-        })
-    }
-
-    #[staticmethod]
-    fn from_binary(binary_data: Vec<u8>) -> PyResult<Self> {
-        let (index, _): (Index, usize) =
-            bincode::decode_from_slice(&binary_data[..], config::standard()).map_err(|e| {
-                PyErr::new::<PyValueError, _>(format!("Deserialization of Index failed: {}", e))
-            })?;
-        Ok(PyIndex(index))
-    }
-
     fn get_allowed_tokens(&self, state: StateId) -> Option<Vec<TokenId>> {
         self.0.allowed_tokens(state)
     }
@@ -145,6 +138,34 @@ impl PyIndex
     fn get_initial_state(&self) -> StateId {
         self.0.initial()
     }
+    fn __repr__(&self) -> String {
+        format!("{:#?}", self.0)
+    }
+
+    fn __str__(&self) -> String {
+        format!("{}", self.0)
+    }
+
+    fn __reduce__(&self) -> PyResult<(PyObject, (Vec<u8>,))> {
+        Python::with_gil(|py| {
+            let cls = PyModule::import_bound(py, "outlines_core.fsm.outlines_core_rs")?
+                .getattr("Index")?;
+            let binary_data: Vec<u8> = bincode::encode_to_vec(&self.0, config::standard())
+                .map_err(|e| {
+                    PyErr::new::<PyValueError, _>(format!("Serialization of Index failed: {}", e))
+                })?;
+            Ok((cls.getattr("from_binary")?.to_object(py), (binary_data,)))
+        })
+    }
+
+    #[staticmethod]
+    fn from_binary(binary_data: Vec<u8>) -> PyResult<Self> {
+        let (index, _): (Index, usize) =
+            bincode::decode_from_slice(&binary_data[..], config::standard()).map_err(|e| {
+                PyErr::new::<PyValueError, _>(format!("Deserialization of Index failed: {}", e))
+            })?;
+        Ok(PyIndex(index))
+    }
 }
 
 #[pyfunction(name = "build_regex_from_schema")]
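
Relocating `__reduce__` and `from_binary` below the new dunder methods changes nothing functionally: `__reduce__` still hands pickle the `from_binary` constructor plus the bincode-serialized bytes, so an `Index` survives a pickle round trip. A sketch, reusing the hypothetical `index` from above:

    import pickle

    data = pickle.dumps(index)     # invokes Index.__reduce__ -> (from_binary, (bytes,))
    restored = pickle.loads(data)  # invokes Index.from_binary(binary_data)
    assert restored.get_transitions() == index.get_transitions()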
16 changes: 14 additions & 2 deletions src/vocabulary/mod.rs
@@ -149,9 +149,21 @@ impl Vocabulary
 
 impl std::fmt::Display for Vocabulary {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        writeln!(f, "[{:?}]", self.eos_token_id)?;
+        writeln!(
+            f,
+            "Vocabulary object with eos_token_id={:?} and the following tokens to token_ids:",
+            self.eos_token_id
+        )?;
         for (token, token_ids) in self.tokens.iter() {
-            writeln!(f, "{:?} -> {:?}", token, token_ids)?;
+            writeln!(
+                f,
+                "{:?} -> {:?}",
+                token
+                    .iter()
+                    .map(|b| format!("0x{:02X}", b))
+                    .collect::<Vec<_>>(),
+                token_ids
+            )?;
         }
         Ok(())
     }
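
The reworked formatter names the object, surfaces the eos token id, and prints each token's bytes as hex strings rather than raw byte values. For the hypothetical vocabulary above, `print(vocabulary)` would emit something like this (exact ordering depends on the underlying map):

    print(vocabulary)
    # Vocabulary object with eos_token_id=3 and the following tokens to token_ids:
    # ["0x61"] -> [0]
    # ["0x7A"] -> [2]
    # ["0x62"] -> [1]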
1 change: 1 addition & 0 deletions tests/fsm/test_guide.py
@@ -73,6 +73,7 @@ def test_str_and_bytes_produce_the_same():
     index2 = Index(regex, vocabulary2)
     guide1 = Guide(index1)
     guide2 = Guide(index2)
+    assert False
 
     assert guide1.read_next_token(3) == guide2.read_next_token(3)
     # `a` and `b` have similar transitions to `z`