Static Whisper pipeline: compile the decoder lazily through DecoderCache.

Summary of the change below:
  * reshape_input_ids() now issues a single reshape for the input named
    "input_ids". The previous code rebuilt a name-to-shape map over every
    model input and matched any input whose name contained "input_ids".
  * DecoderCache takes an additional ov::PartialShape constructor argument
    and stores it as m_lhs_shape. get_model() passes it to
    reshape_to_static() when the decoder model reports is_dynamic();
    otherwise it calls reshape_input_ids() as before. Each newly compiled
    decoder has its properties printed before it is stored in the cache.
  * The StaticWhisperPipeline constructor no longer reshapes or compiles
    the decoder model. generate() requests get_model(1) and creates the
    infer request when init_ids is empty.

diff --git a/src/cpp/src/whisper_pipeline_static.cpp b/src/cpp/src/whisper_pipeline_static.cpp
index 3e09099c28..5e0f2ef20e 100644
--- a/src/cpp/src/whisper_pipeline_static.cpp
+++ b/src/cpp/src/whisper_pipeline_static.cpp
@@ -425,19 +425,7 @@ void reshape_to_static_encoder(std::shared_ptr<ov::Model> model, const size_t fe
 }
 
 void reshape_input_ids(std::shared_ptr<ov::Model> model, const uint32_t input_size) {
-    std::map<std::string, ov::PartialShape> new_shapes;
-    for (auto input : model->inputs()) {
-        const auto& input_name = input.get_any_name();
-        ov::PartialShape new_shape;
-        if (input_name.find("input_ids") != std::string::npos) {
-            new_shape = ov::PartialShape({1, input_size});
-        } else {
-            new_shape = input.get_partial_shape();
-        }
-        new_shapes.emplace(input_name, new_shape);
-    }
-
-    model->reshape(new_shapes);
+    model->reshape({{"input_ids", ov::PartialShape({1, input_size})}});
 }
 
 void preprocess_encoder(std::shared_ptr<ov::Model> model) {
@@ -516,10 +504,16 @@ namespace genai {
 
 ov::CompiledModel DecoderCache::get_model(uint8_t input_ids_size) {
     if (m_cache.find(input_ids_size) == m_cache.cend()) {
-        reshape_input_ids(m_decoder_model, input_ids_size);
+        if (m_decoder_model->is_dynamic()) { // model is dynamic, reshaping it to static
+            reshape_to_static(m_decoder_model, input_ids_size, input_ids_size, m_lhs_shape);
+        } else {
+            reshape_input_ids(m_decoder_model, input_ids_size);
+        }
 
         ov::Core core = utils::singleton_core();
-        m_cache.insert({input_ids_size, core.compile_model(m_decoder_model, "NPU")});
+        ov::CompiledModel compiled_model = core.compile_model(m_decoder_model, "NPU");
+        ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper decoder model");
+        m_cache.emplace(input_ids_size, compiled_model);
     }
 
     return m_cache.at(input_ids_size);
@@ -541,7 +535,6 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
     reshape_to_static_encoder(encoder_model, m_feature_extractor.feature_size);
 
     auto last_hidden_state_shape = get_encoder_hidden_state_shape(encoder_model);
-    reshape_to_static(decoder_model, 1, 1, last_hidden_state_shape); // for detect_language()
     reshape_to_static(decoder_with_past_model, 1, max_sequence_length, last_hidden_state_shape);
 
     // Replace KV-tensors for the entire cache to tensors only for new token
@@ -556,10 +549,8 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
     ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper encoder model");
     m_models.encoder = compiled_model.create_infer_request();
 
-    m_decoder_cache = DecoderCache(decoder_model);
-    compiled_model = m_decoder_cache.get_model(1);
-    ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper decoder model");
-    m_models.decoder = compiled_model.create_infer_request();
+    // Will compile decoder model when it's needed
+    m_decoder_cache = DecoderCache(decoder_model, last_hidden_state_shape);
 
     compiled_model = core.compile_model(decoder_with_past_model, "NPU");
     ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper decoder with past model");
@@ -635,7 +626,7 @@ WhisperDecodedResults WhisperPipeline::StaticWhisperPipeline::generate(
 
         // prepare init_ids just once for whole input
         if (init_ids.empty()) {
-            OPENVINO_ASSERT(m_models.decoder.get_tensor("input_ids").get_size() == 1);
+            m_models.decoder = m_decoder_cache.get_model(1).create_infer_request(); // for detect_language()
             init_ids = prepare_init_ids(hidden_state_tensor, m_models.decoder, config, return_timestamps, raw_metrics);
 
             // Get decoder with size of input_ids
diff --git a/src/cpp/src/whisper_pipeline_static.hpp b/src/cpp/src/whisper_pipeline_static.hpp
index 9913760792..c0e4aa8220 100644
--- a/src/cpp/src/whisper_pipeline_static.hpp
+++ b/src/cpp/src/whisper_pipeline_static.hpp
@@ -17,13 +17,16 @@ namespace genai {
 
 class DecoderCache {
 public:
-    DecoderCache() {}
-    DecoderCache(std::shared_ptr<ov::Model> model) : m_decoder_model(model) {}
+    DecoderCache() = default;
+    DecoderCache(std::shared_ptr<ov::Model> model, ov::PartialShape shape)
+        : m_decoder_model(model)
+        , m_lhs_shape(shape) {}
 
     ov::CompiledModel get_model(uint8_t input_ids_size);
 private:
     std::unordered_map<uint8_t, ov::CompiledModel> m_cache;
     std::shared_ptr<ov::Model> m_decoder_model;
+    ov::PartialShape m_lhs_shape;
 };
 
 class WhisperPipeline::StaticWhisperPipeline : public WhisperPipeline::WhisperPipelineImplBase {