Move detect_language to base decoder
as-suvorov committed Jan 23, 2025
1 parent e802584 commit abde309
Showing 7 changed files with 19 additions and 43 deletions.
@@ -18,7 +18,7 @@ int main(int argc, char* argv[]) try {
 ov::genai::WhisperGenerationConfig config = pipeline.get_generation_config();
 config.max_new_tokens = 100; // increase this based on your speech length
 // 'task' and 'language' parameters are supported for multilingual models only
-config.language = "<|en|>"; // can switch to <|zh|> for Chinese language
+// config.language = "<|en|>"; // can switch to <|zh|> for Chinese language
 config.task = "transcribe";
 config.return_timestamps = true;

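With the explicit language commented out of the sample, transcription falls back to automatic language detection, which this commit centralizes in the base decoder (see the decoder.cpp hunk below). A minimal sketch of both modes, assuming the sample's surrounding `pipeline` (an ov::genai::WhisperPipeline), a `raw_speech` float buffer, and that `generate(raw_speech, config)` is the overload the sample already uses:

// Sketch only: `pipeline` and `raw_speech` are assumed to come from the surrounding sample.
ov::genai::WhisperGenerationConfig config = pipeline.get_generation_config();
config.max_new_tokens = 100;
config.task = "transcribe";

// Leave config.language unset: the language is detected from the audio itself
// (the detect_language step added in this commit).
auto detected = pipeline.generate(raw_speech, config);

// Or pin the language explicitly for a multilingual model:
config.language = "<|en|>"; // e.g. "<|zh|>" for Chinese
auto pinned = pipeline.generate(raw_speech, config);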
17 changes: 17 additions & 0 deletions src/cpp/src/whisper/models/decoder.cpp
@@ -22,6 +22,23 @@ std::shared_ptr<WhisperDecoder> WhisperDecoder::from_path(const std::filesystem:
 return std::make_shared<WhisperStatefullDecoder>(models_path, device, properties);
 }

+std::pair<int64_t, float> WhisperDecoder::detect_language(const ov::Tensor& encoder_hidden_state,
+const int64_t decoder_start_token_id) {
+Tensor input_ids_tensor{ov::element::i64, {1, 1}};
+input_ids_tensor.data<int64_t>()[0] = decoder_start_token_id;
+
+Tensor beam_idx_tensor{ov::element::i32, {1}};
+beam_idx_tensor.data<int32_t>()[0] = 0;
+
+auto [output_tensor, infer_ms] = decode(encoder_hidden_state, input_ids_tensor, beam_idx_tensor);
+
+int64_t output_token = ov::genai::utils::argmax(output_tensor, 0);
+
+reset_state();
+
+return {output_token, infer_ms};
+}
+
 /**
  * Encoder hidden states expected to be with batch 1
  * Copy encoder hidden state tensor from batch 1 to requested batch_size.
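The new base-class implementation only needs the virtual decode()/reset_state() pair: it feeds a single 1x1 input_ids tensor holding the decoder start token, runs one decode step against the encoder hidden state, and the argmax over that step's logits is the predicted language token (e.g. <|en|> or <|zh|>). A hedged sketch of a caller, where `models_path`, `encoder_hidden_state`, and `start_token_id` are assumptions standing in for values the surrounding whisper pipeline already has:

// Sketch only: models_path, encoder_hidden_state and start_token_id are assumed inputs.
std::shared_ptr<WhisperDecoder> decoder = WhisperDecoder::from_path(models_path, "CPU", ov::AnyMap{});

// One forward pass from the start token; the highest-scoring logit of that step is the
// language token id, and the decoder state is reset before real generation begins.
auto [language_token_id, infer_ms] = decoder->detect_language(encoder_hidden_state, start_token_id);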
3 changes: 1 addition & 2 deletions src/cpp/src/whisper/models/decoder.hpp
@@ -15,8 +15,7 @@ class WhisperDecoder {
 const std::string& device,
 const ov::AnyMap& properties);

-virtual std::pair<int64_t, float> detect_language(const Tensor& encoder_hidden_state,
-const int64_t decoder_start_token_id) = 0;
+std::pair<int64_t, float> detect_language(const Tensor& encoder_hidden_state, const int64_t decoder_start_token_id);

 virtual std::pair<Tensor, float> decode(const Tensor& encoder_hidden_state,
 const Tensor& input_ids,
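The header change is the crux of the refactor: detect_language drops from a pure-virtual method that each decoder had to re-implement to a plain base-class method built on the remaining virtual hooks (decode and reset_state), which is why the identical bodies below can be deleted from both concrete decoders. A self-contained illustration of the pattern with hypothetical names, not the openvino.genai classes:

#include <iostream>
#include <utility>

// Hypothetical stand-ins for the real classes, to show the shape of the change.
class Decoder {
public:
    virtual ~Decoder() = default;

    // Shared once in the base class: one decode step from a start token,
    // then clear the decoder state before real generation begins.
    std::pair<int, float> detect_language(int start_token) {
        auto [best_token, ms] = decode_step(start_token);  // virtual dispatch to the derived class
        reset_state();
        return {best_token, ms};
    }

protected:
    virtual std::pair<int, float> decode_step(int start_token) = 0;  // model-specific
    virtual void reset_state() = 0;                                  // model-specific
};

class StatefulDecoder : public Decoder {
protected:
    std::pair<int, float> decode_step(int start_token) override {
        return {start_token + 1, 0.5f};  // stub in place of a real inference call
    }
    void reset_state() override {}
};

int main() {
    StatefulDecoder decoder;
    auto [token, ms] = decoder.detect_language(50258);  // any start-token id
    std::cout << token << " (" << ms << " ms)\n";       // inherited, not re-implemented
}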
17 changes: 0 additions & 17 deletions src/cpp/src/whisper/models/statefull_decoder.cpp
@@ -18,23 +18,6 @@ WhisperStatefullDecoder::WhisperStatefullDecoder(const std::filesystem::path& mo
 m_request = compiled_model.create_infer_request();
 }

-std::pair<int64_t, float> WhisperStatefullDecoder::detect_language(const ov::Tensor& encoder_hidden_state,
-const int64_t decoder_start_token_id) {
-Tensor input_ids_tensor{ov::element::i64, {1, 1}};
-input_ids_tensor.data<int64_t>()[0] = decoder_start_token_id;
-
-Tensor beam_idx_tensor{ov::element::i32, {1}};
-beam_idx_tensor.data<int32_t>()[0] = 0;
-
-auto [output_tensor, infer_ms] = decode(encoder_hidden_state, input_ids_tensor, beam_idx_tensor);
-
-int64_t output_token = ov::genai::utils::argmax(output_tensor, 0);
-
-reset_state();
-
-return {output_token, infer_ms};
-}
-
 std::pair<ov::Tensor, float> WhisperStatefullDecoder::decode(const Tensor& encoder_hidden_state,
 const Tensor& input_ids,
 const Tensor& beam_idx) {
3 changes: 0 additions & 3 deletions src/cpp/src/whisper/models/statefull_decoder.hpp
@@ -14,9 +14,6 @@ class WhisperStatefullDecoder : public WhisperDecoder {
 const std::string& device,
 const ov::AnyMap& properties);

-std::pair<int64_t, float> detect_language(const Tensor& encoder_hidden_state,
-const int64_t decoder_start_token_id) override;
-
 std::pair<Tensor, float> decode(const Tensor& encoder_hidden_state,
 const Tensor& input_ids,
 const Tensor& beam_idx) override;
17 changes: 0 additions & 17 deletions src/cpp/src/whisper/models/with_past_decoder.cpp
@@ -97,23 +97,6 @@ WhisperWithPastDecoder::WhisperWithPastDecoder(const std::filesystem::path& mode
 m_request_decoder_with_past = compiled_model.create_infer_request();
 }

-std::pair<int64_t, float> WhisperWithPastDecoder::detect_language(const ov::Tensor& encoder_hidden_state,
-const int64_t decoder_start_token_id) {
-Tensor input_ids_tensor{ov::element::i64, {1, 1}};
-input_ids_tensor.data<int64_t>()[0] = decoder_start_token_id;
-
-Tensor beam_idx_tensor{ov::element::i32, {1}};
-beam_idx_tensor.data<int32_t>()[0] = 0;
-
-auto [output_tensor, infer_ms] = decode(encoder_hidden_state, input_ids_tensor, beam_idx_tensor);
-
-int64_t output_token = ov::genai::utils::argmax(output_tensor, 0);
-
-reset_state();
-
-return {output_token, infer_ms};
-}
-
 std::pair<Tensor, float> WhisperWithPastDecoder::decode(const Tensor& encoder_hidden_state,
 const Tensor& input_ids,
 const Tensor& beam_idx) {
3 changes: 0 additions & 3 deletions src/cpp/src/whisper/models/with_past_decoder.hpp
@@ -14,9 +14,6 @@ class WhisperWithPastDecoder : public WhisperDecoder {
 const std::string& device,
 const ov::AnyMap& properties);

-std::pair<int64_t, float> detect_language(const Tensor& encoder_hidden_state,
-const int64_t decoder_start_token_id) override;
-
 std::pair<Tensor, float> decode(const Tensor& encoder_hidden_state,
 const Tensor& input_ids,
 const Tensor& beam_idx) override;
