Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrantq committed Jan 17, 2025
1 parent 35a9165 commit b85c191
Showing 1 changed file with 14 additions and 68 deletions.
82 changes: 14 additions & 68 deletions enzyme/Enzyme/Herbie.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,25 @@

#include "llvm/Demangle/Demangle.h"

#include "llvm/ExecutionEngine/Orc/LLJIT.h"
#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h"

#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"

#include "llvm/Passes/PassBuilder.h"

#include "llvm/Support/Casting.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/InstructionCost.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/Program.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
#include <llvm/Support/JSON.h>

#include "llvm/Pass.h"

Expand Down Expand Up @@ -91,7 +99,7 @@ static cl::opt<int> HerbieTimeout("herbie-timeout", cl::init(120), cl::Hidden,
"candidate expressions."));
static cl::opt<std::string>
FPOptCachePath("fpopt-cache-path", cl::init(""), cl::Hidden,
cl::desc("Experimental: path to cache Herbie results"));
cl::desc("Path to cache Herbie results"));
static cl::opt<int>
HerbieNumPoints("herbie-num-pts", cl::init(1024), cl::Hidden,
cl::desc("Number of input points Herbie uses to evaluate "
Expand Down Expand Up @@ -307,7 +315,7 @@ class FPNode {
val = builder.CreateUnaryIntrinsic(Intrinsic::cos, operandValues[0],
nullptr, "herbie.cos");
} else if (op == "tan") {
#if LLVM_VERSION_MAJOR >= 16 // TODO: Double check version
#if LLVM_VERSION_MAJOR > 16 // TODO: Double check version
val = builder.CreateUnaryIntrinsic(Intrinsic::tan, operandValues[0],
"herbie.tan");
#else
Expand Down Expand Up @@ -2215,71 +2223,9 @@ InstructionCost getInstructionCompCost(const Instruction *I,
llvm_unreachable(msg.c_str());
}

llvm::errs()
<< "IMPORTANT: Custom cost model not provided, using default cost!\n";

unsigned Opcode = I->getOpcode();
switch (Opcode) {
case Instruction::FNeg: {
SmallVector<const Value *, 1> Args(I->operands());
return TTI.getArithmeticInstrCost(
Opcode, I->getType(), TargetTransformInfo::TCK_Latency,
getOperandValueKind(I->getOperand(0)), TargetTransformInfo::OK_AnyValue,
getOperandValueProperties(I->getOperand(0)),
TargetTransformInfo::OP_None, Args, I);
}
case Instruction::FAdd:
case Instruction::FSub:
case Instruction::FMul:
case Instruction::FDiv: {
SmallVector<const Value *, 2> Args(I->operands());
return TTI.getArithmeticInstrCost(
Opcode, I->getType(), TargetTransformInfo::TCK_Latency,
getOperandValueKind(I->getOperand(0)),
getOperandValueKind(I->getOperand(1)),
getOperandValueProperties(I->getOperand(0)),
getOperandValueProperties(I->getOperand(1)), Args, I);
}
case Instruction::FCmp: {
const auto *FCI = cast<FCmpInst>(I);
return TTI.getCmpSelInstrCost(Opcode, FCI->getType(), /* CondTy */ nullptr,
FCI->getPredicate(),
TargetTransformInfo::TCK_Latency, I);
}
case Instruction::PHI: {
return TTI.getInstructionCost(I, TargetTransformInfo::TCK_Latency);
}
default: {
if (const auto *Call = dyn_cast<CallInst>(I)) {
if (Function *CalledFunc = Call->getCalledFunction()) {
if (CalledFunc->isIntrinsic()) {
auto IID = CalledFunc->getIntrinsicID();
SmallVector<Type *, 4> OperandTypes;
SmallVector<const Value *, 4> Args;
for (auto &Arg : Call->args()) {
OperandTypes.push_back(Arg->getType());
Args.push_back(Arg.get());
}

IntrinsicCostAttributes ICA(IID, Call->getType(), Args, OperandTypes,
Call->getFastMathFlags(),
cast<IntrinsicInst>(I));
return TTI.getIntrinsicInstrCost(ICA,
TargetTransformInfo::TCK_Latency);
} else {
SmallVector<Type *, 4> ArgTypes;
for (auto &Arg : Call->args())
ArgTypes.push_back(Arg->getType());

return TTI.getCallInstrCost(CalledFunc, Call->getType(), ArgTypes,
TargetTransformInfo::TCK_Latency);
}
}
}
llvm::errs() << "WARNING: Using default cost for " << *I << "\n";
return TTI.getInstructionCost(I, TargetTransformInfo::TCK_Latency);
}
}
std::string msg = "Custom cost model: instruction cost for " +
std::string(I->getOpcodeName()) + " not found!";
llvm_unreachable(msg.c_str());
}

InstructionCost computeMaxCost(
Expand Down

0 comments on commit b85c191

Please sign in to comment.