Merge branch 'master' into at/static-llm-pipeline-dynamic-shape-model
TolyaTalamanov authored Jan 4, 2025
2 parents df988be + 002f84f commit a9cae71
Showing 25 changed files with 559 additions and 732 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/mac.yml
@@ -1,4 +1,4 @@
name: macOS (12, Python 3.9)
name: macOS (12, Python 3.10)
on:
workflow_dispatch:
pull_request:
@@ -16,7 +16,7 @@ concurrency:
cancel-in-progress: true

env:
PYTHON_VERSION: '3.9'
PYTHON_VERSION: '3.10'
OV_BRANCH: master
OV_TARBALL: ''

14 changes: 7 additions & 7 deletions src/cpp/src/continuous_batching_impl.cpp
@@ -28,6 +28,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::ContinuousBatchingImpl(

bool is_need_per_layer_cache_control = scheduler_config.use_cache_eviction;
utils::apply_paged_attention_transformations(model, device_config, is_need_per_layer_cache_control);
utils::apply_gather_before_matmul_transformation(model);

initialize_pipeline(model, scheduler_config, properties, device_config, core);
}
@@ -444,7 +445,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_fill_prompt_log_probs(
const float * logits_data = logits.data<float>();
ov::Shape logits_shape = logits.get_shape();
OPENVINO_ASSERT(logits_shape.size() == 3);
size_t batch_seq_len = logits_shape[1], vocab_size = logits_shape[2];
size_t vocab_size = logits_shape[2];
for (size_t sequence_group_id = 0, currently_processed_tokens = 0; sequence_group_id < sequence_groups.size(); ++sequence_group_id) {
SequenceGroup::Ptr sequence_group = sequence_groups[sequence_group_id];
// requests not scheduled, in decoding phase or not echoing are not processed
@@ -454,26 +455,25 @@

size_t num_running_sequences = sequence_group->num_running_seqs();
OPENVINO_ASSERT(num_running_sequences == 1);
size_t actual_seq_len = sequence_group->get_num_scheduled_tokens();
size_t padded_amount_of_processed_tokens = std::max(actual_seq_len, batch_seq_len);
size_t output_seq_len = sequence_group->get_output_seq_len();

const float * sequence_group_logits_data = logits_data + vocab_size * currently_processed_tokens;

size_t num_prompt_tokens_processed = sequence_group->get_num_processed_tokens();
OPENVINO_ASSERT(num_prompt_tokens_processed + actual_seq_len <= sequence_group->get_prompt_len());
OPENVINO_ASSERT(num_prompt_tokens_processed + output_seq_len <= sequence_group->get_prompt_len());

// if we processed the whole prompt we don't include last logprob as it will be processed by the sampler (it's already completion)
// otherwise we include it as it will be used in the next part of the prompt
int exclude_last_logprob = 1;
if (num_prompt_tokens_processed + actual_seq_len < sequence_group->get_prompt_len())
if (num_prompt_tokens_processed + output_seq_len < sequence_group->get_prompt_len())
exclude_last_logprob = 0;

// if we start processing the prompt we add "fake" log prob for the first position (begin of sequence)
if (num_prompt_tokens_processed == 0)
sequence_group->append_prompt_log_prob(1.0);

for (int token_logits_offset = 0, token_id_offset = num_prompt_tokens_processed + 1;
token_logits_offset < actual_seq_len - exclude_last_logprob;
token_logits_offset < output_seq_len - exclude_last_logprob;
token_logits_offset++, token_id_offset++) {

const float* token_logits = (sequence_group_logits_data + token_logits_offset * vocab_size);
@@ -498,7 +498,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_fill_prompt_log_probs(

sequence_group->append_prompt_log_prob(token_logit - max_value - log_sum);
}
currently_processed_tokens += padded_amount_of_processed_tokens * num_running_sequences;
currently_processed_tokens += output_seq_len * num_running_sequences;
// For max_new_tokens == 0, we don't reach sampling so need to notify handle separately
if(sequence_group->get_sampling_parameters().max_new_tokens == 0) {
sequence_group->notify_handle_echo_only();
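Note on the new utils::apply_gather_before_matmul_transformation(model) call added to the constructor above: judging by its name and by the ModelRunner changes further down, the transformation places a Gather on the sequence axis ahead of the LM-head MatMul so that logits are computed only for the token positions the sampler actually needs. The toy OpenVINO graph below is an illustration of that shape effect only, not the real pass and not part of this commit; the opset13 constructors, names, and shapes are assumptions made for the sketch.

    // Illustration only: with a Gather on the sequence axis in front of the LM-head
    // MatMul, the logits shape depends on how many rows are kept, not on seq_len.
    #include <memory>
    #include <vector>
    #include <openvino/openvino.hpp>
    #include <openvino/opsets/opset13.hpp>

    std::shared_ptr<ov::Model> build_toy_lm_head(int64_t hidden_size = 8, int64_t vocab_size = 32) {
        using namespace ov::opset13;
        // [batch, seq_len, hidden] hidden states plus the row indices to keep.
        auto hidden_states = std::make_shared<Parameter>(ov::element::f32,
                                                         ov::PartialShape{-1, -1, hidden_size});
        auto kept_rows = std::make_shared<Parameter>(ov::element::i64, ov::PartialShape{-1});
        auto seq_axis = Constant::create(ov::element::i64, ov::Shape{}, {1});
        // Keep only the rows that will be sampled (or echoed)...
        auto gathered = std::make_shared<Gather>(hidden_states, kept_rows, seq_axis);
        // ...so the MatMul produces [batch, kept, vocab] instead of [batch, seq_len, vocab].
        auto weights = Constant::create(
            ov::element::f32, ov::Shape{size_t(hidden_size), size_t(vocab_size)},
            std::vector<float>(size_t(hidden_size * vocab_size), 0.01f));
        auto logits = std::make_shared<MatMul>(gathered, weights);
        return std::make_shared<ov::Model>(ov::OutputVector{logits->output(0)},
                                           ov::ParameterVector{hidden_states, kept_rows});
    }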
14 changes: 7 additions & 7 deletions src/cpp/src/generation_config.cpp
@@ -230,9 +230,9 @@ void GenerationConfig::validate() const {
OPENVINO_ASSERT(temperature > 0, "When 'do_sample' is true, temperature must be a strictly positive float, but got ", temperature);
} else {
// parameters requiring multinomial
OPENVINO_ASSERT(top_k == std::numeric_limits<size_t>::max(), "When 'do_sample' is false, top_k must be max of size_t, but got ", top_k);
OPENVINO_ASSERT(top_p == 1.0f, "When 'do_sample' is false, top_p must be 1.0f, but got ", top_p);
OPENVINO_ASSERT(temperature == 1.0f, "When 'do_sample' is false, temperature must be a 1.0f, but got ", temperature);
// OPENVINO_ASSERT(top_k == std::numeric_limits<size_t>::max(), "When 'do_sample' is false, top_k must be max of size_t, but got ", top_k);
// OPENVINO_ASSERT(top_p == 1.0f, "When 'do_sample' is false, top_p must be 1.0f, but got ", top_p);
// OPENVINO_ASSERT(temperature == 1.0f, "When 'do_sample' is false, temperature must be a 1.0f, but got ", temperature);
}

if (is_beam_search()) {
@@ -252,10 +252,10 @@ void GenerationConfig::validate() const {
}
} else {
// parameters requiring beam search
OPENVINO_ASSERT(num_beam_groups == 1, "'num_beam_groups' is supported by beam search only and should be 1 otherwise, but got ", num_beam_groups);
OPENVINO_ASSERT(no_repeat_ngram_size == std::numeric_limits<size_t>::max(), "'no_repeat_ngram_size' is supported only by beam search, otherwise should be set to max of size_t, but got ", no_repeat_ngram_size);
OPENVINO_ASSERT(diversity_penalty == 0.0f, "'diversity_penalty' is set to ", diversity_penalty, " (default is 0.0f), which is supported only by beam search sampling");
OPENVINO_ASSERT(length_penalty == 1.0f, "'length_penalty' is set to ", length_penalty, " (default is 1.0f), which is supported only by beam search sampling");
// OPENVINO_ASSERT(num_beam_groups == 1, "'num_beam_groups' is supported by beam search only and should be 1 otherwise, but got ", num_beam_groups);
// OPENVINO_ASSERT(no_repeat_ngram_size == std::numeric_limits<size_t>::max(), "'no_repeat_ngram_size' is supported only by beam search, otherwise should be set to max of size_t, but got ", no_repeat_ngram_size);
// OPENVINO_ASSERT(diversity_penalty == 0.0f, "'diversity_penalty' is set to ", diversity_penalty, " (default is 0.0f), which is supported only by beam search sampling");
// OPENVINO_ASSERT(length_penalty == 1.0f, "'length_penalty' is set to ", length_penalty, " (default is 1.0f), which is supported only by beam search sampling");
}

// assistant generation
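Caller-visible effect of commenting out these checks: a config that leaves sampling-only or beam-search-only knobs at non-default values no longer fails validate() when the corresponding mode is off; the unused knobs are simply ignored. The sketch below is not part of this commit; it assumes the public fields named in the assert messages, the openvino/genai/generation_config.hpp header, and that setting max_new_tokens is enough to satisfy the remaining, still-active checks.

    // Hedged sketch: the relaxed GenerationConfig::validate() from the caller's side.
    #include <iostream>
    #include "openvino/genai/generation_config.hpp"

    int main() {
        ov::genai::GenerationConfig config;
        config.max_new_tokens = 16;  // keep the unrelated stop-condition checks happy
        config.do_sample = false;    // greedy decoding
        config.top_k = 50;           // sampling-only knob at a non-default value
        config.num_beam_groups = 2;  // beam-only knob, but num_beams stays 1 (no beam search)

        // Previously this would throw ("top_k must be max of size_t" with
        // do_sample == false); with the asserts commented out it returns normally.
        config.validate();
        std::cout << "config accepted\n";
        return 0;
    }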
2 changes: 1 addition & 1 deletion src/cpp/src/llm_pipeline_stateful.cpp
@@ -38,7 +38,7 @@ StatefulLLMPipeline::StatefulLLMPipeline(
const ov::AnyMap& properties,
const ov::genai::GenerationConfig& generation_config)
: LLMPipelineImplBase(tokenizer, generation_config), m_sampler(m_tokenizer) {
utils::slice_matmul_stateful_model(model);
utils::apply_slice_before_matmul_transformation(model);
m_kv_cache_seq_length_axis = ov::genai::utils::get_seq_len_axis(model);

ov::CompiledModel compiled_model;
51 changes: 42 additions & 9 deletions src/cpp/src/model_runner.hpp
@@ -114,28 +114,54 @@ class ModelRunner {
subsequence_begins_data[0] = 0;
block_indices_begins_data[0] = 0;

bool matmul_gathering_is_available = false;
size_t gathering_current_index = 0;
std::vector<int64_t> gather_indices_values;
try {
std::ignore = m_request.get_tensor("sampled_tokens_indices");
matmul_gathering_is_available = true;
} catch (const ov::Exception&) {}


for (size_t i = 0; i < num_sequence_groups; ++i) {
size_t seq_group_id = scheduler_output.m_scheduled_sequence_groups_ids[i];
SequenceGroup::CPtr sequence_group = sequence_groups[seq_group_id];
std::vector<Sequence::CPtr> running_sequences = sequence_group->get_running_sequences();
SequenceGroup::Ptr sequence_group = sequence_groups[seq_group_id];
std::vector<Sequence::Ptr> running_sequences = sequence_group->get_running_sequences();
size_t num_running_sequences = running_sequences.size();
size_t num_scheduled_tokens = sequence_group->get_num_scheduled_tokens();
size_t group_position_id = sequence_group->get_num_processed_tokens();
size_t prompt_len = sequence_group->get_prompt_len();

// spec: In case of multiple input tokens for current sequence (prompt_len > 1),
// context_len corresponds to first token within subgroup of scheduled tokens
size_t group_context_len = group_position_id;
// Next variables are only for sliced matmul case
size_t output_seq_len = 0;
const bool echo_output = sequence_group->get_sampling_parameters().echo;
const bool sampling_is_required = sequence_group->requires_sampling();
const size_t tokens_to_sample_per_sequence = 1 + sequence_group->get_num_tokens_to_validate();

for (size_t seq_id = 0; seq_id < num_running_sequences; ++seq_id) {
output_seq_len = 0;
Sequence::CPtr sequence = running_sequences[seq_id];

for (size_t token_id = 0, position_id = group_position_id; token_id < num_scheduled_tokens; ++token_id, ++position_id) {
for (size_t token_id = 0, position_id = group_position_id; token_id < num_scheduled_tokens; ++token_id, ++position_id, ++gathering_current_index) {
// compute token for current sequence
input_ids_data[token_id] = position_id < sequence_group->get_prompt_len() ?
input_ids_data[token_id] = position_id < prompt_len ?
sequence_group->get_prompt_ids()[position_id] :
sequence->get_generated_ids()[position_id - sequence_group->get_prompt_len()];
sequence->get_generated_ids()[position_id - prompt_len];

position_ids_data[token_id] = position_id;

// Check if token gathering is required for the entire sequence group
if (matmul_gathering_is_available && (sampling_is_required || echo_output)) {
// Determine if the current token should be gathered
if (echo_output ||
// Skip gathering for prompt tokens
group_position_id + token_id >= prompt_len - 1 &&
// Gather only the last scheduled token or 1 + num_tokens_to_validate tokens for SD
// In SD, tokens_to_sample_per_sequence may exceed num_scheduled_tokens
token_id + tokens_to_sample_per_sequence >= num_scheduled_tokens) {
gather_indices_values.push_back(gathering_current_index);
output_seq_len++;
}
}
}

size_t expected_kv_cache_size = sequence_group->get_num_processed_tokens() - sequence_group->get_num_evicted_tokens();
@@ -153,6 +179,7 @@ class ModelRunner {
subsequence_begins_data += 1;
block_indices_begins_data += 1;
}
sequence_group->set_output_seq_len(matmul_gathering_is_available ? output_seq_len : num_scheduled_tokens);
}

// typical LLM parameters
@@ -168,6 +195,12 @@
m_request.set_tensor("block_indices_begins", block_indices_begins);
m_request.set_tensor("max_context_len", max_context_len);

if (matmul_gathering_is_available) {
ov::Tensor gather_indices(ov::element::i64, {gather_indices_values.size()});
std::memcpy(gather_indices.data(), gather_indices_values.data(), gather_indices_values.size() * sizeof(int64_t));
m_request.set_tensor("sampled_tokens_indices", gather_indices);
}

// print_tensor("input_ids", input_ids);
// print_tensor("position_ids", position_ids);

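The gathering branch added to ModelRunner above is central to this diff: when the compiled model exposes a sampled_tokens_indices input, the runner records which scheduled positions actually need logits — every position when echo is on, otherwise only the last prompt/generated token plus any speculative-validation tokens — and stores the count via set_output_seq_len. The standalone sketch below restates just that selection rule on toy data; it is not the ModelRunner code, the struct and its fields are invented for the example, and the real code additionally requires that gathering is available in the model.

    // Standalone restatement of the gather-index rule from the diff, on toy data.
    #include <cstddef>
    #include <cstdint>
    #include <iostream>
    #include <utility>
    #include <vector>

    struct ToyGroup {
        std::size_t prompt_len;
        std::size_t num_processed_tokens;   // a.k.a. group_position_id
        std::size_t num_scheduled_tokens;
        std::size_t num_tokens_to_validate; // > 0 only for speculative decoding
        bool echo = false;
        bool requires_sampling = true;
    };

    // Returns the flat indices (base_index + position within the scheduled tokens)
    // whose logits must be gathered, plus the resulting output_seq_len.
    std::pair<std::vector<int64_t>, std::size_t> select_gather_indices(const ToyGroup& g,
                                                                       std::size_t base_index) {
        std::vector<int64_t> indices;
        std::size_t output_seq_len = 0;
        const std::size_t tokens_to_sample = 1 + g.num_tokens_to_validate;
        for (std::size_t token_id = 0; token_id < g.num_scheduled_tokens; ++token_id) {
            const bool past_prompt = g.num_processed_tokens + token_id >= g.prompt_len - 1;
            const bool in_sampling_window = token_id + tokens_to_sample >= g.num_scheduled_tokens;
            if ((g.requires_sampling || g.echo) && (g.echo || (past_prompt && in_sampling_window))) {
                indices.push_back(static_cast<int64_t>(base_index + token_id));
                ++output_seq_len;
            }
        }
        return {indices, output_seq_len};
    }

    int main() {
        // A 10-token prompt, fully scheduled, no speculative tokens, no echo:
        // only the last prompt position is gathered, so output_seq_len == 1.
        auto [indices, out_len] = select_gather_indices({/*prompt_len=*/10, /*processed=*/0,
                                                         /*scheduled=*/10, /*validate=*/0}, 0);
        std::cout << "gathered " << out_len << " row(s), first index " << indices.front() << "\n";
        return 0;
    }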
13 changes: 6 additions & 7 deletions src/cpp/src/sampler.cpp
@@ -749,7 +749,7 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
const float * logits_data = logits.data<float>();
ov::Shape logits_shape = logits.get_shape();
OPENVINO_ASSERT(logits_shape.size() == 3);
size_t batch_seq_len = logits_shape[1], vocab_size = logits_shape[2];
size_t vocab_size = logits_shape[2];

SamplerOutput sampler_output;
for (size_t sequence_group_id = 0, currently_processed_tokens = 0; sequence_group_id < sequence_groups.size(); ++sequence_group_id) {
@@ -758,8 +758,7 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
continue;

size_t num_running_sequences = sequence_group->num_running_seqs();
size_t actual_seq_len = sequence_group->get_num_scheduled_tokens(); // points to a token which needs to be sampled
size_t padded_amount_of_processed_tokens = std::max(actual_seq_len, batch_seq_len);
size_t output_seq_len = sequence_group->get_output_seq_len();
const ov::genai::GenerationConfig& sampling_params = sequence_group->get_sampling_parameters();

const auto request_id = sequence_group->get_request_id();
@@ -774,13 +773,13 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
auto& stop_strings = m_stop_strings.at(request_id);
auto& logit_processor = m_logit_processors.at(request_id);
const void * sequence_group_logits_data = logits_data + vocab_size * currently_processed_tokens;
ov::Tensor sequence_group_logits(ov::element::f32, ov::Shape{num_running_sequences, actual_seq_len, vocab_size}, (void *)sequence_group_logits_data);
ov::Tensor sequence_group_logits(ov::element::f32, ov::Shape{num_running_sequences, output_seq_len, vocab_size}, (void *)sequence_group_logits_data);
size_t max_removed_tokens_per_request = 0, min_generated_len = std::numeric_limits<size_t>::max(), updated_validation_len = 0;
if (sequence_group->requires_sampling()) {
// get number of token to be validated
auto num_tokens_to_process = sequence_group->get_num_tokens_to_validate();
if (num_tokens_to_process > actual_seq_len - 1) {
auto delta = num_tokens_to_process - (actual_seq_len - 1);
if (num_tokens_to_process > output_seq_len - 1) {
auto delta = num_tokens_to_process - (output_seq_len - 1);
updated_validation_len = std::max(updated_validation_len, delta);
num_tokens_to_process -= delta;
}
@@ -914,7 +913,7 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
}

// accumulate a number of processed tokens
currently_processed_tokens += padded_amount_of_processed_tokens * num_running_sequences;
currently_processed_tokens += output_seq_len * num_running_sequences;
}

return sampler_output;
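Both Sampler::sample above and _fill_prompt_log_probs earlier now walk the logits buffer using each group's output_seq_len instead of padding every group to a shared batch_seq_len. The arithmetic-only sketch below shows how the per-group offsets fall out of that accumulation; the struct and function names are made up and it is not the library code.

    // Arithmetic-only sketch: offsets into a flat logits buffer laid out as
    // [sum_over_groups(output_seq_len * num_running_sequences), vocab_size].
    #include <cstddef>
    #include <iostream>
    #include <vector>

    struct GroupShape {
        std::size_t num_running_sequences;
        std::size_t output_seq_len; // rows produced for this group after gathering
    };

    std::vector<std::size_t> group_logits_offsets(const std::vector<GroupShape>& groups,
                                                  std::size_t vocab_size) {
        std::vector<std::size_t> offsets;
        std::size_t currently_processed_tokens = 0;
        for (const auto& g : groups) {
            offsets.push_back(vocab_size * currently_processed_tokens);
            // Mirrors the updated accumulation: no std::max(actual_seq_len, batch_seq_len)
            // padding any more, just the gathered length per sequence.
            currently_processed_tokens += g.output_seq_len * g.num_running_sequences;
        }
        return offsets;
    }

    int main() {
        // Two groups: a prompt group that kept one gathered row, and a speculative
        // group validating 3 tokens (output_seq_len == 4), with a vocab of 32000.
        const std::vector<GroupShape> groups = {{1, 1}, {1, 4}};
        for (std::size_t offset : group_logits_offsets(groups, 32000))
            std::cout << offset << "\n"; // prints 0 and 32000
        return 0;
    }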
13 changes: 13 additions & 0 deletions src/cpp/src/sequence_group.hpp
@@ -222,6 +222,8 @@ class SequenceGroup : public std::enable_shared_from_this<SequenceGroup> {
size_t m_num_validation_tokens = 0;
// flag to enable/disable token generation, e.g. in speculative decoding scenario
bool m_is_gen_paused = false;
// output seq len at current iteration
size_t m_output_seq_len = 0;

size_t m_num_streamed_tokens = 0, m_stream_window_size = 0;

@@ -394,6 +396,14 @@ class SequenceGroup : public std::enable_shared_from_this<SequenceGroup> {
return m_num_processed_tokens;
}

size_t get_output_seq_len() const {
return m_output_seq_len;
}

void set_output_seq_len(size_t len) {
m_output_seq_len = len;
}

/**
* Registers within the sequence group that a given amount of tokens
* has been evicted from the underlying KV cache.
@@ -436,11 +446,14 @@ class SequenceGroup : public std::enable_shared_from_this<SequenceGroup> {

void schedule_tokens(size_t num_tokens) {
m_num_scheduled_tokens = num_tokens;
// Unless otherwise specified, the sampler will process all scheduled tokens.
m_output_seq_len = num_tokens;
}

void clear_scheduled_tokens() {
m_num_scheduled_tokens = 0;
m_num_validation_tokens = 0;
m_output_seq_len = 0;
}

bool is_scheduled() const {
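The new m_output_seq_len field is intentionally written from two places: schedule_tokens() gives it a safe default (logits for every scheduled token), and the model runner narrows it when matmul gathering is active, which is what the sampler then reads. A toy restatement of that contract, not the real class:

    // Toy restatement of the output_seq_len contract added to SequenceGroup.
    #include <cassert>
    #include <cstddef>

    class ToySequenceGroup {
        std::size_t m_num_scheduled_tokens = 0;
        std::size_t m_output_seq_len = 0;
    public:
        void schedule_tokens(std::size_t num_tokens) {
            m_num_scheduled_tokens = num_tokens;
            m_output_seq_len = num_tokens; // default: logits for every scheduled token
        }
        void set_output_seq_len(std::size_t len) { m_output_seq_len = len; } // runner narrows it
        std::size_t get_output_seq_len() const { return m_output_seq_len; }
    };

    int main() {
        ToySequenceGroup group;
        group.schedule_tokens(10);   // e.g. a 10-token prompt chunk
        assert(group.get_output_seq_len() == 10);
        group.set_output_seq_len(1); // gathering kept only the last row
        assert(group.get_output_seq_len() == 1);
        return 0;
    }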
@@ -33,6 +33,8 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::SpeculativeDecodingImpl(con

utils::apply_paged_attention_transformations(main_model, main_model_desc.scheduler_config.use_cache_eviction);
utils::apply_paged_attention_transformations(draft_model, main_model_desc.scheduler_config.use_cache_eviction);
utils::apply_gather_before_matmul_transformation(main_model);
utils::apply_gather_before_matmul_transformation(draft_model);

std::string draft_device = draft_model_desc.device.empty() ? main_model_desc.device : draft_model_desc.device;
bool is_draft_scheduler_undefined = draft_model_desc.scheduler_config == SchedulerConfig();