Skip to content

Commit

Permalink
[GPU] scatter element update fix and dynamic impl (#29087)
Browse files Browse the repository at this point in the history
### Details:
- added input idx calculation because output idx was different due to
output padding
 - part of accuracy fix of model FasterRCNN_ResNet50_FPN_V2.onnx

### Tickets:
 - small part of 101294
  • Loading branch information
michal-miotk authored Feb 21, 2025
1 parent 587360c commit b6fd0f9
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,18 @@ struct scatter_elements_update_impl : typed_primitive_impl_ocl<scatter_elements_
return make_deep_copy<scatter_elements_update_impl, kernel_params_t>(*this);
}

static kernel_params_t get_kernel_params(const kernel_impl_params& impl_param) {
void load(BinaryInputBuffer& ib) override {
parent::load(ib);
if (is_dynamic() && _kernel_data.kernelName.length() != 0) {
auto& kernel_selector = kernel_selector_t::Instance();
auto kernel_impl = kernel_selector.GetImplementation(_kernel_data.kernelName);
kernel_impl->GetUpdateDispatchDataFunc(_kernel_data);
}
}

static kernel_params_t get_kernel_params(const kernel_impl_params& impl_param, bool is_shape_agnostic = false) {
const auto& primitive = impl_param.typed_desc<scatter_elements_update>();
auto params = get_default_params<kernel_selector::scatter_elements_update_params>(impl_param);
auto params = get_default_params<kernel_selector::scatter_elements_update_params>(impl_param, is_shape_agnostic);

params.axis = convert_axis(primitive->axis, impl_param.get_input_layout(0).get_rank());
params.mode = convert_reduction_mode(primitive->mode);
Expand All @@ -83,6 +92,16 @@ struct scatter_elements_update_impl : typed_primitive_impl_ocl<scatter_elements_
params.inputs.push_back(convert_data_tensor(impl_param.get_input_layout(2)));
return params;
}

void update_dispatch_data(const kernel_impl_params& impl_param) override {
// If model loaded from cache, params are not initialized, so we create a new object and reuse it in the future
if (_kernel_data.params == nullptr) {
_kernel_data.params = std::make_shared<kernel_params_t>(get_kernel_params(impl_param, true));
}

update_shapes(*_kernel_data.params, impl_param);
(_kernel_data.update_dispatch_data_func)(*_kernel_data.params, _kernel_data);
}
};

std::unique_ptr<primitive_impl> ScatterElementsUpdateImplementationManager::create_impl(const program_node& node, const kernel_impl_params& params) const {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
// Copyright (C) 2024 Intel Corporation
// Copyright (C) 2024-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "predicates.hpp"
#include "registry.hpp"
#include "intel_gpu/primitives/scatter_elements_update.hpp"
#include "primitive_inst.h"
Expand All @@ -18,6 +19,7 @@ using namespace cldnn;
const std::vector<std::shared_ptr<cldnn::ImplementationManager>>& Registry<scatter_elements_update>::get_implementations() {
static const std::vector<std::shared_ptr<ImplementationManager>> impls = {
OV_GPU_CREATE_INSTANCE_OCL(ocl::ScatterElementsUpdateImplementationManager, shape_types::static_shape)
OV_GPU_CREATE_INSTANCE_OCL(ocl::ScatterElementsUpdateImplementationManager, shape_types::dynamic_shape)
};

return impls;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#define GET_INDICES_INDEX(idx_order) INPUT1_GET_INDEX(idx_order)
#define GET_UPDATES_INDEX(idx_order) INPUT2_GET_INDEX(idx_order)
#define GET_OUTPUT_INDEX(idx_order) OUTPUT_GET_INDEX(idx_order)
#define GET_INPUT_INDEX(idx_order) INPUT0_GET_INDEX(idx_order)
#if OUTPUT_DIMS == 4
#define ORDER b,f,y,x
#define IDX_ORDER idx_b,idx_f,idx_y,idx_x
Expand Down Expand Up @@ -63,7 +64,8 @@
}
#endif

KERNEL(scatter_elements_update_ref)(const __global INPUT0_TYPE* data,
KERNEL(scatter_elements_update_ref)(OPTIONAL_SHAPE_INFO_ARG
const __global INPUT0_TYPE* data,
const __global INPUT1_TYPE* indices,
const __global INPUT2_TYPE* updates,
__global OUTPUT_TYPE* output
Expand All @@ -76,7 +78,6 @@ KERNEL(scatter_elements_update_ref)(const __global INPUT0_TYPE* data,
const uint dim0 = get_global_id(0);
const uint dim1 = get_global_id(1);
const uint dim2 = get_global_id(2);

#ifndef IS_SECOND_ITER // First kernel
#if OUTPUT_DIMS == 4
const uint x = dim0;
Expand All @@ -97,16 +98,15 @@ KERNEL(scatter_elements_update_ref)(const __global INPUT0_TYPE* data,
const uint f = dim2 % OUTPUT_FEATURE_NUM;
const uint b = dim2 / OUTPUT_FEATURE_NUM;
#endif

const uint input_idx = GET_INPUT_INDEX(ORDER);
const uint output_idx = GET_OUTPUT_INDEX(ORDER);
INPUT0_TYPE val = data[output_idx];
INPUT0_TYPE val = data[input_idx];
#if HAS_FUSED_OPS
FUSED_OPS_FIRST_KERNEL;
output[output_idx] = TO_OUTPUT_TYPE(FUSED_OPS_RESULT_FIRST_KERNEL);
#else
output[output_idx] = ACTIVATION(val, ACTIVATION_PARAMS);
#endif

#else // Second kernel
#if OUTPUT_DIMS == 4
const uint idx_x = dim0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ ParamsKey ScatterElementsUpdateKernelRef::GetSupportedKey() const {
k.EnableTensorPitches();
k.EnableBatching();
k.EnableDifferentTypes();
k.EnableDynamicShapesSupport();
return k;
}

Expand Down Expand Up @@ -162,6 +163,20 @@ bool ScatterElementsUpdateKernelRef::Validate(const Params& p) const {
return true;
}

void ScatterElementsUpdateKernelRef::GetUpdateDispatchDataFunc(KernelData& kd) const {
kd.update_dispatch_data_func = [this](const Params& params, KernelData& kd) {
const auto& prim_params = static_cast<const scatter_elements_update_params&>(params);
OPENVINO_ASSERT(kd.kernels.size() == 2, "[GPU] Invalid kernels size for update dispatch data func");

for (size_t i = 0; i < 2; ++i) {
auto dispatchData = SetDefault(prim_params, i == 1);
kd.kernels[i].params.workGroups.global = dispatchData.gws;
kd.kernels[i].params.workGroups.local = dispatchData.lws;
kd.kernels[i].skip_execution = KernelData::SkipKernelExecution(prim_params);
}
};
}

KernelsData ScatterElementsUpdateKernelRef::GetKernelsData(const Params& params) const {
if (!Validate(params)) {
return {};
Expand All @@ -171,6 +186,8 @@ KernelsData ScatterElementsUpdateKernelRef::GetKernelsData(const Params& params)
scatter_elements_update_params& newParams = *static_cast<scatter_elements_update_params*>(kd.params.get());
auto cldnn_jit = GetJitConstants(newParams);

GetUpdateDispatchDataFunc(kd);

for (int i = 0; i < 2; i++) {
auto dispatchData = SetDefault(newParams, (i == 1));
auto entry_point = GetEntryPoint(kernelName, newParams.layerID, params, i);
Expand All @@ -182,7 +199,8 @@ KernelsData ScatterElementsUpdateKernelRef::GetKernelsData(const Params& params)

clKernelData& kernel = kd.kernels[i];

FillCLKernelData(kernel, dispatchData, params.engineInfo, kernelName, jit, entry_point, "", false, false, 3, GetFusedPrimitiveInputsCount(params));
FillCLKernelData(kernel, dispatchData, params.engineInfo, kernelName, jit, entry_point, "", false, false, 3, GetFusedPrimitiveInputsCount(params), 1,
params.is_shape_agnostic);
}

return {kd};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,6 @@ class ScatterElementsUpdateKernelRef : public KernelBaseOpenCL {

protected:
bool Validate(const Params& p) const override;
void GetUpdateDispatchDataFunc(KernelData& kd) const override;
};
} // namespace kernel_selector

0 comments on commit b6fd0f9

Please sign in to comment.