Skip to content

Commit

Permalink
[GPU] Fix U8 weights identification as a quantization case (openvinot…
Browse files Browse the repository at this point in the history
…oolkit#24488)

### Details:
- Fix the issue where `conv_params.quantization` was not correctly
identified as a quantization case when using U8 weights data (which is
allowed by ConvertConvolutionToInternal transformation pass)

### Tickets:
 - 139740
  • Loading branch information
sshlyapn authored May 21, 2024
1 parent 24137ff commit 248c841
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/plugins/intel_gpu/src/graph/impls/ocl/convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,8 @@ struct convolution_impl : typed_primitive_impl_ocl<convolution> {

if ((impl_param.input_layouts[0].data_type == data_types::u8 ||
impl_param.input_layouts[0].data_type == data_types::i8) &&
impl_param.input_layouts[1].data_type == data_types::i8) {
(impl_param.input_layouts[1].data_type == data_types::i8 ||
impl_param.input_layouts[1].data_type == data_types::u8)) {
if (!primitive->weights_zero_points.empty() && !primitive->activations_zero_points.empty()) {
conv_params.quantization = kernel_selector::QuantizationType::ASYMMETRIC_DATA_AND_WEIGHTS;
} else if (!primitive->weights_zero_points.empty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,60 @@ const std::vector<LayerTestsDefinitions::ConvolutionQDqTransformationParam> para
ov::element::u8.get_type_name()
},

// Actual:
//
// Constant
// | Constant Constant Constant Constant
// | /FP32 /FP32 /FP32 /FP32
// FakeQuantize FakeQuantize
// |FP32 |FP32
// | |
// Convert Constant Convert
// |U8 |U8 |U8
// | | |
// Convert Convert Convert Constant
// \FP32 /FP32 |FP32 /U8
// \ / | /
// Subtract Constant Subtract Constant
// \FP32 /FP32 |FP32 /FP32
// \ / | /
// Multiply Multiply
// \FP32 /FP32
// \ /
// Convolution
//
// Transformed:
//
// Parameter Constant Constant
// \U8 /U8 /U8
// \ / /
// Subtract Subtract
// \FP32 /FP32
// \ /
// Convolution Constant
// \FP32 /FP32
// \ /
// Multiply
{
{ 256ul, {{ 1, 1, 1, 1 }}, { -12.8f }, { 12.7f }, { 0.f }, { 255.f }, ov::element::f32 },
{ ov::element::u8, false },
{
{ov::element::f32},
{ {128.f}, ov::element::f32, {}, false, 1ul, ov::element::u8, true },
{ {0.1f}, ov::element::f32, {}, false }
},
{ std::vector<float>{ 15.f }, ov::element::f32},
{ 256ul, ov::Shape({ 1, 1, 1, 1 }), { 0.f }, { 25.5f }, { 0.f }, { 255.f }, ov::element::f32 },
{ ov::element::u8, false },
{
{ ov::element::f32, false },
{ {0.3f}, ov::element::f32, {}, false, 1ul, ov::element::u8, true },
{ {0.2f}, ov::element::f32, {}, false }
},
"Convolution",
ov::element::u8.get_type_name()
},

// Actual:
//
// Constant
Expand Down

0 comments on commit 248c841

Please sign in to comment.