From d2f85d74a9492a4203e768eee8392253ecd0fbc7 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Fri, 1 Mar 2024 21:32:16 -0500 Subject: [PATCH 01/17] WIP jax reverse more --- WORKSPACE | 13 ++++++--- src/enzyme_ad/jax/primitives.py | 47 +++++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 4 deletions(-) diff --git a/WORKSPACE b/WORKSPACE index f337884bf..d05df1766 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -67,13 +67,18 @@ ENZYME_SHA256 = "f9479530b08aeb3ecbf0c420d0e2f222fdf8bcf6c20a218271b365db3a3053a # path = "../Enzyme/enzyme" # ) -http_archive( +local_repository( name = "enzyme", - sha256 = ENZYME_SHA256, - strip_prefix = "Enzyme-" + ENZYME_COMMIT + "/enzyme", - urls = ["https://github.com/EnzymeAD/Enzyme/archive/{commit}.tar.gz".format(commit = ENZYME_COMMIT)], + path = "../Enzyme/enzyme" ) +# http_archive( +# name = "enzyme", +# sha256 = ENZYME_SHA256, +# strip_prefix = "Enzyme-" + ENZYME_COMMIT + "/enzyme", +# urls = ["https://github.com/EnzymeAD/Enzyme/archive/{commit}.tar.gz".format(commit = ENZYME_COMMIT)], +# ) + JAX_COMMIT = "9a098e922aff62a3b49bd673b9518d97ee599248" JAX_SHA256 = "" diff --git a/src/enzyme_ad/jax/primitives.py b/src/enzyme_ad/jax/primitives.py index df2460c34..67f905468 100644 --- a/src/enzyme_ad/jax/primitives.py +++ b/src/enzyme_ad/jax/primitives.py @@ -1051,6 +1051,53 @@ def fwd_partial_eval(trace, *args, **kwargs): pe.custom_partial_eval_rules[_enzyme_fwd_p] = fwd_partial_eval +def primal_partial_eval(trace, *args, **kwargs): + pipeline_options = kwargs["pipeline_options"] + if not pipeline_options.mlir_ad() or kwargs["lang"] != LANG_MHLO or pipeline_options.ad_level() == 0: + return trace.default_process_primitive(_enzyme_primal_p, args, kwargs) + + assert len(args) % 2 == 0 + nr_primals = len(args) // 2 + primals, tangents = args[0::2], args[1::2] + all_primals_known = all(p.is_known() for p in primals) + some_tangents_unknown = any(not t.is_known() for t in tangents) + + if not (all_primals_known and some_tangents_unknown): + return trace.default_process_primitive(_enzyme_primal_p, args, kwargs) + + shadow_aug_args = primals + tangents + + out_shapes = kwargs["out_shapes"] + out_shapes2 = out_shapes[:len(out_shapes)//2] + del kwargs["out_shapes"] + + shadows_known = trace.default_process_primitive( + _enzyme_shadow_aug_p, shadow_aug_args, kwargs | {'out_shapes':out_shapes2} + ) + + passes = pipeline_options.pass_pipeline() + start = passes.rindex("enzyme-wrap{") + prev_passes = passes[:start] + end = passes.index("}", start) + post_passes = passes[end+1:] + newpasses = prev_passes + post_passes[1:] + + if pipeline_options.stablehlo_inject(): + pipeline_options = JaXPipeline(newpasses) + else: + pipeline_options = NewXLAPipeline(newpasses, pipeline_options.mlir_ad()) + + (in_tree, in_idx_map, mfunc) = kwargs["source"] + + avals = {k//2: v for k, v in in_idx_map.items() if k % 2 == 0} + source = (in_tree, avals, mfunc) + + primalret = trace.default_process_primitive(_enzyme_primal_p, primals, {'out_shapes':out_shapes2, 'source':source, 'fn':kwargs['fn'], 'argv':kwargs['argv'], 'lang':kwargs['lang'], 'pipeline_options':pipeline_options}) + return primalret + shadows_known + + +pe.custom_partial_eval_rules[_enzyme_primal_p] = primal_partial_eval + def primal_partial_eval(trace, *args, **kwargs): pipeline_options = kwargs["pipeline_options"] From a3512d56152412a78884d9acf43fcfd5c3df792e Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Sat, 2 Mar 2024 15:05:34 -0500 Subject: [PATCH 02/17] fix --- src/enzyme_ad/jax/primitives.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/enzyme_ad/jax/primitives.py b/src/enzyme_ad/jax/primitives.py index 67f905468..c724c95b9 100644 --- a/src/enzyme_ad/jax/primitives.py +++ b/src/enzyme_ad/jax/primitives.py @@ -613,6 +613,7 @@ def zero(ty): else: z = zero(orig_types[v]) results2.append(z) + results = tuple(results2) else: identifier, tmpBuf = enzyme_call.create_enzyme_cpu_kernel( @@ -1087,10 +1088,11 @@ def primal_partial_eval(trace, *args, **kwargs): else: pipeline_options = NewXLAPipeline(newpasses, pipeline_options.mlir_ad()) - (in_tree, in_idx_map, mfunc) = kwargs["source"] + (in_tree, in_idx_map, out_idx_map, mfunc) = kwargs["source"] avals = {k//2: v for k, v in in_idx_map.items() if k % 2 == 0} - source = (in_tree, avals, mfunc) + outmap2 = {k//2: v for k, v in out_idx_map.items() if k % 2 == 0} + source = (in_tree, avals, outmap2, mfunc) primalret = trace.default_process_primitive(_enzyme_primal_p, primals, {'out_shapes':out_shapes2, 'source':source, 'fn':kwargs['fn'], 'argv':kwargs['argv'], 'lang':kwargs['lang'], 'pipeline_options':pipeline_options}) return primalret + shadows_known From d9a85c675dd18239f7663d3a5523c3107249b028 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sat, 2 Mar 2024 15:07:03 -0500 Subject: [PATCH 03/17] fix enzyme commit --- WORKSPACE | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/WORKSPACE b/WORKSPACE index d05df1766..f337884bf 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -67,18 +67,13 @@ ENZYME_SHA256 = "f9479530b08aeb3ecbf0c420d0e2f222fdf8bcf6c20a218271b365db3a3053a # path = "../Enzyme/enzyme" # ) -local_repository( +http_archive( name = "enzyme", - path = "../Enzyme/enzyme" + sha256 = ENZYME_SHA256, + strip_prefix = "Enzyme-" + ENZYME_COMMIT + "/enzyme", + urls = ["https://github.com/EnzymeAD/Enzyme/archive/{commit}.tar.gz".format(commit = ENZYME_COMMIT)], ) -# http_archive( -# name = "enzyme", -# sha256 = ENZYME_SHA256, -# strip_prefix = "Enzyme-" + ENZYME_COMMIT + "/enzyme", -# urls = ["https://github.com/EnzymeAD/Enzyme/archive/{commit}.tar.gz".format(commit = ENZYME_COMMIT)], -# ) - JAX_COMMIT = "9a098e922aff62a3b49bd673b9518d97ee599248" JAX_SHA256 = "" From b3779787838d210aba4676d52b829af02675216f Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Sat, 2 Mar 2024 15:10:52 -0500 Subject: [PATCH 04/17] cleanup format --- src/enzyme_ad/jax/primitives.py | 40 +++++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 12 deletions(-) diff --git a/src/enzyme_ad/jax/primitives.py b/src/enzyme_ad/jax/primitives.py index c724c95b9..f71729033 100644 --- a/src/enzyme_ad/jax/primitives.py +++ b/src/enzyme_ad/jax/primitives.py @@ -1052,11 +1052,16 @@ def fwd_partial_eval(trace, *args, **kwargs): pe.custom_partial_eval_rules[_enzyme_fwd_p] = fwd_partial_eval + def primal_partial_eval(trace, *args, **kwargs): pipeline_options = kwargs["pipeline_options"] - if not pipeline_options.mlir_ad() or kwargs["lang"] != LANG_MHLO or pipeline_options.ad_level() == 0: + if ( + not pipeline_options.mlir_ad() + or kwargs["lang"] != LANG_MHLO + or pipeline_options.ad_level() == 0 + ): return trace.default_process_primitive(_enzyme_primal_p, args, kwargs) - + assert len(args) % 2 == 0 nr_primals = len(args) // 2 primals, tangents = args[0::2], args[1::2] @@ -1067,34 +1072,45 @@ def primal_partial_eval(trace, *args, **kwargs): return trace.default_process_primitive(_enzyme_primal_p, args, kwargs) shadow_aug_args = primals + tangents - + out_shapes = kwargs["out_shapes"] - out_shapes2 = out_shapes[:len(out_shapes)//2] + out_shapes2 = out_shapes[: len(out_shapes) // 2] del kwargs["out_shapes"] shadows_known = trace.default_process_primitive( - _enzyme_shadow_aug_p, shadow_aug_args, kwargs | {'out_shapes':out_shapes2} + _enzyme_shadow_aug_p, shadow_aug_args, kwargs | {"out_shapes": out_shapes2} ) - + passes = pipeline_options.pass_pipeline() start = passes.rindex("enzyme-wrap{") prev_passes = passes[:start] end = passes.index("}", start) - post_passes = passes[end+1:] + post_passes = passes[end + 1 :] newpasses = prev_passes + post_passes[1:] - + if pipeline_options.stablehlo_inject(): pipeline_options = JaXPipeline(newpasses) else: pipeline_options = NewXLAPipeline(newpasses, pipeline_options.mlir_ad()) - + (in_tree, in_idx_map, out_idx_map, mfunc) = kwargs["source"] - avals = {k//2: v for k, v in in_idx_map.items() if k % 2 == 0} - outmap2 = {k//2: v for k, v in out_idx_map.items() if k % 2 == 0} + avals = {k // 2: v for k, v in in_idx_map.items() if k % 2 == 0} + outmap2 = {k // 2: v for k, v in out_idx_map.items() if k % 2 == 0} source = (in_tree, avals, outmap2, mfunc) - primalret = trace.default_process_primitive(_enzyme_primal_p, primals, {'out_shapes':out_shapes2, 'source':source, 'fn':kwargs['fn'], 'argv':kwargs['argv'], 'lang':kwargs['lang'], 'pipeline_options':pipeline_options}) + primalret = trace.default_process_primitive( + _enzyme_primal_p, + primals, + { + "out_shapes": out_shapes2, + "source": source, + "fn": kwargs["fn"], + "argv": kwargs["argv"], + "lang": kwargs["lang"], + "pipeline_options": pipeline_options, + }, + ) return primalret + shadows_known From ad4090bda22925994eecac1ef8e29ded2e9a321c Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Tue, 5 Mar 2024 23:25:46 -0500 Subject: [PATCH 05/17] fixup --- WORKSPACE | 4 +- .../StableHLOAutoDiffOpInterfaceImpl.cpp | 212 ++++++++++++++++++ src/enzyme_ad/jax/primitives.py | 6 +- test/bench_vs_xla.py | 4 +- test/llama.py | 2 +- 5 files changed, 220 insertions(+), 8 deletions(-) diff --git a/WORKSPACE b/WORKSPACE index f337884bf..547cf7dc1 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -60,8 +60,8 @@ load("@rules_python//python/pip_install:repositories.bzl", "pip_install_dependen pip_install_dependencies() -ENZYME_COMMIT = "0b621884bc531329095d202f042f6599a86614ec" -ENZYME_SHA256 = "f9479530b08aeb3ecbf0c420d0e2f222fdf8bcf6c20a218271b365db3a3053ad" +ENZYME_COMMIT = "97066352a40b3c66f9a1f41ec1802af255216c0c" +ENZYME_SHA256 = "" # local_repository( # name = "enzyme", # path = "../Enzyme/enzyme" diff --git a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp index 3b7ebaaa9..e2f509c3d 100644 --- a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp +++ b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp @@ -187,6 +187,215 @@ class AutoDiffReduceFwd } }; +class AutoDiffBroadcastInDimRev + : public ReverseAutoDiffOpInterface::ExternalModel< + AutoDiffBroadcastInDimRev, BroadcastInDimOp> { +public: + LogicalResult createReverseModeAdjoint(Operation *orig, OpBuilder &builder, + MGradientUtilsReverse *gutils, + SmallVector caches) const { + auto op = cast(orig); + auto inTy = op.getOperand().getType(); + auto outTy = op.getType(); + auto zero = gutils->getShadowType(inTy) + .cast() + .createNullValue(builder, op.getLoc()); + auto inDiffe = gutils->diffe(op, builder); + gutils->zeroDiffe(op, builder); + + SmallVector bcastDims(op.getBroadcastDimensions().begin(), + op.getBroadcastDimensions().end()); + + if (bcastDims.size() == 0 && inTy.getShape().size() == 0) { + for (size_t i = 0; i < outTy.getShape().size(); i++) { + bcastDims.push_back(i); + } + } + assert(outTy.getShape().size() == + inTy.getShape().size() + bcastDims.size()); + + auto red = builder.create(op.getLoc(), + TypeRange(gutils->getShadowType(inTy)), + inDiffe, zero, bcastDims); + red.getBody().push_back(new Block()); + Block &body = red.getBody().front(); + OpBuilder bodyBuilder(orig->getContext()); + bodyBuilder.setInsertionPointToEnd(&body); + + body.addArgument(zero.getType(), op.getLoc()); + body.addArgument(zero.getType(), op.getLoc()); + auto add = bodyBuilder.create(op.getLoc(), body.getArgument(0), + body.getArgument(1)); + bodyBuilder.create(op.getLoc(), ValueRange(add)); + + gutils->addToDiffe(op.getOperand(), red->getResult(0), builder); + return success(); + } + + SmallVector cacheValues(Operation *orig, + MGradientUtilsReverse *gutils) const { + return {}; + } + + void createShadowValues(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils) const {} +}; + +class AutoDiffSliceRev + : public ReverseAutoDiffOpInterface::ExternalModel { +public: + LogicalResult createReverseModeAdjoint(Operation *orig, OpBuilder &builder, + MGradientUtilsReverse *gutils, + SmallVector caches) const { + auto op = cast(orig); + auto inTy = op.getOperand().getType(); + auto outTy = op.getType(); + auto zero = inTy.cast().createNullValue(builder, + op.getLoc()); + auto inDiffe = gutils->diffe(op, builder); + gutils->zeroDiffe(op, builder); + + Value idxs; + { + SmallVector concat_data; + for (size_t i = 0; i < outTy.getShape().size(); i++) { + concat_data.push_back(outTy.getShape()[i]); 
+ } + concat_data.push_back(1); + auto toConcatType = + RankedTensorType::get(concat_data, builder.getI32Type()); + std::vector inds; + size_t idx = 0; + for (auto &&[start, limit, stride] : llvm::zip( + op.getStartIndices(), op.getLimitIndices(), op.getStrides())) { + std::vector data; + for (int32_t i = start; i < limit; i += stride) { + data.push_back(i); + } + + Value ind = builder.create(op.getLoc(), RankedTensorType::get({data.size()}, builder.getI32Type()), + builder.getI32TensorAttr(data)); + + auto bcast_ind = builder.getDenseI64ArrayAttr({idx}); + ind = builder.create(op.getLoc(), toConcatType, ind, + bcast_ind); + inds.push_back(ind); + idx++; + } + idxs = builder.create( + op.getLoc(), inds, builder.getI64IntegerAttr(concat_data.size() - 1)); + } + + // empty extra index into the slice + std::vector update_window_dims; + std::vector scatter_dims_to_operand_dims; + std::vector inserted_window_dims; + for (int i = 0; i < inTy.getShape().size(); i++) { + scatter_dims_to_operand_dims.push_back(i); + inserted_window_dims.push_back(i); + } + + int64_t indexVectorDim = inTy.getShape().size(); + + auto dims = ScatterDimensionNumbersAttr::get( + builder.getContext(), update_window_dims, inserted_window_dims, + scatter_dims_to_operand_dims, indexVectorDim); + + // auto prev = gutils->diffe(op.getOperand(), builder); + + auto red = builder.create( + op.getLoc(), TypeRange(gutils->getShadowType(inTy)), ValueRange(zero), + idxs, ValueRange(inDiffe), dims, + /*indices_are_sorted*/ builder.getBoolAttr(true), + /*unique_indices*/ builder.getBoolAttr(true)); + + red.getUpdateComputation().push_back(new Block()); + Block &body = red.getUpdateComputation().front(); + OpBuilder bodyBuilder(orig->getContext()); + bodyBuilder.setInsertionPointToEnd(&body); + + auto TT = RankedTensorType::get({}, inTy.getElementType()); + body.addArgument(TT, op.getLoc()); + body.addArgument(TT, op.getLoc()); + /* + auto add = bodyBuilder.create(op.getLoc(), body.getArgument(0), + body.getArgument(1)); + bodyBuilder.create(op.getLoc(), ValueRange(add)); + */ + bodyBuilder.create(op.getLoc(), ValueRange(body.getArgument(1))); + + gutils->addToDiffe(op.getOperand(), red->getResult(0), builder); + // gutils->setDiffe(op.getOperand(), red->getResult(0), builder); + + return success(); + } + + SmallVector cacheValues(Operation *orig, + MGradientUtilsReverse *gutils) const { + return {}; + } + + void createShadowValues(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils) const {} +}; + +class AutoDiffReduceRev + : public ReverseAutoDiffOpInterface::ExternalModel { +public: + LogicalResult createReverseModeAdjoint(Operation *orig, OpBuilder &builder, + MGradientUtilsReverse *gutils, + SmallVector caches) const { + auto op = cast(orig); + if (!isEligibleForCompactPrint(op)) { + orig->emitError() << "Unsupported operation in reduction rev autodiff(1): " + << *orig << "\n"; + return failure(); + } + + Operation &innerOp = op.getBody().front().front(); + + auto inTy = op->getOperand(0).getType(); + auto zero = inTy.cast().createNullValue(builder, + op.getLoc()); + auto inDiffe = gutils->diffe(op->getResult(0), builder); + gutils->zeroDiffe(op->getResult(0), builder); + + if (isa(innerOp)) { + if (!gutils->isConstantValue(op.getInputs()[0])) { + Value bcast; + + if (op->getResult(0).getType().cast().getShape().size() == 0) + bcast = builder.create(op.getLoc(), gutils->getShadowType(inTy), inDiffe, builder.getDenseI64ArrayAttr({})); + else + bcast = builder.create(op.getLoc(), gutils->getShadowType(inTy), inDiffe, 
op.getDimensions()); + + gutils->addToDiffe(op.getInputs()[0], bcast, builder); + } + if (!gutils->isConstantValue(op.getInitValues()[0])) { + gutils->addToDiffe(op.getInitValues()[0], inDiffe, builder); + } + return success(); + } + + if (isa(innerOp) || isa(innerOp)) { + } + + orig->emitError() << "Unsupported operation in reduction rev autodiff(1): " + << *orig << "\n"; + return failure(); + } + + SmallVector cacheValues(Operation *orig, + MGradientUtilsReverse *gutils) const { + return {}; + } + + void createShadowValues(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils) const {} +}; + } // namespace void mlir::enzyme::registerStableHLODialectAutoDiffInterface( @@ -196,5 +405,8 @@ void mlir::enzyme::registerStableHLODialectAutoDiffInterface( registerInterfaces(context); ReduceOp::attachInterface>(*context); ReduceOp::attachInterface>(*context); + BroadcastInDimOp::attachInterface(*context); + SliceOp::attachInterface(*context); + ReduceOp::attachInterface(*context); }); } diff --git a/src/enzyme_ad/jax/primitives.py b/src/enzyme_ad/jax/primitives.py index f71729033..dec87152f 100644 --- a/src/enzyme_ad/jax/primitives.py +++ b/src/enzyme_ad/jax/primitives.py @@ -933,7 +933,7 @@ def make_zero(tan, prim): shadconv = None if pipeline_options.mlir_ad() and kwargs["lang"] == LANG_MHLO: act_tup = ",".join(["enzyme_dup" for a in arg_primals]) - afterad = "arith-raise{stablehlo=true}, enzyme-hlo-opt, cse, canonicalize" + afterad = "arith-raise{stablehlo=true}, enzyme-hlo-opt, print, cse, canonicalize" newpasses = ( "inline{default-pipeline=canonicalize max-iterations=4}," + "enzyme-hlo-opt,cse,enzyme-wrap{infn=main outfn= retTy=enzyme_dup argTys=" @@ -1196,8 +1196,8 @@ def enzyme_vjp(shadow_rets, *prim_args, **kwargs): ad_pass = ad_pass.replace("ForwardMode", "ReverseModeCombined") newpasses = ( prev_passes - + ad_pass - + ",canonicalize, remove-unnecessary-enzyme-ops, enzyme-simplify-math, enzyme-hlo-opt, canonicalize, cse" + + "print," + ad_pass + + ",canonicalize, remove-unnecessary-enzyme-ops, enzyme-simplify-math, enzyme-hlo-opt, canonicalize, cse, print" + post_passes ) diff --git a/test/bench_vs_xla.py b/test/bench_vs_xla.py index 0971b7870..cf6cf5e22 100644 --- a/test/bench_vs_xla.py +++ b/test/bench_vs_xla.py @@ -257,7 +257,7 @@ def setUp(self): self.douts = [1.0] def nomlir(x): - return [(name, a) for (name, a) in x if not a.mlir_ad()] + return [(name, a) for (name, a) in x if name != "NewXLAMLIR"] self.revfilter = nomlir @@ -276,7 +276,7 @@ def setUp(self): self.douts = [jnp.array([i * i for i in range(dim)], dtype=jnp.float32)] def nomlir(x): - return [(name, a) for (name, a) in x if not a.mlir_ad()] + return [(name, a) for (name, a) in x if name != "NewXLAMLIR"] self.revfilter = nomlir diff --git a/test/llama.py b/test/llama.py index 90eb16069..7eb661ee3 100644 --- a/test/llama.py +++ b/test/llama.py @@ -33,7 +33,7 @@ def silu(x): pipeline = enzyme_jax.NewXLAPipeline(mlirad=True) pipeline = enzyme_jax.JaXPipeline() -pipeline = enzyme_jax.NewXLAPipeline(mlirad=False) +# pipeline = enzyme_jax.NewXLAPipeline(mlirad=False) def forward(x, config, weights, key_cache, value_cache): From 7f3c90d94355da14e90e29ea273bd8fd65758790 Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Wed, 6 Mar 2024 13:52:35 -0500 Subject: [PATCH 06/17] tmp --- WORKSPACE | 17 +-- .../jax/Implementations/HLODerivatives.td | 2 +- .../jax/Implementations/MHLODerivatives.td | 2 + .../StableHLOAutoDiffOpInterfaceImpl.cpp | 81 ++++++++++- .../Implementations/StableHLODerivatives.td | 2 + src/enzyme_ad/jax/primitives.py | 4 +- test/bench_vs_xla.py | 42 ++++++ test/llama.py | 126 ++++-------------- 8 files changed, 156 insertions(+), 120 deletions(-) diff --git a/WORKSPACE b/WORKSPACE index 547cf7dc1..0bd4bc04c 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -62,18 +62,19 @@ pip_install_dependencies() ENZYME_COMMIT = "97066352a40b3c66f9a1f41ec1802af255216c0c" ENZYME_SHA256 = "" -# local_repository( -# name = "enzyme", -# path = "../Enzyme/enzyme" -# ) -http_archive( +local_repository( name = "enzyme", - sha256 = ENZYME_SHA256, - strip_prefix = "Enzyme-" + ENZYME_COMMIT + "/enzyme", - urls = ["https://github.com/EnzymeAD/Enzyme/archive/{commit}.tar.gz".format(commit = ENZYME_COMMIT)], + path = "../Enzyme/enzyme" ) +# http_archive( +# name = "enzyme", +# sha256 = ENZYME_SHA256, +# strip_prefix = "Enzyme-" + ENZYME_COMMIT + "/enzyme", +# urls = ["https://github.com/EnzymeAD/Enzyme/archive/{commit}.tar.gz".format(commit = ENZYME_COMMIT)], +# ) + JAX_COMMIT = "9a098e922aff62a3b49bd673b9518d97ee599248" JAX_SHA256 = "" diff --git a/src/enzyme_ad/jax/Implementations/HLODerivatives.td b/src/enzyme_ad/jax/Implementations/HLODerivatives.td index f1c07383a..9c75260a2 100644 --- a/src/enzyme_ad/jax/Implementations/HLODerivatives.td +++ b/src/enzyme_ad/jax/Implementations/HLODerivatives.td @@ -100,7 +100,7 @@ def : HLOReadOnlyIdentityOp<"SliceOp">; def Reduce : HLOInst<"ReduceOp">; def : HLOReadOnlyIdentityOp<"BroadcastInDimOp">; -def : HLOReadOnlyIdentityOp<"ConcatenateOp">; +def : HLOMemoryIdentityOp<"ConcatenateOp", [], [-1]>; // convert diff --git a/src/enzyme_ad/jax/Implementations/MHLODerivatives.td b/src/enzyme_ad/jax/Implementations/MHLODerivatives.td index ca4de2fd0..2606731da 100644 --- a/src/enzyme_ad/jax/Implementations/MHLODerivatives.td +++ b/src/enzyme_ad/jax/Implementations/MHLODerivatives.td @@ -4,6 +4,8 @@ class HLODerivative resultOps, dag class HLOInst : Inst; +class HLOMemoryIdentityOp ptrargs_, list storedargs_ = [], dag patternToMatch=(Unimplemented), list reverse_ = []> : MemoryIdentityOp<"mhlo", opName_, ptrargs_, storedargs_, patternToMatch, reverse_>; + class HLOReadOnlyIdentityOp ptrargs_ = [0], dag patternToMatch=(Unimplemented), list reverse_ = []> : ReadOnlyIdentityOp<"mhlo", opName_, ptrargs_, patternToMatch, reverse_>; class HLOControlFlowOp : ControlFlowOp<"mhlo", opName_, impl_>; diff --git a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp index e2f509c3d..c8e5ac1b5 100644 --- a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp +++ b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp @@ -197,25 +197,33 @@ class AutoDiffBroadcastInDimRev auto op = cast(orig); auto inTy = op.getOperand().getType(); auto outTy = op.getType(); - auto zero = gutils->getShadowType(inTy) - .cast() - .createNullValue(builder, op.getLoc()); auto inDiffe = gutils->diffe(op, builder); gutils->zeroDiffe(op, builder); SmallVector bcastDims(op.getBroadcastDimensions().begin(), op.getBroadcastDimensions().end()); + Value zero = nullptr; if (bcastDims.size() == 0 && inTy.getShape().size() == 0) { for (size_t i = 0; i < outTy.getShape().size(); i++) { 
bcastDims.push_back(i); } + zero = gutils->getShadowType(inTy) + .cast() + .createNullValue(builder, op.getLoc()); + } else { + SmallVector dims; + for (size_t i = 0; i < inTy.getShape().size(); i++) { + if (llvm::is_contained(bcastDims, i)) continue; + dims.push_back(i); + } + zero = gutils->getShadowType(RankedTensorType::get(dims, inTy.getElementType())) + .cast() + .createNullValue(builder, op.getLoc()); } - assert(outTy.getShape().size() == - inTy.getShape().size() + bcastDims.size()); auto red = builder.create(op.getLoc(), - TypeRange(gutils->getShadowType(inTy)), + TypeRange(zero.getType()), inDiffe, zero, bcastDims); red.getBody().push_back(new Block()); Block &body = red.getBody().front(); @@ -228,7 +236,14 @@ class AutoDiffBroadcastInDimRev body.getArgument(1)); bodyBuilder.create(op.getLoc(), ValueRange(add)); - gutils->addToDiffe(op.getOperand(), red->getResult(0), builder); + llvm::errs() << " red: " << *red << "\n"; + + Value res = red->getResult(0); + Type resTy = gutils->getShadowType(op.getOperand().getType()); + if (res.getType() != resTy) + res = builder.create(op.getLoc(), resTy, res); + + gutils->addToDiffe(op.getOperand(), res, builder); return success(); } @@ -396,6 +411,57 @@ class AutoDiffReduceRev MGradientUtilsReverse *gutils) const {} }; +class AutoDiffConcatenateRev + : public ReverseAutoDiffOpInterface::ExternalModel { +public: + LogicalResult createReverseModeAdjoint(Operation *orig, OpBuilder &builder, + MGradientUtilsReverse *gutils, + SmallVector caches) const { + auto op = cast(orig); + + auto inDiffe = gutils->diffe(op->getResult(0), builder); + gutils->zeroDiffe(op->getResult(0), builder); + + auto dim = op.getDimension(); + for (auto &ope : op->getOpOperands()) { + auto op = ope.get(); + if (gutils->isConstantValue(op)) continue; + auto inTy = gutils->getShadowType(op.getType()); + SmallVector start; + SmallVector limit; + SmallVector strides; + SmallVector tys; + auto RT = inTy.cast(); + for (auto i=0; i(op.getLoc(), RankedTensorType::get(tys, RT.getElementType()), inDiffe, start, limit, strides); + auto res2 = builder.create(op.getLoc(), inTy, res); + gutils->addToDiffe(op, res2, builder); + } + return success(); + } + + SmallVector cacheValues(Operation *orig, + MGradientUtilsReverse *gutils) const { + return {}; + } + + void createShadowValues(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils) const {} +}; + } // namespace void mlir::enzyme::registerStableHLODialectAutoDiffInterface( @@ -408,5 +474,6 @@ void mlir::enzyme::registerStableHLODialectAutoDiffInterface( BroadcastInDimOp::attachInterface(*context); SliceOp::attachInterface(*context); ReduceOp::attachInterface(*context); + ConcatenateOp::attachInterface(*context); }); } diff --git a/src/enzyme_ad/jax/Implementations/StableHLODerivatives.td b/src/enzyme_ad/jax/Implementations/StableHLODerivatives.td index a803c077e..f71b9a61c 100644 --- a/src/enzyme_ad/jax/Implementations/StableHLODerivatives.td +++ b/src/enzyme_ad/jax/Implementations/StableHLODerivatives.td @@ -4,6 +4,8 @@ class HLODerivative resultOps, dag class HLOInst : Inst; +class HLOMemoryIdentityOp ptrargs_, list storedargs_ = [], dag patternToMatch=(Unimplemented), list reverse_ = []> : MemoryIdentityOp<"stablehlo", opName_, ptrargs_, storedargs_, patternToMatch, reverse_>; + class HLOReadOnlyIdentityOp ptrargs_ = [0], dag patternToMatch=(Unimplemented), list reverse_ = []> : ReadOnlyIdentityOp<"stablehlo", opName_, ptrargs_, patternToMatch, reverse_>; class HLOControlFlowOp : ControlFlowOp<"stablehlo", opName_, 
impl_>; diff --git a/src/enzyme_ad/jax/primitives.py b/src/enzyme_ad/jax/primitives.py index dec87152f..a1a9173c6 100644 --- a/src/enzyme_ad/jax/primitives.py +++ b/src/enzyme_ad/jax/primitives.py @@ -933,7 +933,7 @@ def make_zero(tan, prim): shadconv = None if pipeline_options.mlir_ad() and kwargs["lang"] == LANG_MHLO: act_tup = ",".join(["enzyme_dup" for a in arg_primals]) - afterad = "arith-raise{stablehlo=true}, enzyme-hlo-opt, print, cse, canonicalize" + afterad = "arith-raise{stablehlo=true}, enzyme-hlo-opt, cse, canonicalize" newpasses = ( "inline{default-pipeline=canonicalize max-iterations=4}," + "enzyme-hlo-opt,cse,enzyme-wrap{infn=main outfn= retTy=enzyme_dup argTys=" @@ -1197,7 +1197,7 @@ def enzyme_vjp(shadow_rets, *prim_args, **kwargs): newpasses = ( prev_passes + "print," + ad_pass - + ",canonicalize, remove-unnecessary-enzyme-ops, enzyme-simplify-math, enzyme-hlo-opt, canonicalize, cse, print" + + ",canonicalize, remove-unnecessary-enzyme-ops, enzyme-simplify-math, enzyme-hlo-opt, canonicalize, cse" + post_passes ) diff --git a/test/bench_vs_xla.py b/test/bench_vs_xla.py index cf6cf5e22..cd7f5a9ef 100644 --- a/test/bench_vs_xla.py +++ b/test/bench_vs_xla.py @@ -287,5 +287,47 @@ def cache(x): self.name = "cache" +class Slicing(EnzymeJaxTest): + def setUp(self): + dim = 3 + self.ins = [jnp.array(range(dim), dtype=jnp.float32).reshape(1, dim, 1)] + self.dins = [jnp.array([i * i for i in range(dim)], dtype=jnp.float32).reshape(1, dim, 1)] + self.douts = [jnp.array([i * i for i in range(dim)], dtype=jnp.float32)] + + def nomlir(x): + return [(name, a) for (name, a) in x if name != "NewXLAMLIR"] + + self.revfilter = nomlir + + def slicing(x): + return x[0, 0:1, 0] * jnp.ones((3,)) + + self.fn = slicing + self.name = "slicing" + + +class ActivityMismatch(EnzymeJaxTest): + def setUp(self): + dim = 12 + self.ins = [jnp.array(range(dim), dtype=jnp.float32)] + self.dins = [jnp.array([i * i for i in range(dim)], dtype=jnp.float32)] + self.douts = [jnp.array([i * i for i in range(2*dim)], dtype=jnp.float32).reshape((2, dim))] + + def nomlir(x): + return [(name, a) for (name, a) in x if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"] + + self.revfilter = nomlir + + def f(x): + toconv2 = jnp.ones((dim, dim)) + k = jnp.einsum('jk,k->j', toconv2, x) + kcl = jnp.zeros((1, dim)) + h = jnp.reshape(k, (1, dim)) + kcl = jnp.append(kcl, h, axis=0) + return kcl + + self.fn = f + self.name = "activitymismatch" + if __name__ == "__main__": absltest.main() diff --git a/test/llama.py b/test/llama.py index 7eb661ee3..0377034a1 100644 --- a/test/llama.py +++ b/test/llama.py @@ -11,21 +11,21 @@ def rmsnorm(x, weight): ss = 1 / jnp.sqrt(x.dot(x) / x.shape[0] + 1e-5) - return weight * x * ss + return x # weight * x * ss def softmax(x): - max_val = jnp.max(x) - x = jnp.exp(x - max_val) - return x / sum(x) + # max_val = jnp.max(x) + # x = jnp.exp(x - max_val) + return x # / sum(x) def sigmoid(x): - return 1 / (1 + jnp.exp(-x)) + return 1 # / (1 + jnp.exp(-x)) def silu(x): - return x * sigmoid(x) + return x # * sigmoid(x) # Token is token value @@ -134,7 +134,7 @@ def forward(x, config, weights, key_cache, value_cache): fcr = jnp.cos(val) fci = jnp.sin(val) - rotM = jnp.array([[fcr, -fci], [fci, fcr]]) + rotM = jnp.array([[0.0, -1.0], [1., 0.]]) toconv.append(rotM) toconv2 = toconv[: kv_dim // 2] + [jnp.eye(2)] * (dim // 2 - kv_dim // 2) @@ -143,99 +143,20 @@ def forward(x, config, weights, key_cache, value_cache): keys2 = [] values2 = [] - for l in range(n_layers): - xb = rmsnorm(x, 
rms_att_weight[l, :]) - if asserts: - assert xb.shape == (dim,) - - q = wq[l, :, :] @ xb - if asserts: - assert q.shape == (dim,) - - k = wk[l, :, :] @ xb - if asserts: - assert q.shape == (kv_dim,) - - v = wv[l, :, :] @ xb - if asserts: - assert q.shape == (kv_dim,) - - q_tmp = jnp.reshape(q, (dim // 2, 2)) - k_tmp = jnp.reshape(k, (dim // 2, 2)) - - # dim == head_size * n_heads - - # Batched gemv - k = jnp.reshape(jnp.einsum("ijk,ik -> ij", toconv2, k_tmp), (dim,)) - q = jnp.reshape(jnp.einsum("ijk,ik -> ij", toconv, q_tmp), (dim,)) - - key_cache_l = key_cache[l, :, :] - key_cache_l = jnp.append(key_cache_l, jnp.reshape(k, (1, dim)), axis=0) - value_cache_l = value_cache[l, :, :] - value_cache_l = jnp.append(value_cache_l, jnp.reshape(v, (1, dim)), axis=0) - keys2.append(key_cache_l) - values2.append(value_cache_l) - - xbs2 = [] - for h in range(n_heads): - q2 = q[head_size * h : head_size * (h + 1)] - if asserts: - assert q2.shape == (head_size,) - - # For kv_mul consecutive heads, they share the same kv cache - # reshape key_cache last dim from (kv_dim,) to (kv_mul, head_size) - # generalized einsum reducing the last dim, the rest are batch - att = [] - - key_index = h // kv_mul - - att = jnp.einsum( - "ij,j->i", - key_cache_l[:, key_index * head_size : (key_index + 1) * head_size], - q2, - ) - - att = att / jnp.sqrt(head_size) - - att = softmax(att) - - x_tmp = jnp.einsum( - "ij,i->j", - value_cache_l[:, key_index * head_size : (key_index + 1) * head_size], - att, - ) - - xbs2.append(x_tmp) - - # Todo right concat - xb = jnp.concatenate(xbs2, axis=None) - - xb2 = wo[l, :, :] @ xb - if asserts: - assert xb2.shape == (dim,) - - x += xb2 - - # Rmsnorm and feedforward swiglu - - xb = rmsnorm(x, rms_ffn_weight[l, :]) - - # Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x)) - # first calculate self.w1(x) and self.w3(x) - - hb = w1[l, :, :] @ xb - hb2 = w3[l, :, :] @ xb - - hb = silu(hb) + k = wk[0, :, :] @ x + + k_tmp = jnp.reshape(k, (dim // 2, 2)) - hb = hb * hb2 + # dim == head_size * n_heads - xb = w2[l, :, :] @ hb + # Batched gemv + k = jnp.reshape(jnp.einsum("ijk,ik -> ij", toconv2, k_tmp), (dim,)) - x += xb + key_cache_l = key_cache[0, :, :] + h = jnp.reshape(k, (1, dim)) + key_cache_l = jnp.append(key_cache_l, h, axis=0) - x = rmsnorm(x, rms_final_weight) - logits = wcls @ x + x = key_cache_l return x @@ -243,11 +164,12 @@ def forward(x, config, weights, key_cache, value_cache): class Llama(absltest.TestCase): def test_llama_random(self): config = { - "dim": 288, + "dim": 2, "hidden_dim": 768, - "n_layers": 6, - "n_heads": 6, - "n_kv_heads": 6, + "n_layers": 1, + # "n_heads": 6, + "n_heads": 1, + "n_kv_heads": 1, "vocab_size": 32000, "seq_len": 256, } @@ -404,9 +326,9 @@ def erev(x, weights, kc, vc, dx, dkc, dvc): primals, f_vjp = jax.vjp(efunc, x, weights, kc, vc) return f_vjp(dx) # , dkc, dvc) - eres = erev(x, weights, key_cache, value_cache, dx, dkc, dvc) + eres = erev(x, weights, key_cache, value_cache, res, dkc, dvc) print("Enzyme rev", eres) - jres = jrev(x, weights, key_cache, value_cache, dx, dkc, dvc) + jres = jrev(x, weights, key_cache, value_cache, res, dkc, dvc) print("Jax rev", jres) print( From 53e5d583d6b28184118585270ceab3032ed90c2f Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Thu, 7 Mar 2024 22:53:20 -0800 Subject: [PATCH 07/17] continue --- BUILD | 29 + WORKSPACE | 2 +- patches/xla.patch | 35 + src/enzyme_ad/jax/BUILD | 41 ++ .../jax/Implementations/HLODerivatives.td | 202 +++++- .../jax/Implementations/MHLODerivatives.td | 2 +- .../StableHLOAutoDiffOpInterfaceImpl.cpp | 112 +++- .../Implementations/StableHLODerivatives.td | 2 +- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 597 +++++++++++++++++- src/enzyme_ad/jax/primitives.py | 2 +- test/bench_vs_xla.py | 47 ++ test/llama.py | 304 ++++++--- 12 files changed, 1234 insertions(+), 141 deletions(-) diff --git a/BUILD b/BUILD index 409ea10c7..06caae17d 100644 --- a/BUILD +++ b/BUILD @@ -18,6 +18,35 @@ py_package( packages = ["@//src/enzyme_ad/jax:enzyme_call.so", "@llvm-project//clang:builtin_headers_gen"], ) +cc_binary( + name = "enzymexlamlir-opt", + srcs = ["//src/enzyme_ad/jax:enzymexlamlir-opt.cpp"], + visibility = ["//visibility:public"], + deps = [ + "//src/enzyme_ad/jax:XLADerivatives", + "@enzyme//:EnzymeMLIR", + "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:AsyncDialect", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ConversionPasses", + "@llvm-project//mlir:DLTIDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LinalgDialect", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:OpenMPDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Transforms", + ], +) + py_wheel( name = "enzyme_ad", distribution = "enzyme_ad", diff --git a/WORKSPACE b/WORKSPACE index 0bd4bc04c..b73f72f42 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -39,7 +39,7 @@ http_archive( strip_prefix = "xla-" + XLA_COMMIT, urls = ["https://github.com/wsmoses/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)], patch_args = ["-p1"], - patches = ["//:patches/xla.patch"], + patches = ["//:patches/xla.patch", "//:patches/xla2.patch", ], ) PYRULES_COMMIT = "fe33a4582c37499f3caeb49a07a78fc7948a8949" diff --git a/patches/xla.patch b/patches/xla.patch index 5e799e11a..175553b26 100644 --- a/patches/xla.patch +++ b/patches/xla.patch @@ -16,3 +16,38 @@ ) cc_library( + +--- a/xla/mlir/backends/cpu/transforms/BUILD ++++ b/xla/mlir/backends/cpu/transforms/BUILD +@@ -4,7 +4,7 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") + + package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], +- default_visibility = ["//xla:internal"], ++ default_visibility = ["//xla:friends"], + licenses = ["notice"], + ) + + gentbl_cc_library( + +--- a/xla/mlir/memref/BUILD ++++ b/xla/mlir/memref/BUILD +@@ -1,6 +1,7 @@ + package_group( + name = "friends", + packages = [ ++ "public", + "//xla/mlir/...", + # copybara:uncomment_begin(google-only) + # # TODO(ezhulenev): Clean up dependencies that are leforvers from Autofusion project. + +--- a/xla/mlir/math/BUILD ++++ b/xla/mlir/math/BUILD +@@ -1,6 +1,7 @@ + package_group( + name = "friends", + packages = [ ++ "public", + "//xla/mlir/...", + # copybara:uncomment_begin(google-only) + # # TODO(ezhulenev): Clean up dependencies that are leforvers from Autofusion project. 
diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index 1a6383b86..43c44e2d1 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -3,6 +3,7 @@ load("@pybind11_bazel//:build_defs.bzl", "pybind_extension", "pybind_library") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") load("@llvm-project//llvm:tblgen.bzl", "gentbl") +exports_files(["enzymexlamlir-opt.cpp"]) licenses(["notice"]) package( @@ -29,10 +30,12 @@ pybind_library( "@llvm-project//llvm:AsmParser", "@llvm-project//llvm:CodeGen", "@llvm-project//llvm:Core", + "@llvm-project//llvm:MC", "@llvm-project//llvm:IRReader", "@llvm-project//llvm:Linker", "@llvm-project//llvm:OrcJIT", "@llvm-project//llvm:Support", + "@llvm-project//llvm:TargetParser", ], ) @@ -139,6 +142,15 @@ cc_library( ":stablehlo-derivatives", "@stablehlo//:stablehlo_ops", "@stablehlo//:stablehlo_passes", + "@stablehlo//:reference_ops", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:FunctionInterfaces", + "@llvm-project//mlir:ControlFlowInterfaces", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:CommonFolders", + "@llvm-project//mlir:Transforms", "@xla//xla/mlir_hlo", "@enzyme//:EnzymeMLIR", ] @@ -174,6 +186,9 @@ pybind_library( "@xla//xla/client:client_library", "@xla//xla/client:executable_build_options", "@xla//xla/client:xla_computation", + "@xla//xla/service:service", + "@xla//xla/service:local_service", + "@xla//xla/service:local_service_utils", "@xla//xla/service:buffer_assignment_proto_cc", "@xla//xla/service:buffer_assignment_proto_cc_impl", "@xla//xla/service/cpu:cpu_executable", @@ -191,6 +206,8 @@ pybind_library( "@xla//xla:xla_proto_cc", "@xla//xla:xla_proto_cc_impl", + "@stablehlo//:stablehlo_ops", + # Make CPU target available to XLA. "@xla//xla/service:cpu_plugin", @@ -198,16 +215,27 @@ pybind_library( "@xla//xla/mlir_hlo", "@xla//xla/translate/mhlo_to_hlo:mlir_hlo_to_hlo", + "@xla//xla/hlo/ir:hlo", + # This is necessary for XLA protobufs to link "@com_google_protobuf//:protobuf", # MLIR dialects and parser. 
+ "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FuncExtensions", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@xla//xla/mlir_hlo:all_passes", + "@xla//xla:printer", + # EnzymeMLIR "@enzyme//:EnzymeMLIR", + + "@com_google_absl//absl/status:statusor", # Mosaic "@jax//jaxlib/mosaic:tpu_dialect", @@ -230,6 +258,19 @@ pybind_extension( "@com_google_absl//absl/status:statusor", "@stablehlo//:stablehlo_passes", "@xla//xla/stream_executor:stream_executor_impl", + "@xla//xla/mlir/backends/cpu/transforms:passes", + "@xla//xla/mlir/memref/transforms:passes", + "@xla//xla/mlir/math/transforms:passes", + "@xla//xla/mlir/runtime/transforms:passes", + "@xla//xla/mlir_hlo:deallocation_passes", + "@xla//xla/mlir_hlo:lhlo", + "@xla//xla/mlir_hlo:lhlo_gpu", + "@xla//xla/mlir_hlo:all_passes", + "@xla//xla/hlo/ir:hlo", + "@xla//xla/service/cpu:cpu_executable", + "@enzyme//:EnzymeStatic", + "@enzyme//:EnzymeMLIR", + ], visibility = ["//visibility:public"], ) diff --git a/src/enzyme_ad/jax/Implementations/HLODerivatives.td b/src/enzyme_ad/jax/Implementations/HLODerivatives.td index 9c75260a2..513dbeeb6 100644 --- a/src/enzyme_ad/jax/Implementations/HLODerivatives.td +++ b/src/enzyme_ad/jax/Implementations/HLODerivatives.td @@ -12,7 +12,7 @@ def Sin : HLOInst<"SineOp">; def Sqrt : HLOInst<"SqrtOp">; def Exp : HLOInst<"ExpOp">; -def Dot : HLOInst<"DotGeneralOp">; +def Dot : HLOInst<"DotGeneralOp", "->getResult(0)">; def Compare : HLOInst<"CompareOp">; def Select : HLOInst<"SelectOp">; @@ -92,6 +92,7 @@ def : HLODerivative<"MaxOp", (Op $x, $y), (Select (Compare $x, $y, (LT)), (SelectIfActive $y, (Shadow $y), (HLOConstantFP<"0"> $y)), (SelectIfActive $x, (Shadow $x), (HLOConstantFP<"0"> $x))) >; +def Transpose : HLOInst<"TransposeOp">; def Reshape : HLOInst<"ReshapeOp">; def : HLOReadOnlyIdentityOp<"ReshapeOp", [0], (Op $x), [(Reshape (TypeOf $x), (DiffeRet))]>; @@ -107,10 +108,205 @@ def : HLOMemoryIdentityOp<"ConcatenateOp", [], [-1]>; def ResultDotDim : GlobalExpr; def ResultDotPrec : GlobalExpr; + +def ShadowLHSDotDim : GlobalExpr shadowBatchingDimensions; + for (auto en : llvm::enumerate(existingattr.getLhsBatchingDimensions())) + shadowBatchingDimensions.push_back(en.index()); + + SmallVector rhsContractingDimensions; + SmallVector shadowResultContractingDimensions; + + for (auto en : llvm::enumerate(op.getRhs().getType().getShape())) { + if (llvm::is_contained(existingattr.getRhsBatchingDimensions(), en.index())) continue; + if (llvm::is_contained(existingattr.getRhsContractingDimensions(), en.index())) continue; + rhsContractingDimensions.push_back(en.index()); + shadowResultContractingDimensions.push_back(resultidx++); + resultidx++; + } + + DotDimensionNumbersAttr::get(existingattr.getContext(), shadowBatchingDimensions, existingattr.getRhsBatchingDimensions(), shadowResultContractingDimensions, rhsContractingDimensions); +}]>; + +def ShadowLHSDotRes : GlobalExprgetResult(0).getType().cast(); + SmallVector shapes; + // Result order is batches, lhs results, rhs results [in this case contracting dims] + + for (auto en2 : llvm::enumerate(existingattr.getLhsBatchingDimensions())) { + shapes.push_back(op.getLhs().getType().getShape()[en2.value()]); + } + + for (auto en : llvm::enumerate(op.getLhs().getType().getShape())) { + if (llvm::is_contained(existingattr.getLhsBatchingDimensions(), en.index())) continue; + if (llvm::is_contained(existingattr.getLhsContractingDimensions(), en.index())) continue; + 
shapes.push_back(en.value()); + } + + for (auto en : llvm::enumerate(op.getRhs().getType().getShape())) { + ssize_t contractidx = -1; + + for (auto en2 : llvm::enumerate(existingattr.getRhsContractingDimensions())) { + if (en2.value() == en.index()) { + contractidx = en2.index(); + break; + } + } + + if (contractidx == -1) continue; + + shapes.push_back(op.getRhs().getType().getShape()[existingattr.getRhsContractingDimensions()[contractidx]]); + } + + RankedTensorType::get(shapes, prev.getElementType()); +}]>; + +def ShadowLHSTranspose : GlobalExpr transposes; + + // Result order is batches, lhs results, rhs results [in this case contracting dims] + for (auto en2 : llvm::enumerate(existingattr.getLhsBatchingDimensions())) { + transposes.push_back(en2.value()); + } + + for (auto en : llvm::enumerate(op.getLhs().getType().getShape())) { + if (llvm::is_contained(existingattr.getLhsBatchingDimensions(), en.index())) continue; + if (llvm::is_contained(existingattr.getLhsContractingDimensions(), en.index())) continue; + transposes.push_back(en.index()); + } + + for (auto en : llvm::enumerate(op.getRhs().getType().getShape())) { + if (llvm::is_contained(existingattr.getRhsBatchingDimensions(), en.index())) continue; + + ssize_t contractidx = -1; + + for (auto en2 : llvm::enumerate(existingattr.getRhsContractingDimensions())) { + if (en2.value() == en.index()) { + contractidx = en2.index(); + break; + } + } + + if (contractidx == -1) continue; + + transposes.push_back(existingattr.getLhsContractingDimensions()[contractidx]); + } + + builder.getNamedAttr(TransposeOp::getAttributeNames()[0], builder.getDenseI64ArrayAttr(transposes)); +}]>; + +def ShadowRHSDotDim : GlobalExpr shadowBatchingDimensions; + for (auto en : llvm::enumerate(existingattr.getLhsBatchingDimensions())) + shadowBatchingDimensions.push_back(en.index()); + + SmallVector lhsContractingDimensions; + SmallVector shadowResultContractingDimensions; + + for (auto en : llvm::enumerate(op.getLhs().getType().getShape())) { + if (llvm::is_contained(existingattr.getLhsBatchingDimensions(), en.index())) continue; + if (llvm::is_contained(existingattr.getLhsContractingDimensions(), en.index())) continue; + lhsContractingDimensions.push_back(en.index()); + shadowResultContractingDimensions.push_back(resultidx++); + resultidx++; + } + + DotDimensionNumbersAttr::get(existingattr.getContext(), existingattr.getLhsBatchingDimensions(), shadowBatchingDimensions, lhsContractingDimensions, shadowResultContractingDimensions); +}]>; + +def ShadowRHSDotRes : GlobalExprgetResult(0).getType().cast(); + SmallVector shapes; + // Result order is batches, lhs results [in this case contracting dims], rhs results + + for (auto en2 : llvm::enumerate(existingattr.getLhsBatchingDimensions())) { + shapes.push_back(op.getLhs().getType().getShape()[en2.value()]); + } + + for (auto en : llvm::enumerate(op.getLhs().getType().getShape())) { + ssize_t contractidx = -1; + + for (auto en2 : llvm::enumerate(existingattr.getLhsContractingDimensions())) { + if (en2.value() == en.index()) { + contractidx = en2.index(); + break; + } + } + + if (contractidx == -1) continue; + + shapes.push_back(op.getLhs().getType().getShape()[existingattr.getLhsContractingDimensions()[contractidx]]); + } + + for (auto en : llvm::enumerate(op.getRhs().getType().getShape())) { + if (llvm::is_contained(existingattr.getRhsBatchingDimensions(), en.index())) continue; + if (llvm::is_contained(existingattr.getRhsContractingDimensions(), en.index())) continue; + shapes.push_back(en.value()); + } + + 
RankedTensorType::get(shapes, prev.getElementType()); +}]>; + +def ShadowRHSTranspose : GlobalExpr transposes; + + // Result order is batches, lhs results [in this case contracting dims], rhs results + for (auto en2 : llvm::enumerate(existingattr.getRhsBatchingDimensions())) { + transposes.push_back(en2.value()); + } + + for (auto en : llvm::enumerate(op.getLhs().getType().getShape())) { + if (llvm::is_contained(existingattr.getLhsBatchingDimensions(), en.index())) continue; + + ssize_t contractidx = -1; + + for (auto en2 : llvm::enumerate(existingattr.getLhsContractingDimensions())) { + if (en2.value() == en.index()) { + contractidx = en2.index(); + break; + } + } + + if (contractidx == -1) continue; + + transposes.push_back(existingattr.getRhsContractingDimensions()[contractidx]); + } + + + for (auto en : llvm::enumerate(op.getRhs().getType().getShape())) { + if (llvm::is_contained(existingattr.getRhsBatchingDimensions(), en.index())) continue; + if (llvm::is_contained(existingattr.getRhsContractingDimensions(), en.index())) continue; + transposes.push_back(en.index()); + } + + builder.getNamedAttr(TransposeOp::getAttributeNames()[0], builder.getDenseI64ArrayAttr(transposes)); +}]>; + def : HLODerivative<"DotGeneralOp", (Op $lhs, $rhs), [ - (Dot (ResultTypes), (DiffeRet), $rhs, (ResultDotDim), (ResultDotPrec)), - (Dot (ResultTypes), $lhs, (DiffeRet), (ResultDotDim), (ResultDotPrec)) + (Transpose (TypeOf $lhs), (Dot (ShadowLHSDotRes), (DiffeRet), $rhs, (ShadowLHSDotDim), (ResultDotPrec)), (ShadowLHSTranspose)), + (Transpose (TypeOf $rhs), (Dot (ShadowRHSDotRes), $lhs, (DiffeRet), (ShadowRHSDotDim), (ResultDotPrec)), (ShadowRHSTranspose)) ], (Add (SelectIfActive $lhs, (Dot (ResultTypes), (Shadow $lhs), $rhs, (ResultDotDim), (ResultDotPrec)), (HLOConstantFP<"0">)), (SelectIfActive $rhs, (Dot (ResultTypes), $lhs, (Shadow $rhs), (ResultDotDim), (ResultDotPrec)), (HLOConstantFP<"0">))) >; diff --git a/src/enzyme_ad/jax/Implementations/MHLODerivatives.td b/src/enzyme_ad/jax/Implementations/MHLODerivatives.td index 2606731da..58946931c 100644 --- a/src/enzyme_ad/jax/Implementations/MHLODerivatives.td +++ b/src/enzyme_ad/jax/Implementations/MHLODerivatives.td @@ -2,7 +2,7 @@ include "src/enzyme_ad/jax/Implementations/Common.td" class HLODerivative resultOps, dag forwardOps=(ForwardFromSummedReverse)> : MLIRDerivative<"mhlo", opName_, patternToMatch, resultOps, forwardOps>; -class HLOInst : Inst; +class HLOInst : Inst; class HLOMemoryIdentityOp ptrargs_, list storedargs_ = [], dag patternToMatch=(Unimplemented), list reverse_ = []> : MemoryIdentityOp<"mhlo", opName_, ptrargs_, storedargs_, patternToMatch, reverse_>; diff --git a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp index c8e5ac1b5..5fb428937 100644 --- a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp +++ b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp @@ -203,28 +203,19 @@ class AutoDiffBroadcastInDimRev SmallVector bcastDims(op.getBroadcastDimensions().begin(), op.getBroadcastDimensions().end()); - Value zero = nullptr; - if (bcastDims.size() == 0 && inTy.getShape().size() == 0) { - for (size_t i = 0; i < outTy.getShape().size(); i++) { - bcastDims.push_back(i); - } - zero = gutils->getShadowType(inTy) - .cast() - .createNullValue(builder, op.getLoc()); - } else { - SmallVector dims; - for (size_t i = 0; i < inTy.getShape().size(); i++) { - if (llvm::is_contained(bcastDims, i)) continue; - 
dims.push_back(i); - } - zero = gutils->getShadowType(RankedTensorType::get(dims, inTy.getElementType())) + SmallVector newDims; + for (auto en : llvm::enumerate(outTy.getShape())) { + if (llvm::is_contained(bcastDims, en.index())) continue; + newDims.push_back(en.index()); + } + + Value zero = gutils->getShadowType(inTy) .cast() .createNullValue(builder, op.getLoc()); - } auto red = builder.create(op.getLoc(), TypeRange(zero.getType()), - inDiffe, zero, bcastDims); + inDiffe, zero, newDims); red.getBody().push_back(new Block()); Block &body = red.getBody().front(); OpBuilder bodyBuilder(orig->getContext()); @@ -236,8 +227,6 @@ class AutoDiffBroadcastInDimRev body.getArgument(1)); bodyBuilder.create(op.getLoc(), ValueRange(add)); - llvm::errs() << " red: " << *red << "\n"; - Value res = red->getResult(0); Type resTy = gutils->getShadowType(op.getOperand().getType()); if (res.getType() != resTy) @@ -271,6 +260,25 @@ class AutoDiffSliceRev auto inDiffe = gutils->diffe(op, builder); gutils->zeroDiffe(op, builder); + SmallVector starts; + SmallVector edge_padding_high; + SmallVector interior_padding; + for (auto &&[start, limit, stride, dim] : llvm::zip( + op.getStartIndices(), op.getLimitIndices(), op.getStrides(), inTy.getShape())) { + starts.push_back(start); + edge_padding_high.push_back(dim - limit); + interior_padding.push_back(stride - 1); + } + + + auto zeroPad = RankedTensorType::get({}, inTy.getElementType()).cast().createNullValue(builder, + op.getLoc()); + auto red = builder.create(op.getLoc(), inDiffe, zeroPad, builder.getDenseI64ArrayAttr(starts), builder.getDenseI64ArrayAttr(edge_padding_high), builder.getDenseI64ArrayAttr(interior_padding)); + + gutils->addToDiffe(op.getOperand(), red->getResult(0), builder); + return success(); + #if 0 + Value idxs; { SmallVector concat_data; @@ -288,11 +296,10 @@ class AutoDiffSliceRev for (int32_t i = start; i < limit; i += stride) { data.push_back(i); } - - Value ind = builder.create(op.getLoc(), RankedTensorType::get({data.size()}, builder.getI32Type()), + Value ind = builder.create(op.getLoc(), RankedTensorType::get({(int64_t)data.size()}, builder.getI32Type()), builder.getI32TensorAttr(data)); - auto bcast_ind = builder.getDenseI64ArrayAttr({idx}); + auto bcast_ind = builder.getDenseI64ArrayAttr({(int64_t)idx}); ind = builder.create(op.getLoc(), toConcatType, ind, bcast_ind); inds.push_back(ind); @@ -344,6 +351,7 @@ class AutoDiffSliceRev // gutils->setDiffe(op.getOperand(), red->getResult(0), builder); return success(); + #endif } SmallVector cacheValues(Operation *orig, @@ -371,20 +379,31 @@ class AutoDiffReduceRev Operation &innerOp = op.getBody().front().front(); - auto inTy = op->getOperand(0).getType(); + auto inTy = op->getOperand(0).getType().cast(); auto zero = inTy.cast().createNullValue(builder, op.getLoc()); auto inDiffe = gutils->diffe(op->getResult(0), builder); gutils->zeroDiffe(op->getResult(0), builder); + SmallVector toBroadcast; + { + size_t idx=0; + for (auto en : llvm::enumerate(inTy.getShape())) { + if (llvm::is_contained(op.getDimensions(), en.index())) { + // reduced op + continue; + } + toBroadcast.push_back(idx); + idx++; + } + } + if (isa(innerOp)) { if (!gutils->isConstantValue(op.getInputs()[0])) { Value bcast; + - if (op->getResult(0).getType().cast().getShape().size() == 0) - bcast = builder.create(op.getLoc(), gutils->getShadowType(inTy), inDiffe, builder.getDenseI64ArrayAttr({})); - else - bcast = builder.create(op.getLoc(), gutils->getShadowType(inTy), inDiffe, op.getDimensions()); + bcast = 
builder.create(op.getLoc(), gutils->getShadowType(inTy), inDiffe, builder.getDenseI64ArrayAttr(toBroadcast)); gutils->addToDiffe(op.getInputs()[0], bcast, builder); } @@ -395,6 +414,34 @@ class AutoDiffReduceRev } if (isa(innerOp) || isa(innerOp)) { + // TODO: technically we should invert the order here to pick the last value (or divide by count) if multiple are the same as the + // result + auto ores = gutils->getNewFromOriginal(op->getResult(0)); + + if (!gutils->isConstantValue(op.getInputs()[0])) { + auto oprev = gutils->getNewFromOriginal(op.getInputs()[0]); + auto attr = builder.getDenseI64ArrayAttr(toBroadcast); + auto bc = builder.create(op.getLoc(), oprev.getType(), ores, attr); + + auto cmp = builder.create(op.getLoc(), bc, oprev, ComparisonDirection::EQ); + + auto bc2 = builder.create(op.getLoc(), oprev.getType(), inDiffe, attr); + + auto res = builder.create(op.getLoc(), cmp, bc2, zero); + gutils->addToDiffe(op.getInputs()[0], res, builder); + } + if (!gutils->isConstantValue(op.getInitValues()[0])) { + auto oprev = gutils->getNewFromOriginal(op.getInitValues()[0]); + + auto zeroI = inDiffe.getType().cast().createNullValue(builder, + op.getLoc()); + + auto cmp = builder.create(op.getLoc(), ores, oprev, ComparisonDirection::EQ); + + auto res = builder.create(op.getLoc(), cmp, inDiffe, zeroI); + gutils->addToDiffe(op.getInitValues()[0], res, builder); + } + return success(); } orig->emitError() << "Unsupported operation in reduction rev autodiff(1): " @@ -424,9 +471,9 @@ class AutoDiffConcatenateRev gutils->zeroDiffe(op->getResult(0), builder); auto dim = op.getDimension(); + size_t startDim = 0; for (auto &ope : op->getOpOperands()) { auto op = ope.get(); - if (gutils->isConstantValue(op)) continue; auto inTy = gutils->getShadowType(op.getType()); SmallVector start; SmallVector limit; @@ -434,18 +481,19 @@ class AutoDiffConcatenateRev SmallVector tys; auto RT = inTy.cast(); for (auto i=0; iisConstantValue(op)) continue; auto res = builder.create(op.getLoc(), RankedTensorType::get(tys, RT.getElementType()), inDiffe, start, limit, strides); auto res2 = builder.create(op.getLoc(), inTy, res); gutils->addToDiffe(op, res2, builder); diff --git a/src/enzyme_ad/jax/Implementations/StableHLODerivatives.td b/src/enzyme_ad/jax/Implementations/StableHLODerivatives.td index f71b9a61c..8881decc6 100644 --- a/src/enzyme_ad/jax/Implementations/StableHLODerivatives.td +++ b/src/enzyme_ad/jax/Implementations/StableHLODerivatives.td @@ -2,7 +2,7 @@ include "src/enzyme_ad/jax/Implementations/Common.td" class HLODerivative resultOps, dag forwardOps=(ForwardFromSummedReverse)> : MLIRDerivative<"stablehlo", opName_, patternToMatch, resultOps, forwardOps>; -class HLOInst : Inst; +class HLOInst : Inst; class HLOMemoryIdentityOp ptrargs_, list storedargs_ = [], dag patternToMatch=(Unimplemented), list reverse_ = []> : MemoryIdentityOp<"stablehlo", opName_, ptrargs_, storedargs_, patternToMatch, reverse_>; diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 9ad70bfa9..9c4ecc87b 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -19,11 +19,14 @@ #include "stablehlo/dialect/StablehloOps.h" #include "stablehlo/transforms/Passes.h" +#include "stablehlo/reference/Ops.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Dialect/CommonFolders.h" + #define DEBUG_TYPE "enzyme" using namespace mlir; @@ -51,6 +54,357 @@ 
struct SliceSimplification final : OpRewritePattern { } }; + +DenseElementsAttr fromTensor(stablehlo::Tensor inp) { + auto type = inp.getType(); + auto elemType = type.getElementType(); + + + if (elemType.isF32()) { + auto floatValues = ArrayRef((float*)inp.getData(), inp.getNumElements()); + return DenseFPElementsAttr::get(type, floatValues); + } + + if (elemType.isF64()) { + auto floatValues = ArrayRef((double*)inp.getData(), inp.getNumElements()); + return DenseFPElementsAttr::get(type, floatValues); + } + + if (elemType.isSignlessInteger(8)) { + auto floatValues = ArrayRef((int8_t*)inp.getData(), inp.getNumElements()); + return DenseIntElementsAttr::get(type, floatValues); + } + if (elemType.isSignlessInteger(16)) { + auto floatValues = ArrayRef((int16_t*)inp.getData(), inp.getNumElements()); + return DenseIntElementsAttr::get(type, floatValues); + } + if (elemType.isSignlessInteger(32)) { + auto floatValues = ArrayRef((int32_t*)inp.getData(), inp.getNumElements()); + return DenseIntElementsAttr::get(type, floatValues); + } + if (elemType.isSignlessInteger(64)) { + auto floatValues = ArrayRef((int64_t*)inp.getData(), inp.getNumElements()); + return DenseIntElementsAttr::get(type, floatValues); + } + if (elemType.isUnsignedInteger(8)) { + auto floatValues = ArrayRef((uint8_t*)inp.getData(), inp.getNumElements()); + return DenseIntElementsAttr::get(type, floatValues); + } + if (elemType.isUnsignedInteger(16)) { + auto floatValues = ArrayRef((uint16_t*)inp.getData(), inp.getNumElements()); + return DenseIntElementsAttr::get(type, floatValues); + } + if (elemType.isUnsignedInteger(32)) { + auto floatValues = ArrayRef((uint32_t*)inp.getData(), inp.getNumElements()); + return DenseIntElementsAttr::get(type, floatValues); + } + if (elemType.isUnsignedInteger(64)) { + auto floatValues = ArrayRef((uint64_t*)inp.getData(), inp.getNumElements()); + return DenseIntElementsAttr::get(type, floatValues); + } + + assert(0); +} + +/* +%22 = stablehlo.dot_general %21, %16, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<288x288xf32>, tensor<288xf32>) -> tensor<288xf32> +%27 = stablehlo.reshape %22 : (tensor<288xf32>) -> tensor<144x2xf32> +%28 = stablehlo.dot_general %6, %27, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<144x2x2xf32>, tensor<144x2xf32>) -> tensor<144x2xf32> + +should become + +%a21 = stablehlo.reshape %21 : (tensor<288xf32>) -> tensor<144x2xf32> + +%22 = stablehlo.dot_general %a21, %16, batching_dims = [1] x [], contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<144x2x288xf32>, tensor<288xf32>) -> tensor<2x144xf32> + +%28 = stablehlo.dot_general %6, %22, batching_dims = [0] x [1], contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<144x2x2xf32>, tensor<144x2xf32>) -> tensor<144x2xf32> + +TODO +*/ + +struct DotReshapeDot final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::stablehlo::DotGeneralOp op, + PatternRewriter &rewriter) const override { + auto type = dyn_cast(op.getType()); + if (!type) + return failure(); + + return failure(); + } +}; + + +/* + + %1192 = stablehlo.pad %1189, %cst_0, low = [0], high = [1], interior = [0] : (tensor<1xf32>, tensor) -> tensor<2xf32> + %1193 = arith.addf %1191, %1192 : tensor<2xf32> + +*/ +struct AddPad final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::stablehlo::AddOp op, + PatternRewriter &rewriter) const override { + 
auto type = dyn_cast(op.getType()); + if (!type) + return failure(); + + for (int i=0; i<2; i++) { + if (auto lhs = op->getOperand(i).getDefiningOp()) { + auto rhs = op->getOperand(1-i); + + if (!matchPattern(lhs.getPaddingValue(), m_AnyZeroFloat())) { + continue; + } + + bool legal = true; + for (auto step : lhs.getInteriorPadding()) { + if (step != 0) { + legal = true; + break; + } + } + if (!legal) continue; + + ssize_t padidx = -1; + + SmallVector idxs; + for (auto &&[low, high, dim] : llvm::zip(lhs.getEdgePaddingLow(), lhs.getEdgePaddingHigh(), type.getShape())) { + padidx++; + if (low == 0 && high == dim) continue; + idxs.push_back(padidx-1); + } + + if (idxs.size() == 0) { + auto idx = idxs[0]; + + SmallVector strides(type.getShape().size(), 1); + SmallVector starts(type.getShape().size(), 0); + SmallVector limits(type.getShape().begin(), type.getShape().end()); + + starts[idx] = lhs.getEdgePaddingLow()[idx]; + limits[idx] = type.getShape()[idx] - lhs.getEdgePaddingLow()[idx]; + + auto midSlice = rewriter.create(op.getLoc(), rhs, starts, limits, strides); + + starts[idx] = 0; + limits[idx] = lhs.getEdgePaddingLow()[idx]; + auto prevSlice = rewriter.create(op.getLoc(), rhs, starts, limits, strides); + + starts[idx] = type.getShape()[idx] - lhs.getEdgePaddingLow()[idx]; + limits[idx] = 0; + auto postSlice = rewriter.create(op.getLoc(), rhs, starts, limits, strides); + + auto mid = rewriter.create(op.getLoc(), midSlice, lhs.getOperand()); + + Value vals[3] = {prevSlice, mid, postSlice}; + rewriter.replaceOpWithNewOp(op, vals, idx); + return success(); + } + + } + } + + return failure(); + } +}; + +struct ConcatConstProp final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::stablehlo::ConcatenateOp op, + PatternRewriter &rewriter) const override { + auto type = dyn_cast(op.getType()); + if (!type) + return failure(); + + + SmallVector constants; + constants.assign(op->getNumOperands(), DenseElementsAttr()); + bool legal = true; + for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) { + matchPattern(op->getOperand(i), m_Constant(&constants[i])); + if (!constants[i]) legal = false; + } + + if (legal) { + + SmallVector inps; + for (auto &c : constants) + inps.push_back(mlir::stablehlo::evalConstantOp(c)); + auto out = mlir::stablehlo::evalConcatenateOp(inps, op.getDimension(), op.getType()); + rewriter.replaceOpWithNewOp(op, op.getType(), fromTensor(out)); + return success(); + } + return failure(); + } +}; + +struct BroadcastToReshape final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::stablehlo::BroadcastInDimOp op, + PatternRewriter &rewriter) const override { + auto type = dyn_cast(op.getType()); + if (!type) + return failure(); + + DenseElementsAttr inp; + matchPattern(op->getOperand(0), m_Constant(&inp)); + if (inp) { + auto inp0 = mlir::stablehlo::evalConstantOp(inp); + auto out = mlir::stablehlo::evalBroadcastInDimOp(inp0, mlir::stablehlo::Axes(op.getBroadcastDimensions()), op.getType()); + rewriter.replaceOpWithNewOp(op, op.getType(), fromTensor(out)); + return success(); + /* + if (inp.isSplat()) { + rewriter.replaceOpWithNewOp(op, op.getType(), SplatElementsAttr::get(op.getType().getShape(), inp.getSplatValue())); + return success(); + } + */ + } + + // Ensure these are sorted + for (auto en : llvm::enumerate(op.getBroadcastDimensions())) { + if (en.index() == 0) continue; + if (op.getBroadcastDimensions()[en.index() - 1] >= en.value()) { + return failure(); 
+ } + } + + // Check that no new data is added + for (auto en : llvm::enumerate(op.getType().getShape())) { + ssize_t idx=-1; + for (auto en2 : llvm::enumerate(op.getBroadcastDimensions())) { + if (en2.value() == en.index()) + idx = en2.index(); + } + if (idx != -1) { + if (en.value() != op.getOperand().getType().getShape()[idx]) { + return failure(); + } + continue; + } + if (en.value() != 1) return failure(); + } + + // replace with reshape + rewriter.replaceOpWithNewOp(op, op.getType(), op.getOperand()); + return success(); + } +}; + +#if 0 +struct ScatterToPad final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::stablehlo::ScatterOp op, + PatternRewriter &rewriter) const override { + auto type = dyn_cast(op.getType()); + if (!type) + return failure(); + + Block &body = op.getUpdateComputation().front(); + if (body.size() != 1) + + Operation &innerOp = body.front(); + if (!isa(&innerOp)) { + return failure(); + } + if (innerOp->getNumOperands() != 1) { + return failure(); + } + auto retop = innerOp->getOperand(0).dyn_cast(); + if (!retop) return failure(); + if (retop.getOwner() != &body) return failure(); + if (retop.getArgNumber() != 1) return failure(); + + if (op.getInputs().size() != 1) return failure(); + + mlir::SplatElementsAttr prev; + if (!matchPattern(op.getInputs()[0], m_Constant(&prev))) { + return failure(); + } + + mlir::DenseElementsAttr idx; + if (!matchPattern(op.getScatterIndices()[0], m_Constant(&idx))) { + return failure(); + } + auto idx2 = mlir::stablehlo::evalConstantOp(idx); + + if (!op.getIndicesAreSorted()) return failure(); + if (!op.getUniqueIndices()) return failure(); + + auto dims = op.getScatterDimensionNumbers(); + if (dims.getInsertedWindowDims() != op.getScatterDimsToOperandDims()) + return failure(); + for (auto en : llvm::enumerate(dims.getInsertedWindowDims())) { + if (en.value() != en.index()) return failure(); + } + + auto update = op.getUpdates()[0]; + auto updateTy = update.getType().cast(); + if (op.getIndexVectorDim() != updateTy.getShape().size()) return failure(); + + SmallVector starts; + SmallVector edge_padding_high; + SmallVector interior_padding; + for (size_t lidx = 0; lidx < idx2.getShape()[idx2.getShape().size()-1]; lidx++) { + + uint64_t start = 0; + uint64_t step = 0 + for (size_t incidx = 0; incidx < idx2.getShape()[lidx]; incidx++) { + std::optional value; + bool legal = true; + std::function)> checkAllEqual = [&](SmallVector prefix) { + if (prefix.size() == lidx) + prefix.push_back(incidx); + + if (prefix.size() == idx2.getShape().size()-1) { + prefix.push_back(lidx); + auto cur = idx2.get(prefix); + if (value) { + legal &= value == cur; + } else { + value = cur; + } + return; + } + for (size_t j = 0; j < idx2.getShape()[prefix.size()]; j++) { + SmallVector prefix2(prefix); + prefix2.push_back(j); + checkAllEqual(prefix2); + } + }; + checkAllEqual({}); + assert(value); + + uint64_t cur = (*value).getIntegerValue().getZExtValue(); + if (incidx == 0) { + start = cur; + } else if (incidx == 1) { + step = cur - start; + } else { + // Only support step size of one + if (start + incidx * step != cur) { + return failure(); + } + } + + } + start.push_back(start); + edge_padding_high.push_back(idx2.getShape()[lidx] - start - ); + interior_padding.push_back(step - 1); + } + + auto padval = builder.create(op.getLoc(), RankedTensorType::get({}, prev.getType().getElementType()), prev.getSplatValue()); + auto pad = builder.replaceOpWithNewOp(op, update, padval) + return failure(); +}; 
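+
+// A rough sketch (hypothetical IR) of what the disabled ScatterToPad pattern
+// above is aiming for: a scatter whose update region simply returns the
+// update value, applied to a splat-constant operand at sorted, unique indices
+// forming an arithmetic progression, is equivalent to padding the update
+// tensor with the splat scalar.
+//
+//   %r = "stablehlo.scatter"(%splat, %idxs, %upd) ({
+//     ^bb0(%old: tensor<f32>, %new: tensor<f32>):
+//       stablehlo.return %new : tensor<f32>
+//   }) {...} : (tensor<8xf32>, tensor<4x1xi64>, tensor<4xf32>) -> tensor<8xf32>
+//
+// with %idxs = [[2], [3], [4], [5]] would become
+//
+//   %r = stablehlo.pad %upd, %splat_scalar, low = [2], high = [2],
+//        interior = [0] : (tensor<4xf32>, tensor<f32>) -> tensor<8xf32>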
+#endif + struct AddSimplify : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -67,6 +421,22 @@ struct AddSimplify : public OpRewritePattern { return success(); } + SmallVector constants; + constants.assign(op->getNumOperands(), Attribute()); + for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) + matchPattern(op->getOperand(i), m_Constant(&constants[i])); + + if (auto res = constFoldBinaryOpConditional( + constants, + [](const APFloat &a, const APFloat &b) -> std::optional { + APFloat res2(a); + res2.add(b, llvm::RoundingMode::NearestTiesToEven); + return res2; + })) { + rewriter.replaceOpWithNewOp(op, op.getType(), res.cast()); + return success(); + } + return failure(); } }; @@ -87,6 +457,231 @@ struct SubSimplify : public OpRewritePattern { return success(); } + SmallVector constants; + constants.assign(op->getNumOperands(), Attribute()); + for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) + matchPattern(op->getOperand(i), m_Constant(&constants[i])); + + if (auto res = constFoldBinaryOpConditional( + constants, + [](const APFloat &a, const APFloat &b) -> std::optional { + APFloat res2(a); + res2.subtract(b, llvm::RoundingMode::NearestTiesToEven); + return res2; + })) { + rewriter.replaceOpWithNewOp(op, op.getType(), res.cast()); + return success(); + } + + return failure(); + } +}; + + +struct NegateSimplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::stablehlo::NegOp op, + PatternRewriter &rewriter) const final { + + SmallVector constants; + constants.assign(op->getNumOperands(), Attribute()); + for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) + matchPattern(op->getOperand(i), m_Constant(&constants[i])); + + if (auto res = mlir::constFoldUnaryOpConditional ( + constants, + [](const APFloat &a) -> std::optional { + return -a; + })) { + rewriter.replaceOpWithNewOp(op, op.getType(), res.cast()); + return success(); + } + + return failure(); + } +}; + +struct MulSimplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::stablehlo::MulOp op, + PatternRewriter &rewriter) const final { + + if (matchPattern(op.getLhs(), m_AnyZeroFloat())) { + rewriter.replaceOp(op, op.getLhs()); + return success(); + } + if (matchPattern(op.getLhs(), m_AnyZeroFloat())) { + rewriter.replaceOp(op, op.getRhs()); + return success(); + } + + SmallVector constants; + constants.assign(op->getNumOperands(), Attribute()); + for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) + matchPattern(op->getOperand(i), m_Constant(&constants[i])); + + if (auto res = constFoldBinaryOpConditional( + constants, + [](const APFloat &a, const APFloat &b) -> std::optional { + APFloat res2(a); + res2.multiply(b, llvm::RoundingMode::NearestTiesToEven); + return res2; + })) { + rewriter.replaceOpWithNewOp(op, op.getType(), res.cast()); + return success(); + } + + return failure(); + } +}; + +struct DivSimplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::stablehlo::DivOp op, + PatternRewriter &rewriter) const final { + + if (matchPattern(op.getLhs(), m_AnyZeroFloat())) { + rewriter.replaceOp(op, op.getLhs()); + return success(); + } + + SmallVector constants; + constants.assign(op->getNumOperands(), Attribute()); + for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) + matchPattern(op->getOperand(i), m_Constant(&constants[i])); + + if (auto res = constFoldBinaryOpConditional( + 
constants, + [](const APFloat &a, const APFloat &b) -> std::optional { + APFloat res2(a); + res2.divide(b, llvm::RoundingMode::NearestTiesToEven); + return res2; + })) { + rewriter.replaceOpWithNewOp(op, op.getType(), res.cast()); + return success(); + } + + return failure(); + } +}; + +struct PowSimplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::stablehlo::PowOp op, + PatternRewriter &rewriter) const final { + + SmallVector constants; + constants.assign(op->getNumOperands(), Attribute()); + for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) + matchPattern(op->getOperand(i), m_Constant(&constants[i])); + + if (auto res = constFoldBinaryOpConditional( + constants, + [](const APFloat &a, const APFloat &b) -> std::optional { + if (a.getSizeInBits(a.getSemantics()) == 64 && + b.getSizeInBits(b.getSemantics()) == 64) + return APFloat(pow(a.convertToDouble(), b.convertToDouble())); + + if (a.getSizeInBits(a.getSemantics()) == 32 && + b.getSizeInBits(b.getSemantics()) == 32) + return APFloat(powf(a.convertToFloat(), b.convertToFloat())); + return {}; + })) { + rewriter.replaceOpWithNewOp(op, op.getType(), res.cast()); + return success(); + } + + return failure(); + } +}; + +struct CosSimplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::stablehlo::CosineOp op, + PatternRewriter &rewriter) const final { + + SmallVector constants; + constants.assign(op->getNumOperands(), Attribute()); + for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) + matchPattern(op->getOperand(i), m_Constant(&constants[i])); + + if (auto res = constFoldUnaryOpConditional( + constants, + [](const APFloat &a) -> std::optional { + if (a.getSizeInBits(a.getSemantics()) == 64) + return APFloat(cos(a.convertToDouble())); + + if (a.getSizeInBits(a.getSemantics()) == 32) + return APFloat(cosf(a.convertToFloat())); + return {}; + })) { + rewriter.replaceOpWithNewOp(op, op.getType(), res.cast()); + return success(); + } + + return failure(); + } +}; + +struct SinSimplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::stablehlo::SineOp op, + PatternRewriter &rewriter) const final { + + SmallVector constants; + constants.assign(op->getNumOperands(), Attribute()); + for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) + matchPattern(op->getOperand(i), m_Constant(&constants[i])); + + if (auto res = constFoldUnaryOpConditional( + constants, + [](const APFloat &a) -> std::optional { + if (a.getSizeInBits(a.getSemantics()) == 64) + return APFloat(sin(a.convertToDouble())); + + if (a.getSizeInBits(a.getSemantics()) == 32) + return APFloat(sinf(a.convertToFloat())); + return {}; + })) { + rewriter.replaceOpWithNewOp(op, op.getType(), res.cast()); + return success(); + } + + return failure(); + } +}; + +struct SqrtSimplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::stablehlo::SqrtOp op, + PatternRewriter &rewriter) const final { + + SmallVector constants; + constants.assign(op->getNumOperands(), Attribute()); + for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) + matchPattern(op->getOperand(i), m_Constant(&constants[i])); + + if (auto res = constFoldUnaryOpConditional( + constants, + [](const APFloat &a) -> std::optional { + if (a.getSizeInBits(a.getSemantics()) == 64) + return APFloat(sqrt(a.convertToDouble())); + + if 
(a.getSizeInBits(a.getSemantics()) == 32) + return APFloat(sqrtf(a.convertToFloat())); + return {}; + })) { + rewriter.replaceOpWithNewOp(op, op.getType(), res.cast()); + return success(); + } + return failure(); } }; @@ -96,7 +691,7 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase { void runOnOperation() override { auto context = getOperation()->getContext(); RewritePatternSet patterns(context); - patterns.add(context); + patterns.add(context); mlir::stablehlo::populateStablehloCanonicalizationPatterns(context, &patterns); diff --git a/src/enzyme_ad/jax/primitives.py b/src/enzyme_ad/jax/primitives.py index a1a9173c6..31aab9c66 100644 --- a/src/enzyme_ad/jax/primitives.py +++ b/src/enzyme_ad/jax/primitives.py @@ -1197,7 +1197,7 @@ def enzyme_vjp(shadow_rets, *prim_args, **kwargs): newpasses = ( prev_passes + "print," + ad_pass - + ",canonicalize, remove-unnecessary-enzyme-ops, enzyme-simplify-math, enzyme-hlo-opt, canonicalize, cse" + + ",canonicalize, remove-unnecessary-enzyme-ops, enzyme-simplify-math, enzyme-hlo-opt, canonicalize, cse, print" + post_passes ) diff --git a/test/bench_vs_xla.py b/test/bench_vs_xla.py index cd7f5a9ef..3c241bb3b 100644 --- a/test/bench_vs_xla.py +++ b/test/bench_vs_xla.py @@ -329,5 +329,52 @@ def f(x): self.fn = f self.name = "activitymismatch" +class GenDot(EnzymeJaxTest): + def setUp(self): + dim = 12 + self.ins = [jnp.array(range(dim), dtype=jnp.float32)] + self.dins = [jnp.array([i * i for i in range(dim)], dtype=jnp.float32)] + self.douts = [jnp.array([i * i for i in range(2*dim)], dtype=jnp.float32).reshape((2, dim))] + + def nomlir(x): + return [(name, a) for (name, a) in x if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"] + + self.revfilter = nomlir + + def f(x): + k = jnp.ones((dim, dim)) @ x + k_tmp = jnp.reshape(k, (2, dim // 2)) + + toconv2 = jnp.ones((2, dim // 2, dim // 2)) + k = jnp.reshape(jnp.einsum("ijk,ik -> ij", toconv2, k_tmp), (dim,)) + + kcl = jnp.zeros((1, dim)) + + h = jnp.reshape(k, (1, dim)) + kcl = jnp.append(kcl, h, axis=0) + return kcl + + self.fn = f + self.name = "GenDot" + + +class Concat(EnzymeJaxTest): + def setUp(self): + dim = 12 + self.ins = [jnp.array(range(dim), dtype=jnp.float32), 10*jnp.array(range(dim), dtype=jnp.float32)] + self.dins = [jnp.array([i * i for i in range(dim)], dtype=jnp.float32), jnp.array([i * i *i / 3. 
for i in range(dim)], dtype=jnp.float32)] + self.douts = [jnp.array([i * i for i in range(2*dim)], dtype=jnp.float32)] + + def nomlir(x): + return [(name, a) for (name, a) in x if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"] + + self.revfilter = nomlir + + def f(x, y): + return jnp.concat([x, y], axis=None) + + self.fn = f + self.name = "Concat" + if __name__ == "__main__": absltest.main() diff --git a/test/llama.py b/test/llama.py index 0377034a1..fb0cfc99b 100644 --- a/test/llama.py +++ b/test/llama.py @@ -11,21 +11,21 @@ def rmsnorm(x, weight): ss = 1 / jnp.sqrt(x.dot(x) / x.shape[0] + 1e-5) - return x # weight * x * ss + return weight * x * ss def softmax(x): - # max_val = jnp.max(x) - # x = jnp.exp(x - max_val) - return x # / sum(x) + max_val = jnp.max(x) + x = jnp.exp(x - max_val) + return x / sum(x) def sigmoid(x): - return 1 # / (1 + jnp.exp(-x)) + return 1 / (1 + jnp.exp(-x)) def silu(x): - return x # * sigmoid(x) + return x * sigmoid(x) # Token is token value @@ -134,7 +134,7 @@ def forward(x, config, weights, key_cache, value_cache): fcr = jnp.cos(val) fci = jnp.sin(val) - rotM = jnp.array([[0.0, -1.0], [1., 0.]]) + rotM = jnp.array([[fcr, -fci], [fci, fcr]]) toconv.append(rotM) toconv2 = toconv[: kv_dim // 2] + [jnp.eye(2)] * (dim // 2 - kv_dim // 2) @@ -143,20 +143,99 @@ def forward(x, config, weights, key_cache, value_cache): keys2 = [] values2 = [] - k = wk[0, :, :] @ x - - k_tmp = jnp.reshape(k, (dim // 2, 2)) + for l in range(n_layers): + xb = rmsnorm(x, rms_att_weight[l, :]) + if asserts: + assert xb.shape == (dim,) + + q = wq[l, :, :] @ xb + if asserts: + assert q.shape == (dim,) + + k = wk[l, :, :] @ xb + if asserts: + assert q.shape == (kv_dim,) + + v = wv[l, :, :] @ xb + if asserts: + assert q.shape == (kv_dim,) + + q_tmp = jnp.reshape(q, (dim // 2, 2)) + k_tmp = jnp.reshape(k, (dim // 2, 2)) + + # dim == head_size * n_heads + + # Batched gemv + k = jnp.reshape(jnp.einsum("ijk,ik -> ij", toconv2, k_tmp), (dim,)) + q = jnp.reshape(jnp.einsum("ijk,ik -> ij", toconv, q_tmp), (dim,)) + + key_cache_l = key_cache[l, :, :] + key_cache_l = jnp.append(key_cache_l, jnp.reshape(k, (1, dim)), axis=0) + value_cache_l = value_cache[l, :, :] + value_cache_l = jnp.append(value_cache_l, jnp.reshape(v, (1, dim)), axis=0) + keys2.append(key_cache_l) + values2.append(value_cache_l) + + xbs2 = [] + for h in range(n_heads): + q2 = q[head_size * h : head_size * (h + 1)] + if asserts: + assert q2.shape == (head_size,) + + # For kv_mul consecutive heads, they share the same kv cache + # reshape key_cache last dim from (kv_dim,) to (kv_mul, head_size) + # generalized einsum reducing the last dim, the rest are batch + att = [] + + key_index = h // kv_mul + + att = jnp.einsum( + "ij,j->i", + key_cache_l[:, key_index * head_size : (key_index + 1) * head_size], + q2, + ) + + att = att / jnp.sqrt(head_size) + + att = softmax(att) + + x_tmp = jnp.einsum( + "ij,i->j", + value_cache_l[:, key_index * head_size : (key_index + 1) * head_size], + att, + ) + + xbs2.append(x_tmp) + + # Todo right concat + xb = jnp.concatenate(xbs2, axis=None) + + xb2 = wo[l, :, :] @ xb + if asserts: + assert xb2.shape == (dim,) + + x += xb2 + + # Rmsnorm and feedforward swiglu + + xb = rmsnorm(x, rms_ffn_weight[l, :]) - # dim == head_size * n_heads + # Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x)) + # first calculate self.w1(x) and self.w3(x) - # Batched gemv - k = jnp.reshape(jnp.einsum("ijk,ik -> ij", toconv2, k_tmp), (dim,)) + hb = w1[l, :, :] @ xb + hb2 = w3[l, :, :] @ xb - 
key_cache_l = key_cache[0, :, :] - h = jnp.reshape(k, (1, dim)) - key_cache_l = jnp.append(key_cache_l, h, axis=0) + hb = silu(hb) - x = key_cache_l + hb = hb * hb2 + + xb = w2[l, :, :] @ hb + + x += xb + + x = rmsnorm(x, rms_final_weight) + logits = wcls @ x return x @@ -164,12 +243,11 @@ def forward(x, config, weights, key_cache, value_cache): class Llama(absltest.TestCase): def test_llama_random(self): config = { - "dim": 2, + "dim": 288, "hidden_dim": 768, - "n_layers": 1, - # "n_heads": 6, - "n_heads": 1, - "n_kv_heads": 1, + "n_layers": 6, + "n_heads": 6, + "n_kv_heads": 6, "vocab_size": 32000, "seq_len": 256, } @@ -232,89 +310,91 @@ def sfn(x, weights, key_cache, value_cache): efunc = enzyme_jax.enzyme_jax_ir(argv=argv, pipeline_options=pipeline)(func) - eres = efunc(x, weights, key_cache, value_cache) - print("Enzyme primal", eres) - res = jfunc(x, weights, key_cache, value_cache) - print("Jax primal", res) - print(" max error", jnp.max(jnp.abs(eres - res))) - assert (jnp.abs(eres - res) < 1e-3).all() - number = 1000 - print( - "Enzyme primal", - timeit.Timer( - "efunc(x, weights, key_cache, value_cache)", - globals={ - "efunc": efunc, - "x": x, - "weights": weights, - "key_cache": key_cache, - "value_cache": value_cache, - }, - ).timeit(number), - ) - print( - "JaX primal", - timeit.Timer( - "jfunc(x, weights, key_cache, value_cache)", - globals={ - "jfunc": jfunc, - "x": x, - "weights": weights, - "key_cache": key_cache, - "value_cache": value_cache, - }, - ).timeit(number), - ) + if False: + eres = efunc(x, weights, key_cache, value_cache) + print("Enzyme primal", eres) + res = jfunc(x, weights, key_cache, value_cache) + print("Jax primal", res) + print(" max error", jnp.max(jnp.abs(eres - res))) + assert (jnp.abs(eres - res) < 1e-3).all() + + print( + "Enzyme primal", + timeit.Timer( + "efunc(x, weights, key_cache, value_cache)", + globals={ + "efunc": efunc, + "x": x, + "weights": weights, + "key_cache": key_cache, + "value_cache": value_cache, + }, + ).timeit(number), + ) + print( + "JaX primal", + timeit.Timer( + "jfunc(x, weights, key_cache, value_cache)", + globals={ + "jfunc": jfunc, + "x": x, + "weights": weights, + "key_cache": key_cache, + "value_cache": value_cache, + }, + ).timeit(number), + ) # jfunc = jax.jit(partial(forward, config)) # mlir = jax.jit(partial(forward, config)).lower(1, weights, key_cache, value_cache).compiler_ir(dialect="mhlo") - @jax.jit - def jfwd(x, dx, weights, dweights, kc, dkc, vc, dvc): - return jax.jvp(jfunc, (x, weights, kc, vc), (x, weights, dkc, dvc)) + if False: + @jax.jit + def jfwd(x, dx, weights, dweights, kc, dkc, vc, dvc): + return jax.jvp(jfunc, (x, weights, kc, vc), (x, weights, dkc, dvc)) - @jax.jit - def efwd(x, dx, weights, dweights, kc, dkc, vc, dvc): - return jax.jvp(efunc, (x, weights, kc, vc), (x, weights, dkc, dvc)) + @jax.jit + def efwd(x, dx, weights, dweights, kc, dkc, vc, dvc): + return jax.jvp(efunc, (x, weights, kc, vc), (x, weights, dkc, dvc)) - eres = efwd( - x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache - ) - print("Enzyme fwd", eres) - jres = jfwd( - x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache - ) - print("Jax fwd", jres) - print( - "Enzyme fwd", - timeit.Timer( - "efwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache)", - globals={ - "efwd": efwd, - "x": x, - "dx": dx, - "weights": weights, - "dweights": dweights, - "key_cache": key_cache, - "value_cache": value_cache, - }, - ).timeit(number), - ) - print( - "JaX fwd", - 
timeit.Timer( - "jfwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache)", - globals={ - "jfwd": jfwd, - "x": x, - "dx": dx, - "weights": weights, - "dweights": dweights, - "key_cache": key_cache, - "value_cache": value_cache, - }, - ).timeit(number), - ) + eres = efwd( + x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache + ) + print("Enzyme fwd", eres) + jres = jfwd( + x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache + ) + print("Jax fwd", jres) + print( + "Enzyme fwd", + timeit.Timer( + "efwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache)", + globals={ + "efwd": efwd, + "x": x, + "dx": dx, + "weights": weights, + "dweights": dweights, + "key_cache": key_cache, + "value_cache": value_cache, + }, + ).timeit(number), + ) + print( + "JaX fwd", + timeit.Timer( + "jfwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache)", + globals={ + "jfwd": jfwd, + "x": x, + "dx": dx, + "weights": weights, + "dweights": dweights, + "key_cache": key_cache, + "value_cache": value_cache, + }, + ).timeit(number), + ) @jax.jit def jrev(x, weights, kc, vc, dx, dkc, dvc): @@ -326,11 +406,17 @@ def erev(x, weights, kc, vc, dx, dkc, dvc): primals, f_vjp = jax.vjp(efunc, x, weights, kc, vc) return f_vjp(dx) # , dkc, dvc) - eres = erev(x, weights, key_cache, value_cache, res, dkc, dvc) + eres = erev(x, weights, key_cache, value_cache, dx, dkc, dvc) print("Enzyme rev", eres) - jres = jrev(x, weights, key_cache, value_cache, res, dkc, dvc) + jres = jrev(x, weights, key_cache, value_cache, dx, dkc, dvc) print("Jax rev", jres) + jrev2 = enzyme_jax.enzyme_jax_ir(argv=argv, pipeline_options=enzyme_jax.JaXPipeline("inline{default-pipeline=canonicalize max-iterations=4}," + + "canonicalize,cse,print,enzyme-hlo-opt,cse,print"))(jrev) + + jres2 = jrev2(x, weights, key_cache, value_cache, dx, dkc, dvc) + print("Jax2 rev", jres2) + print( "Enzyme rev", timeit.Timer( @@ -363,6 +449,22 @@ def erev(x, weights, kc, vc, dx, dkc, dvc): }, ).timeit(number), ) + print( + "JaX2 rev", + timeit.Timer( + "jrev2(x, weights, key_cache, value_cache, dx, dkc, dvc)", + globals={ + "jrev2": jrev2, + "x": x, + "weights": weights, + "key_cache": key_cache, + "value_cache": value_cache, + "dx": dx, + "dkc": dkc, + "dvc": dvc, + }, + ).timeit(number), + ) if __name__ == "__main__": From ab7728ab191489f086ff5dc3fb1a8c0abe876df1 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 7 Mar 2024 22:54:10 -0800 Subject: [PATCH 08/17] fix --- src/enzyme_ad/jax/enzymexlamlir-opt.cpp | 118 ++++++++++++++++++++++++ 1 file changed, 118 insertions(+) create mode 100644 src/enzyme_ad/jax/enzymexlamlir-opt.cpp diff --git a/src/enzyme_ad/jax/enzymexlamlir-opt.cpp b/src/enzyme_ad/jax/enzymexlamlir-opt.cpp new file mode 100644 index 000000000..b1543565e --- /dev/null +++ b/src/enzyme_ad/jax/enzymexlamlir-opt.cpp @@ -0,0 +1,118 @@ +//===- enzymemlir-opt.cpp - The enzymemlir-opt driver ---------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the 'enzymemlir-opt' tool, which is the enzyme analog +// of mlir-opt, used to drive compiler passes, e.g. for testing. 
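+//
+// An illustrative invocation (the flag spellings follow standard mlir-opt
+// conventions and are an assumption, not something this patch guarantees):
+//
+//   enzymexlamlir-opt input.mlir \
+//     --pass-pipeline='builtin.module(enzyme-hlo-opt,canonicalize,cse)'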
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/Passes.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Async/IR/Async.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/DLTI/DLTI.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Func/Extensions/InlinerExtension.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/InitAllPasses.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Tools/mlir-opt/MlirOptMain.h" +#include "mlir/Transforms/Passes.h" + +#include "Enzyme/MLIR/Dialect/Dialect.h" +#include "Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h" +#include "Enzyme/MLIR/Passes/Passes.h" + +#include "Enzyme/MLIR/Dialect/Ops.h" + +#include "Implementations/XLADerivatives.h" +#include "Passes/Passes.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +using namespace mlir; + +class MemRefInsider + : public mlir::MemRefElementTypeInterface::FallbackModel {}; + +template +struct PtrElementModel + : public mlir::LLVM::PointerElementTypeInterface::ExternalModel< + PtrElementModel, T> {}; + +int main(int argc, char **argv) { + mlir::DialectRegistry registry; + + // Register MLIR stuff + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + + registry.insert(); + + mlir::registerenzymePasses(); + regsiterenzymeXLAPasses(); + mlir::enzyme::registerXLAAutoDiffInterfaces(registry); + + mlir::func::registerInlinerExtension(registry); + + // Register the standard passes we want. + mlir::registerCSEPass(); + mlir::registerConvertAffineToStandardPass(); + mlir::registerSCCPPass(); + mlir::registerInlinerPass(); + mlir::registerCanonicalizerPass(); + mlir::registerSymbolDCEPass(); + mlir::registerLoopInvariantCodeMotionPass(); + mlir::registerConvertSCFToOpenMPPass(); + mlir::affine::registerAffinePasses(); + mlir::registerReconcileUnrealizedCasts(); + + registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) { + LLVM::LLVMFunctionType::attachInterface(*ctx); + LLVM::LLVMArrayType::attachInterface(*ctx); + LLVM::LLVMPointerType::attachInterface(*ctx); + LLVM::LLVMStructType::attachInterface(*ctx); + MemRefType::attachInterface>(*ctx); + LLVM::LLVMStructType::attachInterface< + PtrElementModel>(*ctx); + LLVM::LLVMPointerType::attachInterface< + PtrElementModel>(*ctx); + LLVM::LLVMArrayType::attachInterface>( + *ctx); + }); + + // Register the autodiff interface implementations for upstream dialects. + enzyme::registerCoreDialectAutodiffInterfaces(registry); + + return mlir::asMainReturnCode(mlir::MlirOptMain( + argc, argv, "Enzyme modular optimizer driver", registry)); +} \ No newline at end of file From 216ba31bfe986ac1d8c8ef9debdd8b02ddf33375 Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Thu, 7 Mar 2024 22:55:17 -0800 Subject: [PATCH 09/17] fix --- patches/xla2.patch | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 patches/xla2.patch diff --git a/patches/xla2.patch b/patches/xla2.patch new file mode 100644 index 000000000..45a67472f --- /dev/null +++ b/patches/xla2.patch @@ -0,0 +1,10 @@ +--- a/xla/mlir/runtime/BUILD ++++ b/xla/mlir/runtime/BUILD +@@ -19,6 +19,7 @@ package_group( + # TODO(ezhulenev): All targets depending on mlir must be under xla/mlir folder + "//xla/service/cpu/...", + "//xla/service/gpu/...", ++ "public", + ], + ) + \ No newline at end of file From 6aba69db70928a5d852d96e5659b51991833eb04 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Fri, 8 Mar 2024 02:12:57 -0500 Subject: [PATCH 10/17] fixup --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 29 ++++++++++++++--------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 9c4ecc87b..33c6fa6a2 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -173,33 +173,40 @@ struct AddPad final : OpRewritePattern { SmallVector idxs; for (auto &&[low, high, dim] : llvm::zip(lhs.getEdgePaddingLow(), lhs.getEdgePaddingHigh(), type.getShape())) { padidx++; - if (low == 0 && high == dim) continue; - idxs.push_back(padidx-1); + if (low == 0 && high == 0) continue; + idxs.push_back(padidx); } - if (idxs.size() == 0) { + if (idxs.size() == 1) { auto idx = idxs[0]; SmallVector strides(type.getShape().size(), 1); SmallVector starts(type.getShape().size(), 0); SmallVector limits(type.getShape().begin(), type.getShape().end()); - starts[idx] = lhs.getEdgePaddingLow()[idx]; - limits[idx] = type.getShape()[idx] - lhs.getEdgePaddingLow()[idx]; - - auto midSlice = rewriter.create(op.getLoc(), rhs, starts, limits, strides); + SmallVector vals; + if (lhs.getEdgePaddingLow()[idx] != 0) { starts[idx] = 0; limits[idx] = lhs.getEdgePaddingLow()[idx]; auto prevSlice = rewriter.create(op.getLoc(), rhs, starts, limits, strides); + vals.push_back(prevSlice); + } - starts[idx] = type.getShape()[idx] - lhs.getEdgePaddingLow()[idx]; - limits[idx] = 0; - auto postSlice = rewriter.create(op.getLoc(), rhs, starts, limits, strides); + starts[idx] = lhs.getEdgePaddingLow()[idx]; + limits[idx] = type.getShape()[idx] - lhs.getEdgePaddingHigh()[idx]; + auto midSlice = rewriter.create(op.getLoc(), rhs, starts, limits, strides); auto mid = rewriter.create(op.getLoc(), midSlice, lhs.getOperand()); + vals.push_back(mid); + + if (lhs.getEdgePaddingHigh()[idx] != 0) { + starts[idx] = type.getShape()[idx] - lhs.getEdgePaddingHigh()[idx]; + limits[idx] = 0; + auto postSlice = rewriter.create(op.getLoc(), rhs, starts, limits, strides); + vals.push_back(postSlice); + } - Value vals[3] = {prevSlice, mid, postSlice}; rewriter.replaceOpWithNewOp(op, vals, idx); return success(); } From 913b2f68375fc75e83c67da4439fe816d14a6418 Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Fri, 8 Mar 2024 02:21:46 -0500 Subject: [PATCH 11/17] fix bug --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 33c6fa6a2..8fd689fc2 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -202,7 +202,7 @@ struct AddPad final : OpRewritePattern { if (lhs.getEdgePaddingHigh()[idx] != 0) { starts[idx] = type.getShape()[idx] - lhs.getEdgePaddingHigh()[idx]; - limits[idx] = 0; + limits[idx] = type.getShape()[idx]; auto postSlice = rewriter.create(op.getLoc(), rhs, starts, limits, strides); vals.push_back(postSlice); } From f0ce46e86bcf64e165d7ab7bed2c59ddea6a83cf Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 7 Mar 2024 23:41:12 -0800 Subject: [PATCH 12/17] fixup --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 46 ++++++++++++++++++++++- src/enzyme_ad/jax/primitives.py | 2 +- 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 8fd689fc2..5c0710074 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -54,6 +54,50 @@ struct SliceSimplification final : OpRewritePattern { } }; +struct SliceConcat final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::stablehlo::SliceOp op, + PatternRewriter &rewriter) const override { + auto type = dyn_cast(op.getType()); + if (!type) + return failure(); + + auto concat = op.getOperand().getDefiningOp(); + if (!concat) return failure(); + + + auto dim = concat.getDimension(); + + if (op.getStrides()[dim] != 1) return failure(); + + SmallVector postConcat; + size_t curdim = 0; + for (auto v : concat.getInputs()) { + auto ty = v.getType().cast(); + auto nextdim = ty.getShape()[dim]; + if (op.getStartIndices()[dim] < curdim) { + curdim += nextdim; + continue; + } + if (op.getLimitIndices()[dim] >= curdim) { + curdim += nextdim; + continue; + } + SmallVector nstart(op.getStartIndices().begin(), op.getStartIndices().end()); + SmallVector nend(op.getStartIndices().begin(), op.getStartIndices().end()); + nstart[dim] -= curdim; + if (nstart[dim] < 0) nstart[dim] = 0; + nend[dim] -= curdim; + if (nend[dim] > nextdim) nend[dim] = nextdim; + auto subslice = rewriter.create(op.getLoc(), v, nstart, nend, op.getStrides()); + postConcat.push_back(subslice); + } + rewriter.replaceOpWithNewOp(op, postConcat, dim); + return success(); + } +}; + DenseElementsAttr fromTensor(stablehlo::Tensor inp) { auto type = inp.getType(); @@ -698,7 +742,7 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase { void runOnOperation() override { auto context = getOperation()->getContext(); RewritePatternSet patterns(context); - patterns.add(context); + patterns.add(context); mlir::stablehlo::populateStablehloCanonicalizationPatterns(context, &patterns); diff --git a/src/enzyme_ad/jax/primitives.py b/src/enzyme_ad/jax/primitives.py index 31aab9c66..f0029ba93 100644 --- a/src/enzyme_ad/jax/primitives.py +++ b/src/enzyme_ad/jax/primitives.py @@ -1197,7 +1197,7 @@ def enzyme_vjp(shadow_rets, *prim_args, **kwargs): newpasses = ( prev_passes + "print," + ad_pass - + ",canonicalize, remove-unnecessary-enzyme-ops, enzyme-simplify-math, enzyme-hlo-opt, canonicalize, cse, print" + + ",arith-raise{stablehlo=true},canonicalize, remove-unnecessary-enzyme-ops, 
enzyme-simplify-math, enzyme-hlo-opt, canonicalize, cse, print" + post_passes ) From 31081b5455670409972c7d38c677437c90a437c0 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Fri, 8 Mar 2024 03:20:48 -0500 Subject: [PATCH 13/17] more fixup --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 30 ++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 5c0710074..4078e88dc 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -76,22 +76,23 @@ struct SliceConcat final : OpRewritePattern { for (auto v : concat.getInputs()) { auto ty = v.getType().cast(); auto nextdim = ty.getShape()[dim]; - if (op.getStartIndices()[dim] < curdim) { + if (op.getStartIndices()[dim] >= curdim + nextdim) { curdim += nextdim; continue; } - if (op.getLimitIndices()[dim] >= curdim) { + if (op.getLimitIndices()[dim] <= curdim) { curdim += nextdim; continue; } SmallVector nstart(op.getStartIndices().begin(), op.getStartIndices().end()); - SmallVector nend(op.getStartIndices().begin(), op.getStartIndices().end()); + SmallVector nend(op.getLimitIndices().begin(), op.getLimitIndices().end()); nstart[dim] -= curdim; if (nstart[dim] < 0) nstart[dim] = 0; nend[dim] -= curdim; if (nend[dim] > nextdim) nend[dim] = nextdim; auto subslice = rewriter.create(op.getLoc(), v, nstart, nend, op.getStrides()); postConcat.push_back(subslice); + curdim += nextdim; } rewriter.replaceOpWithNewOp(op, postConcat, dim); return success(); @@ -271,6 +272,29 @@ struct ConcatConstProp final : OpRewritePattern if (!type) return failure(); + if (op->getNumOperands() == 1) { + rewriter.replaceOp(op, op->getOperand(0)); + return success(); + } + + { + SmallVector subconcat; + bool changed = false; + for (auto v : op->getOperands()) { + if (auto c2 = v.getDefiningOp()) + if (c2.getDimension() == op.getDimension()) { + for (auto v2 : c2->getOperands()) + subconcat.push_back(v2); + changed = true; + continue; + } + subconcat.push_back(v); + } + if (changed) { + rewriter.replaceOpWithNewOp(op, subconcat, op.getDimension()); + return success(); + } + } SmallVector constants; constants.assign(op->getNumOperands(), DenseElementsAttr()); From 5eca730e746f10efcf73afbb921d0d70860d0ab4 Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Fri, 8 Mar 2024 08:02:05 -0800 Subject: [PATCH 14/17] sliceslice --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 30 ++++++++++++++++++++++- test/llama.py | 2 +- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 4078e88dc..5a6942d31 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -54,6 +54,34 @@ struct SliceSimplification final : OpRewritePattern { } }; +struct SliceSlice final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::stablehlo::SliceOp op, + PatternRewriter &rewriter) const override { + auto type = dyn_cast(op.getType()); + if (!type) + return failure(); + + auto prev = op.getOperand().getDefiningOp(); + if (!prev) return failure(); + + SmallVector start; + SmallVector end; + SmallVector step; + + for (auto && [pstart, pend, pstep, nstart, nend, nstep] : llvm::zip(prev.getStartIndices(), prev.getLimitIndices(), prev.getStrides(), + op.getStartIndices(), op.getLimitIndices(), op.getStrides() + )) { + start.push_back(pstart + pstep * nstart); + step.push_back(pstep * nstep); + end.push_back(pstart + pstep * nstep * (nend - nstart)); + } + rewriter.replaceOpWithNewOp(op, prev.getOperand(), start, end, step); + return failure(); + } +}; + struct SliceConcat final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -766,7 +794,7 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase { void runOnOperation() override { auto context = getOperation()->getContext(); RewritePatternSet patterns(context); - patterns.add(context); + patterns.add(context); mlir::stablehlo::populateStablehloCanonicalizationPatterns(context, &patterns); diff --git a/test/llama.py b/test/llama.py index fb0cfc99b..9effa6217 100644 --- a/test/llama.py +++ b/test/llama.py @@ -412,7 +412,7 @@ def erev(x, weights, kc, vc, dx, dkc, dvc): print("Jax rev", jres) jrev2 = enzyme_jax.enzyme_jax_ir(argv=argv, pipeline_options=enzyme_jax.JaXPipeline("inline{default-pipeline=canonicalize max-iterations=4}," - + "canonicalize,cse,print,enzyme-hlo-opt,cse,print"))(jrev) + + "canonicalize,cse,enzyme-hlo-opt,cse"))(jrev) jres2 = jrev2(x, weights, key_cache, value_cache, dx, dkc, dvc) print("Jax2 rev", jres2) From c9bfb8fc4847feb60bb2bb6dfc2d63ecc541553b Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Fri, 8 Mar 2024 12:22:07 -0500 Subject: [PATCH 15/17] fix pad opt --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 83 ++++++++++++++++++++--- src/enzyme_ad/jax/compile_with_xla.cc | 2 + src/enzyme_ad/jax/enzymexlamlir-opt.cpp | 3 +- 3 files changed, 78 insertions(+), 10 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 5a6942d31..3b63a98ba 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -75,10 +75,77 @@ struct SliceSlice final : OpRewritePattern { )) { start.push_back(pstart + pstep * nstart); step.push_back(pstep * nstep); - end.push_back(pstart + pstep * nstep * (nend - nstart)); + end.push_back(pstart + pstep * nstart + pstep * nstep * (nend - nstart)); } rewriter.replaceOpWithNewOp(op, prev.getOperand(), start, end, step); - return failure(); + return success(); + } +}; + +// slice(pad x) -> pad(slice x) +struct SlicePad final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::stablehlo::SliceOp op, + PatternRewriter &rewriter) const override { + auto type = dyn_cast(op.getType()); + if (!type) + return failure(); + + auto pad = op.getOperand().getDefiningOp(); + if (!pad) return failure(); + + SmallVector start; + SmallVector end; + SmallVector step; + + SmallVector lpads; + SmallVector hpads; + SmallVector interiors; + + bool needspad = false; + for (auto && [nstart, nend, nstep, lpad, hpad, interior, inshape] : llvm::zip(op.getStartIndices(), op.getLimitIndices(), op.getStrides(), pad.getEdgePaddingLow(), pad.getEdgePaddingHigh(), pad.getInteriorPadding(), pad.getOperand().getType().getShape() + )) { + if (nstep != 1) return failure(); + if (interior != 0) return failure(); + + // start of slice starts after end of value being padded + if (nstart - lpad >= inshape) { + rewriter.replaceOpWithNewOp(op, op.getType(), pad.getPaddingValue(), rewriter.getDenseI64ArrayAttr({})); + return success(); + } + // slice ends before the start of value being padded + if (nend - lpad < inshape) { + rewriter.replaceOpWithNewOp(op, op.getType(), pad.getPaddingValue(), rewriter.getDenseI64ArrayAttr({})); + return success(); + } + if (nstart - lpad < 0) { + start.push_back(0); + lpads.push_back(lpad - nstart); + needspad = true; + } else { + start.push_back(nstart - lpad); + lpads.push_back(0); + } + if (nend - lpad > inshape) { + end.push_back(inshape); + hpads.push_back(nend - lpad - inshape); + needspad = true; + } else { + end.push_back(nend - lpad); + hpads.push_back(0); + } + + step.push_back(1); + interiors.push_back(0); + } + if (needspad) { + auto nslice = rewriter.create(op.getLoc(), pad.getOperand(), start, end, step); + rewriter.replaceOpWithNewOp(op, nslice, pad.getPaddingValue(), lpads, hpads, interiors); + } { + rewriter.replaceOpWithNewOp(op, pad.getOperand(), start, end, step); + } + return success(); } }; @@ -357,16 +424,14 @@ struct BroadcastToReshape final : OpRewritePatterngetOperand(0), m_Constant(&inp)); if (inp) { + if (inp.isSplat()) { + rewriter.replaceOpWithNewOp(op, op.getType(), mlir::SplatElementsAttr::get(op.getType(), inp.getSplatValue())); + return success(); + } auto inp0 = mlir::stablehlo::evalConstantOp(inp); auto out = mlir::stablehlo::evalBroadcastInDimOp(inp0, mlir::stablehlo::Axes(op.getBroadcastDimensions()), op.getType()); rewriter.replaceOpWithNewOp(op, op.getType(), fromTensor(out)); return success(); - /* - if (inp.isSplat()) { - rewriter.replaceOpWithNewOp(op, 
op.getType(), SplatElementsAttr::get(op.getType().getShape(), inp.getSplatValue())); - return success(); - } - */ } // Ensure these are sorted @@ -794,7 +859,7 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase { void runOnOperation() override { auto context = getOperation()->getContext(); RewritePatternSet patterns(context); - patterns.add(context); + patterns.add(context); mlir::stablehlo::populateStablehloCanonicalizationPatterns(context, &patterns); diff --git a/src/enzyme_ad/jax/compile_with_xla.cc b/src/enzyme_ad/jax/compile_with_xla.cc index d4e7cad9b..f3e703209 100644 --- a/src/enzyme_ad/jax/compile_with_xla.cc +++ b/src/enzyme_ad/jax/compile_with_xla.cc @@ -102,6 +102,7 @@ run_pass_pipeline(const std::vector &oldsym_vec, prepareRegistry(registry); MLIRContext context(registry); context.loadDialect(); + context.loadDialect(); context.loadDialect(); context.loadDialect(); context.loadDialect(); @@ -169,6 +170,7 @@ compile_mhlo_to_llvm_with_xla(llvm::StringRef mhlo_text, std::string &output, prepareRegistry(registry); mlir::MLIRContext context(registry); context.loadDialect(); + context.loadDialect(); context.loadDialect(); context.loadDialect(); context.loadDialect(); diff --git a/src/enzyme_ad/jax/enzymexlamlir-opt.cpp b/src/enzyme_ad/jax/enzymexlamlir-opt.cpp index b1543565e..da01b43c5 100644 --- a/src/enzyme_ad/jax/enzymexlamlir-opt.cpp +++ b/src/enzyme_ad/jax/enzymexlamlir-opt.cpp @@ -63,6 +63,7 @@ int main(int argc, char **argv) { registry.insert(); registry.insert(); registry.insert(); + registry.insert(); registry.insert(); registry.insert(); registry.insert(); @@ -115,4 +116,4 @@ int main(int argc, char **argv) { return mlir::asMainReturnCode(mlir::MlirOptMain( argc, argv, "Enzyme modular optimizer driver", registry)); -} \ No newline at end of file +} From 992034fb30c0ccf752392eec2bfa8ea0664b05ab Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Fri, 8 Mar 2024 11:03:00 -0800 Subject: [PATCH 16/17] cleanup --- WORKSPACE | 17 +- src/enzyme_ad/jax/BUILD | 2 + .../StableHLOAutoDiffOpInterfaceImpl.cpp | 168 +++--- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 499 ++++++++++-------- src/enzyme_ad/jax/compile_with_xla.cc | 1 + src/enzyme_ad/jax/enzymexlamlir-opt.cpp | 2 +- src/enzyme_ad/jax/primitives.py | 4 +- test/BUILD | 10 + test/llama.py | 6 +- 9 files changed, 407 insertions(+), 302 deletions(-) diff --git a/WORKSPACE b/WORKSPACE index b73f72f42..e4bad745e 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -60,21 +60,16 @@ load("@rules_python//python/pip_install:repositories.bzl", "pip_install_dependen pip_install_dependencies() -ENZYME_COMMIT = "97066352a40b3c66f9a1f41ec1802af255216c0c" -ENZYME_SHA256 = "" +ENZYME_COMMIT = "0a129ae7e45114a08f281e50632b9f967fae8396" +ENZYME_SHA256 = "715982efd0a0ef8038e8ad35047e9c1941eb3f9cb038883342969b0bcc8915ad" -local_repository( +http_archive( name = "enzyme", - path = "../Enzyme/enzyme" + sha256 = ENZYME_SHA256, + strip_prefix = "Enzyme-" + ENZYME_COMMIT + "/enzyme", + urls = ["https://github.com/EnzymeAD/Enzyme/archive/{commit}.tar.gz".format(commit = ENZYME_COMMIT)], ) -# http_archive( -# name = "enzyme", -# sha256 = ENZYME_SHA256, -# strip_prefix = "Enzyme-" + ENZYME_COMMIT + "/enzyme", -# urls = ["https://github.com/EnzymeAD/Enzyme/archive/{commit}.tar.gz".format(commit = ENZYME_COMMIT)], -# ) - JAX_COMMIT = "9a098e922aff62a3b49bd673b9518d97ee599248" JAX_SHA256 = "" diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD index 43c44e2d1..c3108b0a0 100644 --- a/src/enzyme_ad/jax/BUILD +++ b/src/enzyme_ad/jax/BUILD @@ -145,6 +145,7 @@ cc_library( "@stablehlo//:reference_ops", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:FunctionInterfaces", "@llvm-project//mlir:ControlFlowInterfaces", @@ -225,6 +226,7 @@ pybind_library( "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncExtensions", + "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", diff --git a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp index 5fb428937..0ad4f3751 100644 --- a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp +++ b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp @@ -192,8 +192,8 @@ class AutoDiffBroadcastInDimRev AutoDiffBroadcastInDimRev, BroadcastInDimOp> { public: LogicalResult createReverseModeAdjoint(Operation *orig, OpBuilder &builder, - MGradientUtilsReverse *gutils, - SmallVector caches) const { + MGradientUtilsReverse *gutils, + SmallVector caches) const { auto op = cast(orig); auto inTy = op.getOperand().getType(); auto outTy = op.getType(); @@ -204,17 +204,25 @@ class AutoDiffBroadcastInDimRev op.getBroadcastDimensions().end()); SmallVector newDims; + SmallVector reduceShape; for (auto en : llvm::enumerate(outTy.getShape())) { - if (llvm::is_contained(bcastDims, en.index())) continue; + if (llvm::is_contained(bcastDims, en.index())) { + if (en.value() != 1) { + newDims.push_back(en.index()); + } + continue; + } + reduceShape.push_back(en.value()); newDims.push_back(en.index()); } - Value zero = gutils->getShadowType(inTy) - .cast() - .createNullValue(builder, op.getLoc()); + auto reduceTy = RankedTensorType::get(reduceShape, inTy.getElementType()); + + Value 
zero = gutils->getShadowType(reduceTy) + .cast() + .createNullValue(builder, op.getLoc()); - auto red = builder.create(op.getLoc(), - TypeRange(zero.getType()), + auto red = builder.create(op.getLoc(), TypeRange(zero.getType()), inDiffe, zero, newDims); red.getBody().push_back(new Block()); Block &body = red.getBody().front(); @@ -228,9 +236,9 @@ class AutoDiffBroadcastInDimRev bodyBuilder.create(op.getLoc(), ValueRange(add)); Value res = red->getResult(0); - Type resTy = gutils->getShadowType(op.getOperand().getType()); + Type resTy = gutils->getShadowType(op.getOperand().getType()); if (res.getType() != resTy) - res = builder.create(op.getLoc(), resTy, res); + res = builder.create(op.getLoc(), resTy, res); gutils->addToDiffe(op.getOperand(), res, builder); return success(); @@ -250,8 +258,8 @@ class AutoDiffSliceRev SliceOp> { public: LogicalResult createReverseModeAdjoint(Operation *orig, OpBuilder &builder, - MGradientUtilsReverse *gutils, - SmallVector caches) const { + MGradientUtilsReverse *gutils, + SmallVector caches) const { auto op = cast(orig); auto inTy = op.getOperand().getType(); auto outTy = op.getType(); @@ -263,21 +271,25 @@ class AutoDiffSliceRev SmallVector starts; SmallVector edge_padding_high; SmallVector interior_padding; - for (auto &&[start, limit, stride, dim] : llvm::zip( - op.getStartIndices(), op.getLimitIndices(), op.getStrides(), inTy.getShape())) { + for (auto &&[start, limit, stride, dim] : + llvm::zip(op.getStartIndices(), op.getLimitIndices(), op.getStrides(), + inTy.getShape())) { starts.push_back(start); edge_padding_high.push_back(dim - limit); interior_padding.push_back(stride - 1); } - - auto zeroPad = RankedTensorType::get({}, inTy.getElementType()).cast().createNullValue(builder, - op.getLoc()); - auto red = builder.create(op.getLoc(), inDiffe, zeroPad, builder.getDenseI64ArrayAttr(starts), builder.getDenseI64ArrayAttr(edge_padding_high), builder.getDenseI64ArrayAttr(interior_padding)); + auto zeroPad = RankedTensorType::get({}, inTy.getElementType()) + .cast() + .createNullValue(builder, op.getLoc()); + auto red = builder.create( + op.getLoc(), inDiffe, zeroPad, builder.getDenseI64ArrayAttr(starts), + builder.getDenseI64ArrayAttr(edge_padding_high), + builder.getDenseI64ArrayAttr(interior_padding)); gutils->addToDiffe(op.getOperand(), red->getResult(0), builder); return success(); - #if 0 +#if 0 Value idxs; { @@ -351,7 +363,7 @@ class AutoDiffSliceRev // gutils->setDiffe(op.getOperand(), red->getResult(0), builder); return success(); - #endif +#endif } SmallVector cacheValues(Operation *orig, @@ -368,26 +380,27 @@ class AutoDiffReduceRev ReduceOp> { public: LogicalResult createReverseModeAdjoint(Operation *orig, OpBuilder &builder, - MGradientUtilsReverse *gutils, - SmallVector caches) const { + MGradientUtilsReverse *gutils, + SmallVector caches) const { auto op = cast(orig); if (!isEligibleForCompactPrint(op)) { - orig->emitError() << "Unsupported operation in reduction rev autodiff(1): " - << *orig << "\n"; + orig->emitError() + << "Unsupported operation in reduction rev autodiff(1): " << *orig + << "\n"; return failure(); } Operation &innerOp = op.getBody().front().front(); - + auto inTy = op->getOperand(0).getType().cast(); auto zero = inTy.cast().createNullValue(builder, op.getLoc()); auto inDiffe = gutils->diffe(op->getResult(0), builder); gutils->zeroDiffe(op->getResult(0), builder); - - SmallVector toBroadcast; - { - size_t idx=0; + + SmallVector toBroadcast; + { + size_t idx = 0; for (auto en : llvm::enumerate(inTy.getShape())) { if 
(llvm::is_contained(op.getDimensions(), en.index())) { // reduced op @@ -396,36 +409,40 @@ class AutoDiffReduceRev toBroadcast.push_back(idx); idx++; } - } + } if (isa(innerOp)) { - if (!gutils->isConstantValue(op.getInputs()[0])) { + if (!gutils->isConstantValue(op.getInputs()[0])) { Value bcast; - - bcast = builder.create(op.getLoc(), gutils->getShadowType(inTy), inDiffe, builder.getDenseI64ArrayAttr(toBroadcast)); + bcast = builder.create( + op.getLoc(), gutils->getShadowType(inTy), inDiffe, + builder.getDenseI64ArrayAttr(toBroadcast)); gutils->addToDiffe(op.getInputs()[0], bcast, builder); - } - if (!gutils->isConstantValue(op.getInitValues()[0])) { + } + if (!gutils->isConstantValue(op.getInitValues()[0])) { gutils->addToDiffe(op.getInitValues()[0], inDiffe, builder); - } - return success(); + } + return success(); } if (isa(innerOp) || isa(innerOp)) { - // TODO: technically we should invert the order here to pick the last value (or divide by count) if multiple are the same as the - // result + // TODO: technically we should invert the order here to pick the last + // value (or divide by count) if multiple are the same as the result auto ores = gutils->getNewFromOriginal(op->getResult(0)); if (!gutils->isConstantValue(op.getInputs()[0])) { auto oprev = gutils->getNewFromOriginal(op.getInputs()[0]); auto attr = builder.getDenseI64ArrayAttr(toBroadcast); - auto bc = builder.create(op.getLoc(), oprev.getType(), ores, attr); + auto bc = builder.create(op.getLoc(), oprev.getType(), + ores, attr); - auto cmp = builder.create(op.getLoc(), bc, oprev, ComparisonDirection::EQ); + auto cmp = builder.create(op.getLoc(), bc, oprev, + ComparisonDirection::EQ); - auto bc2 = builder.create(op.getLoc(), oprev.getType(), inDiffe, attr); + auto bc2 = builder.create( + op.getLoc(), oprev.getType(), inDiffe, attr); auto res = builder.create(op.getLoc(), cmp, bc2, zero); gutils->addToDiffe(op.getInputs()[0], res, builder); @@ -433,19 +450,21 @@ class AutoDiffReduceRev if (!gutils->isConstantValue(op.getInitValues()[0])) { auto oprev = gutils->getNewFromOriginal(op.getInitValues()[0]); - auto zeroI = inDiffe.getType().cast().createNullValue(builder, - op.getLoc()); + auto zeroI = + inDiffe.getType().cast().createNullValue( + builder, op.getLoc()); - auto cmp = builder.create(op.getLoc(), ores, oprev, ComparisonDirection::EQ); + auto cmp = builder.create(op.getLoc(), ores, oprev, + ComparisonDirection::EQ); auto res = builder.create(op.getLoc(), cmp, inDiffe, zeroI); gutils->addToDiffe(op.getInitValues()[0], res, builder); } return success(); } - + orig->emitError() << "Unsupported operation in reduction rev autodiff(1): " - << *orig << "\n"; + << *orig << "\n"; return failure(); } @@ -463,8 +482,8 @@ class AutoDiffConcatenateRev ConcatenateOp> { public: LogicalResult createReverseModeAdjoint(Operation *orig, OpBuilder &builder, - MGradientUtilsReverse *gutils, - SmallVector caches) const { + MGradientUtilsReverse *gutils, + SmallVector caches) const { auto op = cast(orig); auto inDiffe = gutils->diffe(op->getResult(0), builder); @@ -472,31 +491,34 @@ class AutoDiffConcatenateRev auto dim = op.getDimension(); size_t startDim = 0; - for (auto &ope : op->getOpOperands()) { - auto op = ope.get(); - auto inTy = gutils->getShadowType(op.getType()); - SmallVector start; - SmallVector limit; - SmallVector strides; - SmallVector tys; - auto RT = inTy.cast(); - for (auto i=0; igetOpOperands()) { + auto op = ope.get(); + auto inTy = gutils->getShadowType(op.getType()); + SmallVector start; + SmallVector limit; + 
SmallVector strides; + SmallVector tys; + auto RT = inTy.cast(); + for (auto i = 0; i < RT.getShape().size(); i++) { + tys.push_back(RT.getShape()[i]); + if (i == dim) { + start.push_back(startDim); + limit.push_back(startDim + RT.getShape()[i]); + startDim += RT.getShape()[i]; + strides.push_back(1); + continue; } - if (gutils->isConstantValue(op)) continue; - auto res = builder.create(op.getLoc(), RankedTensorType::get(tys, RT.getElementType()), inDiffe, start, limit, strides); - auto res2 = builder.create(op.getLoc(), inTy, res); - gutils->addToDiffe(op, res2, builder); + start.push_back(0); + limit.push_back(RT.getShape()[i]); + strides.push_back(1); + } + if (gutils->isConstantValue(op)) + continue; + auto res = builder.create( + op.getLoc(), RankedTensorType::get(tys, RT.getElementType()), inDiffe, + start, limit, strides); + auto res2 = builder.create(op.getLoc(), inTy, res); + gutils->addToDiffe(op, res2, builder); } return success(); } diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 3b63a98ba..661bd7fb7 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -18,8 +18,8 @@ #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "stablehlo/dialect/StablehloOps.h" -#include "stablehlo/transforms/Passes.h" #include "stablehlo/reference/Ops.h" +#include "stablehlo/transforms/Passes.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" @@ -64,20 +64,22 @@ struct SliceSlice final : OpRewritePattern { return failure(); auto prev = op.getOperand().getDefiningOp(); - if (!prev) return failure(); + if (!prev) + return failure(); SmallVector start; SmallVector end; SmallVector step; - for (auto && [pstart, pend, pstep, nstart, nend, nstep] : llvm::zip(prev.getStartIndices(), prev.getLimitIndices(), prev.getStrides(), - op.getStartIndices(), op.getLimitIndices(), op.getStrides() - )) { + for (auto &&[pstart, pend, pstep, nstart, nend, nstep] : llvm::zip( + prev.getStartIndices(), prev.getLimitIndices(), prev.getStrides(), + op.getStartIndices(), op.getLimitIndices(), op.getStrides())) { start.push_back(pstart + pstep * nstart); step.push_back(pstep * nstep); end.push_back(pstart + pstep * nstart + pstep * nstep * (nend - nstart)); } - rewriter.replaceOpWithNewOp(op, prev.getOperand(), start, end, step); + rewriter.replaceOpWithNewOp(op, prev.getOperand(), + start, end, step); return success(); } }; @@ -93,57 +95,71 @@ struct SlicePad final : OpRewritePattern { return failure(); auto pad = op.getOperand().getDefiningOp(); - if (!pad) return failure(); + if (!pad) + return failure(); SmallVector start; SmallVector end; SmallVector step; - + SmallVector lpads; SmallVector hpads; SmallVector interiors; bool needspad = false; - for (auto && [nstart, nend, nstep, lpad, hpad, interior, inshape] : llvm::zip(op.getStartIndices(), op.getLimitIndices(), op.getStrides(), pad.getEdgePaddingLow(), pad.getEdgePaddingHigh(), pad.getInteriorPadding(), pad.getOperand().getType().getShape() - )) { - if (nstep != 1) return failure(); - if (interior != 0) return failure(); - - // start of slice starts after end of value being padded - if (nstart - lpad >= inshape) { - rewriter.replaceOpWithNewOp(op, op.getType(), pad.getPaddingValue(), rewriter.getDenseI64ArrayAttr({})); - return success(); - } - // slice ends before the start of value being padded - if (nend - lpad < inshape) { - rewriter.replaceOpWithNewOp(op, op.getType(), pad.getPaddingValue(), rewriter.getDenseI64ArrayAttr({})); - return 
success(); - } - if (nstart - lpad < 0) { - start.push_back(0); - lpads.push_back(lpad - nstart); - needspad = true; - } else { - start.push_back(nstart - lpad); - lpads.push_back(0); - } - if (nend - lpad > inshape) { - end.push_back(inshape); - hpads.push_back(nend - lpad - inshape); - needspad = true; - } else { - end.push_back(nend - lpad); - hpads.push_back(0); - } + for (auto &&[nstart, nend, nstep, lpad, hpad, interior, inshape] : + llvm::zip(op.getStartIndices(), op.getLimitIndices(), op.getStrides(), + pad.getEdgePaddingLow(), pad.getEdgePaddingHigh(), + pad.getInteriorPadding(), + pad.getOperand().getType().getShape())) { + if (nstep != 1) + return failure(); + if (interior != 0) + return failure(); - step.push_back(1); - interiors.push_back(0); + // start of slice starts after end of value being padded + if (nstart - lpad >= inshape) { + rewriter.replaceOpWithNewOp( + op, op.getType(), pad.getPaddingValue(), + rewriter.getDenseI64ArrayAttr({})); + return success(); + } + // slice ends before the start of value being padded + if (nend - lpad < inshape) { + rewriter.replaceOpWithNewOp( + op, op.getType(), pad.getPaddingValue(), + rewriter.getDenseI64ArrayAttr({})); + return success(); + } + if (nstart - lpad < 0) { + start.push_back(0); + lpads.push_back(lpad - nstart); + needspad = true; + } else { + start.push_back(nstart - lpad); + lpads.push_back(0); + } + if (nend - lpad > inshape) { + end.push_back(inshape); + hpads.push_back(nend - lpad - inshape); + needspad = true; + } else { + end.push_back(nend - lpad); + hpads.push_back(0); + } + + step.push_back(1); + interiors.push_back(0); } if (needspad) { - auto nslice = rewriter.create(op.getLoc(), pad.getOperand(), start, end, step); - rewriter.replaceOpWithNewOp(op, nslice, pad.getPaddingValue(), lpads, hpads, interiors); - } { - rewriter.replaceOpWithNewOp(op, pad.getOperand(), start, end, step); + auto nslice = rewriter.create( + op.getLoc(), pad.getOperand(), start, end, step); + rewriter.replaceOpWithNewOp( + op, nslice, pad.getPaddingValue(), lpads, hpads, interiors); + } + { + rewriter.replaceOpWithNewOp(op, pad.getOperand(), + start, end, step); } return success(); } @@ -159,12 +175,13 @@ struct SliceConcat final : OpRewritePattern { return failure(); auto concat = op.getOperand().getDefiningOp(); - if (!concat) return failure(); - + if (!concat) + return failure(); auto dim = concat.getDimension(); - if (op.getStrides()[dim] != 1) return failure(); + if (op.getStrides()[dim] != 1) + return failure(); SmallVector postConcat; size_t curdim = 0; @@ -179,13 +196,18 @@ struct SliceConcat final : OpRewritePattern { curdim += nextdim; continue; } - SmallVector nstart(op.getStartIndices().begin(), op.getStartIndices().end()); - SmallVector nend(op.getLimitIndices().begin(), op.getLimitIndices().end()); + SmallVector nstart(op.getStartIndices().begin(), + op.getStartIndices().end()); + SmallVector nend(op.getLimitIndices().begin(), + op.getLimitIndices().end()); nstart[dim] -= curdim; - if (nstart[dim] < 0) nstart[dim] = 0; + if (nstart[dim] < 0) + nstart[dim] = 0; nend[dim] -= curdim; - if (nend[dim] > nextdim) nend[dim] = nextdim; - auto subslice = rewriter.create(op.getLoc(), v, nstart, nend, op.getStrides()); + if (nend[dim] > nextdim) + nend[dim] = nextdim; + auto subslice = rewriter.create( + op.getLoc(), v, nstart, nend, op.getStrides()); postConcat.push_back(subslice); curdim += nextdim; } @@ -194,52 +216,53 @@ struct SliceConcat final : OpRewritePattern { } }; - DenseElementsAttr fromTensor(stablehlo::Tensor inp) { auto 
type = inp.getType(); auto elemType = type.getElementType(); - if (elemType.isF32()) { - auto floatValues = ArrayRef((float*)inp.getData(), inp.getNumElements()); + auto floatValues = ArrayRef((float *)inp.getData(), inp.getNumElements()); return DenseFPElementsAttr::get(type, floatValues); } if (elemType.isF64()) { - auto floatValues = ArrayRef((double*)inp.getData(), inp.getNumElements()); + auto floatValues = ArrayRef((double *)inp.getData(), inp.getNumElements()); return DenseFPElementsAttr::get(type, floatValues); } if (elemType.isSignlessInteger(8)) { - auto floatValues = ArrayRef((int8_t*)inp.getData(), inp.getNumElements()); + auto floatValues = ArrayRef((int8_t *)inp.getData(), inp.getNumElements()); return DenseIntElementsAttr::get(type, floatValues); } if (elemType.isSignlessInteger(16)) { - auto floatValues = ArrayRef((int16_t*)inp.getData(), inp.getNumElements()); + auto floatValues = ArrayRef((int16_t *)inp.getData(), inp.getNumElements()); return DenseIntElementsAttr::get(type, floatValues); } if (elemType.isSignlessInteger(32)) { - auto floatValues = ArrayRef((int32_t*)inp.getData(), inp.getNumElements()); + auto floatValues = ArrayRef((int32_t *)inp.getData(), inp.getNumElements()); return DenseIntElementsAttr::get(type, floatValues); } if (elemType.isSignlessInteger(64)) { - auto floatValues = ArrayRef((int64_t*)inp.getData(), inp.getNumElements()); + auto floatValues = ArrayRef((int64_t *)inp.getData(), inp.getNumElements()); return DenseIntElementsAttr::get(type, floatValues); } if (elemType.isUnsignedInteger(8)) { - auto floatValues = ArrayRef((uint8_t*)inp.getData(), inp.getNumElements()); + auto floatValues = ArrayRef((uint8_t *)inp.getData(), inp.getNumElements()); return DenseIntElementsAttr::get(type, floatValues); } if (elemType.isUnsignedInteger(16)) { - auto floatValues = ArrayRef((uint16_t*)inp.getData(), inp.getNumElements()); + auto floatValues = + ArrayRef((uint16_t *)inp.getData(), inp.getNumElements()); return DenseIntElementsAttr::get(type, floatValues); } if (elemType.isUnsignedInteger(32)) { - auto floatValues = ArrayRef((uint32_t*)inp.getData(), inp.getNumElements()); + auto floatValues = + ArrayRef((uint32_t *)inp.getData(), inp.getNumElements()); return DenseIntElementsAttr::get(type, floatValues); } if (elemType.isUnsignedInteger(64)) { - auto floatValues = ArrayRef((uint64_t*)inp.getData(), inp.getNumElements()); + auto floatValues = + ArrayRef((uint64_t *)inp.getData(), inp.getNumElements()); return DenseIntElementsAttr::get(type, floatValues); } @@ -247,17 +270,24 @@ DenseElementsAttr fromTensor(stablehlo::Tensor inp) { } /* -%22 = stablehlo.dot_general %21, %16, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<288x288xf32>, tensor<288xf32>) -> tensor<288xf32> +%22 = stablehlo.dot_general %21, %16, contracting_dims = [1] x [0], precision = +[DEFAULT, DEFAULT] : (tensor<288x288xf32>, tensor<288xf32>) -> tensor<288xf32> %27 = stablehlo.reshape %22 : (tensor<288xf32>) -> tensor<144x2xf32> -%28 = stablehlo.dot_general %6, %27, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<144x2x2xf32>, tensor<144x2xf32>) -> tensor<144x2xf32> +%28 = stablehlo.dot_general %6, %27, batching_dims = [0] x [0], contracting_dims += [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<144x2x2xf32>, +tensor<144x2xf32>) -> tensor<144x2xf32> should become %a21 = stablehlo.reshape %21 : (tensor<288xf32>) -> tensor<144x2xf32> -%22 = stablehlo.dot_general %a21, %16, batching_dims = [1] x [], contracting_dims 
= [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<144x2x288xf32>, tensor<288xf32>) -> tensor<2x144xf32> +%22 = stablehlo.dot_general %a21, %16, batching_dims = [1] x [], +contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : +(tensor<144x2x288xf32>, tensor<288xf32>) -> tensor<2x144xf32> -%28 = stablehlo.dot_general %6, %22, batching_dims = [0] x [1], contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<144x2x2xf32>, tensor<144x2xf32>) -> tensor<144x2xf32> +%28 = stablehlo.dot_general %6, %22, batching_dims = [0] x [1], contracting_dims += [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<144x2x2xf32>, +tensor<144x2xf32>) -> tensor<144x2xf32> TODO */ @@ -275,11 +305,11 @@ struct DotReshapeDot final : OpRewritePattern { } }; - /* - %1192 = stablehlo.pad %1189, %cst_0, low = [0], high = [1], interior = [0] : (tensor<1xf32>, tensor) -> tensor<2xf32> - %1193 = arith.addf %1191, %1192 : tensor<2xf32> + %1192 = stablehlo.pad %1189, %cst_0, low = [0], high = [1], interior = [0] : + (tensor<1xf32>, tensor) -> tensor<2xf32> %1193 = arith.addf %1191, %1192 + : tensor<2xf32> */ struct AddPad final : OpRewritePattern { @@ -291,9 +321,9 @@ struct AddPad final : OpRewritePattern { if (!type) return failure(); - for (int i=0; i<2; i++) { + for (int i = 0; i < 2; i++) { if (auto lhs = op->getOperand(i).getDefiningOp()) { - auto rhs = op->getOperand(1-i); + auto rhs = op->getOperand(1 - i); if (!matchPattern(lhs.getPaddingValue(), m_AnyZeroFloat())) { continue; @@ -306,14 +336,18 @@ struct AddPad final : OpRewritePattern { break; } } - if (!legal) continue; + if (!legal) + continue; ssize_t padidx = -1; SmallVector idxs; - for (auto &&[low, high, dim] : llvm::zip(lhs.getEdgePaddingLow(), lhs.getEdgePaddingHigh(), type.getShape())) { + for (auto &&[low, high, dim] : + llvm::zip(lhs.getEdgePaddingLow(), lhs.getEdgePaddingHigh(), + type.getShape())) { padidx++; - if (low == 0 && high == 0) continue; + if (low == 0 && high == 0) + continue; idxs.push_back(padidx); } @@ -322,35 +356,39 @@ struct AddPad final : OpRewritePattern { SmallVector strides(type.getShape().size(), 1); SmallVector starts(type.getShape().size(), 0); - SmallVector limits(type.getShape().begin(), type.getShape().end()); + SmallVector limits(type.getShape().begin(), + type.getShape().end()); SmallVector vals; if (lhs.getEdgePaddingLow()[idx] != 0) { - starts[idx] = 0; - limits[idx] = lhs.getEdgePaddingLow()[idx]; - auto prevSlice = rewriter.create(op.getLoc(), rhs, starts, limits, strides); - vals.push_back(prevSlice); + starts[idx] = 0; + limits[idx] = lhs.getEdgePaddingLow()[idx]; + auto prevSlice = rewriter.create( + op.getLoc(), rhs, starts, limits, strides); + vals.push_back(prevSlice); } starts[idx] = lhs.getEdgePaddingLow()[idx]; limits[idx] = type.getShape()[idx] - lhs.getEdgePaddingHigh()[idx]; - auto midSlice = rewriter.create(op.getLoc(), rhs, starts, limits, strides); - auto mid = rewriter.create(op.getLoc(), midSlice, lhs.getOperand()); + auto midSlice = rewriter.create( + op.getLoc(), rhs, starts, limits, strides); + auto mid = rewriter.create(op.getLoc(), midSlice, + lhs.getOperand()); vals.push_back(mid); if (lhs.getEdgePaddingHigh()[idx] != 0) { - starts[idx] = type.getShape()[idx] - lhs.getEdgePaddingHigh()[idx]; - limits[idx] = type.getShape()[idx]; - auto postSlice = rewriter.create(op.getLoc(), rhs, starts, limits, strides); - vals.push_back(postSlice); + starts[idx] = type.getShape()[idx] - lhs.getEdgePaddingHigh()[idx]; + limits[idx] = type.getShape()[idx]; + auto postSlice = 
rewriter.create( + op.getLoc(), rhs, starts, limits, strides); + vals.push_back(postSlice); } rewriter.replaceOpWithNewOp(op, vals, idx); return success(); } - } } @@ -358,7 +396,8 @@ struct AddPad final : OpRewritePattern { } }; -struct ConcatConstProp final : OpRewritePattern { +struct ConcatConstProp final + : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(mlir::stablehlo::ConcatenateOp op, @@ -373,22 +412,23 @@ struct ConcatConstProp final : OpRewritePattern } { - SmallVector subconcat; - bool changed = false; - for (auto v : op->getOperands()) { - if (auto c2 = v.getDefiningOp()) - if (c2.getDimension() == op.getDimension()) { - for (auto v2 : c2->getOperands()) - subconcat.push_back(v2); - changed = true; - continue; - } - subconcat.push_back(v); - } - if (changed) { - rewriter.replaceOpWithNewOp(op, subconcat, op.getDimension()); - return success(); - } + SmallVector subconcat; + bool changed = false; + for (auto v : op->getOperands()) { + if (auto c2 = v.getDefiningOp()) + if (c2.getDimension() == op.getDimension()) { + for (auto v2 : c2->getOperands()) + subconcat.push_back(v2); + changed = true; + continue; + } + subconcat.push_back(v); + } + if (changed) { + rewriter.replaceOpWithNewOp( + op, subconcat, op.getDimension()); + return success(); + } } SmallVector constants; @@ -396,7 +436,8 @@ struct ConcatConstProp final : OpRewritePattern bool legal = true; for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) { matchPattern(op->getOperand(i), m_Constant(&constants[i])); - if (!constants[i]) legal = false; + if (!constants[i]) + legal = false; } if (legal) { @@ -404,15 +445,18 @@ struct ConcatConstProp final : OpRewritePattern SmallVector inps; for (auto &c : constants) inps.push_back(mlir::stablehlo::evalConstantOp(c)); - auto out = mlir::stablehlo::evalConcatenateOp(inps, op.getDimension(), op.getType()); - rewriter.replaceOpWithNewOp(op, op.getType(), fromTensor(out)); + auto out = mlir::stablehlo::evalConcatenateOp(inps, op.getDimension(), + op.getType()); + rewriter.replaceOpWithNewOp(op, op.getType(), + fromTensor(out)); return success(); } return failure(); } }; -struct BroadcastToReshape final : OpRewritePattern { +struct BroadcastToReshape final + : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(mlir::stablehlo::BroadcastInDimOp op, @@ -425,18 +469,25 @@ struct BroadcastToReshape final : OpRewritePatterngetOperand(0), m_Constant(&inp)); if (inp) { if (inp.isSplat()) { - rewriter.replaceOpWithNewOp(op, op.getType(), mlir::SplatElementsAttr::get(op.getType(), inp.getSplatValue())); + rewriter.replaceOpWithNewOp( + op, op.getType(), + mlir::SplatElementsAttr::get(op.getType(), + inp.getSplatValue())); return success(); } auto inp0 = mlir::stablehlo::evalConstantOp(inp); - auto out = mlir::stablehlo::evalBroadcastInDimOp(inp0, mlir::stablehlo::Axes(op.getBroadcastDimensions()), op.getType()); - rewriter.replaceOpWithNewOp(op, op.getType(), fromTensor(out)); + auto out = mlir::stablehlo::evalBroadcastInDimOp( + inp0, mlir::stablehlo::Axes(op.getBroadcastDimensions()), + op.getType()); + rewriter.replaceOpWithNewOp(op, op.getType(), + fromTensor(out)); return success(); } // Ensure these are sorted for (auto en : llvm::enumerate(op.getBroadcastDimensions())) { - if (en.index() == 0) continue; + if (en.index() == 0) + continue; if (op.getBroadcastDimensions()[en.index() - 1] >= en.value()) { return failure(); } @@ -444,7 +495,7 @@ struct BroadcastToReshape final : OpRewritePattern(op, 
op.getType(), op.getOperand()); + rewriter.replaceOpWithNewOp(op, op.getType(), + op.getOperand()); return success(); } }; @@ -594,15 +647,18 @@ struct AddSimplify : public OpRewritePattern { for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) matchPattern(op->getOperand(i), m_Constant(&constants[i])); - if (auto res = constFoldBinaryOpConditional( - constants, - [](const APFloat &a, const APFloat &b) -> std::optional { - APFloat res2(a); - res2.add(b, llvm::RoundingMode::NearestTiesToEven); - return res2; - })) { - rewriter.replaceOpWithNewOp(op, op.getType(), res.cast()); - return success(); + if (auto res = + constFoldBinaryOpConditional( + constants, + [](const APFloat &a, + const APFloat &b) -> std::optional { + APFloat res2(a); + res2.add(b, llvm::RoundingMode::NearestTiesToEven); + return res2; + })) { + rewriter.replaceOpWithNewOp( + op, op.getType(), res.cast()); + return success(); } return failure(); @@ -630,22 +686,24 @@ struct SubSimplify : public OpRewritePattern { for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) matchPattern(op->getOperand(i), m_Constant(&constants[i])); - if (auto res = constFoldBinaryOpConditional( - constants, - [](const APFloat &a, const APFloat &b) -> std::optional { - APFloat res2(a); - res2.subtract(b, llvm::RoundingMode::NearestTiesToEven); - return res2; - })) { - rewriter.replaceOpWithNewOp(op, op.getType(), res.cast()); - return success(); + if (auto res = + constFoldBinaryOpConditional( + constants, + [](const APFloat &a, + const APFloat &b) -> std::optional { + APFloat res2(a); + res2.subtract(b, llvm::RoundingMode::NearestTiesToEven); + return res2; + })) { + rewriter.replaceOpWithNewOp( + op, op.getType(), res.cast()); + return success(); } return failure(); } }; - struct NegateSimplify : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -657,13 +715,15 @@ struct NegateSimplify : public OpRewritePattern { for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) matchPattern(op->getOperand(i), m_Constant(&constants[i])); - if (auto res = mlir::constFoldUnaryOpConditional ( - constants, - [](const APFloat &a) -> std::optional { - return -a; - })) { - rewriter.replaceOpWithNewOp(op, op.getType(), res.cast()); - return success(); + if (auto res = + mlir::constFoldUnaryOpConditional( + constants, [](const APFloat &a) -> std::optional { + return -a; + })) { + rewriter.replaceOpWithNewOp( + op, op.getType(), res.cast()); + return success(); } return failure(); @@ -690,15 +750,18 @@ struct MulSimplify : public OpRewritePattern { for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) matchPattern(op->getOperand(i), m_Constant(&constants[i])); - if (auto res = constFoldBinaryOpConditional( - constants, - [](const APFloat &a, const APFloat &b) -> std::optional { - APFloat res2(a); - res2.multiply(b, llvm::RoundingMode::NearestTiesToEven); - return res2; - })) { - rewriter.replaceOpWithNewOp(op, op.getType(), res.cast()); - return success(); + if (auto res = + constFoldBinaryOpConditional( + constants, + [](const APFloat &a, + const APFloat &b) -> std::optional { + APFloat res2(a); + res2.multiply(b, llvm::RoundingMode::NearestTiesToEven); + return res2; + })) { + rewriter.replaceOpWithNewOp( + op, op.getType(), res.cast()); + return success(); } return failure(); @@ -721,15 +784,18 @@ struct DivSimplify : public OpRewritePattern { for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) matchPattern(op->getOperand(i), m_Constant(&constants[i])); - if (auto res = constFoldBinaryOpConditional( - constants, 
- [](const APFloat &a, const APFloat &b) -> std::optional { - APFloat res2(a); - res2.divide(b, llvm::RoundingMode::NearestTiesToEven); - return res2; - })) { - rewriter.replaceOpWithNewOp(op, op.getType(), res.cast()); - return success(); + if (auto res = + constFoldBinaryOpConditional( + constants, + [](const APFloat &a, + const APFloat &b) -> std::optional { + APFloat res2(a); + res2.divide(b, llvm::RoundingMode::NearestTiesToEven); + return res2; + })) { + rewriter.replaceOpWithNewOp( + op, op.getType(), res.cast()); + return success(); } return failure(); @@ -747,20 +813,22 @@ struct PowSimplify : public OpRewritePattern { for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) matchPattern(op->getOperand(i), m_Constant(&constants[i])); - if (auto res = constFoldBinaryOpConditional( - constants, - [](const APFloat &a, const APFloat &b) -> std::optional { - if (a.getSizeInBits(a.getSemantics()) == 64 && - b.getSizeInBits(b.getSemantics()) == 64) - return APFloat(pow(a.convertToDouble(), b.convertToDouble())); - - if (a.getSizeInBits(a.getSemantics()) == 32 && - b.getSizeInBits(b.getSemantics()) == 32) - return APFloat(powf(a.convertToFloat(), b.convertToFloat())); - return {}; - })) { - rewriter.replaceOpWithNewOp(op, op.getType(), res.cast()); - return success(); + if (auto res = constFoldBinaryOpConditional( + constants, + [](const APFloat &a, const APFloat &b) -> std::optional { + if (a.getSizeInBits(a.getSemantics()) == 64 && + b.getSizeInBits(b.getSemantics()) == 64) + return APFloat(pow(a.convertToDouble(), b.convertToDouble())); + + if (a.getSizeInBits(a.getSemantics()) == 32 && + b.getSizeInBits(b.getSemantics()) == 32) + return APFloat(powf(a.convertToFloat(), b.convertToFloat())); + return {}; + })) { + rewriter.replaceOpWithNewOp( + op, op.getType(), res.cast()); + return success(); } return failure(); @@ -778,18 +846,19 @@ struct CosSimplify : public OpRewritePattern { for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) matchPattern(op->getOperand(i), m_Constant(&constants[i])); - if (auto res = constFoldUnaryOpConditional( - constants, - [](const APFloat &a) -> std::optional { - if (a.getSizeInBits(a.getSemantics()) == 64) - return APFloat(cos(a.convertToDouble())); - - if (a.getSizeInBits(a.getSemantics()) == 32) - return APFloat(cosf(a.convertToFloat())); - return {}; - })) { - rewriter.replaceOpWithNewOp(op, op.getType(), res.cast()); - return success(); + if (auto res = + constFoldUnaryOpConditional( + constants, [](const APFloat &a) -> std::optional { + if (a.getSizeInBits(a.getSemantics()) == 64) + return APFloat(cos(a.convertToDouble())); + + if (a.getSizeInBits(a.getSemantics()) == 32) + return APFloat(cosf(a.convertToFloat())); + return {}; + })) { + rewriter.replaceOpWithNewOp( + op, op.getType(), res.cast()); + return success(); } return failure(); @@ -807,18 +876,19 @@ struct SinSimplify : public OpRewritePattern { for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) matchPattern(op->getOperand(i), m_Constant(&constants[i])); - if (auto res = constFoldUnaryOpConditional( - constants, - [](const APFloat &a) -> std::optional { - if (a.getSizeInBits(a.getSemantics()) == 64) - return APFloat(sin(a.convertToDouble())); - - if (a.getSizeInBits(a.getSemantics()) == 32) - return APFloat(sinf(a.convertToFloat())); - return {}; - })) { - rewriter.replaceOpWithNewOp(op, op.getType(), res.cast()); - return success(); + if (auto res = + constFoldUnaryOpConditional( + constants, [](const APFloat &a) -> std::optional { + if 
(a.getSizeInBits(a.getSemantics()) == 64) + return APFloat(sin(a.convertToDouble())); + + if (a.getSizeInBits(a.getSemantics()) == 32) + return APFloat(sinf(a.convertToFloat())); + return {}; + })) { + rewriter.replaceOpWithNewOp( + op, op.getType(), res.cast()); + return success(); } return failure(); @@ -836,18 +906,19 @@ struct SqrtSimplify : public OpRewritePattern { for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) matchPattern(op->getOperand(i), m_Constant(&constants[i])); - if (auto res = constFoldUnaryOpConditional( - constants, - [](const APFloat &a) -> std::optional { - if (a.getSizeInBits(a.getSemantics()) == 64) - return APFloat(sqrt(a.convertToDouble())); - - if (a.getSizeInBits(a.getSemantics()) == 32) - return APFloat(sqrtf(a.convertToFloat())); - return {}; - })) { - rewriter.replaceOpWithNewOp(op, op.getType(), res.cast()); - return success(); + if (auto res = + constFoldUnaryOpConditional( + constants, [](const APFloat &a) -> std::optional { + if (a.getSizeInBits(a.getSemantics()) == 64) + return APFloat(sqrt(a.convertToDouble())); + + if (a.getSizeInBits(a.getSemantics()) == 32) + return APFloat(sqrtf(a.convertToFloat())); + return {}; + })) { + rewriter.replaceOpWithNewOp( + op, op.getType(), res.cast()); + return success(); } return failure(); @@ -859,7 +930,11 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase { void runOnOperation() override { auto context = getOperation()->getContext(); RewritePatternSet patterns(context); - patterns.add(context); + patterns.add(context); mlir::stablehlo::populateStablehloCanonicalizationPatterns(context, &patterns); diff --git a/src/enzyme_ad/jax/compile_with_xla.cc b/src/enzyme_ad/jax/compile_with_xla.cc index f3e703209..b3a8c4cfb 100644 --- a/src/enzyme_ad/jax/compile_with_xla.cc +++ b/src/enzyme_ad/jax/compile_with_xla.cc @@ -10,6 +10,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Parser/Parser.h" #include "mlir/Pass/PassManager.h" #include "stablehlo/dialect/StablehloOps.h" diff --git a/src/enzyme_ad/jax/enzymexlamlir-opt.cpp b/src/enzyme_ad/jax/enzymexlamlir-opt.cpp index da01b43c5..7d748d5b1 100644 --- a/src/enzyme_ad/jax/enzymexlamlir-opt.cpp +++ b/src/enzyme_ad/jax/enzymexlamlir-opt.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/Async/IR/Async.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" #include "mlir/Dialect/DLTI/DLTI.h" +#include "mlir/Dialect/Func/Extensions/InlinerExtension.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -24,7 +25,6 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Func/Extensions/InlinerExtension.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/InitAllPasses.h" diff --git a/src/enzyme_ad/jax/primitives.py b/src/enzyme_ad/jax/primitives.py index f0029ba93..c8af7b98a 100644 --- a/src/enzyme_ad/jax/primitives.py +++ b/src/enzyme_ad/jax/primitives.py @@ -1196,8 +1196,8 @@ def enzyme_vjp(shadow_rets, *prim_args, **kwargs): ad_pass = ad_pass.replace("ForwardMode", "ReverseModeCombined") newpasses = ( prev_passes - + "print," + ad_pass - + ",arith-raise{stablehlo=true},canonicalize, remove-unnecessary-enzyme-ops, enzyme-simplify-math, enzyme-hlo-opt, canonicalize, cse, print" + + ad_pass + + ",arith-raise{stablehlo=true},canonicalize, 
remove-unnecessary-enzyme-ops, enzyme-simplify-math, enzyme-hlo-opt, canonicalize, cse" + post_passes ) diff --git a/test/BUILD b/test/BUILD index 652601d50..464f40a2c 100644 --- a/test/BUILD +++ b/test/BUILD @@ -63,3 +63,13 @@ py_test( ], ) + +py_test( + name = "llama", + srcs = [ + "llama.py", + ], + deps = [ + "//src/enzyme_ad/jax:enzyme_jax_internal", + ], +) diff --git a/test/llama.py b/test/llama.py index 9effa6217..cc588b0a2 100644 --- a/test/llama.py +++ b/test/llama.py @@ -310,8 +310,8 @@ def sfn(x, weights, key_cache, value_cache): efunc = enzyme_jax.enzyme_jax_ir(argv=argv, pipeline_options=pipeline)(func) - number = 1000 - if False: + number = 100 + if True: eres = efunc(x, weights, key_cache, value_cache) print("Enzyme primal", eres) res = jfunc(x, weights, key_cache, value_cache) @@ -348,7 +348,7 @@ def sfn(x, weights, key_cache, value_cache): # jfunc = jax.jit(partial(forward, config)) # mlir = jax.jit(partial(forward, config)).lower(1, weights, key_cache, value_cache).compiler_ir(dialect="mhlo") - if False: + if True: @jax.jit def jfwd(x, dx, weights, dweights, kc, dkc, vc, dvc): return jax.jvp(jfunc, (x, weights, kc, vc), (x, weights, dkc, dvc)) From 9ed373ae7e2a367b46e87703af0c864df988730e Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Fri, 8 Mar 2024 11:47:01 -0800 Subject: [PATCH 17/17] fix --- .../StableHLOAutoDiffOpInterfaceImpl.cpp | 27 ++++++---- test/bench_vs_xla.py | 52 +++++++++++++++---- test/llama.py | 10 +++- 3 files changed, 67 insertions(+), 22 deletions(-) diff --git a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp index 0ad4f3751..79ea13790 100644 --- a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp +++ b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp @@ -203,27 +203,36 @@ class AutoDiffBroadcastInDimRev SmallVector bcastDims(op.getBroadcastDimensions().begin(), op.getBroadcastDimensions().end()); - SmallVector newDims; - SmallVector reduceShape; + SmallVector reducedDims; + SmallVector iterShape; for (auto en : llvm::enumerate(outTy.getShape())) { - if (llvm::is_contained(bcastDims, en.index())) { - if (en.value() != 1) { - newDims.push_back(en.index()); + ssize_t bcastIdx = -1; + for (auto en2 : llvm::enumerate(bcastDims)) { + if (en2.value() == en.index()) { + bcastIdx = en2.index(); + break; + } + } + if (bcastIdx != -1) { + if (en.value() != inTy.getShape()[bcastIdx]) { + reducedDims.push_back(en.index()); + assert(inTy.getShape()[bcastIdx] == 1); + } else { + iterShape.push_back(inTy.getShape()[bcastIdx]); } continue; } - reduceShape.push_back(en.value()); - newDims.push_back(en.index()); + reducedDims.push_back(en.index()); } - auto reduceTy = RankedTensorType::get(reduceShape, inTy.getElementType()); + auto reduceTy = RankedTensorType::get(iterShape, inTy.getElementType()); Value zero = gutils->getShadowType(reduceTy) .cast() .createNullValue(builder, op.getLoc()); auto red = builder.create(op.getLoc(), TypeRange(zero.getType()), - inDiffe, zero, newDims); + inDiffe, zero, reducedDims); red.getBody().push_back(new Block()); Block &body = red.getBody().front(); OpBuilder bodyBuilder(orig->getContext()); diff --git a/test/bench_vs_xla.py b/test/bench_vs_xla.py index 3c241bb3b..5057f5273 100644 --- a/test/bench_vs_xla.py +++ b/test/bench_vs_xla.py @@ -291,7 +291,9 @@ class Slicing(EnzymeJaxTest): def setUp(self): dim = 3 self.ins = [jnp.array(range(dim), dtype=jnp.float32).reshape(1, dim, 1)] - 
self.dins = [jnp.array([i * i for i in range(dim)], dtype=jnp.float32).reshape(1, dim, 1)] + self.dins = [ + jnp.array([i * i for i in range(dim)], dtype=jnp.float32).reshape(1, dim, 1) + ] self.douts = [jnp.array([i * i for i in range(dim)], dtype=jnp.float32)] def nomlir(x): @@ -311,16 +313,24 @@ def setUp(self): dim = 12 self.ins = [jnp.array(range(dim), dtype=jnp.float32)] self.dins = [jnp.array([i * i for i in range(dim)], dtype=jnp.float32)] - self.douts = [jnp.array([i * i for i in range(2*dim)], dtype=jnp.float32).reshape((2, dim))] + self.douts = [ + jnp.array([i * i for i in range(2 * dim)], dtype=jnp.float32).reshape( + (2, dim) + ) + ] def nomlir(x): - return [(name, a) for (name, a) in x if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"] + return [ + (name, a) + for (name, a) in x + if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA" + ] self.revfilter = nomlir def f(x): toconv2 = jnp.ones((dim, dim)) - k = jnp.einsum('jk,k->j', toconv2, x) + k = jnp.einsum("jk,k->j", toconv2, x) kcl = jnp.zeros((1, dim)) h = jnp.reshape(k, (1, dim)) kcl = jnp.append(kcl, h, axis=0) @@ -329,15 +339,24 @@ def f(x): self.fn = f self.name = "activitymismatch" + class GenDot(EnzymeJaxTest): def setUp(self): dim = 12 self.ins = [jnp.array(range(dim), dtype=jnp.float32)] self.dins = [jnp.array([i * i for i in range(dim)], dtype=jnp.float32)] - self.douts = [jnp.array([i * i for i in range(2*dim)], dtype=jnp.float32).reshape((2, dim))] + self.douts = [ + jnp.array([i * i for i in range(2 * dim)], dtype=jnp.float32).reshape( + (2, dim) + ) + ] def nomlir(x): - return [(name, a) for (name, a) in x if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"] + return [ + (name, a) + for (name, a) in x + if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA" + ] self.revfilter = nomlir @@ -349,7 +368,7 @@ def f(x): k = jnp.reshape(jnp.einsum("ijk,ik -> ij", toconv2, k_tmp), (dim,)) kcl = jnp.zeros((1, dim)) - + h = jnp.reshape(k, (1, dim)) kcl = jnp.append(kcl, h, axis=0) return kcl @@ -361,12 +380,22 @@ def f(x): class Concat(EnzymeJaxTest): def setUp(self): dim = 12 - self.ins = [jnp.array(range(dim), dtype=jnp.float32), 10*jnp.array(range(dim), dtype=jnp.float32)] - self.dins = [jnp.array([i * i for i in range(dim)], dtype=jnp.float32), jnp.array([i * i *i / 3. 
for i in range(dim)], dtype=jnp.float32)] - self.douts = [jnp.array([i * i for i in range(2*dim)], dtype=jnp.float32)] + self.ins = [ + jnp.array(range(dim), dtype=jnp.float32), + 10 * jnp.array(range(dim), dtype=jnp.float32), + ] + self.dins = [ + jnp.array([i * i for i in range(dim)], dtype=jnp.float32), + jnp.array([i * i * i / 3.0 for i in range(dim)], dtype=jnp.float32), + ] + self.douts = [jnp.array([i * i for i in range(2 * dim)], dtype=jnp.float32)] def nomlir(x): - return [(name, a) for (name, a) in x if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"] + return [ + (name, a) + for (name, a) in x + if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA" + ] self.revfilter = nomlir @@ -376,5 +405,6 @@ def f(x, y): self.fn = f self.name = "Concat" + if __name__ == "__main__": absltest.main() diff --git a/test/llama.py b/test/llama.py index cc588b0a2..f7b8dd028 100644 --- a/test/llama.py +++ b/test/llama.py @@ -349,6 +349,7 @@ def sfn(x, weights, key_cache, value_cache): # mlir = jax.jit(partial(forward, config)).lower(1, weights, key_cache, value_cache).compiler_ir(dialect="mhlo") if True: + @jax.jit def jfwd(x, dx, weights, dweights, kc, dkc, vc, dvc): return jax.jvp(jfunc, (x, weights, kc, vc), (x, weights, dkc, dvc)) @@ -411,8 +412,13 @@ def erev(x, weights, kc, vc, dx, dkc, dvc): jres = jrev(x, weights, key_cache, value_cache, dx, dkc, dvc) print("Jax rev", jres) - jrev2 = enzyme_jax.enzyme_jax_ir(argv=argv, pipeline_options=enzyme_jax.JaXPipeline("inline{default-pipeline=canonicalize max-iterations=4}," - + "canonicalize,cse,enzyme-hlo-opt,cse"))(jrev) + jrev2 = enzyme_jax.enzyme_jax_ir( + argv=argv, + pipeline_options=enzyme_jax.JaXPipeline( + "inline{default-pipeline=canonicalize max-iterations=4}," + + "canonicalize,cse,enzyme-hlo-opt,cse" + ), + )(jrev) jres2 = jrev2(x, weights, key_cache, value_cache, dx, dkc, dvc) print("Jax2 rev", jres2)