Skip to content

Commit

Permalink
Implement token rotation
Browse files Browse the repository at this point in the history
  • Loading branch information
vshampor committed Nov 12, 2024
1 parent 7039c3e commit 0d60110
Show file tree
Hide file tree
Showing 13 changed files with 690 additions and 38 deletions.
70 changes: 70 additions & 0 deletions src/cpp/src/cache_eviction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -267,4 +267,74 @@ namespace ov::genai {
m_scores[decoder_layer_idx] = new_scores;
m_cache_counter[decoder_layer_idx] = new_counter;
}

CacheRotationCalculator::CacheRotationCalculator(size_t block_size, size_t max_context_length, size_t kv_head_size, double rope_theta) : m_block_size(block_size), m_head_size(kv_head_size) {
// Frequencies follow the original recipe from RoFormer:
// https://arxiv.org/pdf/2104.09864v5
//
// However, the way the rotation coefficients are ultimately applied in Llama and related models from huggingface is very different
// from the RoFormer - the embedding-dimension coefficients are not treated as consecutive x-y coordinate pairs, but are rather
// divided into contiguous x-like and y-like halves - see `rotate_half` function in HF transformers. It can be shown that this form
// still preserves the relative positioning property from the RoFormer article.
OPENVINO_ASSERT(rope_theta > 0, "rope_theta must be positive");
size_t max_position_angle_multiplier = max_context_length;
size_t num_freqs = kv_head_size / 2;
m_rope_sin_lut.resize(max_position_angle_multiplier);
m_rope_cos_lut.resize(max_position_angle_multiplier);

for (size_t i = 0; i < max_position_angle_multiplier; i++) {
m_rope_sin_lut[i].reserve(num_freqs);
m_rope_cos_lut[i].reserve(num_freqs);
for (size_t j = 0; j < num_freqs; j++) {
double exponent = - static_cast<double>(2 * j) / kv_head_size;
double base_angle = std::pow(rope_theta, exponent);
m_rope_sin_lut[i].push_back(-std::sin(i * base_angle)); // minus since we will be rotating by an inverse angle
m_rope_cos_lut[i].push_back(std::cos(i * base_angle));
}
}
}

std::vector<CacheRotationCalculator::BlockRotationData> CacheRotationCalculator::get_rotation_coefficients(const std::set<size_t>& evicted_block_logical_indices, size_t num_logical_blocks_before_eviction) {
OPENVINO_ASSERT(num_logical_blocks_before_eviction * m_block_size < m_rope_sin_lut.size(),
"num_logical_blocks_before_eviction may not correspond to less tokens than max_context_length");

std::vector<BlockRotationData> retval;
if (evicted_block_logical_indices.empty()) {
return retval;
}

for (auto idx : evicted_block_logical_indices) {
OPENVINO_ASSERT(idx < num_logical_blocks_before_eviction);
}

// num_logical_blocks_before_eviction > evicted_block_logical_indices.size() is automatically guaranteed by the
// set property and the previous assertion
retval.reserve(num_logical_blocks_before_eviction - evicted_block_logical_indices.size());

ptrdiff_t current_rotation_delta_in_blocks = 0;
std::vector<size_t> logical_block_space(num_logical_blocks_before_eviction);
std::iota(logical_block_space.begin(), logical_block_space.end(), 0);

for (size_t logical_block_idx : logical_block_space) {
if (evicted_block_logical_indices.find(logical_block_idx) != evicted_block_logical_indices.end()) {
current_rotation_delta_in_blocks += 1;
}
else {
if (current_rotation_delta_in_blocks != 0) {
BlockRotationData block_rotation_data;
block_rotation_data.logical_block_idx = logical_block_idx - current_rotation_delta_in_blocks;
block_rotation_data.cosines.reserve(m_block_size);
block_rotation_data.sines.reserve(m_block_size);
for (size_t i = 0; i < m_block_size; i++) {
block_rotation_data.cosines.push_back(m_rope_cos_lut[current_rotation_delta_in_blocks * m_block_size]);
block_rotation_data.sines.push_back(m_rope_sin_lut[current_rotation_delta_in_blocks * m_block_size]);
}

retval.push_back(block_rotation_data);
}
}
}

return retval;
}
}
64 changes: 64 additions & 0 deletions src/cpp/src/cache_eviction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,4 +117,68 @@ class CacheEvictionAlgorithm {
std::vector<std::vector<size_t>> m_cache_counter;
};

/**
* @brief Computes, based on the logical indices of the blocks to be evicted, the rotation coefficients for the
* remaining cache blocks.
*
* The rotation assumes that the executed model applies rotary positional embedding (RoPE) during the execution of
* the attention operation. Each cache block therefore has the RoPE values already "baked in", with positions equivalent
* to the point in time when the cache block values were originally computed in one of the previous attention operations.
* When blocks are evicted, the logical index space of the remaining blocks is in general no longer contiguous with respect to
* the effective positions of tokens in the blocks. Cache rotation allows to remedy this by effectively adjusting the RoPE positions
* of certain blocks in the cache after eviction, by additionally "rotating" them (in the same sense as in RoPE) by such angles that
* the cache blocks in the logical index space are again contiguous in terms of the RoPE positions. This is supposed to make the
* eviction process less impactful on the accuracy of the generation.
*
* Currently only the basic RoPE method is supported (as applied in the Llama original models). Each model in general may have
* its own RoPE method (e.g. non-linear/NTK frequency scaling), and ideally the cache rotation calculator should be adjusted based on
* the specifics of the RoPE defined by the LLM.
*/
class CacheRotationCalculator {
public:
/**
* Constructs a CacheRotationCalculator.
* @param block_size Block size of the KV cache to evict from.
* @param max_context_length Maximum length possible for a sequence in the current pipeline.
* @param kv_head_size The size (in elements) of the embedding dimension in the attention operation.
* @param rope_theta The base RoPE angle used in the original LLM.
*/
CacheRotationCalculator(size_t block_size, size_t max_context_length, size_t kv_head_size, double rope_theta = 10000.0f);

using RotationCoefficientsPerToken = std::vector<std::vector<double>>; // dimensions: [BLOCK_SIZE, head_size / 2]

/**
* Basic output structure for the calculator.
*/
struct BlockRotationData {
bool operator==(const BlockRotationData& rhs) const {
return (logical_block_idx == rhs.logical_block_idx) && (sines == rhs.sines) && (cosines == rhs.cosines);
}
size_t logical_block_idx; /** Logical index of the block AFTER eviction to which the sine and cosine coefficients should be applied */
RotationCoefficientsPerToken sines; /** The sine coefficients to be applied to this block's contents for rotation, in order of the block's elements */
RotationCoefficientsPerToken cosines; /** The cosine coefficients to be applied to this block's contents for rotation, in order of the block's elements */
};

/**
* Computes the rotation coefficients for the given state of the logical block space when eviction is about to take place.
* @param evicted_block_logical_indices The logical block indices that the prior cache eviction algorithm step determined to be necessary to evict.
* @param num_logical_blocks_before_eviction Number of logical blocks that the evicted-from sequence occupied before the eviction step.
* @return A vector of per-block rotation data, including the indices of blocks after eviction that should be rotated, and the pre-computed trigonometric coefficients necessary for rotation.
*/
std::vector<BlockRotationData> get_rotation_coefficients(const std::set<size_t>& evicted_block_logical_indices, size_t num_logical_blocks_before_eviction);

/**
* @return The size of the embedding dimension that this CacheRotationCalculator was initialized with.
*/
size_t get_head_size() const {
return m_head_size;
}

private:
size_t m_block_size;
size_t m_head_size;
std::vector<std::vector<double>> m_rope_sin_lut; // dimensions: [ max_context_length, head_size / 2]
std::vector<std::vector<double>> m_rope_cos_lut; // dimensions: [ max_context_length, head_size / 2]
};

}
1 change: 1 addition & 0 deletions src/cpp/src/cache_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class CacheManager {
m_value_cache.reserve(m_device_config.get_num_layers());

const std::string device_name = device_config.get_device();
std::cout << "VSHAMPOR: cache precision is " << device_config.get_cache_precision() << std::endl;
if (device_name.find("GPU") == std::string::npos) {// Allocate KV caches
for (size_t decoder_layer_id = 0; decoder_layer_id < m_device_config.get_num_layers(); ++decoder_layer_id) {
ov::Tensor key_cache(device_config.get_cache_precision(), device_config.get_key_cache_shape());
Expand Down
78 changes: 75 additions & 3 deletions src/cpp/src/continuous_batching_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,31 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::init(
}

m_scheduler = std::make_shared<Scheduler>(device_config.get_block_size(), updated_config, device_config.get_num_layers(), can_use_partial_preemption);

// and finally create model runner
bool is_use_cache_eviction = m_scheduler->get_config().use_cache_eviction;
m_model_runner = std::make_shared<ModelRunner>(infer_request, m_scheduler->get_block_size(), device_config.get_num_layers(), is_use_cache_eviction);
if (is_use_cache_eviction) {
m_model_runner = std::make_shared<ModelRunner>(infer_request, m_scheduler->get_block_size(), device_config.get_num_layers(),
/* collect_attention_scores = */ true,
/* is_use_per_layer_cache_control = */ true);
m_rotation_coefficient_stores.reserve(device_config.get_num_layers());
ov::Shape rotation_coefficient_store_shape{ device_config.get_head_size() * (m_scheduler->get_block_size() * scheduler_config.num_kv_blocks) };
for (size_t i = 0; i < device_config.get_num_layers(); i++) {
ov::Tensor store(ov::element::f32, rotation_coefficient_store_shape);
std::memset(store.data(), 0, store.get_byte_size());
m_rotation_coefficient_stores.push_back(store);
}
m_next_step_rotation_coefficients.resize(device_config.get_num_layers());
m_next_step_rotated_block_logical_indices_per_sequence.resize(device_config.get_num_layers());
m_cache_rotation_calculator = std::make_shared<CacheRotationCalculator>(m_scheduler->get_block_size(),
// TODO (vshampor): LUT size equal to max cache size in tokens
// is overkill - find a way to pass the max sequence length instead
m_scheduler->get_block_size() * scheduler_config.num_kv_blocks,
device_config.get_head_size());
} else {
m_model_runner = std::make_shared<ModelRunner>(infer_request, m_scheduler->get_block_size(), device_config.get_num_layers());
}

m_sampler = std::make_shared<Sampler>(m_tokenizer);
m_sampler->set_seed(m_generation_config.rng_seed);

Expand Down Expand Up @@ -196,6 +218,8 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::step() {
// evict unimportant blocks from KV cache, if requested
if (sched_config.use_cache_eviction) {
maybe_evict_cache_blocks(sched_config);
m_model_runner->set_cache_rotation_data(std::move(m_next_step_rotation_coefficients),
std::move(m_next_step_rotated_block_logical_indices_per_sequence));
}

#ifdef DEBUG_CACHE_STATE_DUMP
Expand Down Expand Up @@ -378,19 +402,60 @@ float ContinuousBatchingPipeline::ContinuousBatchingImpl::_get_current_running_a
void ContinuousBatchingPipeline::ContinuousBatchingImpl::maybe_evict_cache_blocks(const SchedulerConfig& sched_config) {
std::unordered_map<SequenceGroup::Ptr, size_t> seq_group_to_num_blocks_evicted_map;
auto sequence_attention_scores = m_model_runner->get_last_attention_scores();

OPENVINO_ASSERT(!sequence_attention_scores.empty());
size_t num_decoder_layers = sequence_attention_scores.begin()->second.size();
std::vector<size_t> num_blocks_to_rotate_for_each_layer(num_decoder_layers, 0);
size_t head_size = m_cache_rotation_calculator->get_head_size();

// necessary since we move from these members during previous steps
m_next_step_rotation_coefficients.clear();
m_next_step_rotated_block_logical_indices_per_sequence.clear();
m_next_step_rotated_block_logical_indices_per_sequence.resize(num_decoder_layers);

for (auto& seq_id_and_attention_scores : sequence_attention_scores) {
auto seq_id = seq_id_and_attention_scores.first;
const auto& attention_scores_for_all_decoder_layers = seq_id_and_attention_scores.second;
if (m_seq_group_id_to_cache_eviction_algo_map.find(seq_id) == m_seq_group_id_to_cache_eviction_algo_map.end()) {
auto num_decoder_layers = attention_scores_for_all_decoder_layers.size();

m_seq_group_id_to_cache_eviction_algo_map[seq_id] = CacheEvictionAlgorithm(sched_config.cache_eviction_config, m_scheduler->get_block_size(), num_decoder_layers);
}
auto& cache_eviction_algo = m_seq_group_id_to_cache_eviction_algo_map[seq_id];

cache_eviction_algo.register_new_token_scores(attention_scores_for_all_decoder_layers);
auto logical_blocks_to_evict = cache_eviction_algo.evict_logical_blocks();


for (size_t layer_idx = 0; layer_idx < logical_blocks_to_evict.size(); layer_idx++) {
if (logical_blocks_to_evict[layer_idx].empty()) {
continue;
}
size_t num_blocks_before_eviction = m_scheduler->get_block_tables(seq_id)[layer_idx].size();
auto rotation_multipliers =
m_cache_rotation_calculator->get_rotation_coefficients(logical_blocks_to_evict[layer_idx],
num_blocks_before_eviction);
for (size_t i = 0; i < rotation_multipliers.size(); i++) {
const auto& block_rotation_data = rotation_multipliers[i];
const auto& rotation_multipliers_cos = block_rotation_data.cosines;
const auto& rotation_multipliers_sin = block_rotation_data.sines;
OPENVINO_ASSERT(rotation_multipliers_cos.size() == rotation_multipliers_sin.size());
OPENVINO_ASSERT(rotation_multipliers_cos.size() == m_scheduler->get_block_size());

m_next_step_rotated_block_logical_indices_per_sequence[layer_idx][seq_id].push_back(block_rotation_data.logical_block_idx);

// Fill the store tensor with rotation coefficient data - cos and sin coefficients are each contiguous, cos goes first
size_t block_offset = num_blocks_to_rotate_for_each_layer[layer_idx] * m_scheduler->get_block_size() * head_size;
auto rotation_multipliers_tensor_data = m_rotation_coefficient_stores[layer_idx].data<float>() + block_offset;
for (size_t tok_idx = 0; tok_idx < rotation_multipliers_cos.size(); tok_idx++) {
size_t position_offset = head_size * tok_idx;
for (size_t embedding_pair_idx = 0; embedding_pair_idx < head_size / 2; embedding_pair_idx++) {
rotation_multipliers_tensor_data[position_offset + embedding_pair_idx] = rotation_multipliers_cos[tok_idx][embedding_pair_idx];
rotation_multipliers_tensor_data[position_offset + embedding_pair_idx + head_size / 2] = rotation_multipliers_sin[tok_idx][embedding_pair_idx];
}
}
num_blocks_to_rotate_for_each_layer[layer_idx] += 1;
}
}

m_scheduler->free_blocks_from_sequence(seq_id, logical_blocks_to_evict);

auto seq_group_ptr_it = std::find_if(m_requests.begin(), m_requests.end(), [seq_id](const SequenceGroup::Ptr& val) { return val->has_sequence_with_id(seq_id); });
Expand All @@ -405,12 +470,19 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::maybe_evict_cache_block
}

}

// Select the previously filled rotation coefficients from the store tensor
for (size_t i = 0; i < num_decoder_layers; i++) {
m_next_step_rotation_coefficients.emplace_back(m_rotation_coefficient_stores[i], ov::Coordinate{0}, ov::Coordinate{num_blocks_to_rotate_for_each_layer[i] * m_scheduler->get_block_size() * head_size});
}

for (const auto& seq_group_ptr_and_num_blocks_evicted : seq_group_to_num_blocks_evicted_map) {
// Assuming that the evicted blocks are always full (since they by design are only selected from intermediate-age blocks)
auto seq_group_ptr = seq_group_ptr_and_num_blocks_evicted.first;
auto num_blocks_evicted = seq_group_ptr_and_num_blocks_evicted.second;
seq_group_ptr->register_token_eviction(num_blocks_evicted * m_scheduler->get_block_size());
}

}

void ContinuousBatchingPipeline::ContinuousBatchingImpl::_fill_prompt_log_probs(std::vector<SequenceGroup::Ptr>& sequence_groups, ov::Tensor& logits) {
Expand Down
16 changes: 14 additions & 2 deletions src/cpp/src/continuous_batching_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,22 @@ class ContinuousBatchingPipeline::ContinuousBatchingImpl : public ContinuousBatc

static const size_t AVG_CACHE_USAGE_WINDOW_SIZE_IN_STEPS = 1000;
std::deque<float> m_previous_step_cache_usages;

// flag to enable validation mode for sampler
bool m_is_validation_mode_enabled = false;

// Pre-allocated per-layer storages for the per-token cache re-rotation coefficients used in cache eviction case
std::vector<ov::Tensor> m_rotation_coefficient_stores;

// Per-layer ROI tensors, reusing storage from the pre-allocated tensors above, that actually represent the
// re-rotation coefficients to be sent to the proper model inputs at the *next* pipeline step.
std::vector<ov::Tensor> m_next_step_rotation_coefficients;

using SeqIdToRotatedLogicalBlocksMap = std::map<size_t, std::vector<size_t>>;
std::vector<SeqIdToRotatedLogicalBlocksMap> m_next_step_rotated_block_logical_indices_per_sequence;

std::shared_ptr<ov::genai::CacheRotationCalculator> m_cache_rotation_calculator;

#ifdef DEBUG_CACHE_STATE_DUMP
size_t step_count = 0;
#endif
Expand Down Expand Up @@ -86,4 +98,4 @@ class ContinuousBatchingPipeline::ContinuousBatchingImpl : public ContinuousBatc
const std::vector<GenerationConfig>& sampling_params,
const StreamerVariant& streamer) override;
};
}
}
4 changes: 4 additions & 0 deletions src/cpp/src/device_config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,10 @@ class DeviceConfig {
return m_num_decoder_layers;
}

size_t get_head_size() const {
return m_head_size;
}

ov::Shape get_key_cache_shape() const {
OPENVINO_ASSERT(!m_key_cache_shape.empty());
return m_key_cache_shape;
Expand Down
Loading

0 comments on commit 0d60110

Please sign in to comment.