Skip to content

Commit

Permalink
Distributed kv cache allocatio to prevent memory peak
Browse files Browse the repository at this point in the history
  • Loading branch information
yeonbok committed Feb 6, 2024
1 parent abc2899 commit 08cb8d0
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 4 deletions.
2 changes: 2 additions & 0 deletions src/plugins/intel_gpu/include/intel_gpu/graph/network.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ struct network {
const ov::intel_gpu::VariableStateInfo& get_variable_info(const std::string &variable_id) const;
const ov::intel_gpu::VariablesMap& get_variables() const;
const ov::intel_gpu::VariablesInfoMap& get_variables_info() const;
std::vector<primitive_id> get_kv_cache_ids() const { return kv_cache_ids; }

const ExecutionConfig& get_config() const { return _config; }

Expand Down Expand Up @@ -255,6 +256,7 @@ struct network {

ov::intel_gpu::VariablesMap _variables_states;
ov::intel_gpu::VariablesInfoMap _variables_state_info;
std::vector<primitive_id> kv_cache_ids;

program::primitives_info _prims_info;
std::map<primitive_id, primitive_id> _ext_id_mapping;
Expand Down
8 changes: 5 additions & 3 deletions src/plugins/intel_gpu/src/graph/include/kv_cache_inst.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,8 @@ class typed_primitive_inst<kv_cache> : public typed_primitive_inst_base<kv_cache

static std::string to_string(const kv_cache_node& node);

static int32_t get_prealloc_iter_num() {
return 128;
}
// Distribute prealloc period to prevent memory peak
int32_t get_prealloc_iter_num() override;

static void update_pad(layout& l, int64_t pad, int64_t sequence_axis_legacy) {
const auto& dyn_pad_dims = l.data_padding.get_dynamic_pad_dims();
Expand Down Expand Up @@ -82,6 +81,9 @@ class typed_primitive_inst<kv_cache> : public typed_primitive_inst_base<kv_cache

typed_primitive_inst(network& network, const kv_cache_node& desc);
typed_primitive_inst(network& network) : parent(network), memory_state::variable("") {}

private:
size_t kv_cache_id = 0;
};

using kv_cache_inst = typed_primitive_inst<kv_cache>;
Expand Down
2 changes: 2 additions & 0 deletions src/plugins/intel_gpu/src/graph/include/primitive_inst.h
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,8 @@ class primitive_inst {

virtual void update_output_memory() {}

virtual int32_t get_prealloc_iter_num() { return -1; }

protected:
primitive_inst(network& network, program_node const& node, bool allocate_memory);

Expand Down
4 changes: 4 additions & 0 deletions src/plugins/intel_gpu/src/graph/kv_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ GPU_DEFINE_PRIMITIVE_TYPE_ID(kv_cache)
kv_cache_inst::typed_primitive_inst(network& network, const kv_cache_node& node) :
parent{network, node, false},
memory_state::variable{node.get_primitive()->variable_info.variable_id} {
kv_cache_id = network.get_kv_cache_ids().size();
}

layout kv_cache_inst::calc_output_layout(const kv_cache_node& node, kernel_impl_params const& impl_param) {
Expand Down Expand Up @@ -55,4 +56,7 @@ std::string kv_cache_inst::to_string(const kv_cache_node& node) {
return primitive_description.str();
}

int32_t kv_cache_inst::get_prealloc_iter_num() {
return 128 + kv_cache_id % 64;
}
} // namespace cldnn
4 changes: 4 additions & 0 deletions src/plugins/intel_gpu/src/graph/network.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "assign_inst.h"
#include "read_value_inst.h"
#include "reshape_inst.h"
#include "kv_cache_inst.h"
#include "program_helpers.h"
#include "to_string_utils.h"
#include "kernels_cache.hpp"
Expand Down Expand Up @@ -1329,6 +1330,9 @@ void network::allocate_primitive_instance(program_node const& node) {
if (node.is_type<data>())
_data_outputs.push_back(inst);
}
if (node.is_type<kv_cache>()) {
kv_cache_ids.push_back(node.id());
}
if (auto state_prim = std::dynamic_pointer_cast<memory_state::variable>(inst)) {
set_variables_state_info(state_prim->variable_id(), node.get_output_layout(0), state_prim->get_user_specified_type());
}
Expand Down
2 changes: 1 addition & 1 deletion src/plugins/intel_gpu/src/graph/primitive_inst.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,7 @@ event::ptr primitive_inst::realloc_if_needed() {

auto current_shape = updated_layout.get_shape();
std::pair<bool, ov::Shape> prealloc_info;
int32_t tmp_prealloc_count = _node->is_type<kv_cache>() ? kv_cache_inst::get_prealloc_iter_num() : -1;
int32_t tmp_prealloc_count = get_prealloc_iter_num();
GPU_DEBUG_IF(debug_config->mem_preallocation_params.is_initialized) {
// If debug config is set, repsect the config most
tmp_prealloc_count = -1;
Expand Down

0 comments on commit 08cb8d0

Please sign in to comment.