From 591c81a2bc4f68bef814e65d0b23d6be6281abf2 Mon Sep 17 00:00:00 2001 From: sbalandi Date: Fri, 3 Jan 2025 23:04:27 +0000 Subject: [PATCH] Add a choice of how to end streaming from callback: STOP or CANCEL --- samples/cpp/text_generation/chat_sample.cpp | 5 +- samples/python/text_generation/chat_sample.py | 3 +- .../prompt_lookup_decoding_lm.py | 10 +- .../speculative_decoding_lm.py | 8 +- .../openvino/genai/generation_handle.hpp | 12 +- .../include/openvino/genai/llm_pipeline.hpp | 2 +- .../include/openvino/genai/streamer_base.hpp | 15 +++ src/cpp/src/continuous_batching_adapter.hpp | 4 +- src/cpp/src/continuous_batching_impl.cpp | 21 +--- src/cpp/src/generation_handle.cpp | 28 +++-- src/cpp/src/generation_stream.hpp | 9 +- src/cpp/src/icontinuous_batching.cpp | 6 +- src/cpp/src/llm_pipeline.cpp | 4 +- src/cpp/src/llm_pipeline_stateful.cpp | 84 ++++++++----- src/cpp/src/llm_pipeline_stateful.hpp | 5 +- src/cpp/src/llm_pipeline_static.cpp | 26 +--- src/cpp/src/lm_encoding.cpp | 25 ++-- src/cpp/src/lm_encoding.hpp | 2 +- .../src/prompt_lookup/prompt_lookup_impl.cpp | 22 +--- src/cpp/src/sequence_group.hpp | 10 +- ...batching_for_speculative_decoding_impl.cpp | 2 +- .../speculative_decoding_impl.cpp | 15 +-- src/cpp/src/text_callback_streamer.cpp | 25 +++- src/cpp/src/text_callback_streamer.hpp | 7 +- src/cpp/src/utils.cpp | 24 ++++ src/cpp/src/utils.hpp | 20 ++- .../src/visual_language/inputs_embedder.cpp | 6 +- src/cpp/src/visual_language/pipeline.cpp | 25 +--- src/python/openvino_genai/__init__.py | 3 +- src/python/openvino_genai/__init__.pyi | 3 +- .../openvino_genai/py_openvino_genai.pyi | 76 ++++++++--- .../py_continuous_batching_pipeline.cpp | 28 ++++- src/python/py_openvino_genai.cpp | 5 + src/python/py_utils.cpp | 38 ++++-- src/python/py_utils.hpp | 2 +- tests/python_tests/test_llm_pipeline.py | 119 ++++++++++++++++-- .../accuracy/continuous_batching_accuracy.cpp | 2 +- ...ntinuous_batching_speculative_decoding.cpp | 2 +- 38 files changed, 477 insertions(+), 226 deletions(-) diff --git a/samples/cpp/text_generation/chat_sample.cpp b/samples/cpp/text_generation/chat_sample.cpp index c0d172563c..61800751d2 100644 --- a/samples/cpp/text_generation/chat_sample.cpp +++ b/samples/cpp/text_generation/chat_sample.cpp @@ -15,11 +15,12 @@ int main(int argc, char* argv[]) try { ov::genai::GenerationConfig config; config.max_new_tokens = 100; - std::function streamer = [](std::string word) { + + auto streamer = [](std::string word) { std::cout << word << std::flush; // Return flag corresponds whether generation should be stopped. // false means continue generation. - return false; + return ov::genai::StreamerRunningStatus::RUNNING; }; pipe.start_chat(); diff --git a/samples/python/text_generation/chat_sample.py b/samples/python/text_generation/chat_sample.py index eee66fb71d..2656cda44b 100755 --- a/samples/python/text_generation/chat_sample.py +++ b/samples/python/text_generation/chat_sample.py @@ -10,8 +10,7 @@ def streamer(subword): print(subword, end='', flush=True) # Return flag corresponds whether generation should be stopped. # False means continue generation. 
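The chat_sample.cpp hunk above switches the sample callback from returning bool to returning ov::genai::StreamerRunningStatus. A minimal sketch of a callback that exercises all three status values follows; the stop phrase and length threshold are illustrative and not part of this patch.

    std::string printed;  // accumulated output, used only for the illustrative conditions below
    auto streamer = [&printed](std::string word) {
        std::cout << word << std::flush;
        printed += word;
        if (printed.find("As an AI") != std::string::npos)
            return ov::genai::StreamerRunningStatus::CANCEL;  // drop this turn from the chat history
        if (printed.size() > 2000)
            return ov::genai::StreamerRunningStatus::STOP;    // stop early, keep the turn in history
        return ov::genai::StreamerRunningStatus::RUNNING;     // keep generating
    };

Such a lambda can be passed to pipe.generate(prompt, config, streamer) exactly as in the sample above.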
- return False - + return openvino_genai.GenerationStatus.RUNNING def main(): parser = argparse.ArgumentParser() diff --git a/samples/python/text_generation/prompt_lookup_decoding_lm.py b/samples/python/text_generation/prompt_lookup_decoding_lm.py index 726391ba9b..54af23e434 100755 --- a/samples/python/text_generation/prompt_lookup_decoding_lm.py +++ b/samples/python/text_generation/prompt_lookup_decoding_lm.py @@ -5,11 +5,11 @@ import argparse import openvino_genai -def streamer(subword): - print(subword, end='', flush=True) - # Return flag corresponds whether generation should be stopped. - # False means continue generation. - return False +def streamer(subword): + print(subword, end='', flush=True) + # Return flag corresponds whether generation should be stopped. + # False means continue generation. + return False def main(): parser = argparse.ArgumentParser() diff --git a/samples/python/text_generation/speculative_decoding_lm.py b/samples/python/text_generation/speculative_decoding_lm.py index 217b8a2730..740d9b589d 100755 --- a/samples/python/text_generation/speculative_decoding_lm.py +++ b/samples/python/text_generation/speculative_decoding_lm.py @@ -8,10 +8,10 @@ import threading def streamer(subword): - print(subword, end='', flush=True) - # Return flag corresponds whether generation should be stopped. - # False means continue generation. - return False + print(subword, end='', flush=True) + # Return flag corresponds whether generation should be stopped. + # False means continue generation. + return False def main(): parser = argparse.ArgumentParser() diff --git a/src/cpp/include/openvino/genai/generation_handle.hpp b/src/cpp/include/openvino/genai/generation_handle.hpp index 6619e3e012..4a59f44471 100644 --- a/src/cpp/include/openvino/genai/generation_handle.hpp +++ b/src/cpp/include/openvino/genai/generation_handle.hpp @@ -15,8 +15,8 @@ enum class GenerationStatus { RUNNING = 0, // Default status for ongoing generation FINISHED = 1, // Status set when generation has been finished IGNORED = 2, // Status set when generation run into out-of-memory condition and could not be continued - DROPPED_BY_PIPELINE = 3, // Currently not used, TODO: implement abort functionality - DROPPED_BY_HANDLE = 4 // Status set when generation handle is dropped + CANCEL = 3, // Status set when generation handle is canceled + STOP = 4 // Status set when generation handle is stopped }; struct EncodedGenerationResult { @@ -74,7 +74,9 @@ class OPENVINO_GENAI_EXPORTS GenerationHandleImpl { std::shared_ptr m_generation_stream; ov::genai::GenerationConfig m_sampling_params; - bool is_dropped(); + bool is_stopped(); + + bool is_canceled(); public: GenerationHandleImpl(std::shared_ptr generation_stream, const ov::genai::GenerationConfig& sampling_params) : @@ -93,6 +95,10 @@ class OPENVINO_GENAI_EXPORTS GenerationHandleImpl { void drop(); + void stop(); + + void cancel(); + GenerationOutputs back(); // Reads result of a generation for single iteration GenerationOutputs read(); diff --git a/src/cpp/include/openvino/genai/llm_pipeline.hpp b/src/cpp/include/openvino/genai/llm_pipeline.hpp index 31b1ac1675..4e723fc5ab 100644 --- a/src/cpp/include/openvino/genai/llm_pipeline.hpp +++ b/src/cpp/include/openvino/genai/llm_pipeline.hpp @@ -19,7 +19,7 @@ namespace ov { namespace genai { // Return flag corresponds whether generation should be stopped: false means continue generation, true means stop. 
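The StreamerVariant and StreamerBase changes in the next hunks add a second callback signature and a m_streaming_finish_status field that the pipelines inspect via get_streaming_status() to decide between generation->stop() and generation->cancel(). A sketch of a custom streamer using that field (the token limit is illustrative) could look like this:

    #include "openvino/genai/streamer_base.hpp"

    class CancellingStreamer : public ov::genai::StreamerBase {
    public:
        size_t m_seen_tokens = 0;

        bool put(int64_t token) override {
            if (++m_seen_tokens > 32) {
                // Pipelines check get_streaming_status() after put() returns true:
                // CANCEL -> generation->cancel(), anything else -> generation->stop().
                m_streaming_finish_status = ov::genai::StreamerRunningStatus::CANCEL;
                return true;   // request that generation ends
            }
            return false;      // keep generating
        }

        void end() override {}
    };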
-using StreamerVariant = std::variant<std::function<bool(std::string)>, std::shared_ptr<StreamerBase>, std::monostate>;
+using StreamerVariant = std::variant<std::function<bool(std::string)>, std::function<StreamerRunningStatus(std::string)>, std::shared_ptr<StreamerBase>, std::monostate>;
 using OptionalGenerationConfig = std::optional<GenerationConfig>;
 using EncodedInputs = std::variant<ov::Tensor, TokenizedInputs>;
 using StringInputs = std::variant<std::string, std::vector<std::string>>;
diff --git a/src/cpp/include/openvino/genai/streamer_base.hpp b/src/cpp/include/openvino/genai/streamer_base.hpp
index f286e896e5..f5d9a0e167 100644
--- a/src/cpp/include/openvino/genai/streamer_base.hpp
+++ b/src/cpp/include/openvino/genai/streamer_base.hpp
@@ -4,10 +4,20 @@
 #pragma once
 
 #include "openvino/genai/tokenizer.hpp"
+#include "openvino/genai/generation_handle.hpp"
+#include <variant>
 
 namespace ov {
 namespace genai {
 
+enum class StreamerRunningStatus {
+    RUNNING = 0, // Continue running inference
+    STOP = 1,    // Stop generation; keep the history as is; the KV cache includes the last request and the generated tokens
+    CANCEL = 2   // Cancel generation; drop the last prompt and all generated tokens from the history; the KV cache includes the history but not the last step
+};
+
+using CallbackTypeVariant = std::variant<bool, StreamerRunningStatus, std::monostate>;
+
 /**
  * @brief Base class for streamers. In order to use it, inherit from this class and implement the put() and end() methods.
  *
  */
 class OPENVINO_GENAI_EXPORTS StreamerBase {
 public:
+    StreamerRunningStatus m_streaming_finish_status = StreamerRunningStatus::RUNNING;
     /// @brief put is called every time new token is decoded,
     /// @return bool flag to indicate whether generation should be stopped, if return true generation stops
     virtual bool put(int64_t token) = 0;
@@ -22,6 +33,10 @@ class OPENVINO_GENAI_EXPORTS StreamerBase {
     /// @brief end is called at the end of generation. It can be used to flush cache if your own streamer has one
     virtual void end() = 0;
 
+    StreamerRunningStatus get_streaming_status() {
+        return m_streaming_finish_status;
+    }
+
     virtual ~StreamerBase();
 };
diff --git a/src/cpp/src/continuous_batching_adapter.hpp b/src/cpp/src/continuous_batching_adapter.hpp
index 00928b342d..bb12542ca8 100644
--- a/src/cpp/src/continuous_batching_adapter.hpp
+++ b/src/cpp/src/continuous_batching_adapter.hpp
@@ -97,7 +97,7 @@ class ContinuousBatchingAdapter final : public LLMPipelineImplBase {
         std::vector<std::string> plain_replies;
         std::vector<float> plain_scores;
         for (GenerationResult& res : generated) {
-            OPENVINO_ASSERT(res.m_status == GenerationStatus::FINISHED || res.m_status == GenerationStatus::DROPPED_BY_HANDLE, "Got unfinished GenerationStatus");
+            OPENVINO_ASSERT(res.m_status == GenerationStatus::FINISHED || res.m_status == GenerationStatus::STOP || res.m_status == GenerationStatus::CANCEL, "Got unfinished GenerationStatus");
             std::move(res.m_generation_ids.begin(), res.m_generation_ids.end(), std::back_inserter(plain_replies));
             std::move(res.m_scores.begin(), res.m_scores.end(), std::back_inserter(plain_scores));
         }
@@ -189,7 +189,7 @@ class ContinuousBatchingAdapter final : public LLMPipelineImplBase {
         std::vector<std::vector<int64_t>> plain_tokens;
         std::vector<float> plain_scores;
         for (EncodedGenerationResult& res : generated) {
-            OPENVINO_ASSERT(res.m_status == GenerationStatus::FINISHED || res.m_status == GenerationStatus::DROPPED_BY_HANDLE, "Got unfinished GenerationStatus");
+            OPENVINO_ASSERT(res.m_status == GenerationStatus::FINISHED || res.m_status == GenerationStatus::STOP, "Got unfinished GenerationStatus");
             std::move(res.m_generation_ids.begin(), res.m_generation_ids.end(), std::back_inserter(plain_tokens));
             std::move(res.m_scores.begin(), res.m_scores.end(), std::back_inserter(plain_scores));
         }
diff --git a/src/cpp/src/continuous_batching_impl.cpp
b/src/cpp/src/continuous_batching_impl.cpp index 56e8bd995f..2c4e425112 100644 --- a/src/cpp/src/continuous_batching_impl.cpp +++ b/src/cpp/src/continuous_batching_impl.cpp @@ -326,17 +326,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector& streamer_ptr = std::visit(overloaded{ - [](std::monostate) -> std::shared_ptr { - return nullptr; - }, - [](const std::shared_ptr& streamer) { - return streamer; - }, - [this](const std::function& streamer) -> std::shared_ptr { - return std::make_unique(m_tokenizer, streamer); - } - }, streamer); + const std::shared_ptr& streamer_ptr = ov::genai::utils::create_streamer(streamer, m_tokenizer); OPENVINO_ASSERT(streamer_ptr == nullptr || input_ids.size() == 1 && sampling_params[0].num_return_sequences == 1 && (sampling_params[0].is_greedy_decoding() || sampling_params[0].is_multinomial()), @@ -367,13 +357,13 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vectorcan_read()) { std::unordered_map token = generation->back(); for (const auto& gen_token : token.begin()->second.generated_ids) { continue_generation = !streamer_ptr->put(gen_token); if (!continue_generation) { - generation->drop(); + streamer_ptr->get_streaming_status() == StreamerRunningStatus::CANCEL ? generation->cancel() : generation->stop(); break; } } @@ -403,6 +393,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vectorget_generation_stream()->get_status(); for (size_t i = 0; i < num_outputs; ++i) { const auto & sequence = sequences[i]; @@ -435,7 +426,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_free_non_running_reque std::vector::iterator requests_iterator = m_requests.begin(); while (requests_iterator != m_requests.end()) { const auto& request = *requests_iterator; - if (request->has_finished() || request->handle_dropped()) { + if(request->has_finished() || request->handle_stopped() || request->handle_canceled()) { for (const auto& sequence: request->get_sequences()) { if (m_scheduler->has_block_table(sequence->get_id())) { m_scheduler->free_sequence(sequence->get_id()); @@ -453,7 +444,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_notify_requests_droppe // Notify the last time by pushing empty output // This causes read() to unblock by adding anything to the queue for (SequenceGroup::Ptr& request : m_requests) { - if (request->handle_dropped()) + if (request->handle_stopped() || request->handle_canceled()) request->push_empty_outputs(); } } diff --git a/src/cpp/src/generation_handle.cpp b/src/cpp/src/generation_handle.cpp index 5d92c560e9..9d2c656b7b 100644 --- a/src/cpp/src/generation_handle.cpp +++ b/src/cpp/src/generation_handle.cpp @@ -9,7 +9,7 @@ using namespace ov::genai; GenerationHandleImpl::~GenerationHandleImpl() { - drop(); + stop(); } GenerationStatus GenerationHandleImpl::get_status() { @@ -17,24 +17,36 @@ GenerationStatus GenerationHandleImpl::get_status() { } bool GenerationHandleImpl::can_read() { - return !is_dropped() && m_generation_stream->can_read(); + return !is_canceled() && !is_stopped() && m_generation_stream->can_read(); } -bool GenerationHandleImpl::is_dropped() { - return get_status() == GenerationStatus::DROPPED_BY_HANDLE; +bool GenerationHandleImpl::is_stopped() { + return get_status() == GenerationStatus::STOP; +} + +bool GenerationHandleImpl::is_canceled() { + return get_status() == GenerationStatus::CANCEL; } void GenerationHandleImpl::drop() { - m_generation_stream->drop(); + m_generation_stream->stop(); +} + +void 
GenerationHandleImpl::stop() {
+    m_generation_stream->stop();
+}
+
+void GenerationHandleImpl::cancel() {
+    m_generation_stream->cancel();
+}
 
 std::unordered_map<uint64_t, GenerationOutput> GenerationHandleImpl::back() {
-    OPENVINO_ASSERT(!is_dropped(), "GenerationHandle cannot be used after it is dropped.");
+    OPENVINO_ASSERT(!is_stopped(), "GenerationHandle cannot be used after it is stopped.");
     return m_generation_stream->back();
 }
 
 std::unordered_map<uint64_t, GenerationOutput> GenerationHandleImpl::read() {
-    OPENVINO_ASSERT(!is_dropped(), "GenerationHandle cannot be used after it is dropped.");
+    OPENVINO_ASSERT(!is_stopped(), "GenerationHandle cannot be used after it is stopped.");
     return m_generation_stream->read();
 }
 
@@ -57,7 +69,7 @@ void add_partial_result(std::unordered_map& partial_
 }
 
 std::vector<GenerationOutput> GenerationHandleImpl::read_all() {
-    OPENVINO_ASSERT(!is_dropped(), "GenerationHandle cannot be used after it is dropped.");
+    OPENVINO_ASSERT(!is_stopped(), "GenerationHandle cannot be used after it is stopped.");
     std::vector<GenerationOutput> results;
     std::unordered_map<uint64_t, GenerationOutput> partial_results;
     // We iterate until generation is running or there are tokens we haven't read yet
diff --git a/src/cpp/src/generation_stream.hpp b/src/cpp/src/generation_stream.hpp
index d76d0cf7f4..50ad2d8175 100644
--- a/src/cpp/src/generation_stream.hpp
+++ b/src/cpp/src/generation_stream.hpp
@@ -51,9 +51,14 @@ class GenerationStream {
         return m_status;
     }
 
-    void drop() {
+    void stop() {
         std::lock_guard<std::mutex> lock(m_mutex);
-        m_status = GenerationStatus::DROPPED_BY_HANDLE;
+        m_status = GenerationStatus::STOP;
+    }
+
+    void cancel() {
+        std::lock_guard<std::mutex> lock(m_mutex);
+        m_status = GenerationStatus::CANCEL;
     }
 };
 }
diff --git a/src/cpp/src/icontinuous_batching.cpp b/src/cpp/src/icontinuous_batching.cpp
index 78f8fda8f7..6f14b73cd8 100644
--- a/src/cpp/src/icontinuous_batching.cpp
+++ b/src/cpp/src/icontinuous_batching.cpp
@@ -77,7 +77,7 @@ ContinuousBatchingPipeline::IContinuousBatchingPipeline::generate(
             const auto decode_start = std::chrono::steady_clock::now();
             generated.push_back(m_tokenizer.decode(res.m_generation_ids.at(idx)));
             raw_counters.detokenization_durations.emplace_back(std::chrono::steady_clock::now() - decode_start);
-            if (m_is_chat_conversation && 0 == idx) {
+            if (m_is_chat_conversation && 0 == idx && res.m_status != ov::genai::GenerationStatus::CANCEL) {
                 m_history.push_back({{"role", "assistant"}, {"content", generated.back()}});
             }
         }
@@ -98,6 +98,10 @@ ContinuousBatchingPipeline::IContinuousBatchingPipeline::generate(
         });
     }
 
+    // If streaming was canceled, the prompt and answer of the current step should not be kept in the history, so remove the prompt from the history.
+    if (m_is_chat_conversation && !encoded.empty() && encoded[0].m_status == ov::genai::GenerationStatus::CANCEL)
+        m_history.pop_back();
+
     return decoded;
 }
 }
diff --git a/src/cpp/src/llm_pipeline.cpp b/src/cpp/src/llm_pipeline.cpp
index 6ebef7bfba..722be3d1e2 100644
--- a/src/cpp/src/llm_pipeline.cpp
+++ b/src/cpp/src/llm_pipeline.cpp
@@ -47,7 +47,9 @@ std::pair split_model_descr(
 std::pair<std::string, Any> streamer(StreamerVariant func) {
     if (auto streamer_obj = std::get_if<std::shared_ptr<StreamerBase>>(&func)) {
         return {utils::STREAMER_ARG_NAME, Any::make<std::shared_ptr<StreamerBase>>(*streamer_obj)};
-    } else {
+    } else if (auto streamer_obj = std::get_if<std::function<StreamerRunningStatus(std::string)>>(&func)) {
+        return {utils::STREAMER_ARG_NAME, Any::make<std::function<StreamerRunningStatus(std::string)>>(*streamer_obj)};
+    } else {
         auto callback = std::get<std::function<bool(std::string)>>(func);
         return {utils::STREAMER_ARG_NAME, Any::make<std::function<bool(std::string)>>(callback)};
     }
diff --git a/src/cpp/src/llm_pipeline_stateful.cpp b/src/cpp/src/llm_pipeline_stateful.cpp
index 2a53154c27..2dd3744cd9 100644
--- a/src/cpp/src/llm_pipeline_stateful.cpp
+++
b/src/cpp/src/llm_pipeline_stateful.cpp @@ -39,7 +39,7 @@ StatefulLLMPipeline::StatefulLLMPipeline( const ov::genai::GenerationConfig& generation_config) : LLMPipelineImplBase(tokenizer, generation_config), m_sampler(m_tokenizer) { utils::apply_slice_before_matmul_transformation(model); - m_kv_cache_seq_length_axis = ov::genai::utils::get_seq_len_axis(model); + m_kv_history_manager.kv_cache_seq_length_axis = ov::genai::utils::get_seq_len_axis(model); ov::CompiledModel compiled_model; if (auto filtered_properties = extract_adapters_from_properties(properties, &m_generation_config.adapters)) { @@ -86,6 +86,9 @@ DecodedResults StatefulLLMPipeline::generate( TokenizedInputs encoded_input; + std::string prev_templated_chat_history(m_templated_chat_history); + std::vector prev_tokenized_chat_history(m_tokenized_chat_history); + if (auto input_vector = std::get_if>(&inputs)) { OPENVINO_ASSERT(!is_chat_conversation, "Can't chat with multiple prompts"); encoded_input = m_tokenizer.encode(*input_vector); @@ -104,7 +107,7 @@ DecodedResults StatefulLLMPipeline::generate( m_history.push_back({{"role", "user"}, {"content", prompt}}); constexpr bool add_generation_prompt = true; - auto new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt); + auto new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt); // Do not add special tokens in chat scenario to be aligned with HF. auto new_chat_tokens = m_tokenizer.encode(new_templated_chat_history, ov::genai::add_special_tokens(false)); auto prev_chat_tokens = m_tokenizer.encode(m_templated_chat_history, ov::genai::add_special_tokens(false)); @@ -116,21 +119,24 @@ DecodedResults StatefulLLMPipeline::generate( if (!m_tokenized_chat_history.empty()) { std::set stop_tokens = config.stop_token_ids; trusted_history_length = ov::genai::utils::get_first_history_difference(prev_chat_tokens.input_ids, m_tokenized_chat_history, stop_tokens); - m_trust_encoded_history = trusted_history_length == SIZE_MAX; } if (m_tokenized_chat_history.empty()) { encoded_input = new_chat_tokens; - } else if (trusted_history_length != SIZE_MAX || m_kv_history_manager.does_kv_cache_need_to_update()) { - // does_kv_cache_need_to_update will be true here if beam search is activated + } else if (trusted_history_length != SIZE_MAX || m_kv_history_manager.does_history_cache_need_to_update()) { + // does_history_cache_need_to_update will be true here if beam search is activated // in beam search mode we want to remove all history about last model answer from kv cache and add the best answer directly // if we have difference in model answer and decoded answer it anyway will be less then entire history, so let's use data from m_kv_history_manager - if (m_kv_history_manager.does_kv_cache_need_to_update()) { + if (m_kv_history_manager.does_history_cache_need_to_update()) { trusted_history_length = m_kv_history_manager.trusted_history_length; } else { - m_kv_history_manager.num_tokens_to_remove_from_kv_cache = m_tokenized_chat_history.size() - trusted_history_length; + size_t num_tokens_to_remove_from_kv_cache = m_tokenized_chat_history.size() - trusted_history_length; // if prev generation was finished because of max len was reached, kv cache is missed one last token, let's keep it - m_kv_history_manager.num_tokens_to_remove_from_kv_cache -= m_last_disappeared_token.has_value() ? 1 : 0; + num_tokens_to_remove_from_kv_cache -= m_last_disappeared_token.has_value() ? 
1 : 0;
+
+            // If streaming was used and canceled on the previous step, num_tokens_to_remove_from_kv_cache may already be set, and it will be larger because it includes both the answer and the prompt.
+            m_kv_history_manager.num_tokens_to_remove_from_kv_cache = num_tokens_to_remove_from_kv_cache > m_kv_history_manager.num_tokens_to_remove_from_kv_cache ?
+                num_tokens_to_remove_from_kv_cache : m_kv_history_manager.num_tokens_to_remove_from_kv_cache;
         }
 
         ov::Tensor new_tensor = ov::Tensor(new_chat_tokens.input_ids.get_element_type(),
@@ -145,6 +151,7 @@ DecodedResults StatefulLLMPipeline::generate(
         new_tensor.copy_to(encoded_input.input_ids);
         encoded_input.attention_mask = new_attention_mask;
         m_last_disappeared_token = std::nullopt;
+        m_kv_history_manager.reset_kv_cache = (trusted_history_length == 0);
     } else {
         encoded_input = utils::subtract_chat_tokenized_inputs(new_chat_tokens, prev_chat_tokens);
     }
@@ -169,11 +176,20 @@ DecodedResults StatefulLLMPipeline::generate(
     auto decode_stop_time = std::chrono::steady_clock::now();
 
     if (is_chat_conversation) {
-        // Tail of chat template is missing in KV cache.
-        // Find the tail to concatenate it with the next input prompt.
-        auto answer = decoded_results.texts[0];
-        m_templated_chat_history.append(answer);
-        m_history.push_back({{"role", "assistant"}, {"content", answer}});
+        if (m_chat_generation_finish_status == ov::genai::GenerationStatus::CANCEL) {
+            // If the chat generation was canceled by the user, roll the history back to its previous state.
+            m_history.pop_back();
+            m_kv_history_manager.num_tokens_to_remove_from_kv_cache += m_tokenized_chat_history.size() - prev_tokenized_chat_history.size();
+            m_templated_chat_history = prev_templated_chat_history;
+            m_tokenized_chat_history = prev_tokenized_chat_history;
+            m_kv_history_manager.reset_kv_cache = m_tokenized_chat_history.empty();
+        } else {
+            // Tail of chat template is missing in KV cache.
+            // Find the tail to concatenate it with the next input prompt.
+            auto answer = decoded_results.texts[0];
+            m_templated_chat_history.append(answer);
+            m_history.push_back({{"role", "assistant"}, {"content", answer}});
+        }
     }
 
     // generate_durations
@@ -218,6 +234,8 @@ EncodedResults StatefulLLMPipeline::generate(
     if (is_chat_conversation && m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS)
        std::copy(input_ids.data<int64_t>(), input_ids.data<int64_t>() + input_ids.get_size(), std::back_inserter(m_tokenized_chat_history));
 
+    size_t real_input_ids_size = input_ids.get_shape().at(1);
+
     // Tail of previous output in chat mode is missing in KV cache.
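The history rollback implemented in the llm_pipeline_stateful.cpp hunks above means that, from the user's point of view, a canceled turn disappears from the chat while a stopped turn stays. A hedged sketch of the intended chat-level behaviour (the model directory and prompts are illustrative):

    ov::genai::LLMPipeline pipe("model_dir", "CPU");   // hypothetical model directory
    ov::genai::GenerationConfig config;
    config.max_new_tokens = 50;

    pipe.start_chat();
    // CANCEL: the prompt and the partial answer are rolled back out of the chat
    // history and the corresponding tokens are trimmed from (or reset in) the KV cache.
    pipe.generate("first question", config, [](std::string word) {
        return ov::genai::StreamerRunningStatus::CANCEL;
    });
    // STOP: generation ends early, but the turn stays in the history and KV cache.
    pipe.generate("second question", config, [](std::string word) {
        return ov::genai::StreamerRunningStatus::STOP;
    });
    pipe.finish_chat();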
if (m_last_disappeared_token.has_value()) { attention_mask = ov::genai::utils::push_front_inputs(attention_mask, 1); @@ -234,14 +252,7 @@ EncodedResults StatefulLLMPipeline::generate( // Stateful pipeline does not provide logprobs for prompt tokens OPENVINO_ASSERT(config.echo == false, "Echo is not supported in the stateful pipeline"); - std::shared_ptr streamer_ptr; - if (auto streamer_obj = std::get_if(&streamer)) { - streamer_ptr = nullptr; - } else if (auto streamer_obj = std::get_if>(&streamer)) { - streamer_ptr = *streamer_obj; - } else if (auto callback = std::get_if>(&streamer)) { - streamer_ptr = std::make_shared(m_tokenizer, *callback); - } + std::shared_ptr streamer_ptr = ov::genai::utils::create_streamer(streamer, m_tokenizer); auto batch_size = input_ids.get_shape().at(0); OPENVINO_ASSERT(streamer_ptr == nullptr || batch_size == 1 && config.num_return_sequences == 1 && @@ -254,7 +265,11 @@ EncodedResults StatefulLLMPipeline::generate( "(input_ids, attention_mask, position_ids, beam_idx) " "but you have '" + std::to_string(num_inputs) + "' inputs"); - ov::genai::utils::trim_kv_cache(m_model_runner, m_kv_history_manager.num_tokens_to_remove_from_kv_cache, m_kv_cache_seq_length_axis, m_adapter_controller); + if (m_kv_history_manager.reset_kv_cache) + reset_kv_state(); + else + ov::genai::utils::trim_kv_cache(m_model_runner, m_kv_history_manager.num_tokens_to_remove_from_kv_cache, + m_kv_history_manager.kv_cache_seq_length_axis, m_adapter_controller); size_t kv_cache_len = 0; ov::Tensor concatenated_attention_mask; @@ -292,8 +307,7 @@ EncodedResults StatefulLLMPipeline::generate( m_adapter_controller->apply(m_model_runner, config.adapters); } - if (is_chat_conversation && !m_trust_encoded_history) { - m_trust_encoded_history = true; + if (is_chat_conversation) { m_kv_history_manager.reset(); } @@ -321,9 +335,11 @@ EncodedResults StatefulLLMPipeline::generate( m_sampler.set_seed(config.rng_seed); } - ov::genai::EncodedResults result; - std::tie(result, m_last_disappeared_token) = get_lm_encoded_results(m_model_runner, input_ids, concatenated_attention_mask, + ov::genai::utils::GenerationFinishInfo finish_info = get_lm_encoded_results(m_model_runner, input_ids, concatenated_attention_mask, streamer_ptr, m_sampler, requests, position_ids, std::nullopt); + ov::genai::EncodedResults result = finish_info.results; + m_last_disappeared_token = finish_info.probably_disappeared_token; + m_chat_generation_finish_status = finish_info.streaming_finish_status; if (is_chat_conversation) { // force remove from kv_cache last answer @@ -332,15 +348,21 @@ EncodedResults StatefulLLMPipeline::generate( m_kv_history_manager.num_tokens_to_remove_from_kv_cache = m_model_runner.get_tensor("attention_mask").get_shape()[1] - prev_attn_mask_size; } - std::copy(result.tokens[0].begin(), result.tokens[0].end(), std::back_inserter(m_tokenized_chat_history)); + if (m_chat_generation_finish_status == ov::genai::GenerationStatus::CANCEL) { + m_kv_history_manager.num_tokens_to_remove_from_kv_cache = m_model_runner.get_tensor("attention_mask").get_shape()[1] - prev_attn_mask_size; + + if (m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS) { + m_tokenized_chat_history.resize(m_tokenized_chat_history.size() - real_input_ids_size); + m_kv_history_manager.num_tokens_to_remove_from_kv_cache += real_input_ids_size; + } + } else { + std::copy(result.tokens[0].begin(), result.tokens[0].end(), std::back_inserter(m_tokenized_chat_history)); + } } else { reset_kv_state(); m_last_disappeared_token 
= std::nullopt; } - if (is_chat_conversation && m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS) - std::copy(result.tokens[0].begin(), result.tokens[0].end(), std::back_inserter(m_tokenized_chat_history)); - auto stop_time = std::chrono::steady_clock::now(); // If is called without tokenization then that stat will not be reported. @@ -354,7 +376,6 @@ EncodedResults StatefulLLMPipeline::generate( void StatefulLLMPipeline::start_chat(const std::string& system_message) { is_chat_conversation = true; - m_trust_encoded_history = true; m_kv_history_manager.reset(); m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF; m_last_disappeared_token = std::nullopt; @@ -387,7 +408,6 @@ void StatefulLLMPipeline::reset_kv_state() { void StatefulLLMPipeline::finish_chat() { is_chat_conversation = false; - m_trust_encoded_history = true; m_kv_history_manager.reset(); m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF; m_last_disappeared_token = std::nullopt; diff --git a/src/cpp/src/llm_pipeline_stateful.hpp b/src/cpp/src/llm_pipeline_stateful.hpp index 968c550a86..05e56aae70 100644 --- a/src/cpp/src/llm_pipeline_stateful.hpp +++ b/src/cpp/src/llm_pipeline_stateful.hpp @@ -24,8 +24,9 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { // If sequence contains some symbols, which could be ambiguously encoded by tokenizer, we need to trim kv cache // If we use beam search sampling with chat mode we need to remove last answer of the model from kv cache and add best answer to history // so, let's keep info about amount of tokens to trim from kv cache and amount of tokens to keep in history - ov::genai::utils::HistoryRemoveManager m_kv_history_manager = {0, 0}; - size_t m_kv_cache_seq_length_axis = 2; + ov::genai::utils::HistoryRemoveManager m_kv_history_manager = {0, 0, 2}; + // Finish reason of last generation for chat scenario + ov::genai::GenerationStatus m_chat_generation_finish_status = ov::genai::GenerationStatus::RUNNING; void reset_kv_state(); public: diff --git a/src/cpp/src/llm_pipeline_static.cpp b/src/cpp/src/llm_pipeline_static.cpp index b17ee959c5..441a8794e2 100644 --- a/src/cpp/src/llm_pipeline_static.cpp +++ b/src/cpp/src/llm_pipeline_static.cpp @@ -661,7 +661,7 @@ void stream_generated_tokens(std::shared_ptr streamer_p std::unordered_map token = handle->back(); for (const auto& gen_token : token.begin()->second.generated_ids) { if (streamer_ptr->put(gen_token)) { - handle->drop(); + streamer_ptr->get_streaming_status() == ov::genai::StreamerRunningStatus::CANCEL ? 
handle->cancel() : handle->stop(); break; } } @@ -882,14 +882,7 @@ EncodedResults StatefulLLMPipeline::generate( config.set_eos_token_id(m_generation_config.eos_token_id); config.validate(); - std::shared_ptr streamer_ptr; - if (auto streamer_obj = std::get_if(&streamer)) { - streamer_ptr = nullptr; - } else if (auto streamer_obj = std::get_if>(&streamer)) { - streamer_ptr = *streamer_obj; - } else if (auto callback = std::get_if>(&streamer)) { - streamer_ptr = std::make_shared(m_tokenizer, *callback); - } + std::shared_ptr streamer_ptr = ov::genai::utils::create_streamer(streamer, m_tokenizer); OPENVINO_ASSERT(config.is_greedy_decoding() || config.is_multinomial(), "Currently only greedy and multinomial decoding are supported"); @@ -956,7 +949,7 @@ EncodedResults StatefulLLMPipeline::generate( m_request.set_tensor("input_ids", ov::Tensor(ov::element::i64, ov::Shape{1,1}, reinterpret_cast(&input_ids_data))); m_request.set_tensor("position_ids", ov::Tensor(ov::element::i64, ov::Shape{1,1}, reinterpret_cast(&position_ids_data))); - while (sequence_group->is_running() && !sequence_group->handle_dropped()) { + while (sequence_group->is_running() && !sequence_group->handle_stopped()) { // KV Cache is full, no further generation is possible if (position_ids_data + 1 == m_kvcache_total) { sequence_group->set_out_of_memory(); @@ -1351,14 +1344,7 @@ EncodedResults StatelessLLMPipeline::generate( config.set_eos_token_id(m_generation_config.eos_token_id); config.validate(); - std::shared_ptr streamer_ptr; - if (auto streamer_obj = std::get_if(&streamer)) { - streamer_ptr = nullptr; - } else if (auto streamer_obj = std::get_if>(&streamer)) { - streamer_ptr = *streamer_obj; - } else if (auto callback = std::get_if>(&streamer)) { - streamer_ptr = std::make_shared(m_tokenizer, *callback); - } + std::shared_ptr streamer_ptr = ov::genai::utils::create_streamer(streamer, m_tokenizer); if (!config.is_greedy_decoding() && !config.is_multinomial()) { OPENVINO_THROW("Currently only greedy and multinomial decoding are supported"); @@ -1461,7 +1447,7 @@ EncodedResults StatelessLLMPipeline::generate( std::fill(attention_mask_data, attention_mask_data + m_kvcache_desc.num_stored_tokens - 1u, 1u); attention_mask_data[m_kvcache_desc.total_size - 1] = 1u; - while (sequence_group->is_running() && !sequence_group->handle_dropped()) { + while (sequence_group->is_running() && !sequence_group->handle_stopped()) { sequence_group->schedule_tokens(1); const auto running_sequences = sequence_group->get_running_sequences(); OPENVINO_ASSERT(running_sequences.size() == 1u); @@ -1480,7 +1466,7 @@ EncodedResults StatelessLLMPipeline::generate( {sequence_group}, m_kvcache_request.get_tensor("logits")); stream_generated_tokens(streamer_ptr, handle); - if (sequence_group->handle_dropped()) + if (sequence_group->handle_stopped()) break; // NB: KV-cache is full, further generation is impossible diff --git a/src/cpp/src/lm_encoding.cpp b/src/cpp/src/lm_encoding.cpp index e2ec3a1b33..a5bc1b7bc0 100644 --- a/src/cpp/src/lm_encoding.cpp +++ b/src/cpp/src/lm_encoding.cpp @@ -13,6 +13,7 @@ #include "debug_utils.hpp" #include "lm_encoding.hpp" #include "openvino/genai/perf_metrics.hpp" +#include "openvino/genai/streamer_base.hpp" namespace ov { @@ -67,7 +68,7 @@ void update_attention_mask_with_beams(ov::Tensor&& attention_mask, std::vector> get_lm_encoded_results( +ov::genai::utils::GenerationFinishInfo get_lm_encoded_results( ov::InferRequest& m_llm, const ov::Tensor& input_ids, const ov::Tensor& attention_mask, @@ -91,7 +92,7 @@ std::pair> 
get_lm_encoded_results( std::unordered_map token = handle->back(); for (const auto& gen_token : token.begin()->second.generated_ids) { if (streamer_ptr->put(gen_token)) { - handle->drop(); + streamer_ptr->get_streaming_status() == StreamerRunningStatus::CANCEL ? handle->cancel() : handle->stop(); break; } } @@ -101,7 +102,7 @@ std::pair> get_lm_encoded_results( auto free_non_running_requests = [&streamer_ptr, &generations, &active_sequence_groups]() { auto removed_it = std::remove_if(active_sequence_groups.begin(), active_sequence_groups.end(), [](SequenceGroup::Ptr sg) -> bool { - return sg->has_finished() || sg->handle_dropped(); + return sg->has_finished() || sg->handle_stopped() || sg->handle_canceled(); }); active_sequence_groups.erase(removed_it, active_sequence_groups.end()); }; @@ -111,8 +112,8 @@ std::pair> get_lm_encoded_results( // Initialize results and performance metrics. - EncodedResults results; - auto& raw_perf_counters = results.perf_metrics.raw_metrics; + ov::genai::utils::GenerationFinishInfo finish_info; + auto& raw_perf_counters = finish_info.results.perf_metrics.raw_metrics; raw_perf_counters.m_inference_durations = {{ MicroSeconds(0.0f) }}; // Initialize inputs @@ -248,25 +249,25 @@ std::pair> get_lm_encoded_results( auto sampling_params = sequence_group->get_sampling_parameters(); const auto& sequences = sequence_group->get_finished_sequences(); size_t num_outputs = std::min(sequence_group->get_sampling_parameters().num_return_sequences, sequences.size()); + finish_info.streaming_finish_status = sequence_group->get_generation_stream()->get_status(); for (size_t seq_id = 0; seq_id < num_outputs; ++seq_id) { const auto & sequence = sequences[seq_id]; const float score = sampling_params.is_beam_search() ? sequence->get_beam_search_score(sampling_params) : sequence->get_cumulative_log_prob(); - results.tokens.push_back(sequence->get_generated_ids()); - results.scores.push_back(score); + finish_info.results.tokens.push_back(sequence->get_generated_ids()); + finish_info.results.scores.push_back(score); } } for (SequenceGroup::Ptr sequence_group : sequence_groups) sampler.clear_request_info(sequence_group->get_request_id()); - // it is not saved in KV cache, we need to add it for some cases - std::optional last_token_of_best_sequence = std::nullopt; - if (sequence_groups[0]->get_finished_sequences()[0]->get_finish_reason() == GenerationFinishReason::LENGTH || sequence_groups[0]->handle_dropped()) - last_token_of_best_sequence = results.tokens[0].back(); + // last generated token is not saved in KV cache, we need to add it for some cases + if (sequence_groups[0]->get_finished_sequences()[0]->get_finish_reason() == GenerationFinishReason::LENGTH || sequence_groups[0]->handle_stopped()) + finish_info.probably_disappeared_token = finish_info.results.tokens[0].back(); - return {results, last_token_of_best_sequence}; + return finish_info; } } // namespace genai diff --git a/src/cpp/src/lm_encoding.hpp b/src/cpp/src/lm_encoding.hpp index 56f6db5227..7b34dcb88e 100644 --- a/src/cpp/src/lm_encoding.hpp +++ b/src/cpp/src/lm_encoding.hpp @@ -8,7 +8,7 @@ namespace ov { namespace genai { -std::pair> get_lm_encoded_results(ov::InferRequest& m_llm, const ov::Tensor& input_ids, const ov::Tensor& attention_mask, +ov::genai::utils::GenerationFinishInfo get_lm_encoded_results(ov::InferRequest& m_llm, const ov::Tensor& input_ids, const ov::Tensor& attention_mask, const std::shared_ptr& streamer_ptr, Sampler& sampler, std::vector sequence_groups, std::optional position_ids, std::optional 
m_embedding, std::optional rope_delta = std::nullopt); diff --git a/src/cpp/src/prompt_lookup/prompt_lookup_impl.cpp b/src/cpp/src/prompt_lookup/prompt_lookup_impl.cpp index 9eb54f700c..5a0c9b7155 100644 --- a/src/cpp/src/prompt_lookup/prompt_lookup_impl.cpp +++ b/src/cpp/src/prompt_lookup/prompt_lookup_impl.cpp @@ -1,6 +1,7 @@ // Copyright (C) 2023-2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 +#include "utils.hpp" #include "prompt_lookup_impl.hpp" #include "text_callback_streamer.hpp" @@ -106,17 +107,7 @@ ContinuousBatchingPipeline::PromptLookupImpl::generate(const std::vectorset_adapters(sampling_params[0].adapters); - const std::shared_ptr& streamer_ptr = std::visit(overloaded{ - [](std::monostate) -> std::shared_ptr { - return nullptr; - }, - [](const std::shared_ptr& streamer) { - return streamer; - }, - [this](const std::function& streamer) -> std::shared_ptr { - return std::make_unique(m_tokenizer, streamer); - } - }, streamer); + const std::shared_ptr& streamer_ptr = ov::genai::utils::create_streamer(streamer, m_tokenizer); OPENVINO_ASSERT(streamer_ptr == nullptr || input_ids.size() == 1 && (sampling_params[0].is_greedy_decoding() || sampling_params[0].is_multinomial()), "Currently streaming is possible only with batch size=1 and only for greedy or multinomial decoding"); @@ -149,7 +140,7 @@ ContinuousBatchingPipeline::PromptLookupImpl::generate(const std::vectorsecond.generated_ids) { continue_generation = !streamer_ptr->put(gen_token); if (!continue_generation) { - generation->drop(); + streamer_ptr->get_streaming_status() == StreamerRunningStatus::CANCEL ? generation->cancel() : generation->stop(); break; } } @@ -160,12 +151,6 @@ ContinuousBatchingPipeline::PromptLookupImpl::generate(const std::vectorend(); } - if (!continue_generation) { - drop_requests(); - } else { - OPENVINO_ASSERT(m_pipeline->is_requests_empty(), "Internal error: current request is supposed to be dropped within step() function as completed"); - } - std::vector results; results.reserve(all_requests.size()); @@ -179,6 +164,7 @@ ContinuousBatchingPipeline::PromptLookupImpl::generate(const std::vectorget_generation_stream()->get_status(); for (size_t i = 0; i < num_outputs; ++i) { const auto & sequence = sequences[i]; diff --git a/src/cpp/src/sequence_group.hpp b/src/cpp/src/sequence_group.hpp index fef9757b43..94601039ae 100644 --- a/src/cpp/src/sequence_group.hpp +++ b/src/cpp/src/sequence_group.hpp @@ -348,7 +348,7 @@ class SequenceGroup : public std::enable_shared_from_this { finished_seqs.reserve(num_total_seqs()); for (size_t seq_id = 0; seq_id < m_sequences.size(); ++seq_id) { - if (m_sequences[seq_id]->has_finished() || m_sequences[seq_id]->out_of_memory() || handle_dropped()) { + if (m_sequences[seq_id]->has_finished() || m_sequences[seq_id]->out_of_memory() || handle_stopped() || handle_canceled()) { finished_seqs.push_back(m_sequences[seq_id]); } } @@ -589,8 +589,12 @@ class SequenceGroup : public std::enable_shared_from_this { m_generation_stream->set_generation_status(status); } - bool handle_dropped() const { - return m_generation_stream->get_status() == GenerationStatus::DROPPED_BY_HANDLE; + bool handle_stopped() const { + return m_generation_stream->get_status() == GenerationStatus::STOP; + } + + bool handle_canceled() const { + return m_generation_stream->get_status() == GenerationStatus::CANCEL; } void push_empty_outputs() { diff --git a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp 
b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp index bec2b75e0d..9a0cdc56d4 100644 --- a/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp +++ b/src/cpp/src/speculative_decoding/continuous_batching_for_speculative_decoding_impl.cpp @@ -28,7 +28,7 @@ ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::finish } } m_sampler->clear_request_info(request->get_request_id()); - request->set_generation_status(GenerationStatus::DROPPED_BY_HANDLE); + request->set_generation_status(GenerationStatus::STOP); } void ContinuousBatchingPipeline::ContinuousBatchingForSpeculativeDecodingImpl::finish_request(int64_t request_id) { diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp index 7a6066fc5c..87615096da 100644 --- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp +++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp @@ -227,17 +227,7 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector< m_main_pipeline->set_adapters(sampling_params[0].adapters); m_draft_pipeline->set_adapters(sampling_params[0].adapters); - const std::shared_ptr& streamer_ptr = std::visit(overloaded{ - [](std::monostate) -> std::shared_ptr { - return nullptr; - }, - [](const std::shared_ptr& streamer) { - return streamer; - }, - [this](const std::function& streamer) -> std::shared_ptr { - return std::make_unique(m_tokenizer, streamer); - } - }, streamer); + const std::shared_ptr& streamer_ptr = ov::genai::utils::create_streamer(streamer, m_tokenizer); OPENVINO_ASSERT(streamer_ptr == nullptr || input_ids.size() == 1 && (sampling_params[0].is_greedy_decoding() || sampling_params[0].is_multinomial()), "Currently streaming is possible only with batch size=1 and only for greedy or multinomial decoding"); @@ -274,7 +264,7 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector< for (const auto& gen_token : token.begin()->second.generated_ids) { continue_generation = !streamer_ptr->put(gen_token); if (!continue_generation) { - main_generation->drop(); + streamer_ptr->get_streaming_status() == StreamerRunningStatus::CANCEL ? 
main_generation->cancel() : main_generation->stop(); break; } } @@ -306,6 +296,7 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector< result.m_request_id = request_id; result.m_generation_ids.resize(num_outputs); result.m_scores.resize(num_outputs); + result.m_status = request->get_generation_stream()->get_status(); for (size_t i = 0; i < num_outputs; ++i) { const auto & sequence = sequences[i]; diff --git a/src/cpp/src/text_callback_streamer.cpp b/src/cpp/src/text_callback_streamer.cpp index 4c4db4311f..17619933c5 100644 --- a/src/cpp/src/text_callback_streamer.cpp +++ b/src/cpp/src/text_callback_streamer.cpp @@ -6,7 +6,7 @@ namespace ov { namespace genai { -TextCallbackStreamer::TextCallbackStreamer(const Tokenizer& tokenizer, std::function callback) { +TextCallbackStreamer::TextCallbackStreamer(const Tokenizer& tokenizer, std::function callback) { m_tokenizer = tokenizer; on_finalized_subword_callback = callback; } @@ -23,7 +23,7 @@ bool TextCallbackStreamer::put(int64_t token) { m_tokens_cache.clear(); m_decoded_lengths.clear(); m_printed_len = 0; - return on_finalized_subword_callback(res.str()); + return is_generation_complete(on_finalized_subword_callback(res.str())); } constexpr size_t delay_n_tokens = 3; @@ -31,13 +31,13 @@ bool TextCallbackStreamer::put(int64_t token) { // e.g. when apostrophe removing regex had worked after adding new tokens. // Printing several last tokens is delayed. if (m_decoded_lengths.size() < delay_n_tokens) { - return on_finalized_subword_callback(res.str()); + return is_generation_complete(on_finalized_subword_callback(res.str())); } constexpr char replacement[] = "\xef\xbf\xbd"; // MSVC with /utf-8 fails to compile � directly with newline in string literal error. if (text.size() >= 3 && text.compare(text.size() - 3, 3, replacement) == 0) { m_decoded_lengths[m_decoded_lengths.size() - 1] = -1; // Don't print incomplete text - return on_finalized_subword_callback(res.str()); + return is_generation_complete(on_finalized_subword_callback(res.str())); } auto print_until = m_decoded_lengths[m_decoded_lengths.size() - delay_n_tokens]; if (print_until != -1 && print_until > m_printed_len) { @@ -47,7 +47,22 @@ bool TextCallbackStreamer::put(int64_t token) { m_printed_len = print_until; } - return on_finalized_subword_callback(res.str()); + return is_generation_complete(on_finalized_subword_callback(res.str())); +} + +bool TextCallbackStreamer::is_generation_complete(CallbackTypeVariant callback_status) { + bool is_complete = false; + if (auto status = std::get_if(&callback_status)) { + m_streaming_finish_status = *status; + is_complete = (m_streaming_finish_status == StreamerRunningStatus::STOP || m_streaming_finish_status == StreamerRunningStatus::CANCEL); + } else if (auto status = std::get_if(&callback_status)) { + is_complete = *status; + m_streaming_finish_status = *status ? 
StreamerRunningStatus::STOP : StreamerRunningStatus::RUNNING; + } else if (auto status = std::get_if(&callback_status)) { + m_streaming_finish_status = StreamerRunningStatus::RUNNING; + } + + return is_complete; } void TextCallbackStreamer::end() { diff --git a/src/cpp/src/text_callback_streamer.hpp b/src/cpp/src/text_callback_streamer.hpp index 2c5fab5700..797d09342e 100644 --- a/src/cpp/src/text_callback_streamer.hpp +++ b/src/cpp/src/text_callback_streamer.hpp @@ -12,11 +12,14 @@ namespace genai { class TextCallbackStreamer: public StreamerBase { public: bool put(int64_t token) override; + void end() override; - TextCallbackStreamer(const Tokenizer& tokenizer, std::function callback); + TextCallbackStreamer(const Tokenizer& tokenizer, std::function callback); + + bool is_generation_complete(CallbackTypeVariant callback_status); - std::function on_finalized_subword_callback = [](std::string words)->bool { return false; }; + std::function on_finalized_subword_callback = [](std::string words)->bool { return false; }; protected: Tokenizer m_tokenizer; diff --git a/src/cpp/src/utils.cpp b/src/cpp/src/utils.cpp index 2d6dfd2ae5..66114f8880 100644 --- a/src/cpp/src/utils.cpp +++ b/src/cpp/src/utils.cpp @@ -3,6 +3,7 @@ #include "utils.hpp" +#include #include #include @@ -15,6 +16,8 @@ #include "openvino/op/tanh.hpp" #include "openvino/op/transpose.hpp" +#include "text_callback_streamer.hpp" + #include "sampler.hpp" namespace ov { @@ -174,11 +177,32 @@ ov::genai::StreamerVariant get_streamer_from_map(const ov::AnyMap& config_map) { streamer = any_val.as>(); } else if (any_val.is>()) { streamer = any_val.as>(); + } else if (any_val.is>()) { + streamer = any_val.as>(); } } return streamer; } +std::shared_ptr create_streamer(StreamerVariant streamer, Tokenizer tokenizer) { + std::shared_ptr streamer_ptr = std::visit(overloaded{ + [](std::monostate) -> std::shared_ptr { + return nullptr; + }, + [](const std::shared_ptr& streamer) { + return streamer; + }, + [&tokenizer = tokenizer](const std::function& streamer) -> std::shared_ptr { + return std::make_unique(tokenizer, streamer); + }, + [&tokenizer = tokenizer](const std::function& streamer) -> std::shared_ptr { + return std::make_unique(tokenizer, streamer); + } + }, streamer); + + return streamer_ptr; +} + ov::genai::OptionalGenerationConfig get_config_from_map(const ov::AnyMap& config_map) { if (config_map.count(CONFIG_ARG_NAME)) return config_map.at(CONFIG_ARG_NAME).as(); diff --git a/src/cpp/src/utils.hpp b/src/cpp/src/utils.hpp index ff1aea1ae9..8d9c53c4d0 100644 --- a/src/cpp/src/utils.hpp +++ b/src/cpp/src/utils.hpp @@ -9,6 +9,8 @@ #include "visual_language/processor_config.hpp" +#include "openvino/genai/streamer_base.hpp" + namespace ov { namespace genai { namespace utils { @@ -32,17 +34,27 @@ struct HistoryRemoveManager { size_t num_tokens_to_remove_from_kv_cache = 0; size_t trusted_history_length = 0; + size_t kv_cache_seq_length_axis = 2; + bool reset_kv_cache = false; - bool does_kv_cache_need_to_update() { - return (trusted_history_length > 0 || num_tokens_to_remove_from_kv_cache > 0); + bool does_history_cache_need_to_update() { + return (trusted_history_length > 0 && num_tokens_to_remove_from_kv_cache > 0); } void reset() { num_tokens_to_remove_from_kv_cache = 0; trusted_history_length = 0; + reset_kv_cache = false; } }; +struct GenerationFinishInfo +{ + EncodedResults results; + std::optional probably_disappeared_token = std::nullopt; + GenerationStatus streaming_finish_status; +}; + Tensor init_attention_mask(const Tensor& 
position_ids); void print_tensor(const ov::Tensor& tensor); @@ -118,6 +130,10 @@ ov::Tensor push_front_inputs(const ov::Tensor& base_tensor, int64_t add_to_front void print_compiled_model_properties(ov::CompiledModel& compiled_Model, const char* model_title); +template struct overloaded : Ts... {using Ts::operator()...;}; +template overloaded(Ts...) -> overloaded; +std::shared_ptr create_streamer(StreamerVariant streamer, Tokenizer tokenizer); + } // namespace utils } // namespace genai } // namespace ov diff --git a/src/cpp/src/visual_language/inputs_embedder.cpp b/src/cpp/src/visual_language/inputs_embedder.cpp index 66b17e5804..1642377872 100644 --- a/src/cpp/src/visual_language/inputs_embedder.cpp +++ b/src/cpp/src/visual_language/inputs_embedder.cpp @@ -201,11 +201,11 @@ class InputsEmbedder::IInputsEmbedder { if (m_tokenized_history.empty()) { encoded_input_ids = new_chat_tokens; - } else if (trusted_history_length != SIZE_MAX || m_kv_history_manager.does_kv_cache_need_to_update()) { - // does_kv_cache_need_to_update will be true here if beam search is activated + } else if (trusted_history_length != SIZE_MAX || m_kv_history_manager.does_history_cache_need_to_update()) { + // does_history_cache_need_to_update will be true here if beam search is activated // in beam search mode we want to remove all history about last model answer from kv cache and add the best answer directly // if we have difference in model answer and decoded answer it anyway will be less then entire history, so let's use data from m_kv_history_manager - if (m_kv_history_manager.does_kv_cache_need_to_update()) { + if (m_kv_history_manager.does_history_cache_need_to_update()) { trusted_history_length = m_kv_history_manager.trusted_history_length; } else { m_kv_history_manager.num_tokens_to_remove_from_kv_cache = m_tokenized_history.size() - trusted_history_length; diff --git a/src/cpp/src/visual_language/pipeline.cpp b/src/cpp/src/visual_language/pipeline.cpp index 95e3064548..15250f3ec4 100644 --- a/src/cpp/src/visual_language/pipeline.cpp +++ b/src/cpp/src/visual_language/pipeline.cpp @@ -21,9 +21,6 @@ using namespace ov::genai; namespace { - -template struct overloaded : Ts... {using Ts::operator()...;}; -template overloaded(Ts...) 
-> overloaded; constexpr size_t BATCH_SIZE = 1; @@ -187,19 +184,7 @@ class ov::genai::VLMPipeline::VLMPipelineImpl { SequenceGroup::Ptr sequence_group = std::make_shared(request_id, prompt_ids, generation_config, block_size); requests.push_back(sequence_group); - std::shared_ptr streamer_ptr = std::visit(overloaded{ - [&m_tokenizer = m_tokenizer]( - const std::function& callback - ) -> std::shared_ptr { - return std::make_shared(m_tokenizer, callback); - }, - [](const std::shared_ptr& ptr) { - return ptr; - }, - [](std::monostate) { - return std::shared_ptr{nullptr}; - }, - }, streamer); + std::shared_ptr streamer_ptr = ov::genai::utils::create_streamer(streamer, m_tokenizer); OPENVINO_ASSERT(streamer_ptr == nullptr || generation_config.num_return_sequences == 1 && (generation_config.is_greedy_decoding() || generation_config.is_multinomial()), @@ -216,10 +201,10 @@ class ov::genai::VLMPipeline::VLMPipelineImpl { m_sampler.set_seed(generation_config.rng_seed); } - ov::genai::EncodedResults encoded_result; - std::optional last_disappeared_token; - std::tie(encoded_result, last_disappeared_token) = ov::genai::get_lm_encoded_results(m_language, inputs_embeds, new_atten_mask, streamer_ptr, m_sampler, requests, + ov::genai::utils::GenerationFinishInfo finish_info = ov::genai::get_lm_encoded_results(m_language, inputs_embeds, new_atten_mask, streamer_ptr, m_sampler, requests, position_ids, m_embedding, rope_delta); + ov::genai::EncodedResults encoded_result = finish_info.results; + auto decode_start_time = std::chrono::steady_clock::now(); VLMDecodedResults decoded; @@ -229,7 +214,7 @@ class ov::genai::VLMPipeline::VLMPipelineImpl { } auto decode_end_time = std::chrono::steady_clock::now(); - m_inputs_embedder->update_tokenized_history(encoded_result.tokens[0], last_disappeared_token, generation_config.is_beam_search(), + m_inputs_embedder->update_tokenized_history(encoded_result.tokens[0], finish_info.probably_disappeared_token, generation_config.is_beam_search(), m_language.get_tensor("attention_mask").get_shape()[1] - (history_size + inputs_embeds_size)); std::string decoded_results = decoded.texts.at(0); diff --git a/src/python/openvino_genai/__init__.py b/src/python/openvino_genai/__init__.py index 0ad7ba3f12..2772f1c47b 100644 --- a/src/python/openvino_genai/__init__.py +++ b/src/python/openvino_genai/__init__.py @@ -16,6 +16,7 @@ PerfMetrics, StreamerBase, get_version, + StreamerRunningStatus ) __version__ = get_version() @@ -84,5 +85,5 @@ GenerationResult, SchedulerConfig, CacheEvictionConfig, - AggregationMode, + AggregationMode ) diff --git a/src/python/openvino_genai/__init__.pyi b/src/python/openvino_genai/__init__.pyi index 0a401ae958..e54d898b2e 100644 --- a/src/python/openvino_genai/__init__.pyi +++ b/src/python/openvino_genai/__init__.pyi @@ -30,6 +30,7 @@ from openvino_genai.py_openvino_genai import Scheduler from openvino_genai.py_openvino_genai import SchedulerConfig from openvino_genai.py_openvino_genai import StopCriteria from openvino_genai.py_openvino_genai import StreamerBase +from openvino_genai.py_openvino_genai import StreamerRunningStatus from openvino_genai.py_openvino_genai import T5EncoderModel from openvino_genai.py_openvino_genai import Text2ImagePipeline from openvino_genai.py_openvino_genai import TokenizedInputs @@ -45,5 +46,5 @@ from openvino_genai.py_openvino_genai import draft_model from openvino_genai.py_openvino_genai import get_version import os as os from . 
import py_openvino_genai -__all__ = ['Adapter', 'AdapterConfig', 'AggregationMode', 'AutoencoderKL', 'CLIPTextModel', 'CLIPTextModelWithProjection', 'CacheEvictionConfig', 'ChunkStreamerBase', 'ContinuousBatchingPipeline', 'CppStdGenerator', 'DecodedResults', 'EncodedResults', 'FluxTransformer2DModel', 'GenerationConfig', 'GenerationResult', 'Generator', 'Image2ImagePipeline', 'ImageGenerationConfig', 'InpaintingPipeline', 'LLMPipeline', 'PerfMetrics', 'RawPerfMetrics', 'SD3Transformer2DModel', 'Scheduler', 'SchedulerConfig', 'StopCriteria', 'StreamerBase', 'T5EncoderModel', 'Text2ImagePipeline', 'TokenizedInputs', 'Tokenizer', 'TorchGenerator', 'UNet2DConditionModel', 'VLMPipeline', 'WhisperGenerationConfig', 'WhisperPerfMetrics', 'WhisperPipeline', 'WhisperRawPerfMetrics', 'draft_model', 'get_version', 'openvino', 'os', 'py_openvino_genai'] +__all__ = ['Adapter', 'AdapterConfig', 'AggregationMode', 'AutoencoderKL', 'CLIPTextModel', 'CLIPTextModelWithProjection', 'CacheEvictionConfig', 'ChunkStreamerBase', 'ContinuousBatchingPipeline', 'CppStdGenerator', 'DecodedResults', 'EncodedResults', 'FluxTransformer2DModel', 'GenerationConfig', 'GenerationResult', 'Generator', 'Image2ImagePipeline', 'ImageGenerationConfig', 'InpaintingPipeline', 'LLMPipeline', 'PerfMetrics', 'RawPerfMetrics', 'SD3Transformer2DModel', 'Scheduler', 'SchedulerConfig', 'StopCriteria', 'StreamerBase', 'StreamerRunningStatus', 'T5EncoderModel', 'Text2ImagePipeline', 'TokenizedInputs', 'Tokenizer', 'TorchGenerator', 'UNet2DConditionModel', 'VLMPipeline', 'WhisperGenerationConfig', 'WhisperPerfMetrics', 'WhisperPipeline', 'WhisperRawPerfMetrics', 'draft_model', 'get_version', 'openvino', 'os', 'py_openvino_genai'] __version__: str diff --git a/src/python/openvino_genai/py_openvino_genai.pyi b/src/python/openvino_genai/py_openvino_genai.pyi index cf80e973f4..11ed46f8cf 100644 --- a/src/python/openvino_genai/py_openvino_genai.pyi +++ b/src/python/openvino_genai/py_openvino_genai.pyi @@ -5,7 +5,7 @@ from __future__ import annotations import openvino._pyopenvino import os import typing -__all__ = ['Adapter', 'AdapterConfig', 'AggregationMode', 'AutoencoderKL', 'CLIPTextModel', 'CLIPTextModelWithProjection', 'CacheEvictionConfig', 'ChunkStreamerBase', 'ContinuousBatchingPipeline', 'CppStdGenerator', 'DecodedResults', 'EncodedGenerationResult', 'EncodedResults', 'FluxTransformer2DModel', 'GenerationConfig', 'GenerationFinishReason', 'GenerationHandle', 'GenerationOutput', 'GenerationResult', 'GenerationStatus', 'Generator', 'Image2ImagePipeline', 'ImageGenerationConfig', 'InpaintingPipeline', 'LLMPipeline', 'MeanStdPair', 'PerfMetrics', 'PipelineMetrics', 'RawPerfMetrics', 'SD3Transformer2DModel', 'Scheduler', 'SchedulerConfig', 'StopCriteria', 'StreamerBase', 'T5EncoderModel', 'Text2ImagePipeline', 'TokenizedInputs', 'Tokenizer', 'TorchGenerator', 'UNet2DConditionModel', 'VLMDecodedResults', 'VLMPerfMetrics', 'VLMPipeline', 'VLMRawPerfMetrics', 'WhisperDecodedResultChunk', 'WhisperDecodedResults', 'WhisperGenerationConfig', 'WhisperPerfMetrics', 'WhisperPipeline', 'WhisperRawPerfMetrics', 'draft_model', 'get_version'] +__all__ = ['Adapter', 'AdapterConfig', 'AggregationMode', 'AutoencoderKL', 'CLIPTextModel', 'CLIPTextModelWithProjection', 'CacheEvictionConfig', 'ChunkStreamerBase', 'ContinuousBatchingPipeline', 'CppStdGenerator', 'DecodedResults', 'EncodedGenerationResult', 'EncodedResults', 'FluxTransformer2DModel', 'GenerationConfig', 'GenerationFinishReason', 'GenerationHandle', 'GenerationOutput', 'GenerationResult', 
diff --git a/src/python/openvino_genai/py_openvino_genai.pyi b/src/python/openvino_genai/py_openvino_genai.pyi
index cf80e973f4..11ed46f8cf 100644
--- a/src/python/openvino_genai/py_openvino_genai.pyi
+++ b/src/python/openvino_genai/py_openvino_genai.pyi
@@ -5,7 +5,7 @@ from __future__ import annotations
 import openvino._pyopenvino
 import os
 import typing
-__all__ = ['Adapter', 'AdapterConfig', 'AggregationMode', 'AutoencoderKL', 'CLIPTextModel', 'CLIPTextModelWithProjection', 'CacheEvictionConfig', 'ChunkStreamerBase', 'ContinuousBatchingPipeline', 'CppStdGenerator', 'DecodedResults', 'EncodedGenerationResult', 'EncodedResults', 'FluxTransformer2DModel', 'GenerationConfig', 'GenerationFinishReason', 'GenerationHandle', 'GenerationOutput', 'GenerationResult', 'GenerationStatus', 'Generator', 'Image2ImagePipeline', 'ImageGenerationConfig', 'InpaintingPipeline', 'LLMPipeline', 'MeanStdPair', 'PerfMetrics', 'PipelineMetrics', 'RawPerfMetrics', 'SD3Transformer2DModel', 'Scheduler', 'SchedulerConfig', 'StopCriteria', 'StreamerBase', 'T5EncoderModel', 'Text2ImagePipeline', 'TokenizedInputs', 'Tokenizer', 'TorchGenerator', 'UNet2DConditionModel', 'VLMDecodedResults', 'VLMPerfMetrics', 'VLMPipeline', 'VLMRawPerfMetrics', 'WhisperDecodedResultChunk', 'WhisperDecodedResults', 'WhisperGenerationConfig', 'WhisperPerfMetrics', 'WhisperPipeline', 'WhisperRawPerfMetrics', 'draft_model', 'get_version']
+__all__ = ['Adapter', 'AdapterConfig', 'AggregationMode', 'AutoencoderKL', 'CLIPTextModel', 'CLIPTextModelWithProjection', 'CacheEvictionConfig', 'ChunkStreamerBase', 'ContinuousBatchingPipeline', 'CppStdGenerator', 'DecodedResults', 'EncodedGenerationResult', 'EncodedResults', 'FluxTransformer2DModel', 'GenerationConfig', 'GenerationFinishReason', 'GenerationHandle', 'GenerationOutput', 'GenerationResult', 'GenerationStatus', 'Generator', 'Image2ImagePipeline', 'ImageGenerationConfig', 'InpaintingPipeline', 'LLMPipeline', 'MeanStdPair', 'PerfMetrics', 'PipelineMetrics', 'RawPerfMetrics', 'SD3Transformer2DModel', 'Scheduler', 'SchedulerConfig', 'StopCriteria', 'StreamerBase', 'StreamerRunningStatus', 'T5EncoderModel', 'Text2ImagePipeline', 'TokenizedInputs', 'Tokenizer', 'TorchGenerator', 'UNet2DConditionModel', 'VLMDecodedResults', 'VLMPerfMetrics', 'VLMPipeline', 'VLMRawPerfMetrics', 'WhisperDecodedResultChunk', 'WhisperDecodedResults', 'WhisperGenerationConfig', 'WhisperPerfMetrics', 'WhisperPipeline', 'WhisperRawPerfMetrics', 'draft_model', 'get_version']
 class Adapter:
     """
     Immutable LoRA Adapter that carries the adaptation matrices and serves as unique adapter identifier.
@@ -379,10 +379,10 @@ class ContinuousBatchingPipeline:
     def add_request(self, request_id: int, prompt: str, generation_config: GenerationConfig) -> GenerationHandle:
         ...
     @typing.overload
-    def generate(self, input_ids: list[openvino._pyopenvino.Tensor], generation_config: list[GenerationConfig], streamer: typing.Callable[[str], bool] | StreamerBase | None = None) -> list[EncodedGenerationResult]:
+    def generate(self, input_ids: list[openvino._pyopenvino.Tensor], generation_config: list[GenerationConfig], streamer: typing.Callable[[str], int | None] | StreamerBase | None = None) -> list[EncodedGenerationResult]:
         ...
     @typing.overload
-    def generate(self, prompts: list[str], generation_config: list[GenerationConfig], streamer: typing.Callable[[str], bool] | StreamerBase | None = None) -> list[GenerationResult]:
+    def generate(self, prompts: list[str], generation_config: list[GenerationConfig], streamer: typing.Callable[[str], int | None] | StreamerBase | None = None) -> list[GenerationResult]:
         ...
     def get_config(self) -> GenerationConfig:
         ...
@@ -444,8 +444,8 @@ class EncodedGenerationResult:
         RUNNING = 0 - Default status for ongoing generation.
         FINISHED = 1 - Status set when generation has been finished.
         IGNORED = 2 - Status set when generation run into out-of-memory condition and could not be continued.
-        DROPPED_BY_PIPELINE = 3 - Currently not used, TODO: implement abort functionality.
-        DROPPED_BY_HANDLE = 4 - Status set when generation handle is dropped.
+        CANCEL = 3 - Status set when generation handle is canceled.
+        STOP = 4 - Status set when generation handle is stopped.
 
     perf_metrics:
         Performance metrics for each generation result.
@@ -675,6 +675,8 @@ class GenerationHandle:
         ...
     def can_read(self) -> bool:
         ...
+    def cancel(self) -> None:
+        ...
     def drop(self) -> None:
         ...
     def get_status(self) -> GenerationStatus:
@@ -683,6 +685,8 @@ class GenerationHandle:
         ...
     def read_all(self) -> list[GenerationOutput]:
         ...
+    def stop(self) -> None:
+        ...
 class GenerationOutput:
     finish_reason: GenerationFinishReason
     generated_ids: list[int]
@@ -702,8 +706,8 @@ class GenerationResult:
         RUNNING = 0 - Default status for ongoing generation.
         FINISHED = 1 - Status set when generation has been finished.
         IGNORED = 2 - Status set when generation run into out-of-memory condition and could not be continued.
-        DROPPED_BY_PIPELINE = 3 - Currently not used, TODO: implement abort functionality.
-        DROPPED_BY_HANDLE = 4 - Status set when generation handle is dropped.
+        CANCEL = 3 - Status set when generation handle is canceled.
+        STOP = 4 - Status set when generation handle is stopped.
 
     perf_metrics:
         Performance metrics for each generation result.
@@ -733,16 +737,16 @@ class GenerationStatus:
 
     IGNORED
 
-    DROPPED_BY_PIPELINE
+    CANCEL
 
-    DROPPED_BY_HANDLE
+    STOP
     """
-    DROPPED_BY_HANDLE: typing.ClassVar[GenerationStatus]  # value = <GenerationStatus.DROPPED_BY_HANDLE: 4>
-    DROPPED_BY_PIPELINE: typing.ClassVar[GenerationStatus]  # value = <GenerationStatus.DROPPED_BY_PIPELINE: 3>
+    CANCEL: typing.ClassVar[GenerationStatus]  # value = <GenerationStatus.CANCEL: 3>
     FINISHED: typing.ClassVar[GenerationStatus]  # value = <GenerationStatus.FINISHED: 1>
     IGNORED: typing.ClassVar[GenerationStatus]  # value = <GenerationStatus.IGNORED: 2>
     RUNNING: typing.ClassVar[GenerationStatus]  # value = <GenerationStatus.RUNNING: 0>
-    __members__: typing.ClassVar[dict[str, GenerationStatus]]  # value = {'RUNNING': <GenerationStatus.RUNNING: 0>, 'FINISHED': <GenerationStatus.FINISHED: 1>, 'IGNORED': <GenerationStatus.IGNORED: 2>, 'DROPPED_BY_PIPELINE': <GenerationStatus.DROPPED_BY_PIPELINE: 3>, 'DROPPED_BY_HANDLE': <GenerationStatus.DROPPED_BY_HANDLE: 4>}
+    STOP: typing.ClassVar[GenerationStatus]  # value = <GenerationStatus.STOP: 4>
+    __members__: typing.ClassVar[dict[str, GenerationStatus]]  # value = {'RUNNING': <GenerationStatus.RUNNING: 0>, 'FINISHED': <GenerationStatus.FINISHED: 1>, 'IGNORED': <GenerationStatus.IGNORED: 2>, 'CANCEL': <GenerationStatus.CANCEL: 3>, 'STOP': <GenerationStatus.STOP: 4>}
     def __eq__(self, other: typing.Any) -> bool:
         ...
     def __getstate__(self) -> int:
@@ -959,7 +963,7 @@ class LLMPipeline:
     """
     This class is used for generation with LLMs
     """
-    def __call__(self, inputs: openvino._pyopenvino.Tensor | TokenizedInputs | str | list[str], generation_config: GenerationConfig | None = None, streamer: typing.Callable[[str], bool] | StreamerBase | None = None, **kwargs) -> EncodedResults | DecodedResults:
+    def __call__(self, inputs: openvino._pyopenvino.Tensor | TokenizedInputs | str | list[str], generation_config: GenerationConfig | None = None, streamer: typing.Callable[[str], int | None] | StreamerBase | None = None, **kwargs) -> EncodedResults | DecodedResults:
         """
         Generates sequences or tokens for LLMs. If input is a string or list of strings then resulting sequences will be already detokenized.
 
@@ -1044,7 +1048,7 @@ class LLMPipeline:
         """
     def finish_chat(self) -> None:
         ...
-    def generate(self, inputs: openvino._pyopenvino.Tensor | TokenizedInputs | str | list[str], generation_config: GenerationConfig | None = None, streamer: typing.Callable[[str], bool] | StreamerBase | None = None, **kwargs) -> EncodedResults | DecodedResults:
+    def generate(self, inputs: openvino._pyopenvino.Tensor | TokenizedInputs | str | list[str], generation_config: GenerationConfig | None = None, streamer: typing.Callable[[str], int | None] | StreamerBase | None = None, **kwargs) -> EncodedResults | DecodedResults:
         """
         Generates sequences or tokens for LLMs. If input is a string or list of strings then resulting sequences will be already detokenized.
 
@@ -1516,6 +1520,46 @@ class StreamerBase:
         """
         Put is called every time new token is decoded. Returns a bool flag to indicate whether generation should be stopped, if return true generation stops
         """
+class StreamerRunningStatus:
+    """
+    Members:
+
+      RUNNING
+
+      CANCEL
+
+      STOP
+    """
+    CANCEL: typing.ClassVar[StreamerRunningStatus]  # value =
+    RUNNING: typing.ClassVar[StreamerRunningStatus]  # value =
+    STOP: typing.ClassVar[StreamerRunningStatus]  # value =
+    __members__: typing.ClassVar[dict[str, StreamerRunningStatus]]  # value = {'RUNNING': , 'CANCEL': , 'STOP': }
+    def __eq__(self, other: typing.Any) -> bool:
+        ...
+    def __getstate__(self) -> int:
+        ...
+    def __hash__(self) -> int:
+        ...
+    def __index__(self) -> int:
+        ...
+    def __init__(self, value: int) -> None:
+        ...
+    def __int__(self) -> int:
+        ...
+    def __ne__(self, other: typing.Any) -> bool:
+        ...
+    def __repr__(self) -> str:
+        ...
+    def __setstate__(self, state: int) -> None:
+        ...
+    def __str__(self) -> str:
+        ...
+    @property
+    def name(self) -> str:
+        ...
+    @property
+    def value(self) -> int:
+        ...
 class T5EncoderModel:
     """
     T5EncoderModel class.
@@ -1821,7 +1865,7 @@ class VLMPipeline:
     def finish_chat(self) -> None:
         ...
     @typing.overload
-    def generate(self, prompt: str, images: list[openvino._pyopenvino.Tensor], generation_config: GenerationConfig, streamer: typing.Callable[[str], bool] | StreamerBase | None = None, **kwargs) -> VLMDecodedResults:
+    def generate(self, prompt: str, images: list[openvino._pyopenvino.Tensor], generation_config: GenerationConfig, streamer: typing.Callable[[str], int | None] | StreamerBase | None = None, **kwargs) -> VLMDecodedResults:
         """
         Generates sequences for VLMs.
 
@@ -1844,7 +1888,7 @@ class VLMPipeline:
         :rtype: VLMDecodedResults
         """
     @typing.overload
-    def generate(self, prompt: str, images: openvino._pyopenvino.Tensor, generation_config: GenerationConfig, streamer: typing.Callable[[str], bool] | StreamerBase | None = None, **kwargs) -> VLMDecodedResults:
+    def generate(self, prompt: str, images: openvino._pyopenvino.Tensor, generation_config: GenerationConfig, streamer: typing.Callable[[str], int | None] | StreamerBase | None = None, **kwargs) -> VLMDecodedResults:
         """
         Generates sequences for VLMs.
 
diff --git a/src/python/py_continuous_batching_pipeline.cpp b/src/python/py_continuous_batching_pipeline.cpp
index aa08eb6da1..20944e3749 100644
--- a/src/python/py_continuous_batching_pipeline.cpp
+++ b/src/python/py_continuous_batching_pipeline.cpp
@@ -83,8 +83,8 @@ auto generation_result_docstring = R"(
     RUNNING = 0 - Default status for ongoing generation.
     FINISHED = 1 - Status set when generation has been finished.
     IGNORED = 2 - Status set when generation run into out-of-memory condition and could not be continued.
-    DROPPED_BY_PIPELINE = 3 - Currently not used, TODO: implement abort functionality.
-    DROPPED_BY_HANDLE = 4 - Status set when generation handle is dropped.
+    CANCEL = 3 - Status set when generation handle is canceled.
+    STOP = 4 - Status set when generation handle is stopped.
 
     perf_metrics:
         Performance metrics for each generation result.
@@ -130,8 +130,8 @@ void init_continuous_batching_pipeline(py::module_& m) {
         .value("RUNNING", ov::genai::GenerationStatus::RUNNING)
         .value("FINISHED", ov::genai::GenerationStatus::FINISHED)
         .value("IGNORED", ov::genai::GenerationStatus::IGNORED)
-        .value("DROPPED_BY_PIPELINE", ov::genai::GenerationStatus::DROPPED_BY_PIPELINE)
-        .value("DROPPED_BY_HANDLE", ov::genai::GenerationStatus::DROPPED_BY_HANDLE);
+        .value("CANCEL", ov::genai::GenerationStatus::CANCEL)
+        .value("STOP", ov::genai::GenerationStatus::STOP);
 
     py::class_<GenerationResult>(m, "GenerationResult", generation_result_docstring)
         .def(py::init<>())
@@ -180,6 +180,8 @@ void init_continuous_batching_pipeline(py::module_& m) {
         .def("get_status", &GenerationHandleImpl::get_status)
         .def("can_read", &GenerationHandleImpl::can_read)
         .def("drop", &GenerationHandleImpl::drop)
+        .def("stop", &GenerationHandleImpl::stop)
+        .def("cancel", &GenerationHandleImpl::cancel)
         .def("back", &GenerationHandleImpl::back)
         .def("read", &GenerationHandleImpl::read)
         .def("read_all", &GenerationHandleImpl::read_all);
@@ -251,14 +253,28 @@ void init_continuous_batching_pipeline(py::module_& m) {
         .def("has_non_finished_requests", &ContinuousBatchingPipeline::has_non_finished_requests)
         .def(
             "generate",
-            py::overload_cast<const std::vector<ov::Tensor>&, const std::vector<ov::genai::GenerationConfig>&, const ov::genai::StreamerVariant&>(&ContinuousBatchingPipeline::generate),
+            [](ov::genai::ContinuousBatchingPipeline& pipe,
+               const std::vector<ov::Tensor>& input_ids,
+               const std::vector<ov::genai::GenerationConfig>& sampling_params,
+               const pyutils::PyBindStreamerVariant& py_streamer
+            ) -> py::typing::List<ov::genai::EncodedGenerationResult> {
+                const ov::genai::StreamerVariant streamer = pyutils::pystreamer_to_streamer(py_streamer);
+                return py::cast(pipe.generate(input_ids, sampling_params, streamer));
+            },
             py::arg("input_ids"),
             py::arg("generation_config"),
             py::arg("streamer") = std::monostate{}
         )
         .def(
             "generate",
-            py::overload_cast<const std::vector<std::string>&, const std::vector<ov::genai::GenerationConfig>&, const ov::genai::StreamerVariant&>(&ContinuousBatchingPipeline::generate),
+            [](ov::genai::ContinuousBatchingPipeline& pipe,
+               const std::vector<std::string>& prompts,
+               const std::vector<ov::genai::GenerationConfig>& sampling_params,
+               const pyutils::PyBindStreamerVariant& py_streamer
+            ) -> py::typing::List<ov::genai::GenerationResult> {
+                const ov::genai::StreamerVariant streamer = pyutils::pystreamer_to_streamer(py_streamer);
+                return py::cast(pipe.generate(prompts, sampling_params, streamer));
+            },
             py::arg("prompts"),
             py::arg("generation_config"),
             py::arg("streamer") = std::monostate{}
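[Editor's note, not part of the patch] A hedged Python sketch of the new GenerationHandle.stop()/cancel() bindings added above. The add_request signature, has_non_finished_requests(), get_status() and read_all() are taken from this patch; the ContinuousBatchingPipeline constructor arguments and step() are assumed from the existing API, and "model_dir" is a placeholder.

import openvino_genai as ov_genai

pipe = ov_genai.ContinuousBatchingPipeline("model_dir", ov_genai.SchedulerConfig(), "CPU")
config = ov_genai.GenerationConfig(max_new_tokens=100)

handle = pipe.add_request(0, "Why is the Sun yellow?", config)

steps = 0
while pipe.has_non_finished_requests():
    pipe.step()
    steps += 1
    if steps == 5:
        handle.cancel()  # or handle.stop() to end the request with the STOP status instead
        break

print(handle.get_status())
if handle.can_read():
    outputs = handle.read_all()  # outputs produced before the request was ended, if any
    print(len(outputs))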
It can be used to flush cache if your own streamer has one"); + py::enum_(m, "StreamerRunningStatus") + .value("RUNNING", ov::genai::StreamerRunningStatus::RUNNING) + .value("CANCEL", ov::genai::StreamerRunningStatus::CANCEL) + .value("STOP", ov::genai::StreamerRunningStatus::STOP); + init_tokenizer(m); init_lora_adapter(m); init_generation_config(m); diff --git a/src/python/py_utils.cpp b/src/python/py_utils.cpp index 90cce498cd..0124a44f19 100644 --- a/src/python/py_utils.cpp +++ b/src/python/py_utils.cpp @@ -316,19 +316,31 @@ ov::genai::StreamerVariant pystreamer_to_streamer(const PyBindStreamerVariant& p ov::genai::StreamerVariant streamer = std::monostate(); std::visit(overloaded { - [&streamer](const std::function& py_callback){ - // Wrap python streamer with manual utf-8 decoding. Do not rely - // on pybind automatic decoding since it raises exceptions on incomplete strings. - auto callback_wrapped = [py_callback](std::string subword) -> bool { - auto py_str = PyUnicode_DecodeUTF8(subword.data(), subword.length(), "replace"); - return py_callback(py::reinterpret_borrow(py_str)); - }; - streamer = callback_wrapped; - }, - [&streamer](std::shared_ptr streamer_cls){ - streamer = streamer_cls; - }, - [](std::monostate none){ /*streamer is already a monostate */ } + [&streamer](const std::function(py::str)>& py_callback){ + // Wrap python streamer with manual utf-8 decoding. Do not rely + // on pybind automatic decoding since it raises exceptions on incomplete strings. + auto callback_wrapped = [py_callback](std::string subword) -> ov::genai::StreamerRunningStatus { + auto py_str = PyUnicode_DecodeUTF8(subword.data(), subword.length(), "replace"); + std::optional callback_output = py_callback(py::reinterpret_borrow(py_str)); + auto result = StreamerRunningStatus::RUNNING; + if (callback_output.has_value()) { + if (*callback_output == (uint16_t)StreamerRunningStatus::RUNNING) { + result = StreamerRunningStatus::RUNNING; + } else if (*callback_output == (uint16_t)StreamerRunningStatus::CANCEL) { + result = StreamerRunningStatus::CANCEL; + } else { + result = StreamerRunningStatus::STOP; + } + } + + return result; + }; + streamer = callback_wrapped; + }, + [&streamer](std::shared_ptr streamer_cls){ + streamer = streamer_cls; + }, + [](std::monostate none){ /*streamer is already a monostate */ } }, py_streamer); return streamer; } diff --git a/src/python/py_utils.hpp b/src/python/py_utils.hpp index c3dbdf6aee..0d73b5639d 100644 --- a/src/python/py_utils.hpp +++ b/src/python/py_utils.hpp @@ -17,7 +17,7 @@ namespace ov::genai::pybind::utils { // When StreamerVariant is used utf-8 decoding is done by pybind and can lead to exception on incomplete texts. // Therefore strings decoding should be handled with PyUnicode_DecodeUTF8(..., "replace") to not throw errors. -using PyBindStreamerVariant = std::variant, std::shared_ptr, std::monostate>; +using PyBindStreamerVariant = std::variant(std::string)>, std::shared_ptr, std::monostate>; template struct overloaded : Ts... 
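[Editor's note, not part of the patch] The wrapper above maps the Python callback's return value as follows: no value means RUNNING, a value equal to RUNNING or CANCEL maps to that status, and any other value maps to STOP. A short illustration of the three return forms a Python callback may therefore use (None, a StreamerRunningStatus member, or an integer):

import openvino_genai as ov_genai

def keep_running(subword):
    print(subword, end="", flush=True)   # returning None is treated as RUNNING

def cancel_now(subword):
    return ov_genai.StreamerRunningStatus.CANCEL   # enum member, converted via __int__

def stop_as_int(subword):
    return int(ov_genai.StreamerRunningStatus.STOP)  # plain int works as well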
diff --git a/tests/python_tests/test_llm_pipeline.py b/tests/python_tests/test_llm_pipeline.py
index 8968f2a083..3b9285c688 100644
--- a/tests/python_tests/test_llm_pipeline.py
+++ b/tests/python_tests/test_llm_pipeline.py
@@ -167,7 +167,12 @@ def user_defined_callback(subword):
     print(subword)
 
 
-@pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword)])
+def user_defined_status_callback(subword):
+    print(subword)
+    return ov_genai.StreamerRunningStatus.RUNNING
+
+
+@pytest.mark.parametrize("callback", [print, user_defined_callback, user_defined_status_callback, lambda subword: print(subword)])
 @pytest.mark.precommit
 @pytest.mark.nightly
 def test_callback_one_string(callback):
@@ -177,7 +182,7 @@ def test_callback_one_string(callback):
     ov_pipe.generate('table is made of', generation_config, callback)
 
 
-@pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword)])
+@pytest.mark.parametrize("callback", [print, user_defined_callback, user_defined_status_callback, lambda subword: print(subword)])
 @pytest.mark.precommit
 @pytest.mark.nightly
 def test_callback_batch_throws(callback):
@@ -186,7 +191,7 @@ def test_callback_batch_throws(callback):
         ov_pipe.generate(['1', '2'], ov_pipe.get_generation_config(), callback)
 
 
-@pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword)])
+@pytest.mark.parametrize("callback", [print, user_defined_callback, user_defined_status_callback, lambda subword: print(subword)])
 @pytest.mark.precommit
 @pytest.mark.nightly
 def test_callback_kwargs_one_string(callback):
@@ -194,7 +199,7 @@ def test_callback_kwargs_one_string(callback):
     pipe.generate('table is made of', max_new_tokens=10, streamer=callback)
 
 
-@pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword)])
+@pytest.mark.parametrize("callback", [print, user_defined_callback, user_defined_status_callback, lambda subword: print(subword)])
 @pytest.mark.precommit
 @pytest.mark.nightly
 @pytest.mark.parametrize("model_descr", get_models_list())
@@ -208,7 +213,7 @@ def test_callback_decoding_metallama(model_descr, callback):
     ov_pipe.generate(prompt, max_new_tokens=300, streamer=callback)
 
 
-@pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword)])
+@pytest.mark.parametrize("callback", [print, user_defined_callback, user_defined_status_callback, lambda subword: print(subword)])
 @pytest.mark.precommit
 @pytest.mark.nightly
 def test_callback_kwargs_batch_throws(callback):
@@ -217,6 +222,106 @@ def test_callback_kwargs_batch_throws(callback):
         pipe.generate(['1', '2'], max_new_tokens=10, streamer=callback)
 
 
+@pytest.mark.precommit
+@pytest.mark.nightly
+def test_callback_terminate_by_bool_sampler():
+    pipe = read_model(get_models_list()[0])[4]
+
+    current_iter = 0
+    num_iters = 10
+    def callback(subword):
+        nonlocal current_iter
+        current_iter += 1
+        return current_iter == num_iters
+
+    ov_generation_config = GenerationConfig(max_new_tokens=100)
+
+    # without attention mask
+    input_ids, _ = input_tensors_list[0]
+    inputs_ov = ov.Tensor(input_ids)
+    ov_output = pipe.generate(inputs_ov, ov_generation_config, streamer=callback)
+
+    assert len(ov_output.tokens[0]) == num_iters
+
+
+@pytest.mark.precommit
+@pytest.mark.nightly
+def test_callback_terminate_by_status_sampler():
+    pipe = read_model(get_models_list()[0])[4]
+
+    current_iter = 0
+    num_iters = 10
+    def callback(subword):
+        nonlocal current_iter
+        current_iter += 1
+        return ov_genai.StreamerRunningStatus.STOP if current_iter == num_iters else ov_genai.StreamerRunningStatus.RUNNING
+
+    ov_generation_config = GenerationConfig(max_new_tokens=100)
+
+    # without attention mask
+    input_ids, _ = input_tensors_list[0]
+    inputs_ov = ov.Tensor(input_ids)
+    ov_output = pipe.generate(inputs_ov, ov_generation_config, streamer=callback)
+
+    assert len(ov_output.tokens[0]) == num_iters
+
+
+@pytest.mark.parametrize("model_descr", get_chat_models_list())
+@pytest.mark.precommit
+@pytest.mark.nightly
+def test_chat_scenario_callback_cancel(model_descr):
+    callback_questions = [
+        '1+1=',
+        'Why is the Sun yellow?',
+        'What is the previous answer?',
+        'What was my first question?'
+    ]
+
+    generation_config_kwargs = dict(max_new_tokens=20)
+
+    chat_history_hf = []
+    chat_history_ov = []
+
+    model_id, path, tokenizer, opt_model, ov_pipe = read_model((model_descr[0], model_descr[1] / '_test_chat'))
+
+    ov_generation_config = GenerationConfig(**generation_config_kwargs)
+    hf_generation_config = convert_to_hf(opt_model.generation_config, ov_generation_config)
+
+    current_iter = 0
+    num_iters = 3
+    def callback(subword):
+        nonlocal current_iter
+        current_iter += 1
+        return ov_genai.StreamerRunningStatus.CANCEL if current_iter == num_iters else ov_genai.StreamerRunningStatus.RUNNING
+
+    ov_pipe.start_chat()
+    for prompt in callback_questions:
+        if (prompt != callback_questions[1]):
+            chat_history_hf.append({'role': 'user', 'content': prompt})
+            chat_history_ov.append({'role': 'user', 'content': prompt})
+
+            chat_prompt = tokenizer.apply_chat_template(chat_history_hf, tokenize=False, add_generation_prompt=True)
+            tokenized = tokenizer(chat_prompt, return_tensors='pt', add_special_tokens=False)
+            prompt_len = tokenized['input_ids'].numel()
+
+            answer = opt_model.generate(**tokenized, generation_config=hf_generation_config).sequences[0]
+            answer_str = tokenizer.decode(answer[prompt_len:], skip_special_tokens=True)
+            chat_history_hf.append({'role': 'assistant', 'content': answer_str})
+
+            answer_ov = ov_pipe.generate(prompt, generation_config=ov_generation_config)
+            chat_history_ov.append({'role': 'assistant', 'content': answer_ov})
+        else:
+            answer_ov = ov_pipe.generate(prompt, generation_config=ov_generation_config, streamer=callback)
+
+    ov_pipe.finish_chat()
+
+    if chat_history_ov != chat_history_hf:
+        print(f'hf_output: {chat_history_hf}')
+        print(f'ov_output: {chat_history_ov}')
+
+    assert chat_history_ov == chat_history_hf
+
+
 class Printer(ov_genai.StreamerBase):
     def __init__(self, tokenizer):
         # super() may work, but once you begin mixing Python and C++
@@ -269,7 +374,7 @@ def test_streamer_kwargs_batch_throws():
 
 @pytest.mark.precommit
 @pytest.mark.nightly
-@pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword)])
+@pytest.mark.parametrize("callback", [print, user_defined_callback, user_defined_status_callback, lambda subword: print(subword)])
 def test_operator_with_callback_one_string(callback):
     ov_pipe = read_model(get_models_list()[0])[4]
     ten_tokens = ov_pipe.get_generation_config()
@@ -279,7 +384,7 @@ def test_operator_with_callback_one_string(callback):
 
 @pytest.mark.precommit
 @pytest.mark.nightly
-@pytest.mark.parametrize("callback", [print, user_defined_callback, lambda subword: print(subword)])
+@pytest.mark.parametrize("callback", [print, user_defined_callback, user_defined_status_callback, lambda subword: print(subword)])
 def test_operator_with_callback_batch_throws(callback):
     ov_pipe = read_model(get_models_list()[0])[4]
     with pytest.raises(RuntimeError):
diff --git a/tools/continuous_batching/accuracy/continuous_batching_accuracy.cpp b/tools/continuous_batching/accuracy/continuous_batching_accuracy.cpp
index d644ba9418..358ba438b0 100644
--- a/tools/continuous_batching/accuracy/continuous_batching_accuracy.cpp
+++ b/tools/continuous_batching/accuracy/continuous_batching_accuracy.cpp
@@ -114,7 +114,7 @@ int main(int argc, char* argv[]) try {
                 print_generation_result(generation_result);
             }
             break;
-        case ov::genai::GenerationStatus::DROPPED_BY_PIPELINE:
+        case ov::genai::GenerationStatus::CANCEL:
             std::cout << "Request was aborted." << std::endl;
             if (generation_result.m_generation_ids.size() > 0) {
                 std::cout << "Partial result:" << std::endl;
diff --git a/tools/continuous_batching/accuracy/continuous_batching_speculative_decoding.cpp b/tools/continuous_batching/accuracy/continuous_batching_speculative_decoding.cpp
index eeb3c0f070..d64c6a51fa 100644
--- a/tools/continuous_batching/accuracy/continuous_batching_speculative_decoding.cpp
+++ b/tools/continuous_batching/accuracy/continuous_batching_speculative_decoding.cpp
@@ -124,7 +124,7 @@ int main(int argc, char* argv[]) try {
                 print_cb_generation_result(generation_result);
             }
            break;
-        case ov::genai::GenerationStatus::DROPPED_BY_PIPELINE:
+        case ov::genai::GenerationStatus::CANCEL:
            std::cout << "Request was aborted." << std::endl;
            if (generation_result.m_generation_ids.size() > 0) {
                std::cout << "Partial result:" << std::endl;
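[Editor's note, not part of the patch] The tests above suggest the intended difference between the two end states: STOP ends generation early but keeps what was produced so far, while a chat turn interrupted with CANCEL is expected not to remain in the chat history. A hedged Python sketch of that chat scenario; "model_dir" is a placeholder and the retention behavior is inferred from test_chat_scenario_callback_cancel rather than stated API documentation.

import openvino_genai as ov_genai

pipe = ov_genai.LLMPipeline("model_dir", "CPU")
config = ov_genai.GenerationConfig(max_new_tokens=20)

def make_cancel_after(n):
    count = 0
    def callback(subword):
        nonlocal count
        count += 1
        return (ov_genai.StreamerRunningStatus.CANCEL if count == n
                else ov_genai.StreamerRunningStatus.RUNNING)
    return callback

pipe.start_chat()
pipe.generate("Why is the Sun yellow?", config, make_cancel_after(3))  # turn interrupted with CANCEL
answer = pipe.generate("What was my first question?", config)          # later turns continue as usual
pipe.finish_chat()
print(answer)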