diff --git a/docs/spec.md b/docs/spec.md
index 9b6476d60d3..a87e5d4fc08 100644
--- a/docs/spec.md
+++ b/docs/spec.md
@@ -140,6 +140,7 @@ which may allow us to remove tuple types from StableHLO
 
 ```ebnf
 ElementType ::= BooleanType | IntegerType | FloatType | ComplexType
+              | QuantizedType
 BooleanType ::= 'i1'
 IntegerType ::= 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
               | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
@@ -173,10 +174,66 @@ values of type `tensor`).
   and an **imaginary part** of the same **element type**. Supported complex
   types are `complex<f32>` (both parts are of type `f32`) and `complex<f64>`
   (both parts are of type `f64`).
-* In the future, we are also planning to introduce **quantized types** that
-  represent integer values obtained via uniform quantization of floating-point
-  values using given scales and zero points
-  ([#588](https://github.com/openxla/stablehlo/issues/588)).
+* **Uniform Quantized types** represent a mapping of real values to quantized
+  integers.
+
+```ebnf
+QuantizedType ::= PerTensorQuantizedType | PerAxisQuantizedType
+PerTensorQuantizedType ::= 'quantized.uniform' '<'
+                  StorageType ':' ExpressedType ':' MinVal ':'
+                  MaxVal ':' Scale ':' ZeroPoint '>'
+PerAxisQuantizedType ::= 'quantized.uniform' '<'
+                  StorageType ':' ExpressedType ':' MinVal ':'
+                  MaxVal ':' Scales ':' ZeroPoints ':'
+                  QuantizationDimension '>'
+StorageType ::= 'i4' | 'u4' | 'i8' | 'u8' | 'i16' | 'u16' | 'i32' | 'u32'
+ExpressedType ::= 'bf16' | 'f32'
+MinVal ::= IntegerConstant
+MaxVal ::= IntegerConstant
+Scales ::= '[' Scale {',' Scale} ']'
+ZeroPoints ::= '[' ZeroPoint {',' ZeroPoint} ']'
+Scale ::= '?' | FloatConstant
+ZeroPoint ::= '?' | IntegerConstant
+QuantizationDimension ::= IntegerConstant
+```
+
+  * `storage_type` is the integer type used for storage.
+  * `expressed_type` is the data type the quantized type represents.
+  * `min_val` is the minimum integer value that the quantized type can take.
+  * `max_val` is the maximum integer value that the quantized type can take.
+  * `scale` is the multiplicative factor used to dequantize back to the
+    `expressed_type`.
+  * `zero_point` is the offset used to dequantize back to the `expressed_type`.
+  * `per_tensor_quantized_types` are quantized per tensor.
+  * `per_axis_quantized_types` are quantized per axis, with separate
+    parameters for each index along the given `quantization_dimension`.
+
+  For a given tensor or dimension, the quantized value (`Quantization`) is
+  computed as follows:
+
+```
+quantized_value = clamp(min_val,
+                        round(expressed_value / scale) + zero_point,
+                        max_val)
+```
+
+  For a given tensor or dimension, the expressed value (`DeQuantization`) is
+  computed as follows:
+
+```
+dequantized_value = (quantized_value - zero_point) * scale
+```
+
+  Given that `alpha` and `beta` are the maximum and minimum values (per tensor
+  or per axis) that need to be represented in `expressed_type`, the values of
+  `scale` and `zero_point` can be computed as follows:
+
+```
+scale = (alpha - beta) / (max_val - min_val)
+zero_point = min_val - round(beta / scale)
+```
+
+
 ```ebnf
 FunctionType ::= '(' [ValueType {',' ValueType}] ')' '->'
                  '(' [ValueType {',' ValueType}] ')'
@@ -418,13 +475,25 @@ have the following constraints:
 
 * (C1) `is_wellformed(literal[:], element_type(type))`.
 
+```ebnf
+QuantizedConstant ::= QuantizedLiteral ':' QuantizedType
+QuantizedLiteral  ::= IntegerLiteral
+```
+
+**Quantized constants** represent quantized values via integer literals.
+Quantized constants have the following constraints:
+
+* (C1) `is_wellformed(literal, type)`, i.e. `literal` can be parsed as
+  a value of type `type`.
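Review note: the per-tensor quantization arithmetic introduced in this change can be sanity-checked with a small Python sketch. It is not part of the patch and not StableHLO code; the function names are hypothetical. It assumes the zero point is chosen as `min_val - round(beta / scale)` so that `beta` maps to `min_val`, and it uses Python's built-in `round` (round-half-to-even), since the spec text does not state a rounding mode.

```python
def compute_params(alpha, beta, min_val, max_val):
    """Derive (scale, zero_point) so that [beta, alpha] maps onto [min_val, max_val].

    Assumption: zero_point = min_val - round(beta / scale), which makes
    quantize(beta) == min_val.
    """
    scale = (alpha - beta) / (max_val - min_val)
    zero_point = min_val - round(beta / scale)
    return scale, zero_point


def quantize(expressed_value, scale, zero_point, min_val, max_val):
    """clamp(min_val, round(expressed_value / scale) + zero_point, max_val)."""
    unclamped = round(expressed_value / scale) + zero_point
    return max(min_val, min(unclamped, max_val))


def dequantize(quantized_value, scale, zero_point):
    """(quantized_value - zero_point) * scale."""
    return (quantized_value - zero_point) * scale


# Example: an i8-style storage range and an asymmetric expressed range.
MIN_VAL, MAX_VAL = -128, 127
scale, zero_point = compute_params(alpha=117.5, beta=-10.0,
                                   min_val=MIN_VAL, max_val=MAX_VAL)

q = quantize(3.0, scale, zero_point, MIN_VAL, MAX_VAL)
x = dequantize(q, scale, zero_point)
```

With these inputs the endpoints of the expressed range land on the endpoints of the storage range, and values outside the range saturate rather than wrap.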
+
 ```ebnf
 TensorConstant ::= TensorLiteral ':' TensorType
 TensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
 DenseLiteral   ::= DenseDimension | DenseElements
 DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
 DenseElements  ::= [ElementLiteral {',' ElementLiteral}]
-ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral
+ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral |
+                   ComplexLiteral | QuantizedLiteral
 ```
 
 **Tensor constants** represent tensor values using nested lists specified via