diff --git a/src/enzyme_ad/jax/Implementations/Common.td b/src/enzyme_ad/jax/Implementations/Common.td index 2d4667415..06551416c 100644 --- a/src/enzyme_ad/jax/Implementations/Common.td +++ b/src/enzyme_ad/jax/Implementations/Common.td @@ -33,11 +33,18 @@ class RegionTerminatorOp { string opName = opName_; } -class MLIRDerivative resultOps> { +class ForwardFromSummedReverseInternal { + int unused = unused_; +} +def ForwardFromSummedReverse : ForwardFromSummedReverseInternal<0>; + + +class MLIRDerivative resultOps, dag forwardOps=(ForwardFromSummedReverse)> { string dialect = dialect_; string opName = opName_; dag PatternToMatch = patternToMatch; list ArgDerivatives = resultOps; + dag ArgDuals = forwardOps; } class Operation { @@ -51,6 +58,13 @@ class DiffeRetIndex indices_> { } def DiffeRet : DiffeRetIndex<[-1]>; +def Shadow : Operation { +} + +class GlobalExpr : Operation{ + string value = val; +} + class Inst : Operation { string name = mnemonic; string dialect = dialect_; @@ -63,5 +77,13 @@ class ConstantFP : Operation { + +} + def Op { } + +def ResultTypes : GlobalExprgetResultTypes()">; + + diff --git a/src/enzyme_ad/jax/Implementations/HLODerivatives.td b/src/enzyme_ad/jax/Implementations/HLODerivatives.td index 6aa4ecfb6..875af33bc 100644 --- a/src/enzyme_ad/jax/Implementations/HLODerivatives.td +++ b/src/enzyme_ad/jax/Implementations/HLODerivatives.td @@ -12,6 +12,8 @@ def Sin : HLOInst<"SineOp">; def Sqrt : HLOInst<"SqrtOp">; def Exp : HLOInst<"ExpOp">; +def Dot : HLOInst<"DotGeneralOp">; + def CheckedMul : HLOInst<"MulOp">; def CheckedDiv : HLOInst<"DivOp">; @@ -83,6 +85,15 @@ def : HLOReadOnlyIdentityOp<"SliceOp">; def : HLOReadOnlyIdentityOp<"BroadcastInDimOp">; def : HLOReadOnlyIdentityOp<"ConcatenateOp">; // convert -// cos -// sin -// sqrt + + +def ResultDotDim : GlobalExpr; +def ResultDotPrec : GlobalExpr; + +def : HLODerivative<"DotGeneralOp", (Op $lhs, $rhs), + [ + (Dot (ResultTypes), (DiffeRet), $rhs, (ResultDotDim), (ResultDotPrec)), + (Dot (ResultTypes), $lhs, (DiffeRet), (ResultDotDim), (ResultDotPrec)) + ], + (Add (SelectIfActive $lhs, (Dot (ResultTypes), (Shadow $lhs), $rhs, (ResultDotDim), (ResultDotPrec)), (HLOConstantFP<"0">)), (SelectIfActive $rhs, (Dot (ResultTypes), $lhs, (Shadow $rhs), (ResultDotDim), (ResultDotPrec)), (HLOConstantFP<"0">))) + >; diff --git a/src/enzyme_ad/jax/Implementations/MHLODerivatives.td b/src/enzyme_ad/jax/Implementations/MHLODerivatives.td index cd319de8d..03ea7e0c6 100644 --- a/src/enzyme_ad/jax/Implementations/MHLODerivatives.td +++ b/src/enzyme_ad/jax/Implementations/MHLODerivatives.td @@ -1,6 +1,6 @@ include "Common.td" -class HLODerivative resultOps> : MLIRDerivative<"mhlo", opName_, patternToMatch, resultOps>; +class HLODerivative resultOps, dag forwardOps=(ForwardFromSummedReverse)> : MLIRDerivative<"mhlo", opName_, patternToMatch, resultOps, forwardOps>; class HLOInst : Inst; diff --git a/src/enzyme_ad/jax/Implementations/StableHLODerivatives.td b/src/enzyme_ad/jax/Implementations/StableHLODerivatives.td index da6291036..100e3179f 100644 --- a/src/enzyme_ad/jax/Implementations/StableHLODerivatives.td +++ b/src/enzyme_ad/jax/Implementations/StableHLODerivatives.td @@ -1,6 +1,6 @@ include "Common.td" -class HLODerivative resultOps> : MLIRDerivative<"stablehlo", opName_, patternToMatch, resultOps>; +class HLODerivative resultOps, dag forwardOps=(ForwardFromSummedReverse)> : MLIRDerivative<"stablehlo", opName_, patternToMatch, resultOps, forwardOps>; class HLOInst : Inst;