
Commit

Non-optional eos_token_id
torymur committed Dec 20, 2024
1 parent 5f85340 commit 739b3d9
Showing 5 changed files with 25 additions and 64 deletions.
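In short: the end-of-sentence token id is no longer optional — `Vocabulary::new` takes a `TokenId` instead of an `Option<TokenId>`, and `eos_token_id()` returns the id directly. A minimal before/after sketch of the caller-facing change, assuming the crate's `Vocabulary` and `TokenId` types are in scope as in its own tests (50256 is just the GPT-2 value used in the tests below):

```rust
// Sketch only — mirrors the updated unit tests, not a verbatim excerpt of the crate.
fn eos_token_id_is_now_required() {
    // Before this commit: Vocabulary::new(Some(50256)), eos_token_id() -> Option<TokenId>.
    // After this commit: the id is supplied and read back directly.
    let vocabulary = Vocabulary::new(50256).insert("blah", 0).insert("1a", 1);
    let eos: TokenId = vocabulary.eos_token_id();
    assert_eq!(eos, 50256);
}
```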
3 changes: 0 additions & 3 deletions src/error.rs
@@ -6,9 +6,6 @@ pub type Result<T, E = crate::Error> = std::result::Result<T, E>;
 pub enum Error {
     #[error("The vocabulary does not allow to build an index that matches the input")]
     InsufficientVocabulary,
-    // TODO: this error will be removed once eos_token_id for vocabulary won't be optional
-    #[error("Index failed since vocabulary doesn't provide eos token id")]
-    IndexEosTokenIdNotAvailable,
     #[error("Failed to build DFA {0}")]
     IndexDfaError(#[from] Box<regex_automata::dfa::dense::BuildError>),
     #[error("Index failed since anchored universal start state doesn't exist")]
13 changes: 4 additions & 9 deletions src/index.rs
@@ -18,12 +18,7 @@ pub struct Index {

 impl Index {
     pub(crate) fn new(regex: &str, vocabulary: &Vocabulary) -> Result<Self> {
-        let eos_token_id = match vocabulary.eos_token_id() {
-            Some(s) => s,
-            // TODO: this error will be removed once eos_token_id for vocabulary won't be optional
-            None => return Err(Error::IndexEosTokenIdNotAvailable),
-        };
-
+        let eos_token_id = vocabulary.eos_token_id();
         let dfa = DFA::new(regex).map_err(Box::new)?;
         let start_state = match dfa.universal_start_state(Anchored::Yes) {
             Some(s) => s,
@@ -135,7 +130,7 @@ mod tests {
     #[test]
     fn index_from_regex() {
         let regex = "0|[1-9][0-9]*";
-        let vocabulary = Vocabulary::new(Some(4))
+        let vocabulary = Vocabulary::new(4)
             .insert("blah", 0)
             .insert("1a", 1)
             .insert("2", 2)
@@ -157,7 +152,7 @@
     #[test]
     fn index_from_regex_initital_in_allowed() {
         let regex = "`\\n(\\.\\n)?`\\n";
-        let vocabulary = Vocabulary::new(Some(104))
+        let vocabulary = Vocabulary::new(104)
             .insert("\n", 103)
             .insert(".", 102)
             .insert("`", 101);
@@ -172,7 +167,7 @@
     #[test]
     fn index_from_regex_multibyte() {
         let regex = "😇| [😈-😍][😇-😎]*";
-        let vocabulary = Vocabulary::new(Some(8))
+        let vocabulary = Vocabulary::new(8)
             .insert(" 😍", 5)
             .insert("blah", 0)
             .insert("😇", 2)
2 changes: 1 addition & 1 deletion src/python_bindings/mod.rs
@@ -182,7 +182,7 @@ impl PyVocabulary {
         Ok(PyVocabulary(v))
     }
 
-    fn get_eos_token_id(&self) -> Option<TokenId> {
+    fn get_eos_token_id(&self) -> TokenId {
         self.0.eos_token_id()
     }
 
60 changes: 13 additions & 47 deletions src/vocabulary/mod.rs
@@ -29,27 +29,19 @@ mod processor;
 /// ```
 #[derive(Clone, Debug, Default, PartialEq, Encode, Decode)]
 pub struct Vocabulary {
-    // TODO: Option is temp for back compatibility
-    eos_token_id: Option<TokenId>,
+    eos_token_id: TokenId,
     tokens: HashMap<Token, Vec<TokenId>>,
 }
 
 impl Vocabulary {
     /// Creates an empty vocabulary.
-    pub fn new(eos_token_id: Option<TokenId>) -> Self {
+    pub fn new(eos_token_id: TokenId) -> Self {
         Self {
             eos_token_id,
             tokens: HashMap::default(),
         }
     }
 
-    pub fn with_eos_token_id(self, eos_token_id: Option<TokenId>) -> Self {
-        Self {
-            eos_token_id,
-            ..self
-        }
-    }
-
     /// Creates the vocabulary of pre-trained model from Hugging Face Hub.
     pub fn from_pretrained(
         model: &str,
@@ -77,7 +69,7 @@ impl Vocabulary {
         };
 
         // Start building the vocabulary from eos_token_id and added tokens.
-        let mut vocabulary = Vocabulary::new(Some(eos_token_id));
+        let mut vocabulary = Vocabulary::new(eos_token_id);
         for (id, added_token) in tokenizer.get_added_tokens_decoder().iter() {
             if !added_token.special {
                 vocabulary = vocabulary.insert(added_token.content.clone(), *id);
@@ -110,7 +102,7 @@ impl Vocabulary {
     }
 
     /// Gets the identifier of the special end of the sentence token.
-    pub fn eos_token_id(&self) -> Option<TokenId> {
+    pub fn eos_token_id(&self) -> TokenId {
         self.eos_token_id
     }
 
@@ -207,7 +199,7 @@ impl From<(TokenId, HashMap<Token, Vec<TokenId>>)> for Vocabulary {
     fn from(values: (TokenId, HashMap<Token, Vec<TokenId>>)) -> Vocabulary {
         let (eos_token_id, tokens) = values;
         Vocabulary {
-            eos_token_id: Some(eos_token_id),
+            eos_token_id,
             tokens,
         }
     }
@@ -217,7 +209,7 @@ impl From<(TokenId, HashMap<String, Vec<TokenId>>)> for Vocabulary {
     fn from(values: (TokenId, HashMap<String, Vec<TokenId>>)) -> Vocabulary {
         let (eos_token_id, tokens) = values;
         Vocabulary {
-            eos_token_id: Some(eos_token_id),
+            eos_token_id,
             tokens: tokens
                 .into_iter()
                 .map(|(k, v)| (k.as_bytes().to_vec(), v))
@@ -226,24 +218,14 @@ impl From<(TokenId, HashMap<String, Vec<TokenId>>)> for Vocabulary {
     }
 }
 
-impl<T, I> FromIterator<(T, I)> for Vocabulary
-where
-    T: Into<Token>,
-    I: IntoIterator<Item = TokenId>,
-{
-    fn from_iter<A: IntoIterator<Item = (T, I)>>(tokens_and_ids: A) -> Self {
-        Vocabulary::new(None).extend(tokens_and_ids)
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
     use rustc_hash::FxHashSet as HashSet;
 
     #[test]
     fn insert() {
-        let vocabulary = Vocabulary::new(None)
+        let vocabulary = Vocabulary::new(4)
             .insert("blah", 0)
             .insert("1a", 1)
             .insert("2", 2)
@@ -258,7 +240,7 @@

     #[test]
     fn extend() {
-        let vocabulary = Vocabulary::new(None).extend([
+        let vocabulary = Vocabulary::new(4).extend([
             ("blah", vec![0]),
             ("1a", vec![1]),
             ("2", vec![2]),
@@ -274,28 +256,19 @@

     #[test]
     fn new_empty_vocabulary() {
-        let vocabulary = Vocabulary::new(None);
-        assert!(vocabulary.eos_token_id.is_none());
+        let vocabulary = Vocabulary::new(1);
+        assert_eq!(vocabulary.eos_token_id, 1);
         assert!(vocabulary.tokens.is_empty());
     }
 
     #[test]
     fn new_empty_vocabulary_from_hashmap() {
-        let vocabulary = Vocabulary::new(None);
-        assert!(vocabulary.eos_token_id.is_none());
+        let map: HashMap<Token, Vec<TokenId>> = HashMap::default();
+        let vocabulary = Vocabulary::from((1_u32, map));
+        assert_eq!(vocabulary.eos_token_id, 1);
         assert!(vocabulary.tokens.is_empty());
     }
 
-    #[test]
-    fn new_vocabulary_from_iterator() {
-        let token: Token = "abc".as_bytes().to_vec();
-        let id: Vec<TokenId> = vec![1];
-        let it = vec![(token, id)];
-        let vocabulary = Vocabulary::from_iter(it);
-        assert!(vocabulary.eos_token_id.is_none());
-        assert!(!vocabulary.tokens.is_empty());
-    }
-
     #[test]
     fn supported_pretrained_models() {
         // Support is expected for these:
@@ -315,7 +288,6 @@
             let vocabulary = Vocabulary::from_pretrained(model, None);
             match vocabulary {
                 Ok(v) => {
-                    assert!(v.eos_token_id().is_some());
                     assert_eq!(v.eos_token_id, v.eos_token_id());
                     assert!(!v.tokens.is_empty());
                 }
@@ -332,9 +304,6 @@

         let v_eos = vocabulary.eos_token_id;
         assert_eq!(v_eos, vocabulary.eos_token_id());
-        assert!(v_eos.is_some());
-
-        let v_eos = v_eos.unwrap();
         assert_eq!(v_eos, 50256);
         assert_eq!(
             tokenizer.id_to_token(v_eos).expect("Token not found"),
@@ -366,9 +335,6 @@

         let v_eos = vocabulary.eos_token_id;
         assert_eq!(v_eos, vocabulary.eos_token_id());
-        assert!(v_eos.is_some());
-
-        let v_eos = v_eos.unwrap();
         assert_eq!(v_eos, 2);
         assert_eq!(
             tokenizer.id_to_token(v_eos).expect("Token not found"),
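With `with_eos_token_id` and the `FromIterator` implementation gone, code that previously collected `(token, ids)` pairs via `Vocabulary::from_iter` now has to pass the eos token id up front and use `extend`, as the updated `extend` test above does. A hedged migration sketch (the pairs and the eos id 3 are illustrative values only):

```rust
// Sketch only — replaces the removed `Vocabulary::from_iter(...)` pattern.
fn vocabulary_from_pairs() -> Vocabulary {
    let pairs = [("blah", vec![0]), ("1a", vec![1])];
    // The eos token id is now mandatory; 3 stands in for a real tokenizer's eos id.
    Vocabulary::new(3).extend(pairs)
}
```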
11 changes: 7 additions & 4 deletions tests/fsm/test_vocabulary.py
@@ -1,18 +1,20 @@
 import pickle
-import pytest
 
+import pytest
 from outlines_core.fsm import Vocabulary
 
+
 def test_supports_strings_as_keys():
     eos_token_id = 3
     tokens = {"1": [1], "a": [2]}
     vocabulary = Vocabulary.from_dict(eos_token_id, tokens)
 
     assert vocabulary.get_eos_token_id() == eos_token_id
     assert vocabulary.get("1") == [1]
     assert vocabulary.get(b"1") == [1]
     assert len(vocabulary) == 2
 
+
 def test_supports_bytes_as_keys():
     eos_token_id = 3
     tokens = {b"1": [1], b"a": [2]}
@@ -23,16 +25,17 @@ def test_supports_bytes_as_keys():
     assert vocabulary.get("1") == [1]
     assert len(vocabulary) == 2
 
+
 def test_do_not_supports_other_types_as_keys():
     eos_token_id = 3
     tokens = {1: [1], 2: [2]}
 
     with pytest.raises(
-        TypeError,
-        match="Expected a dictionary with keys of type String or Bytes"
+        TypeError, match="Expected a dictionary with keys of type String or Bytes"
     ):
         Vocabulary.from_dict(eos_token_id, tokens)
 
+
 def test_pickling():
     eos_token_id = 3
     tokens = {"1": [1], "a": [2]}

