Unpack per-channel hybrid quantized MHLO ops to float ops
PiperOrigin-RevId: 621760079
doyeonkim0 authored and copybara-github committed Apr 5, 2024
1 parent be5c637 commit 304bc3f
Showing 2 changed files with 161 additions and 13 deletions.
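For context, "hybrid" here means weight-only quantization: the lhs and the result stay in float while the quantized rhs weights are dequantized on the fly. Per channel, that is f[r][c] = (q[r][c] - zeroPoint[c]) * scale[c]. A minimal standalone C++ sketch of that arithmetic follows; the function, names, and shapes are illustrative assumptions, not code from this commit.

// Reference sketch only (not code from this commit): per-channel weight-only
// dequantization, f[r][c] = (q[r][c] - zeroPoint[c]) * scale[c], which is what
// the convert / broadcast_subtract / broadcast_multiply sequence computes.
#include <cstdint>
#include <cstdio>
#include <vector>

std::vector<float> dequantizePerChannel(
    const std::vector<int8_t>& q,            // row-major [rows x channels]
    const std::vector<float>& scales,        // one scale per channel
    const std::vector<int64_t>& zeroPoints,  // one zero point per channel
    int64_t rows, int64_t channels) {
  std::vector<float> result(q.size());
  for (int64_t r = 0; r < rows; ++r) {
    for (int64_t c = 0; c < channels; ++c) {
      // The pass skips the subtraction when all zero points are zero.
      float centered = static_cast<float>(q[r * channels + c]) -
                       static_cast<float>(zeroPoints[c]);
      result[r * channels + c] = centered * scales[c];
    }
  }
  return result;
}

int main() {
  // Values mirror the asymmetric test below: scales {3, 4}, zero points {10, 20}.
  std::vector<int8_t> q = {12, 24, 7, 16};
  std::vector<float> f = dequantizePerChannel(q, {3.0f, 4.0f}, {10, 20},
                                              /*rows=*/2, /*channels=*/2);
  for (float v : f) std::printf("%g ", v);  // prints: 6 16 -9 -16
  std::printf("\n");
  return 0;
}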
@@ -592,6 +592,18 @@ struct DotLikeDimensionNumbers {
   SmallVector<int64_t> rhsContractingDims;
 };

+// Checks if zero points of the given quantized type are zero.
+bool isZeroPointZero(QuantType type) {
+  if (isPerTensorType(type)) {
+    return getPerTensorType(type).getZeroPoint() == 0;
+  }
+  if (isPerChannelType(type)) {
+    ArrayRef<int64_t> zeroPoints = getPerChannelType(type).getZeroPoints();
+    return llvm::all_of(zeroPoints, [](int64_t zp) { return zp == 0; });
+  }
+  return false;
+}
+
 // A shared matchAndRewrite implementation for dot-like hybrid quantized
 // operators. Hybrid ops are currently only interpreted as weight-only
 // quantization ops, this might change in the future.
@@ -611,30 +623,37 @@ LogicalResult matchAndRewriteDotLikeHybridOp(
                                                    adaptor.getRhs());
   Operation::result_range resultRange = barrier.getResults();
   Value rhs = resultRange.front();
-  auto rhsElementType = getElementTypeOrSelf(op.getRhs().getType())
-                            .template cast<quant::UniformQuantizedType>();
+  FailureOr<QuantType> rhsElementQuantType =
+      getQuantType(op.getRhs().getType());
+  if (failed(rhsElementQuantType)) {
+    return failure();
+  }
   auto resFloat32TensorType =
       op.getResult().getType().template cast<TensorType>();
   auto rhsFloat32TensorType =
       op.getRhs().getType().template cast<TensorType>().clone(
           rewriter.getF32Type());

   // Get scales and zero points for rhs.
-  Value rhsZeroPoint = rewriter.create<mhlo::ConstantOp>(
-      op->getLoc(), rewriter.getF32FloatAttr((rhsElementType.getZeroPoint())));
-  Value rhsScaleConstant = rewriter.create<mhlo::ConstantOp>(
-      op->getLoc(),
-      rewriter.getF32FloatAttr(static_cast<float>(rhsElementType.getScale())));
+  Value rhsScale, rhsZeroPoint;
+  DenseI64ArrayAttr broadcastDims;
+  getQuantizationParams(rewriter, op->getLoc(), *rhsElementQuantType, rhsScale,
+                        rhsZeroPoint,
+                        /*outputZeroPointInFp=*/true, broadcastDims);

   // Dequantize rhs_float32_tensor.
   Value rhsFloat32Tensor =
       rewriter.create<mhlo::ConvertOp>(op->getLoc(), rhsFloat32TensorType, rhs);
-  rhsFloat32Tensor = rewriter.create<chlo::BroadcastSubOp>(
-      op->getLoc(), rhsFloat32TensorType, rhsFloat32Tensor, rhsZeroPoint,
-      nullptr);
+
+  // Subtract zero points only when it is not zero.
+  if (!isZeroPointZero(*rhsElementQuantType)) {
+    rhsFloat32Tensor = rewriter.create<chlo::BroadcastSubOp>(
+        op->getLoc(), rhsFloat32TensorType, rhsFloat32Tensor, rhsZeroPoint,
+        broadcastDims);
+  }
   rhsFloat32Tensor = rewriter.create<chlo::BroadcastMulOp>(
-      op->getLoc(), rhsFloat32TensorType, rhsFloat32Tensor, rhsScaleConstant,
-      nullptr);
+      op->getLoc(), rhsFloat32TensorType, rhsFloat32Tensor, rhsScale,
+      broadcastDims);

   // Execute conversion target op.
   SmallVector<Value, 2> operands{lhsFloat32Tensor, rhsFloat32Tensor};
@@ -1045,7 +1064,8 @@ FailureOr<bool> isDotLikeOpHybrid(DotLikeOp op) {
     // both per-tensor quantized.
     return false;
   }
-  if (!isLhsQuant && !isLhsQuantPerChannel && isRhsQuant && !isResQuant &&
+  if (!isLhsQuant && !isLhsQuantPerChannel &&
+      (isRhsQuant || isRhsQuantPerChannel) && !isResQuant &&
       !isResQuantPerChannel) {
     return true;
   }
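In the per-channel case the scale and zero-point constants are rank-1 tensors broadcast along the weight's quantized dimension (dimension 1 for the dot_general rhs and dimension 3 for the convolution kernel in the tests below), and the zero-point subtraction is emitted only when some zero point is nonzero. A rough sketch of the parameters this path works with follows; the struct, function, and plain-vector inputs are assumptions for illustration, not the actual getQuantizationParams signature.

// Illustrative sketch only (not the pass's API): what the per-channel branch
// of getQuantizationParams plus isZeroPointZero conceptually produce. The real
// pass materializes the vectors as mhlo.constant ops and the dimension as a
// DenseI64ArrayAttr.
#include <cstdint>
#include <vector>

struct PerChannelQuantParams {
  std::vector<float> scales;      // one entry per channel
  std::vector<float> zeroPoints;  // one entry per channel, cast to float
  int64_t broadcastDim;           // the weight's quantized dimension
  bool allZeroPointsZero;         // if true, the subtract op is skipped
};

PerChannelQuantParams extractParams(const std::vector<double>& scales,
                                    const std::vector<int64_t>& zeroPoints,
                                    int64_t quantizedDim) {
  PerChannelQuantParams p;
  p.broadcastDim = quantizedDim;  // e.g. 1 for the dot_general rhs tests,
                                  // 3 for the convolution kernel tests
  p.allZeroPointsZero = true;
  for (double s : scales) p.scales.push_back(static_cast<float>(s));
  for (int64_t zp : zeroPoints) {
    p.zeroPoints.push_back(static_cast<float>(zp));
    if (zp != 0) p.allZeroPointsZero = false;
  }
  return p;
}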
xla/mlir_hlo/tests/Dialect/mhlo/mhlo-quant-legalize-to-int.mlir (128 additions, 0 deletions)
@@ -1779,6 +1779,61 @@ func.func @dot_hybrid(

// -----

// CHECK-LABEL: func @dot_general_hybrid_per_channel
// CHECK-SAME: %[[ARG0:.*]]: tensor<3x2xf32>
// CHECK-SAME: %[[ARG1:.*]]: tensor<2x2xi8>
func.func @dot_general_hybrid_per_channel(
%arg0: tensor<3x2xf32>,
%arg1: tensor<2x2x!quant.uniform<i8<-127:127>:f32:1, {3.000000e+00, 4.000000e+00}>>
) -> tensor<3x2xf32> {
// CHECK-DAG: %[[BARRIER:.*]] = mhlo.optimization_barrier %[[ARG1]] : tensor<2x2xi8>
// CHECK-DAG: %[[SCALES:.*]] = mhlo.constant dense<[3.000000e+00, 4.000000e+00]> : tensor<2xf32>
// CHECK-DAG: %[[CONVERT:.*]] = mhlo.convert %[[BARRIER]] : (tensor<2x2xi8>) -> tensor<2x2xf32>
// CHECK-NOT: chlo.broadcast_subtract
// CHECK: %[[MUL:.*]] = chlo.broadcast_multiply %[[CONVERT]], %[[SCALES]] {broadcast_dimensions = array<i64: 1>} : (tensor<2x2xf32>, tensor<2xf32>) -> tensor<2x2xf32>
// CHECK: %[[DOT:.*]] = "mhlo.dot_general"(%[[ARG0]], %[[MUL]])
// CHECK: {dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [0]>} : (tensor<3x2xf32>, tensor<2x2xf32>) -> tensor<3x2xf32>
// CHECK: return %[[DOT]]

%0 = "mhlo.dot_general"(%arg0, %arg1) {
dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [1],
rhs_contracting_dimensions = [0]>} : (
tensor<3x2xf32>,
tensor<2x2x!quant.uniform<i8<-127:127>:f32:1, {3.000000e+00, 4.000000e+00}>>
) -> tensor<3x2xf32>
return %0 : tensor<3x2xf32>
}

// -----

// CHECK-LABEL: func @dot_general_hybrid_per_channel_asymmetric
// CHECK-SAME: %[[ARG0:.*]]: tensor<3x2xf32>
// CHECK-SAME: %[[ARG1:.*]]: tensor<2x2xi8>
func.func @dot_general_hybrid_per_channel_asymmetric(
%arg0: tensor<3x2xf32>,
%arg1: tensor<2x2x!quant.uniform<i8<-127:127>:f32:1, {3.000000e+00:10, 4.000000e+00:20}>>
) -> tensor<3x2xf32> {
// CHECK-DAG: %[[BARRIER:.*]] = mhlo.optimization_barrier %[[ARG1]] : tensor<2x2xi8>
// CHECK-DAG: %[[SCALES:.*]] = mhlo.constant dense<[3.000000e+00, 4.000000e+00]> : tensor<2xf32>
// CHECK-DAG: %[[ZPS:.*]] = mhlo.constant dense<[1.000000e+01, 2.000000e+01]> : tensor<2xf32>
// CHECK-DAG: %[[CONVERT:.*]] = mhlo.convert %[[BARRIER]] : (tensor<2x2xi8>) -> tensor<2x2xf32>
// CHECK: %[[SUB:.*]] = chlo.broadcast_subtract %[[CONVERT]], %[[ZPS]] {broadcast_dimensions = array<i64: 1>} : (tensor<2x2xf32>, tensor<2xf32>) -> tensor<2x2xf32>
// CHECK: %[[MUL:.*]] = chlo.broadcast_multiply %[[SUB]], %[[SCALES]] {broadcast_dimensions = array<i64: 1>} : (tensor<2x2xf32>, tensor<2xf32>) -> tensor<2x2xf32>
// CHECK: %[[DOT:.*]] = "mhlo.dot_general"(%[[ARG0]], %[[MUL]])
// CHECK: {dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [0]>} : (tensor<3x2xf32>, tensor<2x2xf32>) -> tensor<3x2xf32>
// CHECK: return %[[DOT]]

%0 = "mhlo.dot_general"(%arg0, %arg1) {
dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [1],
rhs_contracting_dimensions = [0]>} : (
tensor<3x2xf32>,
tensor<2x2x!quant.uniform<i8<-127:127>:f32:1, {3.000000e+00:10, 4.000000e+00:20}>>
) -> tensor<3x2xf32>
return %0 : tensor<3x2xf32>
}

// -----

func.func @dot_hybrid_result_type_not_float(
%arg0: tensor<?x?xf32>,
%arg1: tensor<?x?x!quant.uniform<i8:f32, 1.000000e+00:3>>) {
@@ -1839,6 +1894,79 @@ func.func @conv2d_static_hybrid(

// -----

// CHECK-LABEL: func @conv2d_hybrid_per_channel
// CHECK-SAME: %[[ARG0:.*]]: tensor<128x28x28x1xf32>
// CHECK-SAME: %[[ARG1:.*]]: tensor<3x3x1x2xi8>
func.func @conv2d_hybrid_per_channel(
%arg0: tensor<128x28x28x1xf32>,
%arg1: tensor<3x3x1x2x!quant.uniform<i8:f32:3, {2.000000e+00:0, 1.000000e+00:0}>>
) -> tensor<128x26x26x2xf32> {
// CHECK-DAG: %[[BARRIER:.*]] = mhlo.optimization_barrier %[[ARG1]] : tensor<3x3x1x2xi8>
// CHECK-DAG: %[[SCALES:.*]] = mhlo.constant dense<[2.000000e+00, 1.000000e+00]> : tensor<2xf32>
// CHECK-DAG: %[[CONVERT:.*]] = mhlo.convert %[[BARRIER]] : (tensor<3x3x1x2xi8>) -> tensor<3x3x1x2xf32>
// CHECK-NOT: chlo.broadcast_subtract
// CHECK: %[[MUL:.*]] = chlo.broadcast_multiply %[[CONVERT]], %[[SCALES]] {broadcast_dimensions = array<i64: 3>} : (tensor<3x3x1x2xf32>, tensor<2xf32>) -> tensor<3x3x1x2xf32>
// CHECK: %[[CONV:.*]] = mhlo.convolution(%[[ARG0]], %[[MUL]])
// CHECK-SAME{LITERAL}: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
// CHECK-SAME: {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<128x28x28x1xf32>, tensor<3x3x1x2xf32>) -> tensor<128x26x26x2xf32>
// CHECK: return %[[CONV]]

%0 = mhlo.convolution(%arg0, %arg1)
dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
window = {
stride = [1, 1], pad = [[0, 0], [0, 0]],
lhs_dilate = [1, 1],
rhs_dilate = [1, 1]
}
{
batch_group_count = 1 : i64,
feature_group_count = 1 : i64
} : (
tensor<128x28x28x1xf32>,
tensor<3x3x1x2x!quant.uniform<i8:f32:3, {2.000000e+00:0, 1.000000e+00:0}>>)
-> tensor<128x26x26x2xf32>
return %0 : tensor<128x26x26x2xf32>
}

// -----

// CHECK-LABEL: func @conv2d_hybrid_per_channel_asymmetric
// CHECK-SAME: %[[ARG0:.*]]: tensor<128x28x28x1xf32>
// CHECK-SAME: %[[ARG1:.*]]: tensor<3x3x1x2xi8>
func.func @conv2d_hybrid_per_channel_asymmetric(
%arg0: tensor<128x28x28x1xf32>,
%arg1: tensor<3x3x1x2x!quant.uniform<i8:f32:3, {2.000000e+00:10, 1.000000e+00:20}>>
) -> tensor<128x26x26x2xf32> {
// CHECK-DAG: %[[BARRIER:.*]] = mhlo.optimization_barrier %[[ARG1]] : tensor<3x3x1x2xi8>
// CHECK-DAG: %[[SCALES:.*]] = mhlo.constant dense<[2.000000e+00, 1.000000e+00]> : tensor<2xf32>
// CHECK-DAG: %[[ZPS:.*]] = mhlo.constant dense<[1.000000e+01, 2.000000e+01]> : tensor<2xf32>
// CHECK-DAG: %[[CONVERT:.*]] = mhlo.convert %[[BARRIER]] : (tensor<3x3x1x2xi8>) -> tensor<3x3x1x2xf32>
// CHECK: %[[SUB:.*]] = chlo.broadcast_subtract %[[CONVERT]], %[[ZPS]] {broadcast_dimensions = array<i64: 3>} : (tensor<3x3x1x2xf32>, tensor<2xf32>) -> tensor<3x3x1x2xf32>
// CHECK: %[[MUL:.*]] = chlo.broadcast_multiply %[[SUB]], %[[SCALES]] {broadcast_dimensions = array<i64: 3>} : (tensor<3x3x1x2xf32>, tensor<2xf32>) -> tensor<3x3x1x2xf32>
// CHECK: %[[CONV:.*]] = mhlo.convolution(%[[ARG0]], %[[MUL]])
// CHECK-SAME{LITERAL}: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
// CHECK-SAME: {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<128x28x28x1xf32>, tensor<3x3x1x2xf32>) -> tensor<128x26x26x2xf32>
// CHECK: return %[[CONV]]

%0 = mhlo.convolution(%arg0, %arg1)
dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
window = {
stride = [1, 1], pad = [[0, 0], [0, 0]],
lhs_dilate = [1, 1],
rhs_dilate = [1, 1]
}
{
batch_group_count = 1 : i64,
feature_group_count = 1 : i64
} : (
tensor<128x28x28x1xf32>,
tensor<3x3x1x2x!quant.uniform<i8:f32:3, {2.000000e+00:10, 1.000000e+00:20}>>)
-> tensor<128x26x26x2xf32>
return %0 : tensor<128x26x26x2xf32>
}

// -----

func.func @conv2d_hybrid_result_not_float(
%arg0: tensor<128x28x28x1xf32>,
%arg1: tensor<3x3x1x128x!quant.uniform<i8:f32, 3.000000e+00:0>>) {
