Skip to content

Commit

Permalink
Adapt to upstream (#2228)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Jan 19, 2025
1 parent c4f953f commit bf11541
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ static mlir::Type batchType(mlir::Type type, int64_t width) {
return RankedTensorType::get({width}, type);
}

class FloatTypeInterface
: public AutoDiffTypeInterface::ExternalModel<FloatTypeInterface,
FloatType> {
template <typename ConcreteType>
class FloatTypeInterface : public AutoDiffTypeInterface::ExternalModel<
FloatTypeInterface<ConcreteType>, ConcreteType> {
public:
Value createNullValue(Type self, OpBuilder &builder, Location loc) const {
auto fltType = self.cast<FloatType>();
auto fltType = self.cast<ConcreteType>();
return builder.create<arith::ConstantFloatOp>(
loc, APFloat(fltType.getFloatSemantics(), 0), fltType);
}
Expand Down Expand Up @@ -200,10 +200,10 @@ class ComplexTypeInterface
void mlir::enzyme::registerBuiltinDialectAutoDiffInterface(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *context, BuiltinDialect *) {
BFloat16Type::attachInterface<FloatTypeInterface>(*context);
Float16Type::attachInterface<FloatTypeInterface>(*context);
Float32Type::attachInterface<FloatTypeInterface>(*context);
Float64Type::attachInterface<FloatTypeInterface>(*context);
BFloat16Type::attachInterface<FloatTypeInterface<BFloat16Type>>(*context);
Float16Type::attachInterface<FloatTypeInterface<Float16Type>>(*context);
Float32Type::attachInterface<FloatTypeInterface<Float32Type>>(*context);
Float64Type::attachInterface<FloatTypeInterface<Float64Type>>(*context);
IntegerType::attachInterface<IntegerTypeInterface<IntegerType>>(*context);
IndexType::attachInterface<IntegerTypeInterface<IndexType>>(*context);
UnrankedTensorType::attachInterface<TensorTypeInterface>(*context);
Expand Down
9 changes: 3 additions & 6 deletions enzyme/Enzyme/MustExitScalarEvolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -340,12 +340,9 @@ ScalarEvolution::ExitLimit MustExitScalarEvolution::computeExitLimitFromICmp(
const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsExit,
bool AllowPredicates) {
// If the condition was exit on true, convert the condition to exit on false
ICmpInst::Predicate Pred;
if (!ExitIfTrue)
Pred = ExitCond->getPredicate();
else
Pred = ExitCond->getInversePredicate();
const ICmpInst::Predicate OriginalPred = Pred;
auto Pred = (!ExitIfTrue) ? ExitCond->getPredicate()
: ExitCond->getInversePredicate();
const auto OriginalPred = Pred;

#if LLVM_VERSION_MAJOR < 14
// Handle common loops like: for (X = "string"; *X; ++X)
Expand Down

0 comments on commit bf11541

Please sign in to comment.