Skip to content

Commit

Permalink
disabled div-sqrt for now
Browse files Browse the repository at this point in the history
  • Loading branch information
vimarsh6739 committed Jan 4, 2025
1 parent 58443b8 commit 12138a2
Showing 1 changed file with 16 additions and 190 deletions.
206 changes: 16 additions & 190 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
#include "stablehlo/dialect/StablehloOps.h"
#include "stablehlo/dialect/TypeInference.h"
#include "stablehlo/reference/Ops.h"
// #include "stablehlo/transforms/ChloDecompositionUtils.h"
#include "stablehlo/transforms/ChloDecompositionUtils.h"
#include "stablehlo/transforms/PassUtils.h"
#include "stablehlo/transforms/Passes.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
Expand Down Expand Up @@ -2157,193 +2157,19 @@ struct ConcatToBroadcast final
}
};

// Returns a constant built by `mlir::stablehlo::getConstantLike` from `val`
// whose payload is an infinity in the floating-point semantics of `val`'s
// element type. `negative` selects -inf instead of +inf. The element type of
// `val` is cast to `FloatType`, so callers must pass a floating-point value.
Value getConstantLikeInfValue(OpBuilder &b, Location loc, Value val,
                              bool negative) {
  FloatType elementTy = cast<FloatType>(getElementTypeOrSelf(val.getType()));
  const llvm::APFloat infinity =
      llvm::APFloat::getInf(elementTy.getFloatSemantics(), negative);
  return mlir::stablehlo::getConstantLike(b, loc, infinity, val);
}

// Parameters of the Lanczos approximation of the gamma function. A choice of
// g and n (kLanczosGamma and kLanczosCoefficients.size() + 1) uniquely
// determines the coefficients; the values below are for [g, n] = [7, 9].
// The pairs [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were evaluated and
// [7, 9] seemed to be the least sensitive to the quality of the log function.
// In particular, [5, 7] is the only choice where -1.5e-5 <= lgamma(2) <= 1.5e-5
// for a particularly inaccurate log function.

// The Lanczos parameter g.
constexpr double kLanczosGamma = 7.0;

// Leading (constant) term of the series a(z).
constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478;

// Numerators of the remaining series terms; entry i is divided by (z + i + 1).
constexpr std::array<double, 8> kLanczosCoefficients = {
    676.520368121885098567009190444019,
    -1259.13921672240287047156078755283,
    771.3234287776530788486528258894,
    -176.61502916214059906584551354,
    12.507343278686904814458936853,
    -0.13857109526572011689554707,
    9.984369578019570859563e-6,
    1.50563273514931155834e-7};

// Compute the Lgamma function using Lanczos' approximation from "A Precision
// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
// series B. Vol. 1:
// lgamma(z + 1) = (log(2) + log(pi)) / 2
// + (z + 1/2) * log(t(z))
// - t(z) + log(a(z))
// with t(z) = z + kLanczosGamma + 1/2
// a(z) = kBaseLanczosCoeff
// + sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k))
// where n = kLanczosCoefficients.size().
//
// Emits the ops for this computation through `rewriter` at `loc` and returns
// the resulting value. Only `args.front()` is read; it is the operand x whose
// lgamma is materialized.
Value materializeLgamma(PatternRewriter &rewriter, Location loc,
ValueRange args) {
// If the input is less than 0.5 use Euler's reflection formula.
// gamma(x) = pi / (sin(pi * x) * gamma(1 - x))
// Let z be
// z = -x if x < 1/2
// z = x - 1 otherwise
Value x = args.front();
Value half = mlir::stablehlo::getConstantLike(rewriter, loc, 0.5, x);
Value needToReflect = rewriter.create<mlir::stablehlo::CompareOp>(
loc, x, half, mlir::stablehlo::ComparisonDirection::LT);
Value negX = rewriter.create<mlir::stablehlo::NegOp>(loc, x);
Value one = mlir::stablehlo::getConstantLike(rewriter, loc, 1, x);
Value xSubOne = rewriter.create<mlir::stablehlo::SubtractOp>(loc, x, one);
Value z = rewriter.create<mlir::stablehlo::SelectOp>(loc, needToReflect, negX,
xSubOne);

// Materialize
// a(z) = kBaseLanczosCoeff
// + sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k))
// The loop index i is zero-based, so term i uses the divisor z + (i + 1).
Value a =
mlir::stablehlo::getConstantLike(rewriter, loc, kBaseLanczosCoeff, x);
for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
Value coeff = mlir::stablehlo::getConstantLike(rewriter, loc,
kLanczosCoefficients[i], x);
Value oneBasedIndex =
mlir::stablehlo::getConstantLike(rewriter, loc, i + 1, x);
Value quotient = rewriter.create<mlir::stablehlo::DivOp>(
loc, coeff,
rewriter.create<mlir::stablehlo::AddOp>(loc, z, oneBasedIndex));
a = rewriter.create<mlir::stablehlo::AddOp>(loc, a, quotient);
}

// To improve accuracy on platforms with less-precise log implementations,
// compute log(kLanczosGamma + 1/2) at compile time and use log1p on the
// device.
// Materialize as
// log(t) = log(kLanczosGamma + 1/2 + z)
// = log(kLanczosGamma + 1/2) + log1p(z / (kLanczosGamma + 1/2)).
Value lanczosPlusHalf =
mlir::stablehlo::getConstantLike(rewriter, loc, kLanczosGamma + 0.5, x);
Value t = rewriter.create<mlir::stablehlo::AddOp>(loc, lanczosPlusHalf, z);
Value logTerm = mlir::stablehlo::getConstantLike(
rewriter, loc, std::log(kLanczosGamma + 0.5), x);
Value log1pTerm = rewriter.create<mlir::stablehlo::Log1pOp>(
loc, rewriter.create<mlir::stablehlo::DivOp>(loc, z, lanczosPlusHalf));
Value logT = rewriter.create<mlir::stablehlo::AddOp>(loc, logTerm, log1pTerm);

// Note that t(z) may be large and we need to be careful not to overflow to
// infinity in the relevant term
// r = (z + 1/2) * log(t(z)) - t(z).
// Therefore, we compute this as
// r = (z + 1/2 - t(z) / log(t(z))) * log(t(z)).
Value tDivLogT = rewriter.create<mlir::stablehlo::DivOp>(loc, t, logT);
Value sum = rewriter.create<mlir::stablehlo::SubtractOp>(
loc, rewriter.create<mlir::stablehlo::AddOp>(loc, z, half), tDivLogT);
Value r = rewriter.create<mlir::stablehlo::MulOp>(loc, sum, logT);

// Compute the final result (modulo reflection) as
// lgamma(z + 1) = (log(2) + log(pi)) / 2 + r + log(a(z)).
Value logA = rewriter.create<mlir::stablehlo::LogOp>(loc, a);
Value lgamma = rewriter.create<mlir::stablehlo::AddOp>(
loc,
rewriter.create<mlir::stablehlo::AddOp>(
loc,
mlir::stablehlo::getConstantLike(
rewriter, loc, (std::log(2) + std::log(M_PI)) / 2, x),
r),
logA);

// Compute the reflected value for x < 0.5 as
// lgamma(x) = log(pi) - lgamma(1-x) - log(abs(sin(pi * x))).
//
// The abs is needed because lgamma is the log of the absolute value of the
// gamma function.
//
// We have to be careful when computing the final term above. gamma(x) goes
// to +/-inf at every integer x < 0, and this is controlled by the sin(pi * x)
// term. The slope is large, so precision is particularly important.
//
// Because abs(sin(pi * x)) has period of 1 we can equivalently use
// abs(sin(pi * frac(x))) where frac(x) is the fractional part of x. This is
// more numerically accurate: It doesn't overflow to inf like pi * x would and
// if x is an integer it evaluates to exactly 0 which is important because we
// then take the log of this value, and log(0) is -inf.
//
// We don't have a frac(x) primitive in HLO and computing it is tricky, but
// because abs(sin(pi * x)) = abs(sin(pi * abs(x))), it's good enough for our
// purposes to use abs(frac(x)) = abs(x) - floor(abs(x)).
//
// Furthermore, pi * abs(frac(x)) loses precision when abs(frac(x)) is close
// to 1. To remedy this, we can use the fact that sin(pi * x) in the domain
// [0, 1] is symmetric about x = 0.5.
//

// Convert values of abs_frac > 0.5 to (1 - abs_frac) to improve precision of
// pi * abs_frac for values of abs_frac close to 1.
Value abs = rewriter.create<mlir::stablehlo::AbsOp>(loc, x);
Value absFrac = rewriter.create<mlir::stablehlo::SubtractOp>(
loc, abs, rewriter.create<mlir::stablehlo::FloorOp>(loc, abs));
Value reduceAbsFrac = rewriter.create<mlir::stablehlo::CompareOp>(
loc, half, absFrac, mlir::stablehlo::ComparisonDirection::LT);
absFrac = rewriter.create<mlir::stablehlo::SelectOp>(
loc, reduceAbsFrac,
rewriter.create<mlir::stablehlo::SubtractOp>(loc, one, absFrac), absFrac);

// Materialize reflection.
// reflectionDenom = log(sin(pi * absFrac)); no explicit abs is emitted on the
// sine because absFrac has been reduced to [0, 0.5] above.
Value reflectionDenom = rewriter.create<mlir::stablehlo::LogOp>(
loc,
rewriter.create<mlir::stablehlo::SineOp>(
loc,
rewriter.create<mlir::stablehlo::MulOp>(
loc, mlir::stablehlo::getConstantLike(rewriter, loc, M_PI, x),
absFrac)));
// lgammaReflection = log(pi) - reflectionDenom - lgamma, where `lgamma` was
// computed from z = -x, i.e. it holds lgamma(1 - x).
Value lgammaReflection = rewriter.create<mlir::stablehlo::SubtractOp>(
loc,
rewriter.create<mlir::stablehlo::SubtractOp>(
loc,
mlir::stablehlo::getConstantLike(rewriter, loc, std::log(M_PI), x),
reflectionDenom),
lgamma);

// Avoid computing inf - inf, which is nan. If reflection_denom is not finite
// then it "wins" and the result is -reflection_denom (so log(0) = -inf
// produces +inf).
Value finiteReflectionDenom =
rewriter.create<mlir::stablehlo::IsFiniteOp>(loc, reflectionDenom);
Value negReflectionDenom =
rewriter.create<mlir::stablehlo::NegOp>(loc, reflectionDenom);
lgammaReflection = rewriter.create<mlir::stablehlo::SelectOp>(
loc, finiteReflectionDenom, lgammaReflection, negReflectionDenom);

// Select whether or not to rely on the reflection.
lgamma = rewriter.create<mlir::stablehlo::SelectOp>(loc, needToReflect,
lgammaReflection, lgamma);

// Materialize +/-inf behavior as
// lgamma(+/-inf) = +inf.
Value xIsInf = rewriter.create<chlo::IsInfOp>(loc, x);
return rewriter.create<mlir::stablehlo::SelectOp>(
loc, xIsInf,
getConstantLikeInfValue(rewriter, loc, x, /*negative=*/false), lgamma);
}

// Rewrite pattern for `chlo.lgamma` whose operand is a constant: the op is
// replaced by the value produced by `mlir::stablehlo::materializeLgamma`,
// i.e. by an explicit expansion of lgamma applied to the op's operands.
// NOTE(review): presumably the emitted ops are then folded by other constant
// propagation patterns in this pass — confirm.
struct GammaConstProp final : OpRewritePattern<mlir::chlo::LgammaOp> {
  using OpRewritePattern::OpRewritePattern;

  // Matches `op` only when its operand is a constant (`m_Constant`); returns
  // failure() otherwise and leaves the IR untouched. On success the op is
  // replaced with the materialized lgamma expansion.
  LogicalResult matchAndRewrite(mlir::chlo::LgammaOp op,
                                PatternRewriter &rewriter) const override {
    // return if not constant
    DenseElementsAttr inputAttr;
    if (!matchPattern(op.getOperand(), m_Constant(&inputAttr)))
      return failure();

    // Single declaration of `result`: declaring it twice in this scope is a
    // redefinition error.
    Value result = mlir::stablehlo::materializeLgamma(rewriter, op.getLoc(),
                                                      op->getOperands());
    rewriter.replaceOp(op, result);

    return success();
  }
};
Expand Down Expand Up @@ -7262,17 +7088,17 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
patterns.add<NoNan, NoNanSelfSubSimplify, NoNanAddSubSimplify>(context);
}

patterns
.add<CompareOpCanon, BroadcastInDimOpCanon, ConvertOpCanon,
DynamicBroadcastInDimOpNotActuallyDynamic,
ChainedDynamicBroadcastInDimCanonicalization,
DynamicBroadcastInDimAllDimsNonExpanding, NoopReduceOpCanon,
EmptyReduceOpCanon, DynamicReshapeOpCanon, GetTupleElementOpCanon,
RealOpCanon, ImagOpCanon, ConjComplexNegate,
GetDimensionSizeOpCanon, GatherOpCanon, ReshapeOpCanon,
MergeConsecutiveReshapes, TransposeIsReshape, IfInline, IfToSelect,
ZeroExtentTensorCanon, ReorderElementwiseAndShapeOp,
DynamicGatherOpIsNotDynamic, DivideSqrtToMultiplyRsqrt>(context);
patterns.add<CompareOpCanon, BroadcastInDimOpCanon, ConvertOpCanon,
DynamicBroadcastInDimOpNotActuallyDynamic,
ChainedDynamicBroadcastInDimCanonicalization,
DynamicBroadcastInDimAllDimsNonExpanding, NoopReduceOpCanon,
EmptyReduceOpCanon, DynamicReshapeOpCanon,
GetTupleElementOpCanon, RealOpCanon, ImagOpCanon,
ConjComplexNegate, GetDimensionSizeOpCanon, GatherOpCanon,
ReshapeOpCanon, MergeConsecutiveReshapes, TransposeIsReshape,
IfInline, IfToSelect, ZeroExtentTensorCanon,
ReorderElementwiseAndShapeOp, DynamicGatherOpIsNotDynamic>(
context);
patterns.add<SelectOpCanon>(max_constant_expansion, context,
PatternBenefit(65000));
patterns.add<ConcatenateOpCanon>(max_constant_expansion, context,
Expand Down

0 comments on commit 12138a2

Please sign in to comment.