From 9e91e32e85891c1bf14d214f978bb1eb25a1dc2a Mon Sep 17 00:00:00 2001
From: Ekaterina Shiryaeva
Date: Fri, 10 Jan 2025 10:40:56 +0000
Subject: [PATCH] Address comments

---
 src/cpp/src/whisper_pipeline_static.cpp | 26 ++++++++++++--------------
 src/cpp/src/whisper_pipeline_static.hpp |  9 +++------
 2 files changed, 15 insertions(+), 20 deletions(-)

diff --git a/src/cpp/src/whisper_pipeline_static.cpp b/src/cpp/src/whisper_pipeline_static.cpp
index 5e0f2ef20e..91de478b1c 100644
--- a/src/cpp/src/whisper_pipeline_static.cpp
+++ b/src/cpp/src/whisper_pipeline_static.cpp
@@ -214,9 +214,11 @@ void prepare_decoder_with_past(ov::InferRequest& decoder_with_past, ov::InferReq
 };
 
 int64_t detect_language(ov::Tensor& encoder_hidden_state,
-                        ov::InferRequest decoder,
+                        ov::genai::DecoderCache& decoder_cache,
                         const ov::genai::WhisperGenerationConfig& config,
                         ov::genai::RawPerfMetrics& raw_metrics) {
+    auto decoder = decoder_cache.get_model(1);
+
     decoder.set_tensor("encoder_hidden_states", ov::Tensor{encoder_hidden_state});
 
     std::vector<int32_t> init_ids{static_cast<int32_t>(config.decoder_start_token_id)};
@@ -246,7 +248,7 @@ int64_t detect_language(ov::Tensor& encoder_hidden_state,
 }
 
 std::vector<int32_t> prepare_init_ids(ov::Tensor& encoder_hidden_state,
-                                      ov::InferRequest& decoder,
+                                      ov::genai::DecoderCache& decoder_cache,
                                       const ov::genai::WhisperGenerationConfig& config,
                                       const bool return_timestamps,
                                       ov::genai::RawPerfMetrics& raw_metrics) {
@@ -266,7 +268,7 @@ std::vector<int32_t> prepare_init_ids(ov::Tensor& encoder_hidden_state,
             language_token_id = static_cast<int32_t>(config.lang_to_id.at(language));
         }
     } else {
-        language_token_id = detect_language(encoder_hidden_state, decoder, config, raw_metrics);
+        language_token_id = detect_language(encoder_hidden_state, decoder_cache, config, raw_metrics);
     }
 
     int32_t task_token_id = static_cast<int32_t>(config.transcribe_token_id);
@@ -502,18 +504,14 @@ std::shared_ptr<ov::Model> redirect_new_kv_to_output(const std::shared_ptr<ov::
 namespace ov {
 namespace genai {
 
-ov::CompiledModel DecoderCache::get_model(uint8_t input_ids_size) {
+ov::InferRequest DecoderCache::get_model(uint8_t input_ids_size) {
     if (m_cache.find(input_ids_size) == m_cache.end()) {
-        if (m_decoder_model->is_dynamic()) { // model is dynamic, reshaping it to static
-            reshape_to_static(m_decoder_model, input_ids_size, input_ids_size, m_lhs_shape);
-        } else {
-            reshape_input_ids(m_decoder_model, input_ids_size);
-        }
+        reshape_input_ids(m_decoder_model, input_ids_size);
 
         ov::Core core = utils::singleton_core();
         ov::CompiledModel compiled_model = core.compile_model(m_decoder_model, "NPU");
         ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper decoder model");
-        m_cache.emplace(input_ids_size, compiled_model);
+        m_cache.emplace(input_ids_size, compiled_model.create_infer_request());
     }
     return m_cache.at(input_ids_size);
 }
@@ -535,6 +533,7 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
     reshape_to_static_encoder(encoder_model, m_feature_extractor.feature_size);
 
     auto last_hidden_state_shape = get_encoder_hidden_state_shape(encoder_model);
+    reshape_to_static(decoder_model, 1, 1, last_hidden_state_shape);
     reshape_to_static(decoder_with_past_model, 1, max_sequence_length, last_hidden_state_shape);
 
     // Replace KV-tensors for the entire cache to tensors only for new token
@@ -550,7 +549,7 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
     m_models.encoder = compiled_model.create_infer_request();
 
     // Will compile decoder model when it's needed
-    m_decoder_cache = DecoderCache(decoder_model, last_hidden_state_shape);
+    m_decoder_cache = DecoderCache(decoder_model);
 
     compiled_model = core.compile_model(decoder_with_past_model, "NPU");
     ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper decoder with past model");
@@ -626,11 +625,10 @@ WhisperDecodedResults WhisperPipeline::StaticWhisperPipeline::generate(
 
         // prepare init_ids just once for whole input
         if (init_ids.empty()) {
-            m_models.decoder = m_decoder_cache.get_model(1).create_infer_request(); // for detect_language()
-            init_ids = prepare_init_ids(hidden_state_tensor, m_models.decoder, config, return_timestamps, raw_metrics);
+            init_ids = prepare_init_ids(hidden_state_tensor, m_decoder_cache, config, return_timestamps, raw_metrics);
 
             // Get decoder with size of input_ids
-            m_models.decoder = m_decoder_cache.get_model(init_ids.size()).create_infer_request();
+            m_models.decoder = m_decoder_cache.get_model(init_ids.size());
         }
 
         auto [cancelled, chunk_output_tokens] = full_decode(hidden_state_tensor,
diff --git a/src/cpp/src/whisper_pipeline_static.hpp b/src/cpp/src/whisper_pipeline_static.hpp
index c0e4aa8220..b0618452d4 100644
--- a/src/cpp/src/whisper_pipeline_static.hpp
+++ b/src/cpp/src/whisper_pipeline_static.hpp
@@ -18,15 +18,12 @@ namespace genai {
 class DecoderCache {
 public:
     DecoderCache() = default;
-    DecoderCache(std::shared_ptr<ov::Model> model, ov::PartialShape shape)
-        : m_decoder_model(model)
-        , m_lhs_shape(shape) {}
+    DecoderCache(std::shared_ptr<ov::Model> model) : m_decoder_model(model) {}
 
-    ov::CompiledModel get_model(uint8_t input_ids_size);
+    ov::InferRequest get_model(uint8_t input_ids_size);
 private:
-    std::unordered_map<uint8_t, ov::CompiledModel> m_cache;
+    std::unordered_map<uint8_t, ov::InferRequest> m_cache;
     std::shared_ptr<ov::Model> m_decoder_model;
-    ov::PartialShape m_lhs_shape;
 };
 
 class WhisperPipeline::StaticWhisperPipeline : public WhisperPipeline::WhisperPipelineImplBase {