Skip to content

Commit

Permalink
feat: dynamic_gather to gather opt pass (#193)
Browse files Browse the repository at this point in the history
* feat: dynamic_gather to gather opt pass

* test: test conversion

* feat: support non-int64 slice sizes

* feat: add the pass to td and py
  • Loading branch information
avik-pal authored Dec 16, 2024
1 parent 30861b8 commit 51687b0
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 10 deletions.
69 changes: 59 additions & 10 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6602,6 +6602,54 @@ struct IfToSelect final : public OpRewritePattern<mlir::stablehlo::IfOp> {
}
};

struct DynamicGatherOpIsNotDynamic
    : public OpRewritePattern<stablehlo::DynamicGatherOp> {
  using OpRewritePattern<stablehlo::DynamicGatherOp>::OpRewritePattern;

  // Rewrites stablehlo.dynamic_gather into stablehlo.gather whenever the
  // slice_sizes operand is a compile-time constant, i.e. the op is not
  // actually dynamic.
  LogicalResult matchAndRewrite(stablehlo::DynamicGatherOp op,
                                PatternRewriter &rewriter) const override {
    // The op is only "dynamic" through its slice_sizes operand; bail out
    // unless it folds to a constant.
    DenseIntElementsAttr sliceSizesAttr;
    if (!matchPattern(op.getSliceSizes(), m_Constant(&sliceSizesAttr))) {
      return failure();
    }

    // dynamic_gather permits any integer element type for slice_sizes, while
    // gather requires i64. Reading through APInt and sign-extending via
    // getSExtValue() handles every bit width uniformly. (Building an i64
    // DenseElementsAttr directly from narrower APInts, as a two-step
    // conversion would, trips the bitwidth assertion in
    // DenseElementsAttr::get.)
    SmallVector<int64_t> sliceSizes;
    for (const APInt &size : sliceSizesAttr.getValues<APInt>()) {
      sliceSizes.push_back(size.getSExtValue());
    }

    // dynamic_gather carries no batching dimensions, so both batching dim
    // lists of the gather are empty; all other dimension numbers transfer
    // unchanged.
    rewriter.replaceOpWithNewOp<stablehlo::GatherOp>(
        op, op.getType(), op.getOperand(), op.getStartIndices(),
        stablehlo::GatherDimensionNumbersAttr::get(
            op.getContext(), op.getDimensionNumbers().getOffsetDims(),
            op.getDimensionNumbers().getCollapsedSliceDims(),
            /*operandBatchingDims=*/{},
            /*startIndicesBatchingDims=*/{},
            op.getDimensionNumbers().getStartIndexMap(),
            op.getDimensionNumbers().getIndexVectorDim()),
        DenseI64ArrayAttr::get(op.getContext(), sliceSizes));

    return success();
  }
};

/// Check if a `t` is a tensor with zero extents.
static std::optional<RankedTensorType> isZeroExtent(Type t) {
auto type = t.dyn_cast<RankedTensorType>();
Expand Down Expand Up @@ -6848,16 +6896,17 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
if (no_nan || all_finite)
patterns.add<NoNan>(context);

patterns
.add<CompareOpCanon, BroadcastInDimOpCanon, ConvertOpCanon,
DynamicBroadcastInDimOpNotActuallyDynamic,
ChainedDynamicBroadcastInDimCanonicalization,
DynamicBroadcastInDimAllDimsNonExpanding, NoopReduceOpCanon,
EmptyReduceOpCanon, DynamicReshapeOpCanon, GetTupleElementOpCanon,
RealOpCanon, ImagOpCanon, ConjComplexNegate,
GetDimensionSizeOpCanon, GatherOpCanon, ReshapeOpCanon,
MergeConsecutiveReshapes, TransposeIsReshape, IfInline, IfToSelect,
ZeroExtentTensorCanon, ReorderElementwiseAndShapeOp>(context);
patterns.add<CompareOpCanon, BroadcastInDimOpCanon, ConvertOpCanon,
DynamicBroadcastInDimOpNotActuallyDynamic,
ChainedDynamicBroadcastInDimCanonicalization,
DynamicBroadcastInDimAllDimsNonExpanding, NoopReduceOpCanon,
EmptyReduceOpCanon, DynamicReshapeOpCanon,
GetTupleElementOpCanon, RealOpCanon, ImagOpCanon,
ConjComplexNegate, GetDimensionSizeOpCanon, GatherOpCanon,
ReshapeOpCanon, MergeConsecutiveReshapes, TransposeIsReshape,
IfInline, IfToSelect, ZeroExtentTensorCanon,
ReorderElementwiseAndShapeOp, DynamicGatherOpIsNotDynamic>(
context);
patterns.add<SelectOpCanon>(max_constant_expansion, context,
PatternBenefit(65000));
patterns.add<ConcatenateOpCanon>(max_constant_expansion, context,
Expand Down
4 changes: 4 additions & 0 deletions src/enzyme_ad/jax/TransformOps/TransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,10 @@ def ApplyConstPropThroughBarrierPatterns : EnzymeHLOPatternOp<
"const_prop_through_barrier"> {
let patterns = ["ConstPropThroughBarrier"];
}
// Exposes the DynamicGatherOpIsNotDynamic C++ pattern (dynamic_gather with
// constant slice_sizes -> gather) to the transform dialect under the name
// "dynamic_gather_op_is_not_dynamic".
def DynamicGatherOpIsNotDynamic : EnzymeHLOPatternOp<
"dynamic_gather_op_is_not_dynamic"> {
let patterns = ["DynamicGatherOpIsNotDynamic"];
}

// TODO: better naming for parameters requires a static interface for
// constructing them in search.
Expand Down
1 change: 1 addition & 0 deletions src/enzyme_ad/jax/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def hlo_opts():
transpose_is_reshape<16>;
zero_extent_tensor_canon<16>;
reorder_elementwise_and_shape_op<16>;
dynamic_gather_op_is_not_dynamic<16>;
cse_broadcast_in_dim<16>;
cse_slice<16>;
Expand Down
35 changes: 35 additions & 0 deletions test/lit_tests/dynamicgathertogather.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// RUN: enzymexlamlir-opt --enzyme-hlo-opt %s | FileCheck %s

// dynamic_gather whose slice_sizes operand is a constant of the required i64
// element type: expect a direct rewrite to stablehlo.gather (checked below).
module {
func.func @main(%arg0: tensor<4x4xf64>) -> tensor<2xf64> {
%c = stablehlo.constant dense<1> : tensor<2xi64>
%c_0 = stablehlo.constant dense<[[1, -1], [2, 0]]> : tensor<2x2xi64>
%0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<4x4xf64>) -> tensor<4x4xf64>
%1 = "stablehlo.dynamic_gather"(%0, %c_0, %c) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 1>}> : (tensor<4x4xf64>, tensor<2x2xi64>, tensor<2xi64>) -> tensor<2xf64>
return %1 : tensor<2xf64>
}
}

// CHECK: func.func @main(%arg0: tensor<4x4xf64>) -> tensor<2xf64> {
// CHECK-NEXT: %c = stablehlo.constant dense<{{\[\[}}1, -1{{\]}}, {{\[}}2, 0{{\]\]}}> : tensor<2x2xi64>
// CHECK-NEXT: %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<4x4xf64>) -> tensor<4x4xf64>
// CHECK-NEXT: %1 = "stablehlo.gather"(%0, %c) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 1>, slice_sizes = array<i64: 1, 1>}> : (tensor<4x4xf64>, tensor<2x2xi64>) -> tensor<2xf64>
// CHECK-NEXT: return %1 : tensor<2xf64>
// CHECK-NEXT: }

// Same as the case above but slice_sizes is tensor<2xi32>: exercises the
// non-i64 path, which must widen the sizes to i64 for stablehlo.gather.
module {
func.func @main(%arg0: tensor<4x4xf64>) -> tensor<2xf64> {
%c = stablehlo.constant dense<1> : tensor<2xi32>
%c_0 = stablehlo.constant dense<[[1, -1], [2, 0]]> : tensor<2x2xi64>
%0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<4x4xf64>) -> tensor<4x4xf64>
%1 = "stablehlo.dynamic_gather"(%0, %c_0, %c) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 1>}> : (tensor<4x4xf64>, tensor<2x2xi64>, tensor<2xi32>) -> tensor<2xf64>
return %1 : tensor<2xf64>
}
}

// CHECK: func.func @main(%arg0: tensor<4x4xf64>) -> tensor<2xf64> {
// CHECK-NEXT: %c = stablehlo.constant dense<{{\[\[}}1, -1{{\]}}, {{\[}}2, 0{{\]\]}}> : tensor<2x2xi64>
// CHECK-NEXT: %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<4x4xf64>) -> tensor<4x4xf64>
// CHECK-NEXT: %1 = "stablehlo.gather"(%0, %c) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 1>, slice_sizes = array<i64: 1, 1>}> : (tensor<4x4xf64>, tensor<2x2xi64>) -> tensor<2xf64>
// CHECK-NEXT: return %1 : tensor<2xf64>
// CHECK-NEXT: }

0 comments on commit 51687b0

Please sign in to comment.