Skip to content

Commit

Permalink
[GPU][PoC] shape_info common fix
Browse files Browse the repository at this point in the history
  • Loading branch information
dnkurek committed Dec 20, 2024
1 parent 940b0cb commit 0c70e6e
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
1 change: 1 addition & 0 deletions src/plugins/intel_gpu/include/intel_gpu/graph/network.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ struct network {
bool _is_dynamic = false;
bool _enable_profiling = false;
bool _reset_arguments;
memory::ptr _ptr;

std::unordered_map<primitive_id, std::shared_ptr<primitive_inst>> _primitives;
std::vector<shared_mem_type> _in_out_shared_mem_types;
Expand Down
9 changes: 4 additions & 5 deletions src/plugins/intel_gpu/src/graph/network.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,19 +289,18 @@ void network::preallocate_shape_info_buffers() {
}

std::cout << "Sum of shape elements " << sum << std::endl;

if(sum) {
auto& eng = get_engine();
auto ptr = eng.allocate_memory(layout{{sum * 128}, data_types::i32, format::bfyx}, false);
//auto ptr = eng.allocate_memory(layout{{sum}, data_types::i32, format::bfyx}, false);
_ptr = eng.allocate_memory(layout{{sum}, data_types::i32, format::bfyx}, false);
int new_sum = 0;
for (auto const& prim : _exec_order) {
auto& node = prim->get_node();
int64_t shape_elements = node.get_total_shape_info_size();
if(shape_elements == 0) continue;
auto new_ptr = ptr;
auto new_mem = eng.reinterpret_subbuffer(*ptr, layout{{shape_elements}, data_types::i32, format::bfyx}, new_sum);
auto new_mem = eng.reinterpret_subbuffer(*_ptr, layout{{shape_elements}, data_types::i32, format::bfyx}, new_sum);
prim->set_shape_info_memory_ptr(new_mem);
new_sum += 4096;
new_sum += shape_elements * 4;
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions src/plugins/intel_gpu/src/runtime/ocl/ocl_ext.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -932,21 +932,21 @@ class UsmMemory {

// Requests `size` bytes via _usmHelper.allocate_host, hands the resulting pointer and
// error code to _check_error (tagged "Host"), then passes the pointer to _allocate.
// NOTE(review): this text comes from a flattened commit diff ("3 additions & 3 deletions"),
// so the two `auto ptr` lines below appear to be the removed and the added version of a
// single statement; presumably only the second one exists after the commit — confirm
// against the repository before treating this as compilable code.
void allocateHost(size_t size) {
cl_int error = CL_SUCCESS;
// Presumably the pre-commit line: third argument is 0.
auto ptr = _usmHelper.allocate_host(nullptr, size, 0, &error);
// Presumably the post-commit line: third argument is 1024 * 64 * 16 (= 1048576).
// NOTE(review): looks like an alignment request in bytes — confirm against the helper's signature.
auto ptr = _usmHelper.allocate_host(nullptr, size, 1024 * 64 * 16, &error);
_check_error(size, ptr, error, "Host");
_allocate(ptr);
}

// Requests `size` bytes via _usmHelper.allocate_shared, hands the resulting pointer and
// error code to _check_error (tagged "Shared"), then passes the pointer to _allocate.
// NOTE(review): this text comes from a flattened commit diff ("3 additions & 3 deletions"),
// so the two `auto ptr` lines below appear to be the removed and the added version of a
// single statement; presumably only the second one exists after the commit — confirm
// against the repository before treating this as compilable code.
void allocateShared(size_t size) {
cl_int error = CL_SUCCESS;
// Presumably the pre-commit line: third argument is 0.
auto ptr = _usmHelper.allocate_shared(nullptr, size, 0, &error);
// Presumably the post-commit line: third argument is 1024 * 64 * 16 (= 1048576).
// NOTE(review): looks like an alignment request in bytes — confirm against the helper's signature.
auto ptr = _usmHelper.allocate_shared(nullptr, size, 1024 * 64 * 16, &error);
_check_error(size, ptr, error, "Shared");
_allocate(ptr);
}

// Requests `size` bytes via _usmHelper.allocate_device, hands the resulting pointer and
// error code to _check_error (tagged "Device"), then passes the pointer to _allocate.
// NOTE(review): this text comes from a flattened commit diff ("3 additions & 3 deletions"),
// so the two `auto ptr` lines below appear to be the removed and the added version of a
// single statement; presumably only the second one exists after the commit — confirm
// against the repository before treating this as compilable code.
void allocateDevice(size_t size) {
cl_int error = CL_SUCCESS;
// Presumably the pre-commit line: third argument is 0.
auto ptr = _usmHelper.allocate_device(nullptr, size, 0, &error);
// Presumably the post-commit line: third argument is 1024 * 64 * 16 (= 1048576).
// NOTE(review): looks like an alignment request in bytes — confirm against the helper's signature.
auto ptr = _usmHelper.allocate_device(nullptr, size, 1024 * 64 * 16, &error);
_check_error(size, ptr, error, "Device");
_allocate(ptr);
}
Expand Down

0 comments on commit 0c70e6e

Please sign in to comment.