Skip to content

Commit

Permalink
Adding the missing quantized-type support in the ODS (#1608)
Browse files Browse the repository at this point in the history
Some of the stablehlo ops does not have support for quantized types in
their tablegen specification, prohibits writing StableHLO quantized
programs using those ops. The PR is about adding the missing support for
the following ops. Also, I believe the ongoing specification
[work](#588), should not
deviate much from the proposed changes here.

```
stablehlo.atan2
stablehlo.divide
stablehlo.power
stablehlo.remainder
stablehlo.subtract

stablehlo.abs
stablehlo.cbrt
stablehlo.cosine
stablehlo.exponential
stablehlo.exponential_minus_one
stablehlo.log
stablehlo.log_plus_one
stablehlo.logistic
stablehlo.negate
stablehlo.rsqrt
stablehlo.sign
stablehlo.sine
stablehlo.sqrt
stablehlo.tanh

stablehlo.cholesky
stablehlo.triangular_solve
```

Other than these ops, we have `fft`, `rng`, and `rng_bit_generator` (or
something else which I might be missing) which could be potential
candidates for the support. I propose that we add the support after
adding the specification of those op as adding the support might need
some non-trivial discussion.
  • Loading branch information
sdasgup3 authored Jun 29, 2023
1 parent ce932dd commit 1e33939
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 22 deletions.
10 changes: 8 additions & 2 deletions stablehlo/dialect/Base.td
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,18 @@ def HLO_PredOrIntTensor : TensorOf<[HLO_Pred, HLO_Int]>;
// Any floating-point or complex tensor types
def HLO_FpOrComplexTensor : TensorOf<[HLO_Float, HLO_Complex]>;

// Any int, floating-point or complex tensor types
def HLO_IntFpOrComplexTensor : TensorOf<[HLO_Int, HLO_Float, HLO_Complex]>;
// Any floating-point, complex or quantized tensor types
def HLO_FpComplexOrQuantizedIntTensor : TensorOf<[HLO_Float, HLO_Complex, HLO_QuantizedInt]>;

// Any int, floating-point, complex or quantized tensor types
def HLO_IntFpOrComplexOrQuantizedIntTensor : TensorOf<[HLO_Int, HLO_Float, HLO_Complex, HLO_QuantizedInt]>;

// Any pred, int or floating-point tensor types
def HLO_PredIntOrFpTensor : TensorOf<[HLO_Pred, HLO_Int, HLO_Float]>;

// Any pred, int, floating-point or quantized tensor types
def HLO_PredIntFpOrQuantizedTensor : TensorOf<[HLO_Pred, HLO_Int, HLO_Float, HLO_QuantizedInt]>;

//===----------------------------------------------------------------------===//
// HLO static shape type definitions.
//===----------------------------------------------------------------------===//
Expand Down
40 changes: 20 additions & 20 deletions stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,8 @@ class StableHLO_UnaryElementwiseOp<string mnemonic, list<Trait> traits,
// Abs supports complex to real, so element type is not guaranteed to match.
def StableHLO_AbsOp: StableHLO_UnaryElementwiseOp<"abs",
[Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>],
TensorOf<[HLO_SInt, HLO_Float, HLO_Complex] /* abs_i1 */>,
TensorOf<[HLO_SInt, HLO_Float]>> {
TensorOf<[HLO_SInt, HLO_Float, HLO_Complex, HLO_QuantizedInt] /* abs_i1 */>,
TensorOf<[HLO_SInt, HLO_Float, HLO_QuantizedInt]>> {
let summary = "Abs operation";
let description = [{
Performs element-wise abs operation on `operand` tensor and produces a
Expand All @@ -219,7 +219,7 @@ def StableHLO_AbsOp: StableHLO_UnaryElementwiseOp<"abs",

def StableHLO_CbrtOp: StableHLO_UnaryElementwiseOp<"cbrt",
[Pure, HLO_CompatibleOperandsAndResultType /*cbrt_c1*/],
HLO_FpOrComplexTensor /*cbrt_i1*/> { /*cbrt_c1*/
HLO_FpComplexOrQuantizedIntTensor /*cbrt_i1*/> { /*cbrt_c1*/
let summary = "Cbrt operation";
let description = [{
Performs element-wise cubic root operation on `operand` tensor and produces
Expand Down Expand Up @@ -289,7 +289,7 @@ def StableHLO_ClzOp: StableHLO_UnaryElementwiseOp<"count_leading_zeros",
}

def StableHLO_CosineOp: StableHLO_UnaryElementwiseOp<"cosine",
[Pure, HLO_CompatibleOperandsAndResultType], HLO_FpOrComplexTensor> {
[Pure, HLO_CompatibleOperandsAndResultType], HLO_FpComplexOrQuantizedIntTensor> {
let summary = "Cosine operation";
let description = [{
Performs element-wise cosine operation on `operand` tensor and produces a
Expand All @@ -307,7 +307,7 @@ def StableHLO_CosineOp: StableHLO_UnaryElementwiseOp<"cosine",

def StableHLO_ExpOp: StableHLO_UnaryElementwiseOp<"exponential",
[Pure, HLO_CompatibleOperandsAndResultType /*exponential_c1*/],
HLO_FpOrComplexTensor /*exponential_i1*/> {
HLO_FpComplexOrQuantizedIntTensor /*exponential_i1*/> {
let summary = "Exp operation";
let description = [{
Performs element-wise exponential operation on `operand` tensor and produces
Expand All @@ -325,7 +325,7 @@ def StableHLO_ExpOp: StableHLO_UnaryElementwiseOp<"exponential",

def StableHLO_Expm1Op: StableHLO_UnaryElementwiseOp<"exponential_minus_one",
[Pure, HLO_CompatibleOperandsAndResultType], /*exponential_minus_one_c1*/
HLO_FpOrComplexTensor /*exponential_minus_one_i1*/> { /*exponential_minus_one_c1*/
HLO_FpComplexOrQuantizedIntTensor /*exponential_minus_one_i1*/> { /*exponential_minus_one_c1*/
let summary = "Expm1 operation";
let description = [{
Performs element-wise exponential minus one operation on `operand` tensor
Expand Down Expand Up @@ -402,7 +402,7 @@ def StableHLO_IsFiniteOp: StableHLO_UnaryElementwiseOp<"is_finite", [Pure,

def StableHLO_LogOp: StableHLO_UnaryElementwiseOp<"log",
[Pure, HLO_CompatibleOperandsAndResultType /*log_c1*/],
HLO_FpOrComplexTensor /*log_i1*/> {
HLO_FpComplexOrQuantizedIntTensor /*log_i1*/> {
let summary = "Log operation";
let description = [{
Performs element-wise logarithm operation on `operand` tensor and produces a
Expand All @@ -420,7 +420,7 @@ def StableHLO_LogOp: StableHLO_UnaryElementwiseOp<"log",

def StableHLO_Log1pOp: StableHLO_UnaryElementwiseOp<"log_plus_one",
[Pure, HLO_CompatibleOperandsAndResultType /*log_plus_one_c1*/],
HLO_FpOrComplexTensor /*log_plus_one_i1*/> { /*log_plus_one_c1*/
HLO_FpComplexOrQuantizedIntTensor /*log_plus_one_i1*/> { /*log_plus_one_c1*/
let summary = "Log1p operation";
let description = [{
Performs element-wise logarithm plus one operation on `operand` tensor and
Expand All @@ -438,7 +438,7 @@ def StableHLO_Log1pOp: StableHLO_UnaryElementwiseOp<"log_plus_one",

def StableHLO_LogisticOp: StableHLO_UnaryElementwiseOp<"logistic",
[Pure, HLO_CompatibleOperandsAndResultType /*logistic_c1*/],
HLO_FpOrComplexTensor /*logistic_i1*/> { /*logistic_c1*/
HLO_FpComplexOrQuantizedIntTensor /*logistic_i1*/> { /*logistic_c1*/
let summary = "Logistic operation";
let description = [{
Performs element-wise logistic operation on `operand` tensor and produces a
Expand Down Expand Up @@ -472,7 +472,7 @@ def StableHLO_NotOp: StableHLO_UnaryElementwiseOp<"not",
}

def StableHLO_NegOp: StableHLO_UnaryElementwiseOp<"negate",
[Pure, HLO_CompatibleOperandsAndResultType], HLO_IntFpOrComplexTensor> {
[Pure, HLO_CompatibleOperandsAndResultType], HLO_IntFpOrComplexOrQuantizedIntTensor> {
let summary = "Neg operation";
let description = [{
Performs element-wise negation of `operand` tensor and produces a `result`
Expand Down Expand Up @@ -563,7 +563,7 @@ def StableHLO_RoundNearestEvenOp: StableHLO_UnaryElementwiseOp<"round_nearest_ev

def StableHLO_RsqrtOp: StableHLO_UnaryElementwiseOp<"rsqrt", [Pure,
HLO_CompatibleOperandsAndResultType /* rsqrt_c1 */],
HLO_FpOrComplexTensor /* rsqrt_i1 */> {
HLO_FpComplexOrQuantizedIntTensor /* rsqrt_i1 */> {
let summary = "Rsqrt operation";
let description = [{
Performs element-wise reciprocal square root operation on `operand` tensor
Expand All @@ -582,7 +582,7 @@ def StableHLO_RsqrtOp: StableHLO_UnaryElementwiseOp<"rsqrt", [Pure,

def StableHLO_SignOp: StableHLO_UnaryElementwiseOp<"sign",
[Pure, HLO_CompatibleOperandsAndResultType /*sign_c1*/],
TensorOf<[HLO_SInt, HLO_Float, HLO_Complex]> /*sign_i1*/> { /*sign_c1*/
TensorOf<[HLO_SInt, HLO_Float, HLO_Complex, HLO_QuantizedInt]> /*sign_i1*/> { /*sign_c1*/
let summary = "Sign operation";
let description = [{
Returns the sign of the `operand` element-wise and produces a `result`
Expand All @@ -599,7 +599,7 @@ def StableHLO_SignOp: StableHLO_UnaryElementwiseOp<"sign",
}

def StableHLO_SineOp: StableHLO_UnaryElementwiseOp<"sine",
[Pure, HLO_CompatibleOperandsAndResultType], HLO_FpOrComplexTensor> {
[Pure, HLO_CompatibleOperandsAndResultType], HLO_FpComplexOrQuantizedIntTensor> {
let summary = "Sine operation";
let description = [{
Performs element-wise sine operation on `operand` tensor and produces a
Expand All @@ -617,7 +617,7 @@ def StableHLO_SineOp: StableHLO_UnaryElementwiseOp<"sine",

def StableHLO_SqrtOp: StableHLO_UnaryElementwiseOp<"sqrt", [Pure,
HLO_CompatibleOperandsAndResultType /* sqrt_c1 */],
HLO_FpOrComplexTensor /* sqrt_i1 */> {
HLO_FpComplexOrQuantizedIntTensor /* sqrt_i1 */> {
let summary = "Sqrt operation";
let description = [{
Performs element-wise square root operation on `operand` tensor and produces
Expand All @@ -635,7 +635,7 @@ def StableHLO_SqrtOp: StableHLO_UnaryElementwiseOp<"sqrt", [Pure,

def StableHLO_TanhOp: StableHLO_UnaryElementwiseOp<"tanh",
[Pure, HLO_CompatibleOperandsAndResultType],
HLO_FpOrComplexTensor> {
HLO_FpComplexOrQuantizedIntTensor> {
let summary = "Tanh operation";
let description = [{
Performs element-wise hyperbolic tangent operation on `operand` tensor and
Expand Down Expand Up @@ -705,7 +705,7 @@ def StableHLO_AddOp : StableHLO_BinaryElementwiseOp<"add",

def StableHLO_Atan2Op : StableHLO_BinaryElementwiseOp<"atan2",
[Pure, HLO_CompatibleOperandsAndResultType /*atan2_c1*/],
HLO_FpOrComplexTensor /*atan2_i1, atan2_i2*/> { /*atan2_c1*/
HLO_FpComplexOrQuantizedIntTensor /*atan2_i1, atan2_i2*/> { /*atan2_c1*/
let summary = "Atan2 operation";
let description = [{
Performs element-wise atan2 operation on `lhs` and `rhs` tensor and produces
Expand Down Expand Up @@ -752,7 +752,7 @@ def StableHLO_ComplexOp: StableHLO_BinaryElementwiseOp<"complex", [Pure,

def StableHLO_DivOp : StableHLO_BinaryElementwiseOp<"divide", [Pure,
HLO_CompatibleOperandsAndResultType /* div_c1 */],
HLO_IntFpOrComplexTensor /* div_i1, div_i2 */> {
HLO_IntFpOrComplexOrQuantizedIntTensor /* div_i1, div_i2 */> {
let summary = "Div operation";
let description = [{
Performs element-wise division of dividend `lhs` and divisor `rhs` tensors
Expand Down Expand Up @@ -821,7 +821,7 @@ def StableHLO_MulOp : StableHLO_BinaryElementwiseOp<"multiply",

def StableHLO_PowOp : StableHLO_BinaryElementwiseOp<"power",
[Pure, HLO_CompatibleOperandsAndResultType /* pow_c1 */],
HLO_IntFpOrComplexTensor /* pow_i1, pow_i2 */> {
HLO_IntFpOrComplexOrQuantizedIntTensor /* pow_i1, pow_i2 */> {
let summary = "Power operation";
let description = [{
Performs element-wise exponentiation of `lhs` tensor by `rhs` tensor and
Expand All @@ -839,7 +839,7 @@ def StableHLO_PowOp : StableHLO_BinaryElementwiseOp<"power",

def StableHLO_RemOp : StableHLO_BinaryElementwiseOp<"remainder",
[Pure, HLO_CompatibleOperandsAndResultType /*remainder_c1*/],
HLO_IntFpOrComplexTensor /*remainder_i1, remainder_i2*/> {
HLO_IntFpOrComplexOrQuantizedIntTensor /*remainder_i1, remainder_i2*/> {
let summary = "Rem operation";
let description = [{
Performs element-wise remainder of dividend `lhs` and divisor `rhs` tensors
Expand Down Expand Up @@ -910,7 +910,7 @@ def StableHLO_ShiftRightLogicalOp : StableHLO_BinaryElementwiseOp<"shift_right_l
}

def StableHLO_SubtractOp : StableHLO_BinaryElementwiseOp<"subtract",
[Pure, HLO_CompatibleOperandsAndResultType], HLO_IntFpOrComplexTensor> {
[Pure, HLO_CompatibleOperandsAndResultType], HLO_IntFpOrComplexOrQuantizedIntTensor> {
let summary = "Subtract operation";
let description = [{
Performs element-wise subtraction of two tensors `lhs` and `rhs` and
Expand Down
30 changes: 30 additions & 0 deletions stablehlo/tests/ops_stablehlo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5301,6 +5301,36 @@ func.func @is_compatible_quant_signedness_mismatch(%arg0: tensor<1x!quant.unifor
func.return
}

// -----

// The following is the not the exhaustive list of ops supporting quantized
// types. The list will be updated as part of adding verification support for
// quantized ops.
func.func @quantization_supported_ops(%arg0: tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>, %arg1: tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>, %arg2: tensor<!quant.uniform<i8:f32, 1.0:17>>) {
%0 = "stablehlo.atan2"(%arg0, %arg1) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%1 = "stablehlo.divide"(%arg0, %arg1) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%2 = "stablehlo.power"(%arg0, %arg1) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%3 = "stablehlo.remainder"(%arg0, %arg1) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%4 = "stablehlo.subtract"(%arg0, %arg1) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>

%5 = "stablehlo.abs"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%6 = "stablehlo.cbrt"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%7 = "stablehlo.cosine"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%8 = "stablehlo.exponential"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%9 = "stablehlo.exponential_minus_one"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%10 = "stablehlo.log"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%11 = "stablehlo.log_plus_one"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%12 = "stablehlo.logistic"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%13 = "stablehlo.negate"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%14 = "stablehlo.rsqrt"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%15 = "stablehlo.sign"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%16 = "stablehlo.sine"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%17 = "stablehlo.sqrt"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%18 = "stablehlo.tanh"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
func.return
}


// -----

// CHECK-LABEL: is_compatible_dynamism_bounds
Expand Down

0 comments on commit 1e33939

Please sign in to comment.