
Commit

Address comments
eshiryae committed Jan 10, 2025
1 parent 614e6d9 commit 9e91e32
Showing 2 changed files with 15 additions and 20 deletions.
src/cpp/src/whisper_pipeline_static.cpp (26 changes: 12 additions & 14 deletions)
@@ -214,9 +214,11 @@ void prepare_decoder_with_past(ov::InferRequest& decoder_with_past, ov::InferReq
 };
 
 int64_t detect_language(ov::Tensor& encoder_hidden_state,
-                        ov::InferRequest decoder,
+                        ov::genai::DecoderCache& decoder_cache,
                         const ov::genai::WhisperGenerationConfig& config,
                         ov::genai::RawPerfMetrics& raw_metrics) {
+    auto decoder = decoder_cache.get_model(1);
+
     decoder.set_tensor("encoder_hidden_states", ov::Tensor{encoder_hidden_state});
 
     std::vector<int32_t> init_ids{static_cast<int32_t>(config.decoder_start_token_id)};
@@ -246,7 +248,7 @@ int64_t detect_language(ov::Tensor& encoder_hidden_state,
 }
 
 std::vector<int32_t> prepare_init_ids(ov::Tensor& encoder_hidden_state,
-                                      ov::InferRequest& decoder,
+                                      ov::genai::DecoderCache& decoder_cache,
                                       const ov::genai::WhisperGenerationConfig& config,
                                       const bool return_timestamps,
                                       ov::genai::RawPerfMetrics& raw_metrics) {
@@ -266,7 +268,7 @@ std::vector<int32_t> prepare_init_ids(ov::Tensor& encoder_hidden_state,
             language_token_id = static_cast<int32_t>(config.lang_to_id.at(language));
         }
     } else {
-        language_token_id = detect_language(encoder_hidden_state, decoder, config, raw_metrics);
+        language_token_id = detect_language(encoder_hidden_state, decoder_cache, config, raw_metrics);
     }
 
     int32_t task_token_id = static_cast<int32_t>(config.transcribe_token_id);
Expand Down Expand Up @@ -502,18 +504,14 @@ std::shared_ptr<ov::Model> redirect_new_kv_to_output(const std::shared_ptr<ov::M
namespace ov {
namespace genai {

ov::CompiledModel DecoderCache::get_model(uint8_t input_ids_size) {
ov::InferRequest DecoderCache::get_model(uint8_t input_ids_size) {
if (m_cache.find(input_ids_size) == m_cache.cend()) {
if (m_decoder_model->is_dynamic()) { // model is dynamic, reshaping it to static
reshape_to_static(m_decoder_model, input_ids_size, input_ids_size, m_lhs_shape);
} else {
reshape_input_ids(m_decoder_model, input_ids_size);
}
reshape_input_ids(m_decoder_model, input_ids_size);

ov::Core core = utils::singleton_core();
ov::CompiledModel compiled_model = core.compile_model(m_decoder_model, "NPU");
ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper decoder model");
m_cache.emplace(input_ids_size, compiled_model);
m_cache.emplace(input_ids_size, compiled_model.create_infer_request());
}

return m_cache.at(input_ids_size);
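
For readers skimming the hunk above: DecoderCache::get_model() now compiles the decoder lazily for each requested input_ids size and caches the resulting ov::InferRequest directly, so call sites no longer chain .create_infer_request() themselves. A minimal standalone sketch of that pattern follows; the body of reshape_input_ids is not shown in this diff, so the helper below is an assumption, and a plain ov::Core stands in for utils::singleton_core().

// Sketch only: mirrors the revised DecoderCache, not a drop-in replacement.
#include <memory>
#include <unordered_map>
#include <openvino/openvino.hpp>

// Hypothetical stand-in for the file-local helper; the diff shows only its
// call site, so this body is an assumption.
static void reshape_input_ids(std::shared_ptr<ov::Model>& model, uint8_t input_ids_size) {
    model->reshape({{"input_ids", ov::PartialShape{1, input_ids_size}}});
}

class DecoderCacheSketch {
public:
    explicit DecoderCacheSketch(std::shared_ptr<ov::Model> model)
        : m_decoder_model(std::move(model)) {}

    ov::InferRequest get_model(uint8_t input_ids_size) {
        if (m_cache.find(input_ids_size) == m_cache.cend()) {
            reshape_input_ids(m_decoder_model, input_ids_size);
            ov::Core core;  // the real code reuses utils::singleton_core()
            ov::CompiledModel compiled = core.compile_model(m_decoder_model, "NPU");
            // Cache the infer request itself, keyed by prompt length.
            m_cache.emplace(input_ids_size, compiled.create_infer_request());
        }
        return m_cache.at(input_ids_size);
    }

private:
    std::unordered_map<uint8_t, ov::InferRequest> m_cache;
    std::shared_ptr<ov::Model> m_decoder_model;
};

Note that the key stays uint8_t, so the get_model(init_ids.size()) call later in the diff narrows a size_t implicitly; init prompts are only a handful of tokens, which is presumably why the narrow key type is acceptable.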
@@ -535,6 +533,7 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
     reshape_to_static_encoder(encoder_model, m_feature_extractor.feature_size);
 
     auto last_hidden_state_shape = get_encoder_hidden_state_shape(encoder_model);
+    reshape_to_static(decoder_model, 1, 1, last_hidden_state_shape);
     reshape_to_static(decoder_with_past_model, 1, max_sequence_length, last_hidden_state_shape);
 
     // Replace KV-tensors for the entire cache to tensors only for new token
@@ -550,7 +549,7 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
     m_models.encoder = compiled_model.create_infer_request();
 
     // Will compile decoder model when it's needed
-    m_decoder_cache = DecoderCache(decoder_model, last_hidden_state_shape);
+    m_decoder_cache = DecoderCache(decoder_model);
 
     compiled_model = core.compile_model(decoder_with_past_model, "NPU");
     ov::genai::utils::print_compiled_model_properties(compiled_model, "Static Whisper decoder with past model");
Expand Down Expand Up @@ -626,11 +625,10 @@ WhisperDecodedResults WhisperPipeline::StaticWhisperPipeline::generate(

// prepare init_ids just once for whole input
if (init_ids.empty()) {
m_models.decoder = m_decoder_cache.get_model(1).create_infer_request(); // for detect_language()
init_ids = prepare_init_ids(hidden_state_tensor, m_models.decoder, config, return_timestamps, raw_metrics);
init_ids = prepare_init_ids(hidden_state_tensor, m_decoder_cache, config, return_timestamps, raw_metrics);

// Get decoder with size of input_ids
m_models.decoder = m_decoder_cache.get_model(init_ids.size()).create_infer_request();
m_models.decoder = m_decoder_cache.get_model(init_ids.size());
}

auto [cancelled, chunk_output_tokens] = full_decode(hidden_state_tensor,
src/cpp/src/whisper_pipeline_static.hpp (9 changes: 3 additions & 6 deletions)
@@ -18,15 +18,12 @@ namespace genai {
 class DecoderCache {
 public:
     DecoderCache() = default;
-    DecoderCache(std::shared_ptr<ov::Model> model, ov::PartialShape shape)
-        : m_decoder_model(model)
-        , m_lhs_shape(shape) {}
+    DecoderCache(std::shared_ptr<ov::Model> model) : m_decoder_model(model) {}
 
-    ov::CompiledModel get_model(uint8_t input_ids_size);
+    ov::InferRequest get_model(uint8_t input_ids_size);
 private:
-    std::unordered_map<uint8_t, ov::CompiledModel> m_cache;
+    std::unordered_map<uint8_t, ov::InferRequest> m_cache;
     std::shared_ptr<ov::Model> m_decoder_model;
-    ov::PartialShape m_lhs_shape;
 };
 
 class WhisperPipeline::StaticWhisperPipeline : public WhisperPipeline::WhisperPipelineImplBase {
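
A hedged usage sketch of the slimmed-down interface; the model path, and the assumption that this internal header is on the include path, are illustrative only:

#include <memory>
#include <openvino/openvino.hpp>
#include "whisper_pipeline_static.hpp"  // internal header, assumed reachable

int main() {
    ov::Core core;
    // Hypothetical Whisper decoder IR on disk.
    std::shared_ptr<ov::Model> decoder_model = core.read_model("whisper_decoder.xml");

    // No ov::PartialShape argument anymore: the pipeline reshapes the decoder
    // to static shapes before handing it to the cache.
    ov::genai::DecoderCache cache(decoder_model);

    ov::InferRequest lang_request = cache.get_model(1);  // as used by detect_language()
    ov::InferRequest init_request = cache.get_model(4);  // e.g. a four-token init prompt
    return 0;
}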