From 498b793e8fb8c4d9e4ec9e1af3fcadac6ae87bf4 Mon Sep 17 00:00:00 2001 From: Kurek Date: Fri, 20 Dec 2024 10:53:50 +0100 Subject: [PATCH] [GPU][PoC] shape_info common fix --- .../intel_gpu/include/intel_gpu/graph/network.hpp | 1 + src/plugins/intel_gpu/src/graph/network.cpp | 9 ++++----- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/plugins/intel_gpu/include/intel_gpu/graph/network.hpp b/src/plugins/intel_gpu/include/intel_gpu/graph/network.hpp index df30f17997ef09..fe78f550df5b46 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/graph/network.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/graph/network.hpp @@ -216,6 +216,7 @@ struct network { bool _is_dynamic = false; bool _enable_profiling = false; bool _reset_arguments; + memory::ptr _ptr; std::unordered_map> _primitives; std::vector _in_out_shared_mem_types; diff --git a/src/plugins/intel_gpu/src/graph/network.cpp b/src/plugins/intel_gpu/src/graph/network.cpp index c12c6b677424ad..60672b32850f0a 100644 --- a/src/plugins/intel_gpu/src/graph/network.cpp +++ b/src/plugins/intel_gpu/src/graph/network.cpp @@ -289,19 +289,18 @@ void network::preallocate_shape_info_buffers() { } std::cout << "Sum of shape elements " << sum << std::endl; + if(sum) { auto& eng = get_engine(); - auto ptr = eng.allocate_memory(layout{{sum * 128}, data_types::i32, format::bfyx}, false); - //auto ptr = eng.allocate_memory(layout{{sum}, data_types::i32, format::bfyx}, false); + _ptr = eng.allocate_memory(layout{{sum}, data_types::i32, format::bfyx}, false); int new_sum = 0; for (auto const& prim : _exec_order) { auto& node = prim->get_node(); int64_t shape_elements = node.get_total_shape_info_size(); if(shape_elements == 0) continue; - auto new_ptr = ptr; - auto new_mem = eng.reinterpret_subbuffer(*ptr, layout{{shape_elements}, data_types::i32, format::bfyx}, new_sum); + auto new_mem = eng.reinterpret_subbuffer(*_ptr, layout{{shape_elements}, data_types::i32, format::bfyx}, new_sum); prim->set_shape_info_memory_ptr(new_mem); - new_sum += 4096; + new_sum += shape_elements * 4; } } }