[GPU] Update gemm_tiled_opt dynamic support
Revert parts of earlier PR 28252 and PR 28764. Those PRs caused increased memory usage as well as a 2x performance drop. Instead of fixing the issue in primitive_inst, update gemm_kernel_tiled_opt to reject the kernel when its inputs are dynamic but the output is static.

No new unit tests are added, as PR 28252 and PR 28764 already include tests covering this case.

CVS-162230
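A minimal, self-contained sketch of the rejection rule described above. GemmParamsSketch and validate_tiled_opt are hypothetical stand-ins for the kernel selector's gemm_params and GemmKernelTiledOpt::Validate; the real change is in the diff below.

#include <cstdio>

// Hypothetical, simplified stand-in for the kernel-selector gemm parameters:
// it only models whether the input and output shapes are dynamic.
struct GemmParamsSketch {
    bool dynamic_inputs;
    bool dynamic_outputs;
    bool has_dynamic_inputs() const { return dynamic_inputs; }
    bool has_dynamic_outputs() const { return dynamic_outputs; }
};

// Mirrors the guard added to GemmKernelTiledOpt::Validate in the diff below:
// the tiled-opt kernel is rejected when any input is dynamic but the output is static.
bool validate_tiled_opt(const GemmParamsSketch& p) {
    if (p.has_dynamic_inputs() && !p.has_dynamic_outputs())
        return false;
    return true;
}

int main() {
    std::printf("%d\n", validate_tiled_opt({true, false}));   // 0 -> rejected (dynamic inputs, static output)
    std::printf("%d\n", validate_tiled_opt({true, true}));    // 1 -> not rejected by this check
    std::printf("%d\n", validate_tiled_opt({false, false}));  // 1 -> fully static case is fine
    return 0;
}

This keeps the shape-agnostic path (dynamic inputs with dynamic outputs) available while excluding only the unsupported mixed case, so the broadcast workaround in primitive_inst can be reverted.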
clee30 committed Feb 22, 2025
1 parent f77ef0f commit b9c4709
Showing 2 changed files with 4 additions and 24 deletions.
25 changes: 1 addition & 24 deletions src/plugins/intel_gpu/src/graph/primitive_inst.cpp
@@ -2669,34 +2669,11 @@ bool primitive_inst::is_valid_fusion() const {
const auto& outer_dep = _deps[outer_dep_idx];

const auto& outer_dep_pshape = outer_dep.first->_impl_params->get_output_layout().get_partial_shape();
size_t outer_dep_pshape_count = outer_dep_pshape.is_static() ? ov::shape_size(outer_dep_pshape.to_shape()) : 0;
auto merged_shape = out_pshape;
bool can_broadcast = true;
if (fd.is_type<eltwise>())
can_broadcast = ov::PartialShape::broadcast_merge_into(merged_shape, outer_dep_pshape, fd.typed_desc<eltwise>()->broadcast_spec);

// Check if broadcast happens more than single axis.
// Current gemm_tiled_opt kernel FUSED_OP_LOAD macro cannot support broadcast on dynamic dimension.
if (_node->is_type<gemm>() && can_broadcast == true && merged_shape.rank().get_length() >= outer_dep_pshape.rank().get_length() &&
outer_dep_pshape_count != 1) {
uint8_t broadcast_more_than_single_axis = 0;
auto updated_outer_dep_pshape = ov::PartialShape(outer_dep_pshape);

// Update outer_dep_pshape to merged_shape rank
if (merged_shape.rank().get_length() > outer_dep_pshape.rank().get_length()) {
updated_outer_dep_pshape.insert(updated_outer_dep_pshape.begin(),
merged_shape.rank().get_length() - outer_dep_pshape.rank().get_length(), ov::Dimension(1));
}

for (int64_t i = 0; i < merged_shape.rank().get_length(); i++) {
if (merged_shape[i] != updated_outer_dep_pshape[i])
broadcast_more_than_single_axis++;
}

if (broadcast_more_than_single_axis > 1)
can_broadcast = false;
}

#ifdef ENABLE_ONEDNN_FOR_GPU
// WA for OneDNN binary add fusions: we need to broadcast batch dimension to avoid situation with
// batch dimension mismatch in OneDNN tensor descriptors as follow:
@@ -2714,7 +2691,7 @@ bool primitive_inst::is_valid_fusion() const {
cldnn::format::dimension(data_layout.format),
false);

if (gemm_dims[0] != data_dims[0] && outer_dep_pshape_count != 1)
if (gemm_dims[0] != data_dims[0])
return false;
} else if (_node->is_type<fully_connected>() && _node->get_preferred_impl_type() == impl_types::onednn) {
const auto& fc_layout = _impl_params->get_output_layout();
@@ -443,6 +443,9 @@ bool GemmKernelTiledOpt::Validate(const Params& params) const {
if (gmm_params.has_dynamic_inputs() && !gmm_params.is_shape_agnostic)
return false;

if (gmm_params.has_dynamic_inputs() && !gmm_params.has_dynamic_outputs())
return false;

for (size_t i = 1; i < num_inputs; i++)
if (gmm_params.inputs[0].GetDType() != gmm_params.inputs[i].GetDType())
return false;
