Add a choice of how to end streaming from callback: STOP or CANCEL #1476

Open · wants to merge 2 commits into base: master
6 changes: 3 additions & 3 deletions samples/cpp/text_generation/chat_sample.cpp
@@ -15,11 +15,11 @@ int main(int argc, char* argv[]) try {

ov::genai::GenerationConfig config;
config.max_new_tokens = 100;
std::function<bool(std::string)> streamer = [](std::string word) {

auto streamer = [](std::string word) {
std::cout << word << std::flush;
// Return flag corresponds whether generation should be stopped.
// false means continue generation.
return false;
return ov::genai::StreamerRunningStatus::RUNNING;
};

pipe.start_chat();
2 changes: 1 addition & 1 deletion samples/cpp/text_generation/multinomial_causal_lm.cpp
@@ -21,7 +21,7 @@ int main(int argc, char* argv[]) try {
config.top_k = 30;
auto streamer = [](std::string subword) {
std::cout << subword << std::flush;
return false;
return ov::genai::StreamerRunningStatus::RUNNING;
};

// Since the streamer is set, the results will
2 changes: 1 addition & 1 deletion samples/cpp/text_generation/prompt_lookup_decoding_lm.cpp
@@ -29,7 +29,7 @@ int main(int argc, char* argv[]) try {

auto streamer = [](std::string subword) {
std::cout << subword << std::flush;
return false;
return ov::genai::StreamerRunningStatus::RUNNING;
};

// Since the streamer is set, the results will
2 changes: 1 addition & 1 deletion samples/cpp/text_generation/speculative_decoding_lm.cpp
@@ -33,7 +33,7 @@ int main(int argc, char* argv[]) try {

auto streamer = [](std::string subword) {
std::cout << subword << std::flush;
return false;
return ov::genai::StreamerRunningStatus::RUNNING;
};

// Since the streamer is set, the results will
4 changes: 1 addition & 3 deletions samples/python/text_generation/chat_sample.py
@@ -9,9 +9,7 @@
def streamer(subword):
print(subword, end='', flush=True)
# Return flag corresponds whether generation should be stopped.
# False means continue generation.
return False

return openvino_genai.StreamerRunningStatus.RUNNING

def main():
parser = argparse.ArgumentParser()
4 changes: 2 additions & 2 deletions samples/python/text_generation/multinomial_causal_lm.py
@@ -59,9 +59,9 @@ def get_stop_flag(self):
Checks whether the generation process should be stopped.

Returns:
bool: Always returns False in this implementation.
openvino_genai.StreamerRunningStatus: Always returns RUNNING in this implementation.
"""
return False
return openvino_genai.StreamerRunningStatus.RUNNING

def put_word(self, word: str):
"""
9 changes: 4 additions & 5 deletions samples/python/text_generation/prompt_lookup_decoding_lm.py
@@ -5,11 +5,10 @@
import argparse
import openvino_genai

def streamer(subword):
print(subword, end='', flush=True)
# Return flag corresponds whether generation should be stopped.
# False means continue generation.
return False
def streamer(subword):
print(subword, end='', flush=True)
# Return flag corresponds whether generation should be stopped.
return openvino_genai.StreamerRunningStatus.RUNNING

def main():
parser = argparse.ArgumentParser()
3 changes: 1 addition & 2 deletions samples/python/text_generation/speculative_decoding_lm.py
@@ -9,8 +9,7 @@
def streamer(subword):
print(subword, end='', flush=True)
# Return flag corresponds whether generation should be stopped.
# False means continue generation.
return False
return openvino_genai.StreamerRunningStatus.RUNNING

def main():
parser = argparse.ArgumentParser()
@@ -23,7 +23,7 @@ def streamer(subword: str) -> bool:
print(subword, end='', flush=True)

# No value is returned as in this example we don't want to stop the generation in this method.
# "return None" will be treated the same as "return False".
# "return None" will be treated the same as "return ov::genai::StreamerRunningStatus::RUNNING;".
Contributor suggested change:
- # "return None" will be treated the same as "return ov::genai::StreamerRunningStatus::RUNNING;".
+ # "return None" will be treated the same as "return openvino_genai.StreamerRunningStatus.RUNNING".



def read_image(path: str) -> Tensor:
3 changes: 1 addition & 2 deletions src/README.md
@@ -172,8 +172,7 @@ int main(int argc, char* argv[]) {
auto streamer = [](std::string word) {
std::cout << word << std::flush;
// Return flag corresponds whether generation should be stopped.
// false means continue generation.
return false;
return ov::genai::StreamerRunningStatus::RUNNING;
};
std::cout << pipe.generate("The Sun is yellow because", ov::genai::streamer(streamer), ov::genai::max_new_tokens(200));
}
20 changes: 14 additions & 6 deletions src/cpp/include/openvino/genai/generation_handle.hpp
@@ -15,8 +15,8 @@ enum class GenerationStatus {
RUNNING = 0, // Default status for ongoing generation
FINISHED = 1, // Status set when generation has been finished
IGNORED = 2, // Status set when generation run into out-of-memory condition and could not be continued
DROPPED_BY_PIPELINE = 3, // Currently not used, TODO: implement abort functionality
DROPPED_BY_HANDLE = 4 // Status set when generation handle is dropped
Contributor comment (ilya-lavrenov, Feb 3, 2025): let's deprecate DROPPED_BY_HANDLE via OPENVINO_ENUM_DEPRECATED and assign DROPPED_BY_HANDLE = STOP

CANCEL = 3, // Status set when generation handle is canceled
STOP = 4 // Status set when generation handle is stopped
};
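A rough sketch of that suggested deprecation, assuming OPENVINO_ENUM_DEPRECATED expands to a standard [[deprecated(msg)]] attribute (illustrative only, not part of this diff):

enum class GenerationStatus {
    RUNNING = 0,   // Default status for ongoing generation
    FINISHED = 1,  // Generation has been finished
    IGNORED = 2,   // Generation ran into an out-of-memory condition and could not be continued
    CANCEL = 3,    // Generation handle is canceled
    STOP = 4,      // Generation handle is stopped
    // Deprecated alias kept so existing code continues to compile.
    DROPPED_BY_HANDLE OPENVINO_ENUM_DEPRECATED("Please, use STOP instead of DROPPED_BY_HANDLE.") = STOP
};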

struct EncodedGenerationResult {
@@ -70,10 +70,10 @@ using GenerationOutputs = std::unordered_map<uint64_t, GenerationOutput>;

class GenerationStream;

class OPENVINO_GENAI_EXPORTS GenerationHandleImpl {
class OPENVINO_GENAI_EXPORTS
GenerationHandleImpl {
std::shared_ptr<GenerationStream> m_generation_stream;
ov::genai::GenerationConfig m_sampling_params;

ov::genai::GenerationConfig m_sampling_params;
public:
GenerationHandleImpl(std::shared_ptr<GenerationStream> generation_stream, const ov::genai::GenerationConfig& sampling_params) :
m_generation_stream(std::move(generation_stream)),
@@ -88,10 +88,18 @@ class OPENVINO_GENAI_EXPORTS GenerationHandleImpl {
GenerationStatus get_status();

bool can_read();
bool is_dropped();

bool is_stopped();

bool is_canceled();
Contributor comment: just want to highlight which spelling we want to use, cancelled or canceled. @Wovchena @sbalandi I see that both spellings are OK, but want to draw your attention to it additionally.

OPENVINO_DEPRECATED("Please, use `stop()` instead of `drop()`.")
void drop();

void stop();

void cancel();

GenerationOutputs back();
// Reads result of a generation for single iteration
GenerationOutputs read();
6 changes: 4 additions & 2 deletions src/cpp/include/openvino/genai/llm_pipeline.hpp
@@ -18,8 +18,10 @@
namespace ov {
namespace genai {

// Return flag corresponds whether generation should be stopped: false means continue generation, true means stop.
using StreamerVariant = std::variant<std::function<bool(std::string)>, std::shared_ptr<StreamerBase>, std::monostate>;
// Return flag corresponds to whether generation should be stopped. It could be:
// ov::genai::StreamerRunningStatus flag: RUNNING means continue generation, STOP means stop generation, CANCEL means stop generation and remove the last prompt and answer from history
// *DEPRECATED* bool flag: false means continue generation, true means stop. Please, use `ov::genai::StreamerRunningStatus` instead.
using StreamerVariant = std::variant<std::function<bool(std::string)>, std::function<StreamerRunningStatus(std::string)>, std::shared_ptr<StreamerBase>, std::monostate>;
using OptionalGenerationConfig = std::optional<GenerationConfig>;
using EncodedInputs = std::variant<ov::Tensor, TokenizedInputs>;
using StringInputs = std::variant<std::string, std::vector<std::string>>;
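For illustration, a caller-side sketch of the new status-returning callback, assembled from the samples in this PR (pipe is an ov::genai::LLMPipeline as in src/README.md; the stop condition itself is invented):

auto streamer = [](std::string subword) {
    std::cout << subword << std::flush;
    // Invented condition: stop generation, but keep the chat history, once a newline is produced.
    return subword.find('\n') != std::string::npos
        ? ov::genai::StreamerRunningStatus::STOP
        : ov::genai::StreamerRunningStatus::RUNNING;
};
pipe.generate("The Sun is yellow because", ov::genai::streamer(streamer), ov::genai::max_new_tokens(200));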
19 changes: 19 additions & 0 deletions src/cpp/include/openvino/genai/streamer_base.hpp
@@ -4,16 +4,29 @@
#pragma once

#include "openvino/genai/tokenizer.hpp"
#include "openvino/genai/generation_handle.hpp"
Contributor comment: looks like this header file is not required here anymore

#include <variant>

namespace ov {
namespace genai {

enum class StreamerRunningStatus {
Contributor suggested change (rename the new enum):
- enum class StreamerRunningStatus {
+ enum class StreamingStatus {
@Wovchena @sbalandi what do you think?

RUNNING = 0, // Continue to run inference
STOP = 1, // Stop generation, keep history as is, KV cache includes last request and generated tokens
CANCEL = 2 // Stop generation, drop last prompt and all generated tokens from history, KV cache includes history except the last step
};

using CallbackTypeVariant = std::variant<bool, StreamerRunningStatus, std::monostate>;

/**
* @brief base class for streamers. In order to use inherit from from this class and implement put, and methods
*
* @param m_tokenizer tokenizer
*/
class OPENVINO_GENAI_EXPORTS StreamerBase {
protected:
StreamerRunningStatus m_streaming_finish_status = StreamerRunningStatus::RUNNING;

public:
/// @brief put is called every time new token is decoded,
/// @return bool flag to indicate whether generation should be stopped, if return true generation stops
Contributor comment: maybe we can add a new function put_token / write with a new return type and deprecate the current put with binary status? IMO, it will be more future proof and removes the ambiguity that authors of custom text streamers need to write functions like is_generation_complete.
CC @Wovchena @sbalandi @as-suvorov what do you think?
BTW, if you are OK with a new method, note that we need to select a more or less generic name, which will allow putting a single token or multiple tokens (Whisper / Spec Dec cases).
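If that direction is taken, one hypothetical shape for such a method (purely illustrative, not part of this PR; the name write and its signature are invented):

virtual StreamerRunningStatus write(const std::vector<int64_t>& tokens) {
    for (int64_t token : tokens) {
        // Reuse the existing bool-returning put(); true meant "stop" in the old API.
        if (put(token))
            return m_streaming_finish_status == StreamerRunningStatus::RUNNING
                ? StreamerRunningStatus::STOP
                : m_streaming_finish_status;
    }
    return StreamerRunningStatus::RUNNING;
}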

@@ -22,6 +35,12 @@ class OPENVINO_GENAI_EXPORTS StreamerBase {
/// @brief end is called at the end of generation. It can be used to flush cache if your own streamer has one
virtual void end() = 0;

/// @brief get_streaming_status() is called by the pipeline to get more detailed information about the streaming status. m_streaming_finish_status, which contains the streaming status info, could be set in put().
/// @return ov::genai::StreamerRunningStatus to determine the streaming status of generation, whether generation is running, stopped or cancelled
virtual StreamerRunningStatus get_streaming_status() {
Contributor suggested change:
- virtual StreamerRunningStatus get_streaming_status() {
+ virtual StreamerRunningStatus get_streaming_status() const {

return m_streaming_finish_status;
}

virtual ~StreamerBase();
};
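Putting the pieces together, a minimal custom streamer built on the new interface could look roughly like this (the class name, the stop word, and the naive per-token decode are invented; a real implementation would handle partial tokens the way TextCallbackStreamer does):

class StopWordStreamer : public ov::genai::StreamerBase {
    ov::genai::Tokenizer m_tokenizer;
    std::string m_text;
public:
    explicit StopWordStreamer(const ov::genai::Tokenizer& tokenizer) : m_tokenizer(tokenizer) {}

    bool put(int64_t token) override {
        const std::string piece = m_tokenizer.decode(std::vector<int64_t>{token});
        m_text += piece;
        std::cout << piece << std::flush;
        if (m_text.find("###") != std::string::npos) {
            // CANCEL also drops the current prompt/answer from chat history; use STOP to keep it.
            m_streaming_finish_status = ov::genai::StreamerRunningStatus::CANCEL;
            return true;  // true tells the pipeline to stop feeding tokens
        }
        return false;
    }

    void end() override { m_text.clear(); }
};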

4 changes: 2 additions & 2 deletions src/cpp/src/continuous_batching_adapter.hpp
@@ -97,7 +97,7 @@ class ContinuousBatchingAdapter final : public LLMPipelineImplBase {
std::vector<std::string> plain_replies;
std::vector<float> plain_scores;
for (GenerationResult& res : generated) {
OPENVINO_ASSERT(res.m_status == GenerationStatus::FINISHED || res.m_status == GenerationStatus::DROPPED_BY_HANDLE, "Got unfinished GenerationStatus");
OPENVINO_ASSERT(res.m_status == GenerationStatus::FINISHED || res.m_status == GenerationStatus::STOP || res.m_status == GenerationStatus::CANCEL, "Got unfinished GenerationStatus");
std::move(res.m_generation_ids.begin(), res.m_generation_ids.end(), std::back_inserter(plain_replies));
std::move(res.m_scores.begin(), res.m_scores.end(), std::back_inserter(plain_scores));
}
@@ -189,7 +189,7 @@ class ContinuousBatchingAdapter final : public LLMPipelineImplBase {
std::vector<std::vector<int64_t>> plain_tokens;
std::vector<float> plain_scores;
for (EncodedGenerationResult& res : generated) {
OPENVINO_ASSERT(res.m_status == GenerationStatus::FINISHED || res.m_status == GenerationStatus::DROPPED_BY_HANDLE, "Got unfinished GenerationStatus");
OPENVINO_ASSERT(res.m_status == GenerationStatus::FINISHED || res.m_status == GenerationStatus::STOP || res.m_status == GenerationStatus::CANCEL, "Got unfinished GenerationStatus");
std::move(res.m_generation_ids.begin(), res.m_generation_ids.end(), std::back_inserter(plain_tokens));
std::move(res.m_scores.begin(), res.m_scores.end(), std::back_inserter(plain_scores));
}
19 changes: 5 additions & 14 deletions src/cpp/src/continuous_batching_impl.cpp
@@ -420,17 +420,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
}
set_adapters(sampling_params[0].adapters);

const std::shared_ptr<StreamerBase>& streamer_ptr = std::visit(overloaded{
[](std::monostate) -> std::shared_ptr<StreamerBase> {
return nullptr;
},
[](const std::shared_ptr<StreamerBase>& streamer) {
return streamer;
},
[this](const std::function<bool(std::string)>& streamer) -> std::shared_ptr<StreamerBase> {
return std::make_unique<TextCallbackStreamer>(m_tokenizer, streamer);
}
}, streamer);
const std::shared_ptr<StreamerBase>& streamer_ptr = ov::genai::utils::create_streamer(streamer, m_tokenizer);

OPENVINO_ASSERT(streamer_ptr == nullptr || input_ids.size() == 1 && sampling_params[0].num_return_sequences == 1 &&
(sampling_params[0].is_greedy_decoding() || sampling_params[0].is_multinomial()),
@@ -466,7 +456,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
std::unordered_map<uint64_t, GenerationOutput> token = generation->read();
for (const auto& gen_token : token.begin()->second.generated_ids) {
if (streamer_ptr->put(gen_token)) {
generation->drop();
streamer_ptr->get_streaming_status() == ov::genai::StreamerRunningStatus::CANCEL ? generation->cancel() : generation->stop();
break;
}
}
@@ -525,6 +515,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
result.m_request_id = request_id;
result.m_generation_ids.resize(num_outputs);
result.m_scores.resize(num_outputs);
result.m_status = request->get_generation_stream()->get_status();

for (size_t i = 0; i < num_outputs; ++i) {
const auto & sequence = sequences[i];
@@ -559,7 +550,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_free_non_running_reque
std::vector<SequenceGroup::Ptr>::iterator requests_iterator = m_requests.begin();
while (requests_iterator != m_requests.end()) {
const auto& request = *requests_iterator;
if (request->has_finished() || request->handle_dropped()) {
if(request->has_finished() || request->handle_stopped() || request->handle_canceled()) {
for (const auto& sequence: request->get_sequences()) {
if (m_scheduler->has_block_table(sequence->get_id())) {
m_scheduler->free_sequence(sequence->get_id());
@@ -577,7 +568,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_notify_requests_droppe
// Notify the last time by pushing empty output
// This causes read() to unblock by adding anything to the queue
for (SequenceGroup::Ptr& request : m_requests) {
if (request->handle_dropped())
if (request->handle_stopped() || request->handle_canceled())
request->push_empty_outputs();
}
}
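The hunks above replace the inline std::visit with ov::genai::utils::create_streamer. Its definition is not part of this diff; one plausible shape, assuming TextCallbackStreamer accepts both callback flavours, would be:

std::shared_ptr<StreamerBase> create_streamer(const StreamerVariant& streamer, const Tokenizer& tokenizer) {
    return std::visit(overloaded{
        [](std::monostate) -> std::shared_ptr<StreamerBase> { return nullptr; },
        [](const std::shared_ptr<StreamerBase>& streamer_obj) -> std::shared_ptr<StreamerBase> { return streamer_obj; },
        [&](const std::function<bool(std::string)>& callback) -> std::shared_ptr<StreamerBase> {
            return std::make_shared<TextCallbackStreamer>(tokenizer, callback);
        },
        [&](const std::function<StreamerRunningStatus(std::string)>& callback) -> std::shared_ptr<StreamerBase> {
            return std::make_shared<TextCallbackStreamer>(tokenizer, callback);
        }
    }, streamer);
}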
28 changes: 20 additions & 8 deletions src/cpp/src/generation_handle.cpp
@@ -9,32 +9,44 @@
using namespace ov::genai;

GenerationHandleImpl::~GenerationHandleImpl() {
drop();
stop();
}

GenerationStatus GenerationHandleImpl::get_status() {
return m_generation_stream->get_status();
}

bool GenerationHandleImpl::can_read() {
return !is_dropped() && m_generation_stream->can_read();
return !is_canceled() && !is_stopped() && m_generation_stream->can_read();
}

bool GenerationHandleImpl::is_dropped() {
return get_status() == GenerationStatus::DROPPED_BY_HANDLE;
bool GenerationHandleImpl::is_stopped() {
return get_status() == GenerationStatus::STOP;
}

bool GenerationHandleImpl::is_canceled() {
return get_status() == GenerationStatus::CANCEL;
}

void GenerationHandleImpl::drop() {
m_generation_stream->drop();
m_generation_stream->stop();
}

void GenerationHandleImpl::stop() {
m_generation_stream->stop();
}

void GenerationHandleImpl::cancel() {
m_generation_stream->cancel();
}

std::unordered_map<uint64_t, GenerationOutput> GenerationHandleImpl::back() {
OPENVINO_ASSERT(!is_dropped(), "GenerationHandle cannot be used after it is dropped.");
OPENVINO_ASSERT(!is_stopped(), "GenerationHandle cannot be used after it is stopped.");
Contributor suggested change:
- OPENVINO_ASSERT(!is_stopped(), "GenerationHandle cannot be used after it is stopped.");
+ OPENVINO_ASSERT(!is_stopped() && !is_canceled(), "GenerationHandle cannot be used after it is stopped / canceled.");
The same in other places.

return m_generation_stream->back();
}

std::unordered_map<uint64_t, GenerationOutput> GenerationHandleImpl::read() {
OPENVINO_ASSERT(!is_dropped(), "GenerationHandle cannot be used after it is dropped.");
OPENVINO_ASSERT(!is_stopped(), "GenerationHandle cannot be used after it is stopped.");
return m_generation_stream->read();
}

@@ -57,7 +69,7 @@ void add_partial_result(std::unordered_map<uint64_t, GenerationOutput>& partial_
}

std::vector<GenerationOutput> GenerationHandleImpl::read_all() {
OPENVINO_ASSERT(!is_dropped(), "GenerationHandle cannot be used after it is dropped.");
OPENVINO_ASSERT(!is_stopped(), "GenerationHandle cannot be used after it is stopped.");
std::vector<GenerationOutput> results;
std::unordered_map<uint64_t, GenerationOutput> partial_results;
// We iterate until generation is running or there are tokens we haven't read yet
9 changes: 7 additions & 2 deletions src/cpp/src/generation_stream.hpp
@@ -51,9 +51,14 @@ class GenerationStream {
return m_status;
}

void drop() {
void stop() {
std::lock_guard<std::mutex> lock(m_mutex);
m_status = GenerationStatus::DROPPED_BY_HANDLE;
m_status = GenerationStatus::STOP;
}

void cancel() {
std::lock_guard<std::mutex> lock(m_mutex);
m_status = GenerationStatus::CANCEL;
}
};
}
6 changes: 5 additions & 1 deletion src/cpp/src/icontinuous_batching.cpp
@@ -89,7 +89,7 @@ ContinuousBatchingPipeline::IContinuousBatchingPipeline::generate(
const auto decode_start = std::chrono::steady_clock::now();
generated.push_back(m_tokenizer.decode(res.m_generation_ids.at(idx)));
raw_counters.detokenization_durations.emplace_back(std::chrono::steady_clock::now() - decode_start);
if (m_is_chat_conversation && 0 == idx) {
if (m_is_chat_conversation && 0 == idx && res.m_status != ov::genai::GenerationStatus::CANCEL) {
m_history.push_back({{"role", "assistant"}, {"content", generated.back()}});
}
}
@@ -110,6 +110,10 @@ ContinuousBatchingPipeline::IContinuousBatchingPipeline::generate(
});
}

// if streaming was canceled, the prompt/answer of the current step shouldn't be present in history, so remove the prompt from history
if (m_is_chat_conversation && !encoded.empty() && encoded[0].m_status == ov::genai::GenerationStatus::CANCEL)
Contributor comment: why do we need !encoded.empty()? I suppose encoded always has the same size as the number of inputs.

m_history.pop_back();

return decoded;
}
}
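As a usage-level illustration of the CANCEL path handled above, a chat loop could roll back a turn like this (models_path and the abort condition are invented; the API calls mirror the chat sample):

ov::genai::LLMPipeline pipe(models_path, "CPU");
pipe.start_chat();
std::string prompt;
while (std::getline(std::cin, prompt)) {
    auto streamer = [](std::string word) {
        std::cout << word << std::flush;
        // Invented condition: cancel the whole turn, so the prompt and partial answer are removed from history.
        return word.find("FORBIDDEN") != std::string::npos
            ? ov::genai::StreamerRunningStatus::CANCEL
            : ov::genai::StreamerRunningStatus::RUNNING;
    };
    pipe.generate(prompt, ov::genai::streamer(streamer), ov::genai::max_new_tokens(100));
    std::cout << std::endl;
}
pipe.finish_chat();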
4 changes: 3 additions & 1 deletion src/cpp/src/llm_pipeline.cpp
@@ -47,7 +47,9 @@ std::pair<ov::AnyMap, ov::genai::static_llm::ModelConfigDesc> split_model_descr(
std::pair<std::string, Any> streamer(StreamerVariant func) {
if (auto streamer_obj = std::get_if<std::shared_ptr<StreamerBase>>(&func)) {
return {utils::STREAMER_ARG_NAME, Any::make<std::shared_ptr<StreamerBase>>(*streamer_obj)};
} else {
} else if (auto streamer_obj = std::get_if<std::function<StreamerRunningStatus(std::string)>>(&func)) {
return {utils::STREAMER_ARG_NAME, Any::make<std::function<StreamerRunningStatus(std::string)>>(*streamer_obj)};
} else {
auto callback = std::get<std::function<bool(std::string)>>(func);
return {utils::STREAMER_ARG_NAME, Any::make<std::function<bool(std::string)>>(callback)};
}