Whisper pipeline: use Sampler (#1615)
Ticket: 152889
Closes #1164
as-suvorov authored Jan 24, 2025
1 parent 9caf53c commit 42b16e5
Showing 26 changed files with 626 additions and 342 deletions.
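For orientation before the per-file diffs: the core of this change is that WhisperGenerationConfig now derives from GenerationConfig, so the Whisper pipeline can use the shared Sampler instead of carrying its own copies of the generic decoding parameters. A minimal usage sketch of what that means at the config level follows; it is not part of the commit, and only fields visible in the diffs below (plus num_beams from the base GenerationConfig) are used.

```cpp
// Sketch, not from this commit: generic decoding knobs are now inherited from
// GenerationConfig, while Whisper-specific fields stay on the derived class.
#include "openvino/genai/whisper_generation_config.hpp"

int main() {
    ov::genai::WhisperGenerationConfig config;
    config.max_new_tokens = 128;   // inherited from GenerationConfig (previously a duplicated field)
    config.num_beams = 4;          // beam search is handled by the shared Sampler
    config.pad_token_id = 50257;   // Whisper-specific field kept on WhisperGenerationConfig
    return 0;
}
```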
23 changes: 2 additions & 21 deletions src/cpp/include/openvino/genai/whisper_generation_config.hpp
@@ -6,6 +6,7 @@
#include <filesystem>
#include <optional>

#include "generation_config.hpp"
#include "openvino/genai/tokenizer.hpp"
#include "openvino/runtime/compiled_model.hpp"

@@ -15,28 +16,14 @@ namespace genai {
/**
* @brief Structure to keep whisper generation config parameters.
*/
class OPENVINO_GENAI_EXPORTS WhisperGenerationConfig {
class OPENVINO_GENAI_EXPORTS WhisperGenerationConfig : public GenerationConfig {
public:
WhisperGenerationConfig() = default;
explicit WhisperGenerationConfig(const std::filesystem::path& json_path);

// Generic

// the maximum length the generated tokens can have. Corresponds to the length of the input prompt +
// `max_new_tokens`. Its effect is overridden by `max_new_tokens`, if also set.
size_t max_new_tokens = SIZE_MAX;
// the maximum numbers of tokens to generate, excluding the number of tokens in the prompt.
// max_new_tokens has priority over max_length.
size_t max_length = SIZE_MAX;

// Whisper specific

// Corresponds to the ”<|startoftranscript|>” token.
int64_t decoder_start_token_id = 50258;

// End of stream token id.
int64_t eos_token_id = 50257;

// Padding token id.
int64_t pad_token_id = 50257;

@@ -110,12 +97,6 @@ class OPENVINO_GENAI_EXPORTS WhisperGenerationConfig {
// A list containing the non-speech tokens that will be suppressed during generation.
std::vector<int64_t> suppress_tokens;

/** @brief sets eos_token_id to tokenizer_eos_token_id if eos_token_id is less than 0.
* Otherwise verifies eos_token_id == tokenizer_eos_token_id.
*/
void set_eos_token_id(int64_t tokenizer_eos_token_id);
size_t get_max_new_tokens(size_t prompt_length = 0) const;

void update_generation_config(const ov::AnyMap& config_map = {});

template <typename... Properties>
3 changes: 2 additions & 1 deletion src/cpp/src/debug_utils.hpp
@@ -12,14 +12,15 @@
template <typename T>
void print_array(T * array, size_t size) {
std::cout << " => [ ";
for (size_t i = 0; i < size; ++i) {
for (size_t i = 0; i < std::min(size, size_t(10)); ++i) {
std::cout << array[i] << " ";
}
std::cout << " ] " << std::endl;
}

inline void print_tensor(std::string name, ov::Tensor tensor) {
std::cout << name;
std::cout << " " << tensor.get_shape().to_string();
if (tensor.get_element_type() == ov::element::i32) {
print_array(tensor.data<int>(), tensor.get_size());
} else if (tensor.get_element_type() == ov::element::i64) {
12 changes: 9 additions & 3 deletions src/cpp/src/lm_encoding.cpp
@@ -14,10 +14,13 @@
#include "lm_encoding.hpp"
#include "openvino/genai/perf_metrics.hpp"

namespace {

namespace ov {
namespace genai {

/**
* Set position ids tensor data for next token inference based on provided attention mask
* Supports multi batch
* Supports sparse attention_mask
*/
void update_position_ids(ov::Tensor&& position_ids, const ov::Tensor&& attention_mask) {
const size_t batch_size = attention_mask.get_shape().at(0);
const size_t sequence_length = attention_mask.get_shape().at(1);
@@ -65,7 +68,10 @@ void update_attention_mask_with_beams(ov::Tensor&& attention_mask, std::vector<i
attention_mask.data<int64_t>()[result_prompt_offset + new_shape.at(1) - 1] = 1;
}
}
}

namespace ov {
namespace genai {

std::pair<EncodedResults, std::optional<int64_t>> get_lm_encoded_results(
ov::InferRequest& m_llm,
2 changes: 1 addition & 1 deletion src/cpp/src/logger.hpp
@@ -9,7 +9,7 @@ namespace ov::genai {

class Logger {
public:
static void warn(std::string message) {
static void warn(const std::string& message) {
std::cout << "[WARN] " << message << '\n';
};
};
19 changes: 11 additions & 8 deletions src/cpp/src/sampler.cpp
@@ -408,7 +408,7 @@ void Sampler::GroupBeamSearcher::select_next_tokens(const ov::Tensor& logits,
}

// check whether group has finished
group.is_done(m_parameters);
group.is_done();

// group cannot continue if there are no valid child beams
if (child_beams_per_group[group_id].size() == 0) {
@@ -549,14 +549,14 @@ std::vector<int64_t> Sampler::_try_finish_generation(SequenceGroup::Ptr & sequen
std::vector<int64_t> dropped_seq_ids;
for (auto& running_sequence : sequence_group->get_running_sequences()) {
const auto generated_len = running_sequence->get_generated_len();
if (sampling_params.max_new_tokens <= generated_len ||
if (sequence_group->get_max_new_tokens() <= generated_len ||
is_stop_token_id_hit(running_sequence->get_generated_ids().back(), sampling_params.stop_token_ids) && !sampling_params.ignore_eos) {
// stop sequence by max_new_tokens or stop token (eos included)
running_sequence->set_status(SequenceStatus::FINISHED);

if (is_stop_token_id_hit(running_sequence->get_generated_ids().back(), sampling_params.stop_token_ids) && !sampling_params.ignore_eos) {
running_sequence->set_finish_reason(GenerationFinishReason::STOP);
} else if (sampling_params.max_new_tokens == generated_len) {
} else if (sequence_group->get_max_new_tokens() == generated_len) {
running_sequence->set_finish_reason(GenerationFinishReason::LENGTH);
}

@@ -800,8 +800,8 @@ SamplerOutput Sampler::sample(const std::vector<SequenceGroup::Ptr> & sequence_g
// max counter of needed to be sampled tokens
OPENVINO_ASSERT(running_sequence->get_generated_len() >= token_offset);
size_t generated_and_verified_len = running_sequence->get_generated_len() - token_offset;
OPENVINO_ASSERT(sampling_params.max_new_tokens >= generated_and_verified_len);
size_t max_num_sampled_token = sampling_params.max_new_tokens - generated_and_verified_len;
OPENVINO_ASSERT(sequence_group->get_max_new_tokens() >= generated_and_verified_len);
size_t max_num_sampled_token = sequence_group->get_max_new_tokens() - generated_and_verified_len;
if (max_num_sampled_token == 0) {
stop_sample_tokens(running_sequence, token_offset, max_num_sampled_token, max_removed_tokens_per_request);
break;
@@ -887,7 +887,7 @@ SamplerOutput Sampler::sample(const std::vector<SequenceGroup::Ptr> & sequence_g
// check max length stop criteria
std::vector<Sequence::Ptr> running_sequences = sequence_group->get_running_sequences();
if (!sequence_group->has_finished() &&
running_sequences[0]->get_generated_len() == sampling_params.max_new_tokens) {
running_sequences[0]->get_generated_len() == sequence_group->get_max_new_tokens()) {
// stop sequence by max_new_tokens
m_beam_search_info.at(request_id).finalize(sampler_output);
}
@@ -956,7 +956,7 @@ int64_t Sampler::GroupBeamSearcher::Group::finish(Beam beam, const ov::genai::Ge
return preeempted_sequence_id;
}

void Sampler::GroupBeamSearcher::Group::is_done(const ov::genai::GenerationConfig& sampling_params) {
void Sampler::GroupBeamSearcher::Group::is_done() {
const auto sequence_group = ongoing.front().m_sequence->get_sequence_group_ptr();
const ov::genai::GenerationConfig& sampling_params = sequence_group->get_sampling_parameters();

assert(sampling_params.num_beams % sampling_params.num_beam_groups == 0 &&
"number of beams should be divisible by number of groups");
size_t group_size = sampling_params.num_beams / sampling_params.num_beam_groups;
@@ -977,7 +980,7 @@ void Sampler::GroupBeamSearcher::Group::is_done(const ov::genai::GenerationConfi
return;
}
case ov::genai::StopCriteria::NEVER: {
size_t length = sampling_params.length_penalty > 0.0 ? sampling_params.max_new_tokens : cur_len;
size_t length = sampling_params.length_penalty > 0.0 ? sequence_group->get_max_new_tokens() : cur_len;
float highest_attainable_score = best_sum_logprobs / std::pow(float(length), sampling_params.length_penalty);
done = worst_score >= highest_attainable_score;
return;
2 changes: 1 addition & 1 deletion src/cpp/src/sampler.hpp
@@ -114,7 +114,7 @@ class Sampler::GroupBeamSearcher {
bool done = false;

int64_t finish(Beam beam, const ov::genai::GenerationConfig& sampling_params);
void is_done(const ov::genai::GenerationConfig& sampling_params);
void is_done();
};

SequenceGroup::Ptr m_sequence_group;
2 changes: 1 addition & 1 deletion src/cpp/src/scheduler.hpp
@@ -497,7 +497,7 @@ class Scheduler {
for (auto idx = 0; idx < sequence_groups.size(); idx++) {
auto seq_length = sequence_groups[idx]->get_prompt_len() * m_kv_blocks_initial_multiplier;
auto gen_config = sequence_groups[idx]->get_sampling_parameters();
seq_length = std::min(seq_length, sequence_groups[idx]->get_prompt_len() + gen_config.get_max_new_tokens(sequence_groups[idx]->get_prompt_len()));
seq_length = std::min(seq_length, sequence_groups[idx]->get_prompt_len() + sequence_groups[idx]->get_max_new_tokens());
size_t blocks_num = std::ceil((float)seq_length / m_block_manager->get_block_size());
if (gen_config.is_beam_search()) {
blocks_num *= gen_config.num_beams;
6 changes: 5 additions & 1 deletion src/cpp/src/sequence_group.hpp
@@ -689,7 +689,11 @@ class SequenceGroup : public std::enable_shared_from_this<SequenceGroup> {
GenerationOutputs outputs;
outputs.emplace(0, output);
m_generation_stream->push(std::move(outputs));
}
}

size_t get_max_new_tokens() {
return m_sampling_params.get_max_new_tokens(get_prompt_len());
}
};

inline std::shared_ptr<SequenceGroup> Sequence::get_sequence_group_ptr() const {
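The new SequenceGroup::get_max_new_tokens() above gives the Sampler and Scheduler a single place to resolve the generation budget against the prompt length. A small sketch of the assumed resolution logic behind GenerationConfig::get_max_new_tokens, inferred from the max_new_tokens/max_length comments removed from the Whisper config earlier; it is not copied from this commit:

```cpp
// Assumed semantics: an explicit max_new_tokens wins; otherwise max_length is treated
// as prompt length + generated length; SIZE_MAX means "no limit configured".
#include <cstddef>
#include <cstdint>

size_t resolve_max_new_tokens(size_t max_new_tokens, size_t max_length, size_t prompt_len) {
    if (max_new_tokens != SIZE_MAX) {
        return max_new_tokens;  // max_new_tokens has priority over max_length
    }
    if (max_length != SIZE_MAX) {
        return max_length > prompt_len ? max_length - prompt_len : 0;
    }
    return SIZE_MAX;  // neither limit configured
}
```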
33 changes: 0 additions & 33 deletions src/cpp/src/utils.cpp
@@ -46,22 +46,6 @@ void print_tensor(const ov::Tensor& tensor) {
std::cout << "]" << std::endl;
}

int64_t argmax(const ov::Tensor& logits, const size_t batch_idx) {
if (logits.get_shape()[0] <= batch_idx) {
OPENVINO_THROW("logits batch size doesn't match the number of beams");
}

size_t vocab_size = logits.get_shape().back();
size_t batch_offset = batch_idx * logits.get_shape()[1] * vocab_size;
size_t sequence_offset = (logits.get_shape()[1] - 1) * vocab_size;
const float* logits_data = logits.data<const float>() + batch_offset + sequence_offset;

int64_t out_token = std::max_element(logits_data, logits_data + vocab_size) - logits_data;
float max_logit = logits_data[out_token];

return out_token;
}

/**
* Initializes position ids based on attention mask and starting position
*/
@@ -128,23 +112,6 @@ void set_attention_mask(ov::Tensor&& attention_mask, std::vector<int32_t> next_b
}
}

/**
* Set position ids tensor data for next token inference based on provided attention mask
* Supports multi batch
* Supports sparse attention_mask
*/
void update_position_ids(ov::Tensor&& position_ids, const ov::Tensor&& attention_mask) {
const size_t batch_size = attention_mask.get_shape().at(0);
const size_t atten_length = attention_mask.get_shape().at(1);
position_ids.set_shape({batch_size, 1});

for (size_t batch = 0; batch < batch_size; batch++) {
int64_t* start = attention_mask.data<int64_t>() + batch * atten_length;
// todo: be careful with start + atten_length, probably need to replace with start + atten_length -1
position_ids.data<int64_t>()[batch] = std::accumulate(start, start + atten_length, 0);
}
}

/**
* Get attention mask tensor for next token inference
* Supports multi batch
4 changes: 0 additions & 4 deletions src/cpp/src/utils.hpp
@@ -47,14 +47,10 @@ Tensor init_attention_mask(const Tensor& position_ids);

void print_tensor(const ov::Tensor& tensor);

int64_t argmax(const ov::Tensor& logits, const size_t batch_idx);

void initialize_position_ids(ov::Tensor& position_ids, const ov::Tensor& attention_mask, int64_t start_pos = 0);

ov::Tensor extend_attention(ov::Tensor attention_mask);

void update_position_ids(ov::Tensor&& position_ids, const ov::Tensor&& attention_mask);

template <typename T> struct OmitOptional { using value = T; };
template <typename T> struct OmitOptional<std::optional<T>> { using value = T; };

1 change: 0 additions & 1 deletion src/cpp/src/whisper/logit_processor.cpp
@@ -28,7 +28,6 @@ void process_whisper_timestamp_logits(ov::Tensor& logits,
const std::vector<int64_t>& generated_tokens,
bool initial_step = false) {
const size_t batch_size = logits.get_shape().at(0);
OPENVINO_ASSERT(batch_size == 1, "Batch != 1 is not supported");

size_t vocab_size = logits.get_shape().back();
size_t batch_offset = batch_idx * logits.get_shape()[1] * vocab_size;
52 changes: 51 additions & 1 deletion src/cpp/src/whisper/models/decoder.cpp
@@ -6,7 +6,7 @@
#include <filesystem>

#include "statefull_decoder.hpp"
#include "utils.hpp"
#include "whisper/whisper_utils.hpp"
#include "with_past_decoder.hpp"

namespace ov::genai {
@@ -22,5 +22,55 @@ std::shared_ptr<WhisperDecoder> WhisperDecoder::from_path(const std::filesystem:
return std::make_shared<WhisperStatefullDecoder>(models_path, device, properties);
}

std::pair<int64_t, float> WhisperDecoder::detect_language(const ov::Tensor& encoder_hidden_state,
const int64_t decoder_start_token_id) {
Tensor input_ids_tensor{ov::element::i64, {1, 1}};
input_ids_tensor.data<int64_t>()[0] = decoder_start_token_id;

Tensor beam_idx_tensor{ov::element::i32, {1}};
beam_idx_tensor.data<int32_t>()[0] = 0;

auto [output_tensor, infer_ms] = decode(encoder_hidden_state, input_ids_tensor, beam_idx_tensor);

int64_t output_token = ov::genai::utils::argmax(output_tensor, 0);

reset_state();

return {output_token, infer_ms};
}

/**
* Encoder hidden states expected to be with batch 1
* Copy encoder hidden state tensor from batch 1 to requested batch_size.
* Set new encoder hidden states tensor to infer request.
*/
void WhisperDecoder::_set_encoder_hidden_states_tensor(const Tensor& encoder_hidden_state,
const size_t batch_size,
InferRequest& request) {
const size_t current_batch_size = request.get_tensor("encoder_hidden_states").get_shape().at(0);
// batch hasn't changed, skip
if (current_batch_size == batch_size) {
return;
}

OPENVINO_ASSERT(encoder_hidden_state.get_shape().at(0) == 1);
Shape shape{encoder_hidden_state.get_shape()};
shape[0] = batch_size;

Tensor new_encoder_hidden_states{ov::element::f32, shape};

auto new_encoder_hidden_states_data = new_encoder_hidden_states.data<float>();
auto encoder_hidden_state_data = encoder_hidden_state.data<float>();

for (size_t batch = 0; batch < batch_size; batch++) {
const size_t batch_offset = batch * encoder_hidden_state.get_size();
std::memcpy(new_encoder_hidden_states_data + batch_offset,
encoder_hidden_state_data,
encoder_hidden_state.get_byte_size());
}

request.set_tensor("encoder_hidden_states", new_encoder_hidden_states);
}

WhisperDecoder::~WhisperDecoder() = default;
} // namespace ov::genai
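A hypothetical usage sketch for the new detect_language helper shown above; the include path, helper name, and inputs are assumptions, while the method signature and the 50258 start-of-transcript id come from the diffs in this commit:

```cpp
// Hypothetical pipeline-side helper (a sketch, not from this commit): run one decode
// step from <|startoftranscript|> and return the argmax token as the detected language.
#include <cstdint>

#include "whisper/models/decoder.hpp"  // assumed internal include path

int64_t detect_language_once(ov::genai::WhisperDecoder& decoder,
                             const ov::Tensor& encoder_hidden_state,
                             int64_t decoder_start_token_id = 50258) {
    auto [language_token, infer_ms] = decoder.detect_language(encoder_hidden_state, decoder_start_token_id);
    (void)infer_ms;  // per-call latency metric, unused in this sketch
    return language_token;
}
```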
14 changes: 9 additions & 5 deletions src/cpp/src/whisper/models/decoder.hpp
@@ -15,15 +15,19 @@ class WhisperDecoder {
const std::string& device,
const ov::AnyMap& properties);

virtual std::pair<int64_t, float> detect_language(const ov::Tensor& encoder_hidden_state,
const int64_t decoder_start_token_id) = 0;
std::pair<int64_t, float> detect_language(const Tensor& encoder_hidden_state, const int64_t decoder_start_token_id);

virtual std::pair<ov::Tensor, float> decode(const ov::Tensor& encoder_hidden_state,
const std::vector<int64_t>& input_ids,
const size_t cache_position) = 0;
virtual std::pair<Tensor, float> decode(const Tensor& encoder_hidden_state,
const Tensor& input_ids,
const Tensor& beam_idx) = 0;

virtual void reset_state() = 0;

virtual ~WhisperDecoder();

protected:
void _set_encoder_hidden_states_tensor(const Tensor& encoder_hidden_state,
const size_t batch_size,
InferRequest& request);
};
} // namespace ov::genai