From d1a0e8c672bec195e58794515c0626f101857dab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jannis=20Sch=C3=B6nleber?= Date: Mon, 2 Dec 2024 20:20:12 +0100 Subject: [PATCH] Make `RegexGuide` pickleable again for `vllm` and `tgi` (#99) I understand that `pickleable` is not your priority right now. But the `RegexGuide` needs to be pickled for `vllm` production use, which is multiprocessing-based. This PR reintroduces this pickling capability + some tests. I understand that this introduces more effort on your side. References: https://github.com/dottxt-ai/outlines/issues/1274 https://github.com/vllm-project/vllm/pull/10490 https://github.com/vllm-project/vllm/pull/10576 https://github.com/vllm-project/vllm/issues/10489 It would also tackle the current caching issues: https://github.com/huggingface/text-generation-inference/pull/2766 https://github.com/dottxt-ai/outlines/issues/1283 Closes: #95 --- Cargo.toml | 1 + src/index.rs | 3 +- src/python_bindings/mod.rs | 24 +++++++++++++- tests/fsm/test_serialization.py | 56 +++++++++++++++++++++++++++++++++ 4 files changed, 82 insertions(+), 2 deletions(-) create mode 100644 tests/fsm/test_serialization.py diff --git a/Cargo.toml b/Cargo.toml index 94eab3a0..2082b140 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ regex = "1.10.6" serde-pyobject = "0.4.0" serde_json = { version = "1.0", features = ["preserve_order"] } serde = {version = "1.0", features = ["derive"]} +bincode = "2.0.0-rc.3" # Fragile dependencies, minor updates often break the code hf-hub = "=0.3.2" tokenizers = { version = "=0.20.3", features = ["http"] } diff --git a/src/index.rs b/src/index.rs index cc1187e8..a756445c 100644 --- a/src/index.rs +++ b/src/index.rs @@ -3,6 +3,7 @@ use crate::prelude::{State, TransitionKey}; use crate::regex::{get_vocabulary_transition_keys, state_scan_tokens}; use crate::vocabulary::Vocabulary; use crate::{Error, Result}; +use bincode::{Decode, Encode}; use std::collections::{HashMap, HashSet}; #[derive(Debug)] @@ -32,7 +33,7 @@ impl FSMInfo { } } -#[derive(Debug)] +#[derive(Debug, Encode, Decode)] pub struct Index { initial: u32, finals: HashSet, diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs index 046d7ce9..55d979d1 100644 --- a/src/python_bindings/mod.rs +++ b/src/python_bindings/mod.rs @@ -5,6 +5,7 @@ use crate::regex::get_token_transition_keys; use crate::regex::get_vocabulary_transition_keys; use crate::regex::state_scan_tokens; use crate::regex::walk_fsm; +use bincode::config; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::types::PyDict; @@ -72,7 +73,7 @@ impl PyFSMInfo { } } -#[pyclass(name = "Index")] +#[pyclass(name = "Index", module = "outlines_core.fsm.outlines_core_rs")] pub struct PyIndex(Index); #[pymethods] @@ -89,6 +90,27 @@ impl PyIndex { .map_err(Into::into) } + fn __reduce__(&self) -> PyResult<(PyObject, (Vec,))> { + Python::with_gil(|py| { + let cls = PyModule::import_bound(py, "outlines_core.fsm.outlines_core_rs")? + .getattr("Index")?; + let binary_data: Vec = bincode::encode_to_vec(&self.0, config::standard()) + .map_err(|e| { + PyErr::new::(format!("Serialization of Index failed: {}", e)) + })?; + Ok((cls.getattr("from_binary")?.to_object(py), (binary_data,))) + }) + } + + #[staticmethod] + fn from_binary(binary_data: Vec) -> PyResult { + let (index, _): (Index, usize) = + bincode::decode_from_slice(&binary_data[..], config::standard()).map_err(|e| { + PyErr::new::(format!("Deserialization of Index failed: {}", e)) + })?; + Ok(PyIndex(index)) + } + fn get_allowed_tokens(&self, state: u32) -> Option> { self.0.allowed_tokens(state) } diff --git a/tests/fsm/test_serialization.py b/tests/fsm/test_serialization.py new file mode 100644 index 00000000..d3c38365 --- /dev/null +++ b/tests/fsm/test_serialization.py @@ -0,0 +1,56 @@ +import pickle + +import pytest +from outlines_core.fsm.guide import RegexGuide +from transformers import AutoTokenizer + +from tests.fsm.test_regex import TransformerTokenizer + + +def test_serialization(): + class MockTokenizer: + vocabulary = {"1": 1, "a": 2, "eos": 3} + special_tokens = {"eos"} + eos_token_id = 3 + + def convert_token_to_string(self, token): + return token + + regex_str = "[1-9]" + tokenizer = MockTokenizer() + + fsm = RegexGuide.from_regex(regex_str, tokenizer) + + serialized = pickle.dumps(fsm) + deserialized = pickle.loads(serialized) + + assert fsm.eos_tensor == deserialized.eos_tensor + assert fsm.initial_state == deserialized.initial_state + + +@pytest.mark.parametrize( + "hf_tokenizer_uri, revision", + [ + ("openai-community/gpt2", "607a30d783dfa663caf39e06633721c8d4cfcd7e"), + ("microsoft/phi-2", "ef382358ec9e382308935a992d908de099b64c23"), + ("Qwen/Qwen1.5-0.5B-Chat", "4d14e384a4b037942bb3f3016665157c8bcb70ea"), + ( + "NousResearch/Hermes-2-Pro-Llama-3-8B", + "783fd50eb82d7f57758de033861f54d62dde234f", + ), + ], +) +def test_complex_serialization(hf_tokenizer_uri, revision): + # The combined regular expressions of a lexer state in a Python grammar + regex_str = "(?:(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|[0-9](?:(?:_)?[0-9])*)(?:J|j)|(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|0(?:x|X)(?:(?:_)?(?:[0-9]|[a-f]|[A-F]))+|0(?:b|B)(?:(?:_)?[0-1])+|0(?:o|O)(?:(?:_)?[0-7])+|(?:(?i:([ubf]?r?|r[ubf])('([^\\\\']|.)*?'))|(?i:([ubf]?r?|r[ubf])(\"([^\\\"]|.)*?\")))|(?:(?:\r?\n[\t ]*|#[^\n]*))+|[1-9](?:(?:_)?[0-9])*|\\\\[\t \x0c]*\r?\n|continue|nonlocal|assert|global|import|lambda|return|async|await|break|class|False|match|raise|while|yield|case|from|None|pass|True|with|def|del|for|not|try|if|[^\\W\\d]\\w*|#[^\n]*|[\t \x0c]+|\\.\\.\\.|@|\\{|\\(|\\[|\\-|\\+|\\*|\\~" + + tokenizer = AutoTokenizer.from_pretrained(hf_tokenizer_uri, revision=revision) + tokenizer = TransformerTokenizer(tokenizer) + + fsm = RegexGuide.from_regex(regex_str, tokenizer) + + serialized = pickle.dumps(fsm) + deserialized = pickle.loads(serialized) + + assert fsm.eos_tensor == deserialized.eos_tensor + assert fsm.initial_state == deserialized.initial_state