diff --git a/stablehlo/dialect/Base.td b/stablehlo/dialect/Base.td index 365fa1b59af..fdd9be66eee 100644 --- a/stablehlo/dialect/Base.td +++ b/stablehlo/dialect/Base.td @@ -149,12 +149,18 @@ def HLO_PredOrIntTensor : TensorOf<[HLO_Pred, HLO_Int]>; // Any floating-point or complex tensor types def HLO_FpOrComplexTensor : TensorOf<[HLO_Float, HLO_Complex]>; -// Any int, floating-point or complex tensor types -def HLO_IntFpOrComplexTensor : TensorOf<[HLO_Int, HLO_Float, HLO_Complex]>; +// Any floating-point, complex or quantized tensor types +def HLO_FpComplexOrQuantizedIntTensor : TensorOf<[HLO_Float, HLO_Complex, HLO_QuantizedInt]>; + +// Any int, floating-point, complex or quantized tensor types +def HLO_IntFpOrComplexOrQuantizedIntTensor : TensorOf<[HLO_Int, HLO_Float, HLO_Complex, HLO_QuantizedInt]>; // Any pred, int or floating-point tensor types def HLO_PredIntOrFpTensor : TensorOf<[HLO_Pred, HLO_Int, HLO_Float]>; +// Any pred, int, floating-point or quantized tensor types +def HLO_PredIntFpOrQuantizedTensor : TensorOf<[HLO_Pred, HLO_Int, HLO_Float, HLO_QuantizedInt]>; + //===----------------------------------------------------------------------===// // HLO static shape type definitions. //===----------------------------------------------------------------------===// diff --git a/stablehlo/dialect/StablehloOps.td b/stablehlo/dialect/StablehloOps.td index 7c3aaab08df..46641892144 100644 --- a/stablehlo/dialect/StablehloOps.td +++ b/stablehlo/dialect/StablehloOps.td @@ -200,8 +200,8 @@ class StableHLO_UnaryElementwiseOp traits, // Abs supports complex to real, so element type is not guaranteed to match. def StableHLO_AbsOp: StableHLO_UnaryElementwiseOp<"abs", [Pure, DeclareOpInterfaceMethods], - TensorOf<[HLO_SInt, HLO_Float, HLO_Complex] /* abs_i1 */>, - TensorOf<[HLO_SInt, HLO_Float]>> { + TensorOf<[HLO_SInt, HLO_Float, HLO_Complex, HLO_QuantizedInt] /* abs_i1 */>, + TensorOf<[HLO_SInt, HLO_Float, HLO_QuantizedInt]>> { let summary = "Abs operation"; let description = [{ Performs element-wise abs operation on `operand` tensor and produces a @@ -219,7 +219,7 @@ def StableHLO_AbsOp: StableHLO_UnaryElementwiseOp<"abs", def StableHLO_CbrtOp: StableHLO_UnaryElementwiseOp<"cbrt", [Pure, HLO_CompatibleOperandsAndResultType /*cbrt_c1*/], - HLO_FpOrComplexTensor /*cbrt_i1*/> { /*cbrt_c1*/ + HLO_FpComplexOrQuantizedIntTensor /*cbrt_i1*/> { /*cbrt_c1*/ let summary = "Cbrt operation"; let description = [{ Performs element-wise cubic root operation on `operand` tensor and produces @@ -289,7 +289,7 @@ def StableHLO_ClzOp: StableHLO_UnaryElementwiseOp<"count_leading_zeros", } def StableHLO_CosineOp: StableHLO_UnaryElementwiseOp<"cosine", - [Pure, HLO_CompatibleOperandsAndResultType], HLO_FpOrComplexTensor> { + [Pure, HLO_CompatibleOperandsAndResultType], HLO_FpComplexOrQuantizedIntTensor> { let summary = "Cosine operation"; let description = [{ Performs element-wise cosine operation on `operand` tensor and produces a @@ -307,7 +307,7 @@ def StableHLO_CosineOp: StableHLO_UnaryElementwiseOp<"cosine", def StableHLO_ExpOp: StableHLO_UnaryElementwiseOp<"exponential", [Pure, HLO_CompatibleOperandsAndResultType /*exponential_c1*/], - HLO_FpOrComplexTensor /*exponential_i1*/> { + HLO_FpComplexOrQuantizedIntTensor /*exponential_i1*/> { let summary = "Exp operation"; let description = [{ Performs element-wise exponential operation on `operand` tensor and produces @@ -325,7 +325,7 @@ def StableHLO_ExpOp: StableHLO_UnaryElementwiseOp<"exponential", def StableHLO_Expm1Op: StableHLO_UnaryElementwiseOp<"exponential_minus_one", [Pure, HLO_CompatibleOperandsAndResultType], /*exponential_minus_one_c1*/ - HLO_FpOrComplexTensor /*exponential_minus_one_i1*/> { /*exponential_minus_one_c1*/ + HLO_FpComplexOrQuantizedIntTensor /*exponential_minus_one_i1*/> { /*exponential_minus_one_c1*/ let summary = "Expm1 operation"; let description = [{ Performs element-wise exponential minus one operation on `operand` tensor @@ -402,7 +402,7 @@ def StableHLO_IsFiniteOp: StableHLO_UnaryElementwiseOp<"is_finite", [Pure, def StableHLO_LogOp: StableHLO_UnaryElementwiseOp<"log", [Pure, HLO_CompatibleOperandsAndResultType /*log_c1*/], - HLO_FpOrComplexTensor /*log_i1*/> { + HLO_FpComplexOrQuantizedIntTensor /*log_i1*/> { let summary = "Log operation"; let description = [{ Performs element-wise logarithm operation on `operand` tensor and produces a @@ -420,7 +420,7 @@ def StableHLO_LogOp: StableHLO_UnaryElementwiseOp<"log", def StableHLO_Log1pOp: StableHLO_UnaryElementwiseOp<"log_plus_one", [Pure, HLO_CompatibleOperandsAndResultType /*log_plus_one_c1*/], - HLO_FpOrComplexTensor /*log_plus_one_i1*/> { /*log_plus_one_c1*/ + HLO_FpComplexOrQuantizedIntTensor /*log_plus_one_i1*/> { /*log_plus_one_c1*/ let summary = "Log1p operation"; let description = [{ Performs element-wise logarithm plus one operation on `operand` tensor and @@ -438,7 +438,7 @@ def StableHLO_Log1pOp: StableHLO_UnaryElementwiseOp<"log_plus_one", def StableHLO_LogisticOp: StableHLO_UnaryElementwiseOp<"logistic", [Pure, HLO_CompatibleOperandsAndResultType /*logistic_c1*/], - HLO_FpOrComplexTensor /*logistic_i1*/> { /*logistic_c1*/ + HLO_FpComplexOrQuantizedIntTensor /*logistic_i1*/> { /*logistic_c1*/ let summary = "Logistic operation"; let description = [{ Performs element-wise logistic operation on `operand` tensor and produces a @@ -472,7 +472,7 @@ def StableHLO_NotOp: StableHLO_UnaryElementwiseOp<"not", } def StableHLO_NegOp: StableHLO_UnaryElementwiseOp<"negate", - [Pure, HLO_CompatibleOperandsAndResultType], HLO_IntFpOrComplexTensor> { + [Pure, HLO_CompatibleOperandsAndResultType], HLO_IntFpOrComplexOrQuantizedIntTensor> { let summary = "Neg operation"; let description = [{ Performs element-wise negation of `operand` tensor and produces a `result` @@ -563,7 +563,7 @@ def StableHLO_RoundNearestEvenOp: StableHLO_UnaryElementwiseOp<"round_nearest_ev def StableHLO_RsqrtOp: StableHLO_UnaryElementwiseOp<"rsqrt", [Pure, HLO_CompatibleOperandsAndResultType /* rsqrt_c1 */], - HLO_FpOrComplexTensor /* rsqrt_i1 */> { + HLO_FpComplexOrQuantizedIntTensor /* rsqrt_i1 */> { let summary = "Rsqrt operation"; let description = [{ Performs element-wise reciprocal square root operation on `operand` tensor @@ -582,7 +582,7 @@ def StableHLO_RsqrtOp: StableHLO_UnaryElementwiseOp<"rsqrt", [Pure, def StableHLO_SignOp: StableHLO_UnaryElementwiseOp<"sign", [Pure, HLO_CompatibleOperandsAndResultType /*sign_c1*/], - TensorOf<[HLO_SInt, HLO_Float, HLO_Complex]> /*sign_i1*/> { /*sign_c1*/ + TensorOf<[HLO_SInt, HLO_Float, HLO_Complex, HLO_QuantizedInt]> /*sign_i1*/> { /*sign_c1*/ let summary = "Sign operation"; let description = [{ Returns the sign of the `operand` element-wise and produces a `result` @@ -599,7 +599,7 @@ def StableHLO_SignOp: StableHLO_UnaryElementwiseOp<"sign", } def StableHLO_SineOp: StableHLO_UnaryElementwiseOp<"sine", - [Pure, HLO_CompatibleOperandsAndResultType], HLO_FpOrComplexTensor> { + [Pure, HLO_CompatibleOperandsAndResultType], HLO_FpComplexOrQuantizedIntTensor> { let summary = "Sine operation"; let description = [{ Performs element-wise sine operation on `operand` tensor and produces a @@ -617,7 +617,7 @@ def StableHLO_SineOp: StableHLO_UnaryElementwiseOp<"sine", def StableHLO_SqrtOp: StableHLO_UnaryElementwiseOp<"sqrt", [Pure, HLO_CompatibleOperandsAndResultType /* sqrt_c1 */], - HLO_FpOrComplexTensor /* sqrt_i1 */> { + HLO_FpComplexOrQuantizedIntTensor /* sqrt_i1 */> { let summary = "Sqrt operation"; let description = [{ Performs element-wise square root operation on `operand` tensor and produces @@ -635,7 +635,7 @@ def StableHLO_SqrtOp: StableHLO_UnaryElementwiseOp<"sqrt", [Pure, def StableHLO_TanhOp: StableHLO_UnaryElementwiseOp<"tanh", [Pure, HLO_CompatibleOperandsAndResultType], - HLO_FpOrComplexTensor> { + HLO_FpComplexOrQuantizedIntTensor> { let summary = "Tanh operation"; let description = [{ Performs element-wise hyperbolic tangent operation on `operand` tensor and @@ -705,7 +705,7 @@ def StableHLO_AddOp : StableHLO_BinaryElementwiseOp<"add", def StableHLO_Atan2Op : StableHLO_BinaryElementwiseOp<"atan2", [Pure, HLO_CompatibleOperandsAndResultType /*atan2_c1*/], - HLO_FpOrComplexTensor /*atan2_i1, atan2_i2*/> { /*atan2_c1*/ + HLO_FpComplexOrQuantizedIntTensor /*atan2_i1, atan2_i2*/> { /*atan2_c1*/ let summary = "Atan2 operation"; let description = [{ Performs element-wise atan2 operation on `lhs` and `rhs` tensor and produces @@ -752,7 +752,7 @@ def StableHLO_ComplexOp: StableHLO_BinaryElementwiseOp<"complex", [Pure, def StableHLO_DivOp : StableHLO_BinaryElementwiseOp<"divide", [Pure, HLO_CompatibleOperandsAndResultType /* div_c1 */], - HLO_IntFpOrComplexTensor /* div_i1, div_i2 */> { + HLO_IntFpOrComplexOrQuantizedIntTensor /* div_i1, div_i2 */> { let summary = "Div operation"; let description = [{ Performs element-wise division of dividend `lhs` and divisor `rhs` tensors @@ -821,7 +821,7 @@ def StableHLO_MulOp : StableHLO_BinaryElementwiseOp<"multiply", def StableHLO_PowOp : StableHLO_BinaryElementwiseOp<"power", [Pure, HLO_CompatibleOperandsAndResultType /* pow_c1 */], - HLO_IntFpOrComplexTensor /* pow_i1, pow_i2 */> { + HLO_IntFpOrComplexOrQuantizedIntTensor /* pow_i1, pow_i2 */> { let summary = "Power operation"; let description = [{ Performs element-wise exponentiation of `lhs` tensor by `rhs` tensor and @@ -839,7 +839,7 @@ def StableHLO_PowOp : StableHLO_BinaryElementwiseOp<"power", def StableHLO_RemOp : StableHLO_BinaryElementwiseOp<"remainder", [Pure, HLO_CompatibleOperandsAndResultType /*remainder_c1*/], - HLO_IntFpOrComplexTensor /*remainder_i1, remainder_i2*/> { + HLO_IntFpOrComplexOrQuantizedIntTensor /*remainder_i1, remainder_i2*/> { let summary = "Rem operation"; let description = [{ Performs element-wise remainder of dividend `lhs` and divisor `rhs` tensors @@ -910,7 +910,7 @@ def StableHLO_ShiftRightLogicalOp : StableHLO_BinaryElementwiseOp<"shift_right_l } def StableHLO_SubtractOp : StableHLO_BinaryElementwiseOp<"subtract", - [Pure, HLO_CompatibleOperandsAndResultType], HLO_IntFpOrComplexTensor> { + [Pure, HLO_CompatibleOperandsAndResultType], HLO_IntFpOrComplexOrQuantizedIntTensor> { let summary = "Subtract operation"; let description = [{ Performs element-wise subtraction of two tensors `lhs` and `rhs` and diff --git a/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/tests/ops_stablehlo.mlir index cf80542b79b..7306835dfd1 100644 --- a/stablehlo/tests/ops_stablehlo.mlir +++ b/stablehlo/tests/ops_stablehlo.mlir @@ -5301,6 +5301,36 @@ func.func @is_compatible_quant_signedness_mismatch(%arg0: tensor<1x!quant.unifor func.return } +// ----- + +// The following is the not the exhaustive list of ops supporting quantized +// types. The list will be updated as part of adding verification support for +// quantized ops. +func.func @quantization_supported_ops(%arg0: tensor<1x2x2x!quant.uniform>, %arg1: tensor<1x2x2x!quant.uniform>, %arg2: tensor>) { + %0 = "stablehlo.atan2"(%arg0, %arg1) : (tensor<1x2x2x!quant.uniform>, tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> + %1 = "stablehlo.divide"(%arg0, %arg1) : (tensor<1x2x2x!quant.uniform>, tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> + %2 = "stablehlo.power"(%arg0, %arg1) : (tensor<1x2x2x!quant.uniform>, tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> + %3 = "stablehlo.remainder"(%arg0, %arg1) : (tensor<1x2x2x!quant.uniform>, tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> + %4 = "stablehlo.subtract"(%arg0, %arg1) : (tensor<1x2x2x!quant.uniform>, tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> + + %5 = "stablehlo.abs"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> + %6 = "stablehlo.cbrt"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> + %7 = "stablehlo.cosine"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> + %8 = "stablehlo.exponential"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> + %9 = "stablehlo.exponential_minus_one"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> + %10 = "stablehlo.log"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> + %11 = "stablehlo.log_plus_one"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> + %12 = "stablehlo.logistic"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> + %13 = "stablehlo.negate"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> + %14 = "stablehlo.rsqrt"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> + %15 = "stablehlo.sign"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> + %16 = "stablehlo.sine"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> + %17 = "stablehlo.sqrt"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> + %18 = "stablehlo.tanh"(%arg0) : (tensor<1x2x2x!quant.uniform>) -> tensor<1x2x2x!quant.uniform> + func.return +} + + // ----- // CHECK-LABEL: is_compatible_dynamism_bounds