Whisper pipeline: support stateful decoder #1474

Open. Wants to merge 16 commits into base: master. Showing changes from all commits.
3 changes: 1 addition & 2 deletions .github/workflows/windows.yml
@@ -311,10 +311,9 @@ jobs:
python -m pip install . --verbose --find-links ${env:OV_INSTALL_DIR}/wheels
python -m pip install ./tools/who_what_benchmark --find-links ${env:OV_INSTALL_DIR}/wheels

# will install transformers 4.46.3 version
# transformers 4.46.3 will enable return_timestamps tests
# this check is enabled for windows only. Ticket: 160205.
python -m pip install git+https://github.com/huggingface/optimum-intel.git@753f84db6e0966580eb9eaa74a808213be730631
python -m pip install transformers==4.46.3

python -m pytest -v ./tests/python_tests/test_whisper_pipeline.py -k "not test_smoke"

2 changes: 1 addition & 1 deletion src/README.md
@@ -179,7 +179,7 @@ int main(int argc, char* argv[]) {

Streaming with a custom class:

C++ template for a stremer.
C++ template for a streamer.
```cpp
#include "openvino/genai/streamer_base.hpp"
#include "openvino/genai/llm_pipeline.hpp"
// ...
```
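
Since the hunk is truncated here, a sketch of what such a custom streamer can look like. The `put`/`end` signatures are assumed from the GenAI `StreamerBase` interface of this period (`put` returns `true` to stop generation), so treat it as illustrative rather than the exact README content:

```cpp
#include "openvino/genai/streamer_base.hpp"
#include "openvino/genai/llm_pipeline.hpp"

#include <iostream>

// Assumed interface: put() receives one token id per step; returning true
// stops generation early. end() is called once generation finishes.
class CustomStreamer : public ov::genai::StreamerBase {
public:
    bool put(int64_t token) override {
        std::cout << token << ' ';  // a real streamer would detokenize here
        return false;               // false: continue generation
    }

    void end() override {
        std::cout << std::endl;     // flush once generation is done
    }
};
```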
17 changes: 17 additions & 0 deletions src/cpp/src/logger.hpp
@@ -0,0 +1,17 @@
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once
#include <iostream>
#include <string>

namespace ov::genai {

class Logger {
public:
static void warn(const std::string& message) {
std::cout << "[WARN] " << message << '\n';
}
};

} // namespace ov::genai
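
A minimal call-site sketch for this header-only logger (the `main()` wrapper is just for illustration):

```cpp
#include "logger.hpp"

int main() {
    // Prints "[WARN] decoder with past is deprecated" to stdout.
    ov::genai::Logger::warn("decoder with past is deprecated");
    return 0;
}
```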
26 changes: 26 additions & 0 deletions src/cpp/src/whisper/models/decoder.cpp
@@ -0,0 +1,26 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include "decoder.hpp"

#include <filesystem>

#include "statefull_decoder.hpp"
#include "utils.hpp"
#include "with_past_decoder.hpp"

namespace ov::genai {
std::shared_ptr<WhisperDecoder> WhisperDecoder::from_path(const std::filesystem::path& models_path,
const std::string& device,
const ov::AnyMap& properties) {
bool has_decoder_with_past = std::filesystem::exists(models_path / "openvino_decoder_with_past_model.xml");

if (has_decoder_with_past) {
return std::make_shared<WhisperWithPastDecoder>(models_path, device, properties);
}

return std::make_shared<WhisperStatefullDecoder>(models_path, device, properties);
}

WhisperDecoder::~WhisperDecoder() = default;
} // namespace ov::genai
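
A hypothetical caller sketch; "whisper-tiny" and "CPU" are placeholder arguments, and the factory returns the with-past implementation only when openvino_decoder_with_past_model.xml exists in the model directory:

```cpp
#include "decoder.hpp"

int main() {
    // Placeholder model directory and device for illustration.
    auto decoder = ov::genai::WhisperDecoder::from_path("whisper-tiny", "CPU", ov::AnyMap{});

    // The concrete type (stateful vs. with-past) is hidden behind the interface.
    decoder->reset_state();
    return 0;
}
```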
29 changes: 29 additions & 0 deletions src/cpp/src/whisper/models/decoder.hpp
@@ -0,0 +1,29 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <filesystem>

#include "openvino/genai/whisper_generation_config.hpp"
#include "openvino/runtime/core.hpp"

namespace ov::genai {
class WhisperDecoder {
public:
static std::shared_ptr<WhisperDecoder> from_path(const std::filesystem::path& models_path,
const std::string& device,
const ov::AnyMap& properties);

virtual std::pair<int64_t, float> detect_language(const ov::Tensor& encoder_hidden_state,
const int64_t decoder_start_token_id) = 0;

virtual std::pair<ov::Tensor, float> decode(const ov::Tensor& encoder_hidden_state,
const std::vector<int64_t>& input_ids,
const size_t cache_position) = 0;

virtual void reset_state() = 0;

virtual ~WhisperDecoder();
};
} // namespace ov::genai
60 changes: 60 additions & 0 deletions src/cpp/src/whisper/models/statefull_decoder.cpp
@@ -0,0 +1,60 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include "statefull_decoder.hpp"

#include "utils.hpp"

namespace ov::genai {
WhisperStatefullDecoder::WhisperStatefullDecoder(const std::filesystem::path& models_path,
const std::string& device,
const ov::AnyMap& properties) {
ov::Core core = utils::singleton_core();

auto compiled_model = core.compile_model(models_path / "openvino_decoder_model.xml", device, properties);

utils::print_compiled_model_properties(compiled_model, "whisper decoder model");
m_request = compiled_model.create_infer_request();
}

std::pair<int64_t, float> WhisperStatefullDecoder::detect_language(const ov::Tensor& encoder_hidden_state,
const int64_t decoder_start_token_id) {
auto [output_tensor, infer_ms] = decode(encoder_hidden_state, {decoder_start_token_id}, 0);

int64_t output_token = ov::genai::utils::argmax(output_tensor, 0);

reset_state();

return {output_token, infer_ms};
}

std::pair<ov::Tensor, float> WhisperStatefullDecoder::decode(const ov::Tensor& encoder_hidden_state,
const std::vector<int64_t>& input_ids,
const size_t cache_position) {
m_request.set_tensor("encoder_hidden_states", encoder_hidden_state);

ov::Tensor input_ids_tensor(ov::element::i64, {1, input_ids.size()}, (void*)input_ids.data());
m_request.set_tensor("input_ids", input_ids_tensor);

ov::Tensor cache_position_tensor = m_request.get_tensor("cache_position");
cache_position_tensor.set_shape({input_ids.size()});

auto cache_data = cache_position_tensor.data<int64_t>();
std::iota(cache_data, cache_data + cache_position_tensor.get_size(), cache_position);

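// The stateful model expects beam_idx; with a single sequence it is always 0.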
m_request.get_tensor("beam_idx").set_shape({1});
m_request.get_tensor("beam_idx").data<int32_t>()[0] = 0;

const auto infer_start = std::chrono::steady_clock::now();
m_request.infer();
const auto infer_ms = ov::genai::PerfMetrics::get_microsec(std::chrono::steady_clock::now() - infer_start);

auto output_tensor = m_request.get_tensor("logits");

return {output_tensor, infer_ms};
}

void WhisperStatefullDecoder::reset_state() {
m_request.reset_state();
}
} // namespace ov::genai
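
To make the cache_position handling above concrete, a standalone sketch (not part of the PR) of the std::iota indexing: each decode call writes consecutive positions starting at cache_position, one per input token:

```cpp
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
    auto fill = [](std::size_t n_tokens, std::int64_t cache_position) {
        std::vector<std::int64_t> positions(n_tokens);
        // Same pattern as the decoder: consecutive values from cache_position.
        std::iota(positions.begin(), positions.end(), cache_position);
        return positions;
    };

    // Initial step: e.g. 4 prompt tokens at cache_position 0 -> 0 1 2 3
    for (auto p : fill(4, 0)) std::cout << p << ' ';
    std::cout << '\n';

    // Next step: 1 new token at cache_position 4 -> 4
    for (auto p : fill(1, 4)) std::cout << p << ' ';
    std::cout << '\n';
    return 0;
}
```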
29 changes: 29 additions & 0 deletions src/cpp/src/whisper/models/statefull_decoder.hpp
@@ -0,0 +1,29 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "decoder.hpp"
#include "openvino/runtime/core.hpp"

namespace ov::genai {

class WhisperStatefullDecoder : public WhisperDecoder {
public:
WhisperStatefullDecoder(const std::filesystem::path& models_path,
const std::string& device,
const ov::AnyMap& properties);

std::pair<int64_t, float> detect_language(const ov::Tensor& encoder_hidden_state,
const int64_t decoder_start_token_id) override;

std::pair<ov::Tensor, float> decode(const ov::Tensor& encoder_hidden_state,
const std::vector<int64_t>& input_ids,
const size_t cache_position) override;

void reset_state() override;

private:
ov::InferRequest m_request;
};
} // namespace ov::genai
107 changes: 107 additions & 0 deletions src/cpp/src/whisper/models/with_past_decoder.cpp
@@ -0,0 +1,107 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include "with_past_decoder.hpp"

#include <regex>

#include "logger.hpp"
#include "utils.hpp"

namespace {
void set_past_key_value(ov::InferRequest& source, ov::InferRequest& dest) {
// source outputs:
// present.0.decoder.key
// present.0.decoder.value
// present.0.encoder.key
// present.0.encoder.value

// dest inputs:
// past_key_values.0.decoder.key
// past_key_values.0.decoder.value
// past_key_values.0.encoder.key
// past_key_values.0.encoder.value

for (auto& source_output : source.get_compiled_model().outputs()) {
std::string source_output_name = source_output.get_any_name();
if (source_output_name.find("logits") != std::string::npos) {
continue;
}

std::string with_past_input_name =
std::regex_replace(source_output_name, std::regex("present"), "past_key_values");

auto kv_tensor = source.get_tensor(source_output_name);
dest.set_tensor(with_past_input_name, ov::Tensor{kv_tensor});
}
}
} // namespace

namespace ov::genai {
Contributor Author: @ilya-lavrenov @Wovchena I want to add a deprecation note for this ctor. I saw the OPENVINO_DEPRECATED macro, but I don't think it fits here, since I want to warn the user at runtime. Can I just add std::cout << "[Warning] Whisper decoder with past deprecated ..."? Does OV have logging utilities we can reuse, or is there a better way?

Contributor: We don't have a better way than a runtime message.

WhisperWithPastDecoder::WhisperWithPastDecoder(const std::filesystem::path& models_path,
const std::string& device,
const ov::AnyMap& properties) {
Logger::warn("Whisper decoder models with past is deprecated. Support will be removed in 2026.0.0 release.\n"
"To obtain stateful decoder model use latest `optimum-intel` package:\n"
"pip install optimum-intel@git+https://github.com/huggingface/optimum-intel.git\n"
"optimum-cli export openvino --trust-remote-code --model openai/whisper-tiny whisper-tiny");
ov::Core core = utils::singleton_core();

auto compiled_model = core.compile_model(models_path / "openvino_decoder_model.xml", device, properties);
utils::print_compiled_model_properties(compiled_model, "whisper decoder model");
m_request_decoder = compiled_model.create_infer_request();

compiled_model = core.compile_model(models_path / "openvino_decoder_with_past_model.xml", device, properties);
utils::print_compiled_model_properties(compiled_model, "whisper decoder with past model");
m_request_decoder_with_past = compiled_model.create_infer_request();
}

std::pair<int64_t, float> WhisperWithPastDecoder::detect_language(const ov::Tensor& encoder_hidden_state,
const int64_t decoder_start_token_id) {
auto [output_tensor, infer_ms] = decode(encoder_hidden_state, {decoder_start_token_id}, 0);

int64_t output_token = ov::genai::utils::argmax(output_tensor, 0);

reset_state();

return {output_token, infer_ms};
}

std::pair<ov::Tensor, float> WhisperWithPastDecoder::decode(const ov::Tensor& encoder_hidden_state,
const std::vector<int64_t>& input_ids,
const size_t cache_position) {
const bool initial_step = cache_position == 0;
ov::InferRequest& request = initial_step ? m_request_decoder : m_request_decoder_with_past;

request.set_tensor("encoder_hidden_states", encoder_hidden_state);

const ov::Tensor input_ids_tensor(ov::element::i64, {1, input_ids.size()}, (void*)input_ids.data());
request.set_tensor("input_ids", input_ids_tensor);

if (!initial_step) {
ov::Tensor cache_position_tensor = request.get_tensor("cache_position");
cache_position_tensor.set_shape({1});
cache_position_tensor.data<int64_t>()[0] = cache_position;
}

const auto infer_start = std::chrono::steady_clock::now();
request.infer();
const auto infer_ms = ov::genai::PerfMetrics::get_microsec(std::chrono::steady_clock::now() - infer_start);

auto output_tensor = request.get_tensor("logits");

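// First step: copy the prompt-phase KV cache from the decoder request into
// the with-past request. Later steps: bind the with-past request's own
// "present" outputs to its "past_key_values" inputs once, so subsequent
// iterations reuse the same buffers.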
if (initial_step) {
set_past_key_value(m_request_decoder, m_request_decoder_with_past);
} else if (!m_decoder_with_past_kv_value_set) {
set_past_key_value(m_request_decoder_with_past, m_request_decoder_with_past);
m_decoder_with_past_kv_value_set = true;
}

return {output_tensor, infer_ms};
}

void WhisperWithPastDecoder::reset_state() {
m_request_decoder_with_past.reset_state();
m_decoder_with_past_kv_value_set = false;
}
} // namespace ov::genai
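
For reference, a standalone sketch (not part of the PR) of the present -> past_key_values name mapping that set_past_key_value performs, using the tensor names listed in its comments:

```cpp
#include <iostream>
#include <regex>
#include <string>

int main() {
    const std::string outputs[] = {
        "present.0.decoder.key",
        "present.0.decoder.value",
        "present.0.encoder.key",
        "present.0.encoder.value",
    };
    // Same substitution as set_past_key_value: rename each "present" output
    // to the matching "past_key_values" input.
    for (const auto& name : outputs) {
        std::cout << name << " -> "
                  << std::regex_replace(name, std::regex("present"), "past_key_values")
                  << '\n';
    }
    return 0;
}
```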
32 changes: 32 additions & 0 deletions src/cpp/src/whisper/models/with_past_decoder.hpp
@@ -0,0 +1,32 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "decoder.hpp"
#include "openvino/runtime/core.hpp"

namespace ov::genai {

class WhisperWithPastDecoder : public WhisperDecoder {
public:
WhisperWithPastDecoder(const std::filesystem::path& models_path,
const std::string& device,
const ov::AnyMap& properties);

std::pair<int64_t, float> detect_language(const ov::Tensor& encoder_hidden_state,
const int64_t decoder_start_token_id) override;

std::pair<ov::Tensor, float> decode(const ov::Tensor& encoder_hidden_state,
const std::vector<int64_t>& input_ids,
const size_t cache_position) override;

void reset_state() override;

private:
ov::InferRequest m_request_decoder;
ov::InferRequest m_request_decoder_with_past;
bool m_decoder_with_past_kv_value_set = false;
};

} // namespace ov::genai