Add a choice of how to end streaming from callback: STOP or CANCEL
sbalandi committed Jan 26, 2025
1 parent 48613d5 commit cac1834
Showing 39 changed files with 469 additions and 225 deletions.
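In short, streamer callbacks can now return an ov::genai::StreamerRunningStatus instead of a bool: RUNNING continues generation, STOP ends it while keeping the current turn in the chat history and KV cache, and CANCEL ends it and rolls the current turn back. A minimal sketch of the new callback contract (the model path, device, and cutoff condition are illustrative placeholders, not part of this commit):

#include <iostream>
#include <string>
#include "openvino/genai/llm_pipeline.hpp"

int main() {
    ov::genai::LLMPipeline pipe("models_path", "CPU"); // placeholder path and device

    ov::genai::GenerationConfig config;
    config.max_new_tokens = 100;

    size_t tokens_seen = 0;
    auto streamer = [&tokens_seen](std::string word) {
        std::cout << word << std::flush;
        ++tokens_seen;
        // STOP ends generation but keeps this turn in the history/KV cache;
        // CANCEL ends generation and rolls the current turn back.
        if (tokens_seen > 50) // arbitrary cutoff, for illustration only
            return ov::genai::StreamerRunningStatus::STOP;
        return ov::genai::StreamerRunningStatus::RUNNING;
    };

    pipe.generate("Why is the sky blue?", config, streamer);
}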
5 changes: 3 additions & 2 deletions samples/cpp/text_generation/chat_sample.cpp
@@ -15,11 +15,12 @@ int main(int argc, char* argv[]) try {

ov::genai::GenerationConfig config;
config.max_new_tokens = 100;
std::function<bool(std::string)> streamer = [](std::string word) {

auto streamer = [](std::string word) {
std::cout << word << std::flush;
// The returned flag indicates whether generation should be stopped.
// false means continue generation.
return false;
return ov::genai::StreamerRunningStatus::RUNNING;
};

pipe.start_chat();
3 changes: 1 addition & 2 deletions samples/python/text_generation/chat_sample.py
@@ -10,8 +10,7 @@ def streamer(subword):
print(subword, end='', flush=True)
# The returned flag indicates whether generation should be stopped.
# False means continue generation.
return False

return openvino_genai.StreamerRunningStatus.RUNNING

def main():
parser = argparse.ArgumentParser()
10 changes: 5 additions & 5 deletions samples/python/text_generation/prompt_lookup_decoding_lm.py
@@ -5,11 +5,11 @@
import argparse
import openvino_genai

def streamer(subword):
print(subword, end='', flush=True)
# The returned flag indicates whether generation should be stopped.
# False means continue generation.
return False
def streamer(subword):
print(subword, end='', flush=True)
# The returned flag indicates whether generation should be stopped.
# False means continue generation.
return False

def main():
parser = argparse.ArgumentParser()
19 changes: 13 additions & 6 deletions src/cpp/include/openvino/genai/generation_handle.hpp
@@ -15,8 +15,8 @@ enum class GenerationStatus {
RUNNING = 0, // Default status for ongoing generation
FINISHED = 1, // Status set when generation has been finished
IGNORED = 2, // Status set when generation ran into an out-of-memory condition and could not be continued
DROPPED_BY_PIPELINE = 3, // Currently not used, TODO: implement abort functionality
DROPPED_BY_HANDLE = 4 // Status set when generation handle is dropped
CANCEL = 3, // Status set when generation handle is canceled
STOP = 4 // Status set when generation handle is stopped
};

struct EncodedGenerationResult {
@@ -70,10 +70,10 @@ using GenerationOutputs = std::unordered_map<uint64_t, GenerationOutput>;

class GenerationStream;

class OPENVINO_GENAI_EXPORTS GenerationHandleImpl {
std::shared_ptr<GenerationStream> m_generation_stream;
ov::genai::GenerationConfig m_sampling_params;

public:
GenerationHandleImpl(std::shared_ptr<GenerationStream> generation_stream, const ov::genai::GenerationConfig& sampling_params) :
m_generation_stream(std::move(generation_stream)),
@@ -88,10 +88,17 @@ class OPENVINO_GENAI_EXPORTS GenerationHandleImpl {
GenerationStatus get_status();

bool can_read();
bool is_dropped();

bool is_stopped();

bool is_canceled();

void drop();

void stop();

void cancel();

GenerationOutputs back();
// Reads result of a generation for single iteration
GenerationOutputs read();
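In practice the handle now records how generation ended: stop() keeps what was generated, cancel() discards the current turn, and (per the destructor change in generation_handle.cpp below) destroying a handle defaults to stop(). A rough usage sketch against the continuous-batching API; the add_request/step loop and the should_cancel condition are assumptions for illustration:

#include "openvino/genai/continuous_batching_pipeline.hpp"

void run(ov::genai::ContinuousBatchingPipeline& pipe,
         const ov::genai::GenerationConfig& config,
         bool should_cancel) {
    auto handle = pipe.add_request(/*request_id=*/0, "Why is the sky blue?", config);
    while (pipe.has_non_finished_requests()) {
        pipe.step();
        if (handle->can_read()) {
            auto outputs = handle->read(); // results of a single iteration
            if (should_cancel) {
                handle->cancel(); // drop this turn from history and KV cache
                return;
            }
        }
    }
    // Letting the handle go out of scope calls stop(), keeping the generated tokens.
}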
2 changes: 1 addition & 1 deletion src/cpp/include/openvino/genai/llm_pipeline.hpp
@@ -19,7 +19,7 @@ namespace ov {
namespace genai {

// The returned flag indicates whether generation should be stopped: false means continue generation, true means stop.
using StreamerVariant = std::variant<std::function<bool(std::string)>, std::shared_ptr<StreamerBase>, std::monostate>;
using StreamerVariant = std::variant<std::function<bool(std::string)>, std::function<StreamerRunningStatus(std::string)>, std::shared_ptr<StreamerBase>, std::monostate>;
using OptionalGenerationConfig = std::optional<GenerationConfig>;
using EncodedInputs = std::variant<ov::Tensor, TokenizedInputs>;
using StringInputs = std::variant<std::string, std::vector<std::string>>;
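Both callback styles now coexist as alternatives of the variant. A quick, purely illustrative sketch of storing each form:

#include <functional>
#include <iostream>
#include "openvino/genai/llm_pipeline.hpp"

int main() {
    // Legacy bool form: false means continue, true means stop.
    ov::genai::StreamerVariant legacy = std::function<bool(std::string)>(
        [](std::string word) { std::cout << word; return false; });
    // New status form: RUNNING, STOP, or CANCEL.
    ov::genai::StreamerVariant with_status = std::function<ov::genai::StreamerRunningStatus(std::string)>(
        [](std::string word) { std::cout << word; return ov::genai::StreamerRunningStatus::RUNNING; });
}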
15 changes: 15 additions & 0 deletions src/cpp/include/openvino/genai/streamer_base.hpp
@@ -4,24 +4,39 @@
#pragma once

#include "openvino/genai/tokenizer.hpp"
#include "openvino/genai/generation_handle.hpp"
#include <variant>

namespace ov {
namespace genai {

enum class StreamerRunningStatus {
RUNNING = 0, // Continue running inference
STOP = 1, // Stop generation; keep the history as is, the KV cache includes the last request and generated tokens
CANCEL = 2 // Stop generation; drop the last prompt and all generated tokens from the history, the KV cache includes the history but not the last step
};

using CallbackTypeVariant = std::variant<bool, StreamerRunningStatus, std::monostate>;

/**
* @brief Base class for streamers. To use it, inherit from this class and implement the put() and end() methods.
*
* @param m_tokenizer tokenizer
*/
class OPENVINO_GENAI_EXPORTS StreamerBase {
public:
StreamerRunningStatus m_streaming_finish_status = StreamerRunningStatus::RUNNING;
/// @brief put() is called every time a new token is decoded
/// @return a bool flag indicating whether generation should be stopped; if true, generation stops
virtual bool put(int64_t token) = 0;

/// @brief end() is called at the end of generation. It can be used to flush the cache if your streamer has one
virtual void end() = 0;

StreamerRunningStatus get_streaming_status() {
return m_streaming_finish_status;
}

virtual ~StreamerBase();
};

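For class-based streamers the bool contract of put() is unchanged; the new m_streaming_finish_status member is how a subclass tells the pipeline which way to finish. A sketch of a streamer that cancels once it sees a banned word; the class name, banned-word check, and naive per-token decode are illustrative assumptions:

#include <iostream>
#include <string>
#include <vector>
#include "openvino/genai/streamer_base.hpp"

class BannedWordStreamer : public ov::genai::StreamerBase {
    ov::genai::Tokenizer m_tokenizer;
    std::string m_text;
public:
    explicit BannedWordStreamer(const ov::genai::Tokenizer& tokenizer) : m_tokenizer(tokenizer) {}

    bool put(int64_t token) override {
        // Naive per-token decode; a real streamer buffers tokens for correct detokenization.
        m_text += m_tokenizer.decode(std::vector<int64_t>{token});
        if (m_text.find("banned") != std::string::npos) {
            // Ask the pipeline to end generation AND roll back the current turn.
            m_streaming_finish_status = ov::genai::StreamerRunningStatus::CANCEL;
            return true; // true still means "end streaming"
        }
        return false; // keep generating
    }

    void end() override { std::cout << std::endl; }
};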
4 changes: 2 additions & 2 deletions src/cpp/src/continuous_batching_adapter.hpp
@@ -97,7 +97,7 @@ class ContinuousBatchingAdapter final : public LLMPipelineImplBase {
std::vector<std::string> plain_replies;
std::vector<float> plain_scores;
for (GenerationResult& res : generated) {
OPENVINO_ASSERT(res.m_status == GenerationStatus::FINISHED || res.m_status == GenerationStatus::DROPPED_BY_HANDLE, "Got unfinished GenerationStatus");
OPENVINO_ASSERT(res.m_status == GenerationStatus::FINISHED || res.m_status == GenerationStatus::STOP || res.m_status == GenerationStatus::CANCEL, "Got unfinished GenerationStatus");
std::move(res.m_generation_ids.begin(), res.m_generation_ids.end(), std::back_inserter(plain_replies));
std::move(res.m_scores.begin(), res.m_scores.end(), std::back_inserter(plain_scores));
}
@@ -189,7 +189,7 @@ class ContinuousBatchingAdapter final : public LLMPipelineImplBase {
std::vector<std::vector<int64_t>> plain_tokens;
std::vector<float> plain_scores;
for (EncodedGenerationResult& res : generated) {
OPENVINO_ASSERT(res.m_status == GenerationStatus::FINISHED || res.m_status == GenerationStatus::DROPPED_BY_HANDLE, "Got unfinished GenerationStatus");
OPENVINO_ASSERT(res.m_status == GenerationStatus::FINISHED || res.m_status == GenerationStatus::STOP, "Got unfinished GenerationStatus");
std::move(res.m_generation_ids.begin(), res.m_generation_ids.end(), std::back_inserter(plain_tokens));
std::move(res.m_scores.begin(), res.m_scores.end(), std::back_inserter(plain_scores));
}
19 changes: 5 additions & 14 deletions src/cpp/src/continuous_batching_impl.cpp
@@ -331,17 +331,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
}
set_adapters(sampling_params[0].adapters);

const std::shared_ptr<StreamerBase>& streamer_ptr = std::visit(overloaded{
[](std::monostate) -> std::shared_ptr<StreamerBase> {
return nullptr;
},
[](const std::shared_ptr<StreamerBase>& streamer) {
return streamer;
},
[this](const std::function<bool(std::string)>& streamer) -> std::shared_ptr<StreamerBase> {
return std::make_unique<TextCallbackStreamer>(m_tokenizer, streamer);
}
}, streamer);
const std::shared_ptr<StreamerBase>& streamer_ptr = ov::genai::utils::create_streamer(streamer, m_tokenizer);

OPENVINO_ASSERT(streamer_ptr == nullptr || input_ids.size() == 1 && sampling_params[0].num_return_sequences == 1 &&
(sampling_params[0].is_greedy_decoding() || sampling_params[0].is_multinomial()),
@@ -377,7 +367,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
std::unordered_map<uint64_t, GenerationOutput> token = generation->read();
for (const auto& gen_token : token.begin()->second.generated_ids) {
if (streamer_ptr->put(gen_token)) {
generation->drop();
streamer_ptr->get_streaming_status() == ov::genai::StreamerRunningStatus::CANCEL ? generation->cancel() : generation->stop();
break;
}
}
@@ -436,6 +426,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
result.m_request_id = request_id;
result.m_generation_ids.resize(num_outputs);
result.m_scores.resize(num_outputs);
result.m_status = request->get_generation_stream()->get_status();

for (size_t i = 0; i < num_outputs; ++i) {
const auto & sequence = sequences[i];
@@ -470,7 +461,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_free_non_running_reque
std::vector<SequenceGroup::Ptr>::iterator requests_iterator = m_requests.begin();
while (requests_iterator != m_requests.end()) {
const auto& request = *requests_iterator;
if (request->has_finished() || request->handle_dropped()) {
if (request->has_finished() || request->handle_stopped() || request->handle_canceled()) {
for (const auto& sequence: request->get_sequences()) {
if (m_scheduler->has_block_table(sequence->get_id())) {
m_scheduler->free_sequence(sequence->get_id());
@@ -488,7 +479,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_notify_requests_droppe
// Notify the last time by pushing empty output
// This causes read() to unblock by adding anything to the queue
for (SequenceGroup::Ptr& request : m_requests) {
if (request->handle_dropped())
if (request->handle_stopped() || request->handle_canceled())
request->push_empty_outputs();
}
}
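The ternary above is the heart of the change: a true returned from put() no longer unconditionally drops the request; the pipeline first asks the streamer how it wants to finish. Written out, the dispatch is equivalent to:

if (streamer_ptr->put(gen_token)) {
    if (streamer_ptr->get_streaming_status() == ov::genai::StreamerRunningStatus::CANCEL)
        generation->cancel(); // roll the current turn back
    else
        generation->stop();   // keep history and KV cache
    break;
}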
28 changes: 20 additions & 8 deletions src/cpp/src/generation_handle.cpp
@@ -9,32 +9,44 @@
using namespace ov::genai;

GenerationHandleImpl::~GenerationHandleImpl() {
drop();
stop();
}

GenerationStatus GenerationHandleImpl::get_status() {
return m_generation_stream->get_status();
}

bool GenerationHandleImpl::can_read() {
return !is_dropped() && m_generation_stream->can_read();
return !is_canceled() && !is_stopped() && m_generation_stream->can_read();
}

bool GenerationHandleImpl::is_dropped() {
return get_status() == GenerationStatus::DROPPED_BY_HANDLE;
bool GenerationHandleImpl::is_stopped() {
return get_status() == GenerationStatus::STOP;
}

bool GenerationHandleImpl::is_canceled() {
return get_status() == GenerationStatus::CANCEL;
}

void GenerationHandleImpl::drop() {
m_generation_stream->drop();
m_generation_stream->stop();
}

void GenerationHandleImpl::stop() {
m_generation_stream->stop();
}

void GenerationHandleImpl::cancel() {
m_generation_stream->cancel();
}

std::unordered_map<uint64_t, GenerationOutput> GenerationHandleImpl::back() {
OPENVINO_ASSERT(!is_dropped(), "GenerationHandle cannot be used after it is dropped.");
OPENVINO_ASSERT(!is_stopped(), "GenerationHandle cannot be used after it is stopped.");
return m_generation_stream->back();
}

std::unordered_map<uint64_t, GenerationOutput> GenerationHandleImpl::read() {
OPENVINO_ASSERT(!is_dropped(), "GenerationHandle cannot be used after it is dropped.");
OPENVINO_ASSERT(!is_stopped(), "GenerationHandle cannot be used after it is stopped.");
return m_generation_stream->read();
}

@@ -57,7 +69,7 @@ void add_partial_result(std::unordered_map<uint64_t, GenerationOutput>& partial_
}

std::vector<GenerationOutput> GenerationHandleImpl::read_all() {
OPENVINO_ASSERT(!is_dropped(), "GenerationHandle cannot be used after it is dropped.");
OPENVINO_ASSERT(!is_stopped(), "GenerationHandle cannot be used after it is stopped.");
std::vector<GenerationOutput> results;
std::unordered_map<uint64_t, GenerationOutput> partial_results;
// We iterate until generation is running or there are tokens we haven't read yet
9 changes: 7 additions & 2 deletions src/cpp/src/generation_stream.hpp
@@ -51,9 +51,14 @@ class GenerationStream {
return m_status;
}

void drop() {
void stop() {
std::lock_guard<std::mutex> lock(m_mutex);
m_status = GenerationStatus::DROPPED_BY_HANDLE;
m_status = GenerationStatus::STOP;
}

void cancel() {
std::lock_guard<std::mutex> lock(m_mutex);
m_status = GenerationStatus::CANCEL;
}
};
}
6 changes: 5 additions & 1 deletion src/cpp/src/icontinuous_batching.cpp
@@ -77,7 +77,7 @@ ContinuousBatchingPipeline::IContinuousBatchingPipeline::generate(
const auto decode_start = std::chrono::steady_clock::now();
generated.push_back(m_tokenizer.decode(res.m_generation_ids.at(idx)));
raw_counters.detokenization_durations.emplace_back(std::chrono::steady_clock::now() - decode_start);
if (m_is_chat_conversation && 0 == idx) {
if (m_is_chat_conversation && 0 == idx && res.m_status != ov::genai::GenerationStatus::CANCEL) {
m_history.push_back({{"role", "assistant"}, {"content", generated.back()}});
}
}
@@ -98,6 +98,10 @@ ContinuousBatchingPipeline::IContinuousBatchingPipeline::generate(
});
}

// If streaming was canceled, the prompt and the answer of the current step shouldn't be present in the history, so remove the prompt from the history
if (m_is_chat_conversation && !encoded.empty() && encoded[0].m_status == ov::genai::GenerationStatus::CANCEL)
m_history.pop_back();

return decoded;
}
}
4 changes: 3 additions & 1 deletion src/cpp/src/llm_pipeline.cpp
@@ -47,7 +47,9 @@ std::pair<ov::AnyMap, ov::genai::static_llm::ModelConfigDesc> split_model_descr(
std::pair<std::string, Any> streamer(StreamerVariant func) {
if (auto streamer_obj = std::get_if<std::shared_ptr<StreamerBase>>(&func)) {
return {utils::STREAMER_ARG_NAME, Any::make<std::shared_ptr<StreamerBase>>(*streamer_obj)};
} else {
} else if (auto streamer_obj = std::get_if<std::function<StreamerRunningStatus(std::string)>>(&func)) {
return {utils::STREAMER_ARG_NAME, Any::make<std::function<StreamerRunningStatus(std::string)>>(*streamer_obj)};
} else {
auto callback = std::get<std::function<bool(std::string)>>(func);
return {utils::STREAMER_ARG_NAME, Any::make<std::function<bool(std::string)>>(callback)};
}
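This helper is also what lets the status-returning callback travel through the property-style generate() overload; roughly as follows, assuming the existing properties overload and the max_new_tokens helper:

#include <iostream>
#include <string>
#include "openvino/genai/llm_pipeline.hpp"

int main() {
    ov::genai::LLMPipeline pipe("models_path", "CPU"); // placeholder path and device
    pipe.generate("Why is the sky blue?",
                  ov::genai::streamer([](std::string word) {
                      std::cout << word << std::flush;
                      return ov::genai::StreamerRunningStatus::RUNNING;
                  }),
                  ov::genai::max_new_tokens(100));
}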