From 8e36e65e6570827796acf3ae7df421fc244ca8aa Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 1 Jan 2025 14:18:42 -0500 Subject: [PATCH] Improve cast ft error (#2213) --- enzyme/Enzyme/AdjointGenerator.h | 46 +++++++++++++++++--------------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index 96b8494302e..655bdca6943 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -1380,31 +1380,33 @@ class AdjointGenerator : public llvm::InstVisitor { ss << "Cannot deduce adding type (cast) of " << I; EmitNoTypeError(str, I, gutils, Builder2); } - assert(FT); - auto rule = [&](Value *dif) { - if (I.getOpcode() == CastInst::CastOps::FPTrunc || - I.getOpcode() == CastInst::CastOps::FPExt) { - return Builder2.CreateFPCast(dif, op0->getType()); - } else if (I.getOpcode() == CastInst::CastOps::BitCast) { - return Builder2.CreateBitCast(dif, op0->getType()); - } else if (I.getOpcode() == CastInst::CastOps::Trunc) { - // TODO CHECK THIS - return Builder2.CreateZExt(dif, op0->getType()); - } else { - std::string s; - llvm::raw_string_ostream ss(s); - ss << *I.getParent()->getParent() << "\n"; - ss << "cannot handle above cast " << I << "\n"; - EmitNoDerivativeError(ss.str(), I, gutils, Builder2); - return (llvm::Value *)UndefValue::get(op0->getType()); - } - }; + if (FT) { + + auto rule = [&](Value *dif) { + if (I.getOpcode() == CastInst::CastOps::FPTrunc || + I.getOpcode() == CastInst::CastOps::FPExt) { + return Builder2.CreateFPCast(dif, op0->getType()); + } else if (I.getOpcode() == CastInst::CastOps::BitCast) { + return Builder2.CreateBitCast(dif, op0->getType()); + } else if (I.getOpcode() == CastInst::CastOps::Trunc) { + // TODO CHECK THIS + return Builder2.CreateZExt(dif, op0->getType()); + } else { + std::string s; + llvm::raw_string_ostream ss(s); + ss << *I.getParent()->getParent() << "\n"; + ss << "cannot handle above cast " << I << "\n"; + EmitNoDerivativeError(ss.str(), I, gutils, Builder2); + return (llvm::Value *)UndefValue::get(op0->getType()); + } + }; - Value *dif = diffe(&I, Builder2); - Value *diff = applyChainRule(op0->getType(), Builder2, rule, dif); + Value *dif = diffe(&I, Builder2); + Value *diff = applyChainRule(op0->getType(), Builder2, rule, dif); - addToDiffe(orig_op0, diff, Builder2, FT); + addToDiffe(orig_op0, diff, Builder2, FT); + } } Type *diffTy = gutils->getShadowType(I.getType());