From bf115413decd211fa8f3c7683831505de52bc362 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 18 Jan 2025 23:02:01 -0600 Subject: [PATCH] Adapt to upstream (#2228) --- .../BuiltinAutoDiffTypeInterfaceImpl.cpp | 16 ++++++++-------- enzyme/Enzyme/MustExitScalarEvolution.cpp | 9 +++------ 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp index c38b990ceb6..bde01cb1325 100644 --- a/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp @@ -41,12 +41,12 @@ static mlir::Type batchType(mlir::Type type, int64_t width) { return RankedTensorType::get({width}, type); } -class FloatTypeInterface - : public AutoDiffTypeInterface::ExternalModel { +template +class FloatTypeInterface : public AutoDiffTypeInterface::ExternalModel< + FloatTypeInterface, ConcreteType> { public: Value createNullValue(Type self, OpBuilder &builder, Location loc) const { - auto fltType = self.cast(); + auto fltType = self.cast(); return builder.create( loc, APFloat(fltType.getFloatSemantics(), 0), fltType); } @@ -200,10 +200,10 @@ class ComplexTypeInterface void mlir::enzyme::registerBuiltinDialectAutoDiffInterface( DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *context, BuiltinDialect *) { - BFloat16Type::attachInterface(*context); - Float16Type::attachInterface(*context); - Float32Type::attachInterface(*context); - Float64Type::attachInterface(*context); + BFloat16Type::attachInterface>(*context); + Float16Type::attachInterface>(*context); + Float32Type::attachInterface>(*context); + Float64Type::attachInterface>(*context); IntegerType::attachInterface>(*context); IndexType::attachInterface>(*context); UnrankedTensorType::attachInterface(*context); diff --git a/enzyme/Enzyme/MustExitScalarEvolution.cpp b/enzyme/Enzyme/MustExitScalarEvolution.cpp index dff36303f9c..dcd16068edd 100644 --- a/enzyme/Enzyme/MustExitScalarEvolution.cpp +++ b/enzyme/Enzyme/MustExitScalarEvolution.cpp @@ -340,12 +340,9 @@ ScalarEvolution::ExitLimit MustExitScalarEvolution::computeExitLimitFromICmp( const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsExit, bool AllowPredicates) { // If the condition was exit on true, convert the condition to exit on false - ICmpInst::Predicate Pred; - if (!ExitIfTrue) - Pred = ExitCond->getPredicate(); - else - Pred = ExitCond->getInversePredicate(); - const ICmpInst::Predicate OriginalPred = Pred; + auto Pred = (!ExitIfTrue) ? ExitCond->getPredicate() + : ExitCond->getInversePredicate(); + const auto OriginalPred = Pred; #if LLVM_VERSION_MAJOR < 14 // Handle common loops like: for (X = "string"; *X; ++X)