Add streamer binding (#13)
* Add streamer binding

* remove todo
Wovchena authored May 29, 2024
1 parent 174f67a commit 6709a67
Showing 4 changed files with 83 additions and 13 deletions.
4 changes: 0 additions & 4 deletions src/cpp/include/openvino/genai/streamer_base.hpp
@@ -15,10 +15,6 @@ namespace genai {
*/
class StreamerBase {
public:
    explicit StreamerBase(Tokenizer tokenizer) : m_tokenizer(tokenizer) {}
    Tokenizer m_tokenizer;
    StreamerBase() = default;

    /// @brief put is called every time a new token is decoded
    virtual void put(int64_t token) = 0;

4 changes: 2 additions & 2 deletions src/python/openvino_genai/__init__.py
@@ -9,6 +9,6 @@
if hasattr(os, "add_dll_directory"):
    os.add_dll_directory(os.path.dirname(__file__))

from .py_generate_pipeline import LLMPipeline, Tokenizer, GenerationConfig, DecodedResults, EncodedResults
from .py_generate_pipeline import LLMPipeline, Tokenizer, GenerationConfig, DecodedResults, EncodedResults, StreamerBase

__all__ = ['LLMPipeline', 'Tokenizer', 'GenerationConfig', 'DecodedResults', 'EncodedResults']
__all__ = ['LLMPipeline', 'Tokenizer', 'GenerationConfig', 'DecodedResults', 'EncodedResults', 'StreamerBase']
26 changes: 19 additions & 7 deletions src/python/py_generate_pipeline.cpp
@@ -42,6 +42,7 @@ using ov::genai::GenerationConfig;
using ov::genai::EncodedResults;
using ov::genai::DecodedResults;
using ov::genai::StopCriteria;
using ov::genai::StreamerBase;
using ov::genai::StreamerVariant;

namespace {
@@ -151,6 +152,20 @@ std::string ov_tokenizers_module_path() {
    }
    return py::str(py::module_::import("openvino_tokenizers").attr("_ext_path"));
}
class EmptyStreamer: public StreamerBase {
    // It's impossible to create an instance of a pure virtual class. Define EmptyStreamer instead.
    void put(int64_t token) override {
        PYBIND11_OVERRIDE_PURE(
            void, // Return type
            StreamerBase, // Parent class
            put, // Name of function in C++ (must match Python name)
            token // Argument(s)
        );
    }
    void end() override {
        PYBIND11_OVERRIDE_PURE(void, StreamerBase, end);
    }
};
}

PYBIND11_MODULE(py_generate_pipeline, m) {
@@ -182,13 +197,6 @@ PYBIND11_MODULE(py_generate_pipeline, m) {
.def("generate", py::overload_cast<LLMPipeline&, const std::string&, const GenerationConfig&, const StreamerVariant&>(&call_with_config))

// todo: if input_ids is a ov::Tensor/numpy tensor
// todo: implement calling generate/operator() with StreamerBase or lambda streamer
// signature to be implemented:
// EncodedResults generate(ov::Tensor input_ids,
// std::optional<ov::Tensor> attention_mask,
// OptionalGenerationConfig generation_config=nullopt,
// OptionalStreamerVariant streamer=nullopt);


.def("get_tokenizer", &LLMPipeline::get_tokenizer)
.def("start_chat", &LLMPipeline::start_chat)
@@ -248,4 +256,8 @@ PYBIND11_MODULE(py_generate_pipeline, m) {
.def_readwrite("tokens", &EncodedResults::tokens)
.def_readwrite("scores", &EncodedResults::scores);

    py::class_<StreamerBase, EmptyStreamer, std::shared_ptr<StreamerBase>>(m, "StreamerBase") // Change the holder from unique_ptr to shared_ptr
        .def(py::init<>())
        .def("put", &StreamerBase::put)
        .def("end", &StreamerBase::end);
}
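
For context, EmptyStreamer is a pybind11 trampoline: StreamerBase has pure virtual methods, so the binding constructs the trampoline and routes put/end to Python overrides, while the std::shared_ptr holder lets the pipeline share ownership of Python-derived streamers. A minimal sketch of what the new binding allows on the Python side (CountingStreamer is illustrative and not part of this commit):

import openvino_genai

class CountingStreamer(openvino_genai.StreamerBase):
    def __init__(self):
        super().__init__()  # construct the C++ base held by the shared_ptr holder
        self.count = 0
    def put(self, token_id):  # overrides the pure virtual StreamerBase::put
        self.count += 1
    def end(self):  # overrides the pure virtual StreamerBase::end
        print(f'{self.count} tokens streamed')

# streamer = CountingStreamer()
# pipe.generate('prompt', openvino_genai.GenerationConfig(), streamer)  # pipe as in the tests below
# Calling put/end on a bare openvino_genai.StreamerBase() would instead raise a RuntimeError,
# because PYBIND11_OVERRIDE_PURE finds no Python override to dispatch to.
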
62 changes: 62 additions & 0 deletions tests/python_tests/test_generate_api.py
@@ -142,3 +142,65 @@ def test_callback_kwargs_batch_fail(model_fixture, callback):
    pipe = openvino_genai.LLMPipeline(model_fixture[1], 'CPU')
    with pytest.raises(RuntimeError):
        pipe.generate(['1', '2'], max_new_tokens=10, streamer=callback)


class Printer(openvino_genai.StreamerBase):
    def __init__(self, tokenizer):
        super().__init__()
        self.tokenizer = tokenizer
    def put(self, token_id):
        print(self.tokenizer.decode([token_id])) # Incorrect way to print, but easy to implement
    def end(self):
        print('end')


def test_streamer_one_string(model_fixture):
    pipe = openvino_genai.LLMPipeline(model_fixture[1], 'CPU')
    printer = Printer(pipe.get_tokenizer())
    pipe.generate('', openvino_genai.GenerationConfig(), printer)


def test_streamer_batch_fail(model_fixture):
    pipe = openvino_genai.LLMPipeline(model_fixture[1], 'CPU')
    printer = Printer(pipe.get_tokenizer())
    with pytest.raises(RuntimeError):
        pipe.generate(['1', '2'], openvino_genai.GenerationConfig(), printer)


def test_streamer_kwargs_one_string(model_fixture):
    pipe = openvino_genai.LLMPipeline(model_fixture[1], 'CPU')
    printer = Printer(pipe.get_tokenizer())
    pipe.generate('', do_sample=True, streamer=printer)


def test_streamer_kwargs_batch_fail(model_fixture):
    pipe = openvino_genai.LLMPipeline(model_fixture[1], 'CPU')
    printer = Printer(pipe.get_tokenizer())
    with pytest.raises(RuntimeError):
        pipe.generate('', num_beams=2, streamer=printer)


@pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword)])
def test_operator_with_callback_one_string(model_fixture, callback):
    pipe = openvino_genai.LLMPipeline(model_fixture[1], 'CPU')
    pipe('', openvino_genai.GenerationConfig(), callback)


@pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword)])
def test_operator_with_callback_batch_fail(model_fixture, callback):
    pipe = openvino_genai.LLMPipeline(model_fixture[1], 'CPU')
    with pytest.raises(RuntimeError):
        pipe(['1', '2'], openvino_genai.GenerationConfig(), callback)


def test_operator_with_streamer_kwargs_one_string(model_fixture):
    pipe = openvino_genai.LLMPipeline(model_fixture[1], 'CPU')
    printer = Printer(pipe.get_tokenizer())
    pipe('', do_sample=True, streamer=printer)


def test_operator_with_streamer_kwargs_batch_fail(model_fixture):
    pipe = openvino_genai.LLMPipeline(model_fixture[1], 'CPU')
    printer = Printer(pipe.get_tokenizer())
    with pytest.raises(RuntimeError):
        pipe('', num_beams=2, streamer=printer)
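
Printer.put above decodes each token in isolation, which its own comment flags as an incorrect way to print: multi-token words and byte-level subwords come out fragmented. A sketch of a streamer that avoids this under the same API, by re-decoding the accumulated ids and printing only the new tail (TextPrinter and its attributes are illustrative, not part of this commit):

class TextPrinter(openvino_genai.StreamerBase):
    def __init__(self, tokenizer):
        super().__init__()
        self.tokenizer = tokenizer
        self.token_ids = []    # all token ids received so far
        self.printed_len = 0   # number of characters already printed
    def put(self, token_id):
        self.token_ids.append(token_id)
        text = self.tokenizer.decode(self.token_ids)  # decode the full sequence, as Printer does for a single id
        print(text[self.printed_len:], end='', flush=True)
        self.printed_len = len(text)
    def end(self):
        print(flush=True)
        self.token_ids = []
        self.printed_len = 0

Re-decoding the whole sequence on every put keeps the sketch short; a production streamer would decode incrementally instead.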
