Specification for UniformQuantizeOp and UniformDequantizeOp (#1496)
fixes #531
fixes #530 

## Summary 
This PR proposes the specification for the `uniform_quantize` and
`uniform_dequantize` ops.

The specification of `uniform_quantize` also covers requantization, i.e.
conversion from one quantized tensor type to another.

Please let me know your feedback. 


### Working notes on following the `reference_checklist.md` and `spec_checklist.md`

#### uniform_dequantize

We have the following constraints from the spec  

```
(I1)  `operand` is a  quantized tensor
(C1) `shape(operand) = shape(result)`.
(C2) `element_type(result) = expressed_type(operand)`.
```

These constraints will be comprehensively covered by the following
tests:

```
I1: a) `operand: quantized tensors`. (Covered by ODS).
C1: a) `shape(operand) != shape(result)`. (Covered by ODS)
C2: a) `element_type(result) != expressed_type(operand)`.
```

If we drop the "Covered by ODS" pieces, this will leave us with the
following test cases:

```
C2: a) `element_type(result) != expressed_type(operand)`.
```

We already have a type inference test that covers the above.


#### uniform_quantize

We have the following constraints from the spec  

```
(I1)  `operand: tensor of floating-point or quantized type`.
(C1) `shape(operand) = shape(result)`.
(C2) `expressed_type(result) = is_float(operand) ? element_type(operand) :
  expressed_type(operand)`.
```

These constraints will be comprehensively covered by the following
tests:

```
I1: a) `operand: tensor of floating-point or quantized type`. (Covered by ODS).
C1: a) `shape(operand) != shape(result)`. (Covered by ODS)
C2: a) if_float(operand): `expressed_type(result) != element_type(operand)`.
      b) if_quantized(operand): `expressed_type(result) !=  expressed_type(operand)`.
```

If we drop the "Covered by ODS" pieces, this will leave us with the
following test cases:

```
C2: a) if_float(operand): `expressed_type(result) != element_type(operand)`.
      b) if_quantized(operand): `expressed_type(result) !=  expressed_type(operand)`.
```

The above will be covered as part of #1603.
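
The C2 cases that remain for both ops reduce to comparing expressed types. A minimal Python sketch of that comparison, assuming floating-point element types can be modeled as plain strings and quantized element types as a small dataclass. `QuantizedType`, `uniform_quantize_c2` and `uniform_dequantize_c2` are hypothetical names used only for illustration; they are not StableHLO or MLIR APIs.

```python
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class QuantizedType:
    """Hypothetical stand-in for a quantized element type."""
    storage_type: str    # e.g. "i8"
    expressed_type: str  # e.g. "f32"


# Assumption: a plain string such as "f32" stands in for a float element type.
ElementType = Union[str, QuantizedType]


def uniform_quantize_c2(operand: ElementType, result: QuantizedType) -> bool:
    """expressed_type(result) = is_float(operand) ? element_type(operand)
                                                  : expressed_type(operand)."""
    if isinstance(operand, QuantizedType):
        expected = operand.expressed_type
    else:
        expected = operand
    return result.expressed_type == expected


def uniform_dequantize_c2(operand: QuantizedType, result: str) -> bool:
    """element_type(result) = expressed_type(operand)."""
    return result == operand.expressed_type


q_i8_f32 = QuantizedType("i8", "f32")
print(uniform_quantize_c2("f32", q_i8_f32))                        # float operand, matching type
print(uniform_quantize_c2(QuantizedType("i8", "bf16"), q_i8_f32))  # requantize, mismatching type
print(uniform_dequantize_c2(q_i8_f32, "f32"))                      # dequantize, matching type
```

The negative tests listed above correspond to the inputs for which these predicates return `False`.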
sdasgup3 authored Jun 19, 2023
1 parent 556db3f commit 8993ad5
Showing 6 changed files with 143 additions and 38 deletions.
127 changes: 119 additions & 8 deletions docs/spec.md
@@ -335,11 +335,6 @@ in StableHLO programs. In the meanwhile, here is the list of these operations:
`dynamic_gather`, `dynamic_iota`, `dynamic_pad`, `dynamic_reshape`,
`real_dynamic_slice`, `set_dimension_size`
([#8](https://github.com/openxla/stablehlo/issues/8)).
- * "Quantization" category of StableHLO operations - they were bootstrapped from
- MHLO, but we haven't specced them yet: `uniform_quantize`
- ([#531](https://github.com/openxla/stablehlo/issues/531)) and
- `uniform_dequantize`
- ([#530](https://github.com/openxla/stablehlo/issues/530)).
* Shape computations, including `arith`, `shape` and `tensor` operations
([#8](https://github.com/openxla/stablehlo/issues/8)).

@@ -5535,6 +5530,87 @@ Produces a `result` tuple from values `val`.
// %result: ([1.0, 2.0], (3))
```

### uniform_dequantize

#### Semantics

Performs element-wise conversion of quantized tensor `operand` to a
floating-point tensor `result` according to the quantization parameters defined
by the `operand` type.

More formally, `result = dequantize(operand)`.

#### Inputs

| Label | Name | Type | Constraints |
|-------|-----------|------------------|-------------|
| (I1) | `operand` | quantized tensor | (C1), (C2) |

#### Outputs

| Name | Type | Constraints |
|----------|-------------------------------|-------------|
| `result` | tensor of floating-point type | (C1), (C2) |

#### Constraints

* (C1) `shape(operand) = shape(result)`.
* (C2) `element_type(result) = expressed_type(operand)`.

#### Examples

```mlir
// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]
```
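
The example values can be checked by hand. The sketch below is a rough illustration in Python, not the reference implementation: it assumes per-axis quantization along dimension 0, so element `i` uses the `i`-th `scale:zero_point` pair from the type, and `dequantize_per_axis` is a hypothetical helper name.

```python
def dequantize_per_axis(storage, scales, zero_points):
    """Apply (q - zero_point) * scale with one (scale, zero_point) pair per element."""
    return [(q - zp) * s for q, s, zp in zip(storage, scales, zero_points)]


# !quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}> applied to %operand: [10, 10]
result = dequantize_per_axis([10, 10], scales=[0.1, 0.5], zero_points=[-30, -20])
print(result)  # matches %result: [4.0, 15.0] up to float rounding
```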

### uniform_quantize

#### Semantics

Performs element-wise conversion of floating-point tensor or quantized tensor
`operand` to a quantized tensor `result` according to the quantization
parameters defined by the `result` type.

More formally,

* If `is_float(operand)`:
  * `result = quantize(operand, type(result))`.
* If `is_quantized(operand)`:
  * `float_result = dequantize(operand)`.
  * `result = quantize(float_result, type(result))`.

#### Inputs

| Label | Name | Type | Constraints |
|-------|-----------|--------------------------------------------|-------------|
| (I1) | `operand` | tensor of floating-point or quantized type | (C1), (C2) |

#### Outputs

| Name | Type | Constraints |
|----------|------------------|-------------|
| `result` | quantized tensor | (C1), (C2) |

#### Constraints

* (C1) `shape(operand) = shape(result)`.
* (C2) `expressed_type(result) = is_float(operand) ? element_type(operand) :
  expressed_type(operand)`.

#### Examples

```mlir
// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]

// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]
```
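
Both examples follow from the semantics above: the quantized-operand case is a dequantize followed by a quantize with the result type's parameters. The Python sketch below reproduces the numbers under stated assumptions: per-axis parameters along dimension 0, `i8` storage bounds of -128 and 127, and Python's `round`, which rounds ties to even. The helper names are hypothetical.

```python
def dequantize_per_axis(storage, scales, zero_points):
    """(q - zero_point) * scale, element by element."""
    return [(q - zp) * s for q, s, zp in zip(storage, scales, zero_points)]


def quantize_per_axis(values, scales, zero_points, storage_min=-128, storage_max=127):
    """round(x / scale) + zero_point, clamped to the assumed i8 storage range."""
    out = []
    for x, s, zp in zip(values, scales, zero_points):
        q = round(x / s) + zp  # round() ties to even on floats
        out.append(max(storage_min, min(q, storage_max)))
    return out


# First example: f32 -> {0.1:-30, 0.5:-20}
quantized = quantize_per_axis([4.0, 15.0], [0.1, 0.5], [-30, -20])
print(quantized)  # [10, 10]

# Second example: {0.1:-30, 0.5:-20} -> {0.1:-20, 0.2:-30}
float_result = dequantize_per_axis([10, 10], [0.1, 0.5], [-30, -20])
requantized = quantize_per_axis(float_result, [0.1, 0.2], [-20, -30])
print(requantized)  # [20, 45]
```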

### while

#### Semantics
@@ -6136,6 +6212,41 @@ def baseline_type(x: Value | Placeholder | Type) -> Type:
  return baseline_element_type(type(x))
```

* `dequantize` is defined on quantized tensor types and turns them into
floating-point tensor types. This happens via converting quantized elements
which represent integer values of the storage type into corresponding
floating-point values of the expressed type using the zero point and scale
associated with the quantized element type. At the moment, this function only
works for per-tensor quantization. Per-axis quantization is work in progress
([#1574](https://github.com/openxla/stablehlo/issues/1574)).

```python
def dequantize(x: Value) -> Value:
  assert is_quantized(x)
  x_storage = bitcast_convert(x, storage_type(x))
  # Zero point and scale come from the quantized element type of `x`.
  x_storage_sub = x_storage - zero_point(x)
  x_expressed_sub = convert(x_storage_sub, expressed_type(x))
  return x_expressed_sub * scale(x)
```

* `quantize` is defined on floating-point tensor types and turns them into
quantized tensor types. This happens via converting floating-point values
of the expressed type into corresponding integer values of the storage type
using the zero point and scale associated with the quantized element type.
At the moment, this function only works for per-tensor quantization. Per-axis
quantization is work in progress
([#1574](https://github.com/openxla/stablehlo/issues/1574)).

```python
def quantize(x: Value, type: Type) -> Value:
  assert is_float(x) and is_quantized(type)
  x_expressed_rounded = round_nearest_even(x / scale(type))
  x_storage_rounded = convert(x_expressed_rounded, storage_type(type))
  x_storage_add = x_storage_rounded + zero_point(type)
  x_storage = clamp(storage_min(type), x_storage_add, storage_max(type))
  return bitcast_convert(x_storage, type)
```
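
To make the rounding and clamping steps concrete, here is a small runnable sketch of the per-tensor case on scalar values. It is an illustration under assumptions, not the reference implementation: `PerTensorQuantType` is a hypothetical container, storage is assumed to be `i8`, and Python's unbounded integers stand in for the storage type, so the intermediate sum cannot overflow before the clamp.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class PerTensorQuantType:
    """Hypothetical per-tensor quantization parameters (i8 storage assumed)."""
    scale: float
    zero_point: int
    storage_min: int = -128
    storage_max: int = 127


def quantize(x: float, t: PerTensorQuantType) -> int:
    rounded = round(x / t.scale)      # Python's round() ties to even
    shifted = rounded + t.zero_point  # move into the storage domain
    return max(t.storage_min, min(shifted, t.storage_max))  # saturate


def dequantize(q: int, t: PerTensorQuantType) -> float:
    return (q - t.zero_point) * t.scale


t = PerTensorQuantType(scale=0.5, zero_point=-20)
print(quantize(15.0, t))    # 30 + (-20) = 10
print(quantize(1.25, t))    # 2.5 ties to even 2, then 2 + (-20) = -18
print(quantize(1000.0, t))  # 1980 saturates at storage_max = 127
print(dequantize(127, t))   # (127 + 20) * 0.5 = 73.5
```

The last two lines show that values outside the representable range do not round-trip: `1000.0` quantizes to `127`, which dequantizes to `73.5`.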

* `dequantize_op_quantize` is used to specify element-wise computations on
quantized tensors. It dequantizes, i.e. turns quantized elements into their
expressed types, then performs an operation, and then quantizes, i.e. turns
@@ -6147,10 +6258,10 @@ works for per-tensor quantization. Per-axis quantization is work in progress
def dequantize_op_quantize(op, *inputs_and_output_type):
  inputs = inputs_and_output_type[:-1]
  output_type = inputs_and_output_type[-1]
- float_inputs = [(x - zero_point(x)) * scale(x) for x in inputs]
+
+ float_inputs = map(dequantize, inputs)
  float_result = op(*float_inputs)
- rounded_result = round_nearest_even(float_result / scale(output_type))
- return clamp(storage_min(output_type), rounded_result, storage_max(output_type))
+ return quantize(float_result, output_type)
```

#### Grid computations
4 changes: 2 additions & 2 deletions docs/status.md
@@ -153,7 +153,7 @@ one of the following tracking labels.
| triangular_solve | yes | revisit | yes | no | revisit |
| tuple | yes | yes | yes | yes | no |
| unary_einsum | no | revisit | no | yes | revisit |
- | uniform_dequantize | no | yes\* | yes\* | yes | no |
- | uniform_quantize | no | yes\* | infeasible | yes | no |
+ | uniform_dequantize | yes | yes | yes | yes | no |
+ | uniform_quantize | yes | revisit | infeasible | yes | no |
| while | yes | revisit | yes | revisit | yes |
| xor | yes | yes | yes | yes | yes |
31 changes: 16 additions & 15 deletions stablehlo/dialect/StablehloOps.td
@@ -3055,38 +3055,39 @@ def StableHLO_RngBitGeneratorOp : StableHLO_Op<"rng_bit_generator", [Pure]> {

// TODO(b/230662142): Implement unknown scales/zero_point cases.
def StableHLO_UniformQuantizeOp : StableHLO_UnaryElementwiseOp<"uniform_quantize",
- [Pure], TensorOf<[F32, BF16, HLO_QuantizedInt]>,
- HLO_QuantizedIntTensor> {
+ [Pure], TensorOf<[HLO_Float, HLO_QuantizedInt]> /*uniform_quantize_i1*/,
+ HLO_QuantizedIntTensor> { /*uniform_quantize_c1*/
let summary = "UniformQuantize operation";
let description = [{
- This operation is a work in progress, so it is not yet included in
- the StableHLO specification: https://github.com/openxla/stablehlo/issues/588.
+ Performs element-wise conversion of floating-point tensor or quantized
+ tensor `operand` to a quantized tensor `result` according to the
+ quantization parameters defined by the `result` type.

- Informally, this operation converts floating point tensors or uniform
- quantized tensors to uniform quantized tensors according to the quantization
- parameters defined by the result type.
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#uniform_quantize

Example:
```mlir
- %result = stablehlo.uniform_quantize %operand : (tensor<16x16xf32>) -> tensor<16x16x!quant.uniform<ui8:f32, 34.0:16>>
+ %result = stablehlo.uniform_quantize %operand : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
```
}];
}

def StableHLO_UniformDequantizeOp : StableHLO_UnaryElementwiseOp<"uniform_dequantize",
- [InferTensorType, Pure], HLO_QuantizedIntTensor, TensorOf<[F32, BF16]>> {
+ [InferTensorType, Pure], HLO_QuantizedIntTensor /*uniform_dequantize_i1*/,
+ HLO_FpTensor> { /*uniform_dequantize_c1, uniform_dequantize_c2*/
let summary = "UniformDequantize operation";
let description = [{
- This operation is a work in progress, so it is not yet included in
- the StableHLO specification: https://github.com/openxla/stablehlo/issues/588.
+ Performs element-wise conversion of quantized tensor `operand` to a
+ floating-point tensor `result` according to the quantization parameters
+ defined by the `operand` type.

- Informally, this operation converts uniform quantized tensors to floating
- point tensors according to the quantization parameters defined by the
- operand type.
+ See:
+ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#uniform_dequantize

Example:
```mlir
- %result = stablehlo.uniform_dequantize %operand : (tensor<16x16x!quant.uniform<i8:f32, 34.0:16>>) -> tensor<16x16xf32>
+ %result = stablehlo.uniform_dequantize %operand : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
```
}];
}
2 changes: 2 additions & 0 deletions stablehlo/dialect/TypeInference.cpp
@@ -3009,6 +3009,7 @@ LogicalResult inferUniformDequantizeOp(
// Trait HLO_QuantizedIntTensor in ODS guarantees QuantizedType;
auto quantType = operandType.getElementType().cast<quant::QuantizedType>();
auto shape = operandType.cast<ShapedType>().getShape();
// uniform_dequantize_c1, uniform_dequantize_c2
inferredReturnShapes.emplace_back(shape, quantType.getExpressedType());
return success();
}
@@ -3017,6 +3018,7 @@ LogicalResult inferUniformQuantizeOp(
std::optional<Location> location, Value operand,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
auto operandType = operand.getType().cast<ShapedType>();
// uniform_quantize_c1
inferredReturnShapes.emplace_back(
operandType.hasRank() ? operandType.getShape() : ArrayRef<int64_t>{});
return success();
4 changes: 2 additions & 2 deletions stablehlo/tests/infer_stablehlo.mlir
@@ -234,8 +234,8 @@ func.func @clamp(%arg0: tensor<1xi32>) -> tensor<1xindex> {

// -----

- // CHECK: func @uniform_dequantize
- func.func @uniform_dequantize(%arg: tensor<16x16x!quant.uniform<i8:f32, 34.0:16>>) -> tensor<16x16xindex> {
+ // CHECK: func @uniform_dequantize_c2
+ func.func @uniform_dequantize_c2(%arg: tensor<16x16x!quant.uniform<i8:f32, 34.0:16>>) -> tensor<16x16xindex> {
%0 = stablehlo.uniform_dequantize %arg : (tensor<16x16x!quant.uniform<i8:f32, 34.0:16>>) -> tensor<16x16xf32>
// CHECK: types0 = tensor<16x16xf32>
%1 = "hlo_test_infer.get_return_types"(%0) : (tensor<16x16xf32>) -> tensor<16x16xindex>
13 changes: 2 additions & 11 deletions stablehlo/tests/ops_stablehlo.mlir
@@ -5090,7 +5090,7 @@ func.func @quantized_dot_general(%arg0: tensor<2x16x32x!quant.uniform<i8:f32, 2.

// -----

- // CHECK: func @uniform_quantize
+ // CHECK-LABEL: func @uniform_quantize
func.func @uniform_quantize(%arg: tensor<16x16xf32>) -> tensor<16x16x!quant.uniform<ui8:f32, 34.0:16>> {
%0 = stablehlo.uniform_quantize %arg : (tensor<16x16xf32>) -> tensor<16x16x!quant.uniform<ui8:f32, 34.0:16>>
func.return %0 : tensor<16x16x!quant.uniform<ui8:f32, 34.0:16>>
@@ -5106,30 +5106,21 @@ func.func @uniform_requantize(%arg: tensor<16x16x!quant.uniform<i8:f32, 5.0:20>>

// -----

- // CHECK: func @uniform_dequantize
+ // CHECK-LABEL: func @uniform_dequantize
func.func @uniform_dequantize(%arg: tensor<16x16x!quant.uniform<i8:f32, 34.0:16>>) -> tensor<16x16xf32> {
%0 = stablehlo.uniform_dequantize %arg : (tensor<16x16x!quant.uniform<i8:f32, 34.0:16>>) -> tensor<16x16xf32>
func.return %0 : tensor<16x16xf32>
}

- // -----
-
- // CHECK: func @uniform_dequantize_unranked
- func.func @uniform_dequantize_unranked(%arg: tensor<*x!quant.uniform<i8:f32, 34.0:16>>) -> tensor<*xf32> {
- %0 = stablehlo.uniform_dequantize %arg : (tensor<*x!quant.uniform<i8:f32, 34.0:16>>) -> tensor<*xf32>
- func.return %0 : tensor<*xf32>
- }
-
- // -----

func.func @uniform_dequantize_not_quantize(%arg: tensor<16x16xf32>) -> tensor<16x16xf32> {
// expected-error@+1 {{operand #0 must be tensor of 4/8/16/32-bit uniform quantized signed integer or 4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<16x16xf32>'}}
%0 = stablehlo.uniform_dequantize %arg : (tensor<16x16xf32>) -> tensor<16x16xf32>
func.return %0 : tensor<16x16xf32>
}

// -----

// CHECK-LABEL: func @quantized_constants
func.func @quantized_constants() -> (tensor<2x!quant.uniform<i8:f32, 2.0:15>>, tensor<2x!quant.uniform<ui8:f32, 34.0:16>>, tensor<2x!quant.uniform<i8:f32, 2.0:15>>) {
%0 = stablehlo.constant() {value = dense<[1, 2]> : tensor<2xi8>} : () -> tensor<2x!quant.uniform<i8:f32, 2.000000e+00:15>>