diff --git a/docs/spec.md b/docs/spec.md index de7f6fdeefa..e6ae74ca2c5 100644 --- a/docs/spec.md +++ b/docs/spec.md @@ -335,11 +335,6 @@ in StableHLO programs. In the meanwhile, here is the list of these operations: `dynamic_gather`, `dynamic_iota`, `dynamic_pad`, `dynamic_reshape`, `real_dynamic_slice`, `set_dimension_size` ([#8](https://github.com/openxla/stablehlo/issues/8)). -* "Quantization" category of StableHLO operations - they were bootstrapped from - MHLO, but we haven't specced them yet: `uniform_quantize` - ([#531](https://github.com/openxla/stablehlo/issues/531)) and - `uniform_dequantize` - ([#530](https://github.com/openxla/stablehlo/issues/530)). * Shape computations, including `arith`, `shape` and `tensor` operations ([#8](https://github.com/openxla/stablehlo/issues/8)). @@ -5535,6 +5530,87 @@ Produces a `result` tuple from values `val`. // %result: ([1.0, 2.0], (3)) ``` +### uniform_dequantize + +#### Semantics + +Performs element-wise conversion of quantized tensor `operand` to a +floating-point tensor `result` according to the quantization parameters defined +by the `operand` type. + +More formally, `result = dequantize(operand)`. + +#### Inputs + +| Label | Name | Type | Constraints | +|-------|-----------|------------------|-------------| +| (I1) | `operand` | quantized tensor | (C1), (C2) | + +#### Outputs + +| Name | Type | Constraints | +|----------|-------------------------------|-------------| +| `result` | tensor of floating-point type | (C1), (C2) | + +#### Constraints + +* (C1) `shape(operand) = shape(result)`. +* (C2) `element_type(result) = expressed_type(operand)`. 
+ +#### Examples + +```mlir +// %operand: [10, 10] +%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform>) -> tensor<2xf32> +// %result: [4.0, 15.0] +``` + +### uniform_quantize + +#### Semantics + +Performs element-wise conversion of floating-point tensor or quantized tensor +`operand` to a quantized tensor `result` according to the quantization +parameters defined by the `result` type. + +More formally, + +* If `is_float(operand)`: + * `result = quantize(operand, type(result))`. +* If `is_quantized(operand)`: + * `float_result = dequantize(operand)`. + * `result = quantize(float_result, type(result))`. + +#### Inputs + +| Label | Name | Type | Constraints | +|-------|-----------|--------------------------------------------|-------------| +| (I1) | `operand` | tensor of floating-point or quantized type | (C1), (C2) | + +#### Outputs + +| Name | Type | Constraints | +|----------|------------------|-------------| +| `result` | quantized tensor | (C1), (C2) | + +#### Constraints + +* (C1) `shape(operand) = shape(result)`. +* (C2) `expressed_type(result) = is_float(operand) ? element_type(operand) : + expressed_type(operand)`. + +#### Examples + +```mlir +// %operand: [4.0, 15.0] +%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform> +// %result: [10, 10] + +// %operand: [10, 10] +%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform>) -> tensor<2x!quant.uniform> +// %result: [20, 45] +``` + ### while #### Semantics @@ -6136,6 +6212,41 @@ def baseline_type(x: Value | Placeholder | Type) -> Type: return baseline_element_type(type(x)) ``` +* `dequantize` is defined on quantized tensor types and turns them into +floating-point tensor types. This happens via converting quantized elements +which represent integer values of the storage type into corresponding +floating-point values of the expressed type using the zero point and scale +associated with the quantized element type. 
At the moment, this function only +works for per-tensor quantization. Per-axis quantization is work in progress +([#1574](https://github.com/openxla/stablehlo/issues/1574)). + +```python +def dequantize(x: Value) -> Value: + assert is_quantized(x) + x_storage = bitcast_convert(x, storage_type(x)) + x_storage_sub = x_storage - zero_point(x) + x_expressed_sub = convert(x_storage_sub, expressed_type(x)) + return x_expressed_sub * scale(x) +``` + +* `quantize` is defined on floating-point tensor types and turns them into +quantized tensor types. This happens via converting floating-point values +of the expressed type into corresponding integer values of the storage type +using the zero point and scale associated with the quantized element type. +At the moment, this function only works for per-tensor quantization. Per-axis +quantization is work in progress +([#1574](https://github.com/openxla/stablehlo/issues/1574)). + +```python +def quantize(x: Value, type: Type) -> Value: + assert is_float(x) and is_quantized(type) + x_expressed_rounded = round_nearest_even(x / scale(type)) + x_storage_rounded = convert(x_expressed_rounded, storage_type(type)) + x_storage_add = x_storage_rounded + zero_point(type) + x_storage = clamp(storage_min(type), x_storage_add, storage_max(type)) + return bitcast_convert(x_storage, type) +``` + * `dequantize_op_quantize` is used to specify element-wise computations on quantized tensors. It dequantizes, i.e. turns quantized elements into their expressed types, then performs an operation, and then quantizes, i.e. turns @@ -6147,10 +6258,10 @@ works for per-tensor quantization. 
Per-axis quantization is work in progress def dequantize_op_quantize(op, *inputs_and_output_type): inputs = inputs_and_output_type[:-1] output_type = inputs_and_output_type[-1] - float_inputs = [(x - zero_point(x)) * scale(x) for x in inputs] + + float_inputs = map(dequantize, inputs) float_result = op(*float_inputs) - rounded_result = round_nearest_even(float_result / scale(output_type)) - return clamp(storage_min(output_type), rounded_result, storage_max(output_type)) + return quantize(float_result, output_type) ``` #### Grid computations diff --git a/docs/status.md b/docs/status.md index 820bfc3dbdf..0b8e1d59146 100644 --- a/docs/status.md +++ b/docs/status.md @@ -153,7 +153,7 @@ one of the following tracking labels. | triangular_solve | yes | revisit | yes | no | revisit | | tuple | yes | yes | yes | yes | no | | unary_einsum | no | revisit | no | yes | revisit | -| uniform_dequantize | no | yes\* | yes\* | yes | no | -| uniform_quantize | no | yes\* | infeasible | yes | no | +| uniform_dequantize | yes | yes | yes | yes | no | +| uniform_quantize | yes | revisit | infeasible | yes | no | | while | yes | revisit | yes | revisit | yes | | xor | yes | yes | yes | yes | yes | diff --git a/stablehlo/dialect/StablehloOps.td b/stablehlo/dialect/StablehloOps.td index db0f2add210..6276f5e2915 100644 --- a/stablehlo/dialect/StablehloOps.td +++ b/stablehlo/dialect/StablehloOps.td @@ -3055,38 +3055,39 @@ def StableHLO_RngBitGeneratorOp : StableHLO_Op<"rng_bit_generator", [Pure]> { // TODO(b/230662142): Implement unknown scales/zero_point cases. 
def StableHLO_UniformQuantizeOp : StableHLO_UnaryElementwiseOp<"uniform_quantize", - [Pure], TensorOf<[F32, BF16, HLO_QuantizedInt]>, - HLO_QuantizedIntTensor> { + [Pure], TensorOf<[HLO_Float, HLO_QuantizedInt]> /*uniform_quantize_i1*/, + HLO_QuantizedIntTensor> { /*uniform_quantize_c1*/ let summary = "UniformQuantize operation"; let description = [{ - This operation is a work in progress, so it is not yet included in - the StableHLO specification: https://github.com/openxla/stablehlo/issues/588. + Performs element-wise conversion of floating-point tensor or quantized + tensor `operand` to a quantized tensor `result` according to the + quantization parameters defined by the `result` type. - Informally, this operation converts floating point tensors or uniform - quantized tensors to uniform quantized tensors according to the quantization - parameters defined by the result type. + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#uniform_quantize Example: ```mlir - %result = stablehlo.uniform_quantize %operand : (tensor<16x16xf32>) -> tensor<16x16x!quant.uniform> + %result = stablehlo.uniform_quantize %operand : (tensor<2xf32>) -> tensor<2x!quant.uniform> ``` }]; } def StableHLO_UniformDequantizeOp : StableHLO_UnaryElementwiseOp<"uniform_dequantize", - [InferTensorType, Pure], HLO_QuantizedIntTensor, TensorOf<[F32, BF16]>> { + [InferTensorType, Pure], HLO_QuantizedIntTensor /*uniform_dequantize_i1*/, + HLO_FpTensor> { /*uniform_dequantize_c1, uniform_dequantize_c2*/ let summary = "UniformDequantize operation"; let description = [{ - This operation is a work in progress, so it is not yet included in - the StableHLO specification: https://github.com/openxla/stablehlo/issues/588. + Performs element-wise conversion of quantized tensor `operand` to a + floating-point tensor `result` according to the quantization parameters + defined by the `operand` type. 
- Informally, this operation converts uniform quantized tensors to floating - point tensors according to the quantization parameters defined by the - operand type. + See: + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#uniform_dequantize Example: ```mlir - %result = stablehlo.uniform_dequantize %operand : (tensor<16x16x!quant.uniform>) -> tensor<16x16xf32> + %result = stablehlo.uniform_dequantize %operand : (tensor<2x!quant.uniform>) -> tensor<2xf32> ``` }]; } diff --git a/stablehlo/dialect/TypeInference.cpp b/stablehlo/dialect/TypeInference.cpp index a60b873501e..d4233dfbc20 100644 --- a/stablehlo/dialect/TypeInference.cpp +++ b/stablehlo/dialect/TypeInference.cpp @@ -3009,6 +3009,7 @@ LogicalResult inferUniformDequantizeOp( // Trait HLO_QuantizedIntTensor in ODS guarantees QuantizedType; auto quantType = operandType.getElementType().cast(); auto shape = operandType.cast().getShape(); + // uniform_dequantize_c1, uniform_dequantize_c2 inferredReturnShapes.emplace_back(shape, quantType.getExpressedType()); return success(); } @@ -3017,6 +3018,7 @@ LogicalResult inferUniformQuantizeOp( std::optional location, Value operand, SmallVectorImpl& inferredReturnShapes) { auto operandType = operand.getType().cast(); + // uniform_quantize_c1 inferredReturnShapes.emplace_back( operandType.hasRank() ? 
operandType.getShape() : ArrayRef{}); return success(); diff --git a/stablehlo/tests/infer_stablehlo.mlir b/stablehlo/tests/infer_stablehlo.mlir index eb694bbc5f4..e8082e501ad 100644 --- a/stablehlo/tests/infer_stablehlo.mlir +++ b/stablehlo/tests/infer_stablehlo.mlir @@ -234,8 +234,8 @@ func.func @clamp(%arg0: tensor<1xi32>) -> tensor<1xindex> { // ----- -// CHECK: func @uniform_dequantize -func.func @uniform_dequantize(%arg: tensor<16x16x!quant.uniform>) -> tensor<16x16xindex> { +// CHECK: func @uniform_dequantize_c2 +func.func @uniform_dequantize_c2(%arg: tensor<16x16x!quant.uniform>) -> tensor<16x16xindex> { %0 = stablehlo.uniform_dequantize %arg : (tensor<16x16x!quant.uniform>) -> tensor<16x16xf32> // CHECK: types0 = tensor<16x16xf32> %1 = "hlo_test_infer.get_return_types"(%0) : (tensor<16x16xf32>) -> tensor<16x16xindex> diff --git a/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/tests/ops_stablehlo.mlir index a6ebfd781b7..c4b81e28816 100644 --- a/stablehlo/tests/ops_stablehlo.mlir +++ b/stablehlo/tests/ops_stablehlo.mlir @@ -5090,7 +5090,7 @@ func.func @quantized_dot_general(%arg0: tensor<2x16x32x!quant.uniform) -> tensor<16x16x!quant.uniform> { %0 = stablehlo.uniform_quantize %arg : (tensor<16x16xf32>) -> tensor<16x16x!quant.uniform> func.return %0 : tensor<16x16x!quant.uniform> @@ -5106,7 +5106,7 @@ func.func @uniform_requantize(%arg: tensor<16x16x!quant.uniform> // ----- -// CHECK: func @uniform_dequantize +// CHECK-LABEL: func @uniform_dequantize func.func @uniform_dequantize(%arg: tensor<16x16x!quant.uniform>) -> tensor<16x16xf32> { %0 = stablehlo.uniform_dequantize %arg : (tensor<16x16x!quant.uniform>) -> tensor<16x16xf32> func.return %0 : tensor<16x16xf32> @@ -5114,7 +5114,6 @@ func.func @uniform_dequantize(%arg: tensor<16x16x!quant.uniform // ----- -// CHECK: func @uniform_dequantize_unranked func.func @uniform_dequantize_unranked(%arg: tensor<*x!quant.uniform>) -> tensor<*xf32> { %0 = stablehlo.uniform_dequantize %arg : (tensor<*x!quant.uniform>) -> 
tensor<*xf32> func.return %0 : tensor<*xf32> @@ -5122,14 +5121,6 @@ func.func @uniform_dequantize_unranked(%arg: tensor<*x!quant.uniform) -> tensor<16x16xf32> { - // expected-error@+1 {{operand #0 must be tensor of 4/8/16/32-bit uniform quantized signed integer or 4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<16x16xf32>'}} - %0 = stablehlo.uniform_dequantize %arg : (tensor<16x16xf32>) -> tensor<16x16xf32> - func.return %0 : tensor<16x16xf32> -} - -// ----- - // CHECK-LABEL: func @quantized_constants func.func @quantized_constants() -> (tensor<2x!quant.uniform>, tensor<2x!quant.uniform>, tensor<2x!quant.uniform>) { %0 = stablehlo.constant() {value = dense<[1, 2]> : tensor<2xi8>} : () -> tensor<2x!quant.uniform>