-
Notifications
You must be signed in to change notification settings - Fork 164
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Lower bfloat arithmetic instructions
.
- Loading branch information
1 parent
f8a991b
commit ba99a6e
Showing
6 changed files
with
332 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
179 changes: 179 additions & 0 deletions
179
IGC/VectorCompiler/lib/GenXCodeGen/GenXBFloatLowering.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
/*========================== begin_copyright_notice ============================ | ||
Copyright (C) 2023 Intel Corporation | ||
SPDX-License-Identifier: MIT | ||
============================= end_copyright_notice ===========================*/ | ||
|
||
// | ||
/// GenXBFloatLowering | ||
/// ------------ | ||
/// | ||
/// This pass lowers bfloat instructions into float ones, because bfloat | ||
/// arithmetic operation aren't supported by the hardware. | ||
/// | ||
/// It transforms code, like: | ||
/// %fdiv_res = fdiv bfloat %a, %b | ||
/// Into the following code: | ||
/// %a.ext = fpext bfloat %a to float | ||
/// %b.ext = fpext bfloat %b to float | ||
/// %fdiv = fdiv float %a.ext, %b.ext | ||
/// %trunc = fptrunc float %fdiv to bfloat | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "GenX.h" | ||
#include "GenXSubtarget.h" | ||
#include "GenXTargetMachine.h" | ||
#include "GenXUtil.h" | ||
|
||
#include <llvm/IR/InstVisitor.h> | ||
#include <llvm/CodeGen/TargetPassConfig.h> | ||
|
||
#define DEBUG_TYPE "genx-bfloat-lowering" | ||
|
||
using namespace llvm; | ||
using namespace genx; | ||
|
||
static cl::opt<bool> | ||
EnableGenXBFloatLowering("enable-genx-bfloat-lowering", cl::init(true), | ||
cl::Hidden, | ||
cl::desc("Enable GenX bfloat lowering.")); | ||
|
||
namespace { | ||
// GenXBFloatLowering | ||
class GenXBFloatLowering : public FunctionPass, | ||
public InstVisitor<GenXBFloatLowering> { | ||
bool Modify = false; | ||
const GenXSubtarget *ST = nullptr; | ||
|
||
public: | ||
explicit GenXBFloatLowering() : FunctionPass(ID) {} | ||
StringRef getPassName() const override { return "GenX BFloat lowering"; } | ||
void getAnalysisUsage(AnalysisUsage &AU) const override; | ||
bool runOnFunction(Function &F) override; | ||
|
||
public: | ||
static char ID; | ||
#if LLVM_VERSION_MAJOR > 10 | ||
// fadd, fsub e.t.c | ||
void visitBinaryOperator(BinaryOperator &Inst); | ||
// fneg | ||
void visitUnaryOperator(UnaryOperator &Inst); | ||
// fcmp | ||
void visitFCmpInst(FCmpInst &Inst); | ||
#endif | ||
}; | ||
|
||
} // end namespace | ||
|
||
char GenXBFloatLowering::ID = 0; | ||
namespace llvm { | ||
void initializeGenXBFloatLoweringPass(PassRegistry &); | ||
} | ||
INITIALIZE_PASS_BEGIN(GenXBFloatLowering, "GenXBFloatLowering", | ||
"GenXBFloatLowering", false, false) | ||
INITIALIZE_PASS_END(GenXBFloatLowering, "GenXBFloatLowering", | ||
"GenXBFloatLowering", false, false) | ||
|
||
FunctionPass *llvm::createGenXBFloatLoweringPass() { | ||
initializeGenXBFloatLoweringPass(*PassRegistry::getPassRegistry()); | ||
return new GenXBFloatLowering; | ||
} | ||
|
||
void GenXBFloatLowering::getAnalysisUsage(AnalysisUsage &AU) const { | ||
AU.addRequired<TargetPassConfig>(); | ||
} | ||
|
||
#if LLVM_VERSION_MAJOR > 10 | ||
static Type *getFloatTyFromBfloat(Type *Ty) { | ||
IGC_ASSERT_EXIT(Ty->getScalarType()->isBFloatTy()); | ||
auto *FloatTy = Type::getFloatTy(Ty->getContext()); | ||
|
||
if (Ty->isBFloatTy()) | ||
return FloatTy; | ||
|
||
auto *VecTy = cast<IGCLLVM::FixedVectorType>(Ty); | ||
return IGCLLVM::FixedVectorType::get(FloatTy, VecTy->getNumElements()); | ||
} | ||
|
||
// fadd, fsub e.t.c | ||
void GenXBFloatLowering::visitBinaryOperator(BinaryOperator &Inst) { | ||
auto *Ty = Inst.getType(); | ||
if (!Inst.getType()->getScalarType()->isBFloatTy()) | ||
return; | ||
LLVM_DEBUG(dbgs() << "GenXBFloatLowering: apply on BinaryOperator\n" | ||
<< Inst << "\n"); | ||
auto *Src0 = Inst.getOperand(0); | ||
auto *Src1 = Inst.getOperand(1); | ||
IGC_ASSERT_EXIT(Src0->getType()->getScalarType()->isBFloatTy()); | ||
IGC_ASSERT_EXIT(Src1->getType()->getScalarType()->isBFloatTy()); | ||
IRBuilder<> Builder(&Inst); | ||
auto *FloatTy = getFloatTyFromBfloat(Src0->getType()); | ||
Instruction::BinaryOps Opcode = Inst.getOpcode(); | ||
auto *Op0Conv = Builder.CreateFPExt(Src0, FloatTy); | ||
auto *Op1Conv = Builder.CreateFPExt(Src1, FloatTy); | ||
auto *InstUpdate = Builder.CreateBinOp(Opcode, Op0Conv, Op1Conv); | ||
auto *Trunc = Builder.CreateFPTrunc(InstUpdate, Ty); | ||
Inst.replaceAllUsesWith(Trunc); | ||
Inst.eraseFromParent(); | ||
Modify = true; | ||
} | ||
|
||
void GenXBFloatLowering::visitFCmpInst(FCmpInst &Inst) { | ||
auto *Src0 = Inst.getOperand(0); | ||
auto *Src1 = Inst.getOperand(1); | ||
auto *SrcTy = Src0->getType(); | ||
if (!SrcTy->getScalarType()->isBFloatTy()) | ||
return; | ||
LLVM_DEBUG(dbgs() << "GenXBFloatLowering: apply on FCmp\n" << Inst << "\n"); | ||
IGC_ASSERT_EXIT(Src1->getType()->getScalarType()->isBFloatTy()); | ||
IRBuilder<> Builder(&Inst); | ||
auto *FloatTy = getFloatTyFromBfloat(SrcTy); | ||
auto *Op0Conv = Builder.CreateFPExt(Src0, FloatTy); | ||
auto *Op1Conv = Builder.CreateFPExt(Src1, FloatTy); | ||
auto *InstUpdate = Builder.CreateFCmp(Inst.getPredicate(), Op0Conv, Op1Conv); | ||
Inst.replaceAllUsesWith(InstUpdate); | ||
Inst.eraseFromParent(); | ||
Modify = true; | ||
} | ||
|
||
// fneg | ||
void GenXBFloatLowering::visitUnaryOperator(UnaryOperator &Inst) { | ||
auto *Src = Inst.getOperand(0); | ||
auto *SrcTy = Src->getType(); | ||
if (!SrcTy->getScalarType()->isBFloatTy()) | ||
return; | ||
LLVM_DEBUG(dbgs() << "GenXBFloatLowering: apply on UnaryOperator\n" | ||
<< Inst << "\n"); | ||
IGC_ASSERT_EXIT(Inst.getOpcode() == Instruction::FNeg); | ||
IRBuilder<> Builder(&Inst); | ||
auto *FloatTy = getFloatTyFromBfloat(SrcTy); | ||
auto *SrcExt = Builder.CreateFPExt(Src, FloatTy); | ||
auto *InstUpdate = Builder.CreateUnOp(Inst.getOpcode(), SrcExt); | ||
auto *Trunc = Builder.CreateFPTrunc(InstUpdate, SrcTy); | ||
Inst.replaceAllUsesWith(Trunc); | ||
Inst.eraseFromParent(); | ||
Modify = true; | ||
} | ||
#endif | ||
|
||
/*********************************************************************** | ||
* GenXBFloatLowering::runOnFunction | ||
*/ | ||
bool GenXBFloatLowering::runOnFunction(Function &F) { | ||
if (!EnableGenXBFloatLowering) | ||
return false; | ||
LLVM_DEBUG(dbgs() << "GenXBFloatLowering started\n"); | ||
|
||
#if LLVM_VERSION_MAJOR > 10 | ||
ST = &getAnalysis<TargetPassConfig>() | ||
.getTM<GenXTargetMachine>() | ||
.getGenXSubtarget(); | ||
|
||
visit(F); | ||
#endif | ||
|
||
LLVM_DEBUG(dbgs() << "GenXBFloatLowering ended\n"); | ||
return Modify; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
;=========================== begin_copyright_notice ============================ | ||
; | ||
; Copyright (C) 2023 Intel Corporation | ||
; | ||
; SPDX-License-Identifier: MIT | ||
; | ||
;============================ end_copyright_notice ============================= | ||
|
||
; REQUIRES: llvm_12_or_greater | ||
; RUN: %opt %use_old_pass_manager% -GenXBFloatLowering -march=genx64 -mcpu=Gen9 -mtriple=spir64-unknown-unknown -S < %s | FileCheck %s | ||
; COM: Supported instructions FAdd, FSub, FMul, FDiv, FRem, FCmp | ||
|
||
define bfloat @scalar_bfloat_fadd(bfloat %a, bfloat %b) { | ||
%fadd_res = fadd bfloat %a, %b | ||
ret bfloat %fadd_res | ||
} | ||
; CHECK-LABEL: define bfloat @scalar_bfloat_fadd | ||
; CHECK-DAG: %[[A_EXP:[a-z0-9.]+]] = fpext bfloat %a to float | ||
; CHECK-DAG: %[[B_EXP:[a-z0-9.]+]] = fpext bfloat %b to float | ||
; CHECK: %[[fadd:[a-z0-9.]+]] = fadd float %[[A_EXP]], %[[B_EXP]] | ||
; CHECK: %[[TRUNK:[a-z0-9.]+]] = fptrunc float %[[fadd]] to bfloat | ||
|
||
define <123 x bfloat> @vector_bfloat_fadd(<123 x bfloat> %a, <123 x bfloat> %b) { | ||
%fadd_res = fadd <123 x bfloat> %a, %b | ||
ret <123 x bfloat> %fadd_res | ||
} | ||
; CHECK-LABEL: define <123 x bfloat> @vector_bfloat_fadd | ||
; CHECK-DAG: %[[A_EXP:[a-z0-9.]+]] = fpext <123 x bfloat> %a to <123 x float> | ||
; CHECK-DAG: %[[B_EXP:[a-z0-9.]+]] = fpext <123 x bfloat> %b to <123 x float> | ||
; CHECK: %[[fadd:[a-z0-9.]+]] = fadd <123 x float> %[[A_EXP]], %[[B_EXP]] | ||
; CHECK: %[[TRUNK:[a-z0-9.]+]] = fptrunc <123 x float> %[[fadd]] to <123 x bfloat> | ||
|
||
define bfloat @scalar_bfloat_fsub(bfloat %a, bfloat %b) { | ||
%fsub_res = fsub bfloat %a, %b | ||
ret bfloat %fsub_res | ||
} | ||
; CHECK-LABEL: define bfloat @scalar_bfloat_fsub | ||
; CHECK-DAG: %[[A_EXP:[a-z0-9.]+]] = fpext bfloat %a to float | ||
; CHECK-DAG: %[[B_EXP:[a-z0-9.]+]] = fpext bfloat %b to float | ||
; CHECK: %[[fsub:[a-z0-9.]+]] = fsub float %[[A_EXP]], %[[B_EXP]] | ||
; CHECK: %[[TRUNK:[a-z0-9.]+]] = fptrunc float %[[fsub]] to bfloat | ||
|
||
define <123 x bfloat> @vector_bfloat_fsub(<123 x bfloat> %a, <123 x bfloat> %b) { | ||
%fsub_res = fsub <123 x bfloat> %a, %b | ||
ret <123 x bfloat> %fsub_res | ||
} | ||
; CHECK-LABEL: define <123 x bfloat> @vector_bfloat_fsub | ||
; CHECK-DAG: %[[A_EXP:[a-z0-9.]+]] = fpext <123 x bfloat> %a to <123 x float> | ||
; CHECK-DAG: %[[B_EXP:[a-z0-9.]+]] = fpext <123 x bfloat> %b to <123 x float> | ||
; CHECK: %[[fsub:[a-z0-9.]+]] = fsub <123 x float> %[[A_EXP]], %[[B_EXP]] | ||
; CHECK: %[[TRUNK:[a-z0-9.]+]] = fptrunc <123 x float> %[[fsub]] to <123 x bfloat> | ||
|
||
define bfloat @scalar_bfloat_fmul(bfloat %a, bfloat %b) { | ||
%fmul_res = fmul bfloat %a, %b | ||
ret bfloat %fmul_res | ||
} | ||
; CHECK-LABEL: define bfloat @scalar_bfloat_fmul | ||
; CHECK-DAG: %[[A_EXP:[a-z0-9.]+]] = fpext bfloat %a to float | ||
; CHECK-DAG: %[[B_EXP:[a-z0-9.]+]] = fpext bfloat %b to float | ||
; CHECK: %[[fmul:[a-z0-9.]+]] = fmul float %[[A_EXP]], %[[B_EXP]] | ||
; CHECK: %[[TRUNK:[a-z0-9.]+]] = fptrunc float %[[fmul]] to bfloat | ||
|
||
define <123 x bfloat> @vector_bfloat_fmul(<123 x bfloat> %a, <123 x bfloat> %b) { | ||
%fmul_res = fmul <123 x bfloat> %a, %b | ||
ret <123 x bfloat> %fmul_res | ||
} | ||
; CHECK-LABEL: define <123 x bfloat> @vector_bfloat_fmul | ||
; CHECK-DAG: %[[A_EXP:[a-z0-9.]+]] = fpext <123 x bfloat> %a to <123 x float> | ||
; CHECK-DAG: %[[B_EXP:[a-z0-9.]+]] = fpext <123 x bfloat> %b to <123 x float> | ||
; CHECK: %[[fmul:[a-z0-9.]+]] = fmul <123 x float> %[[A_EXP]], %[[B_EXP]] | ||
; CHECK: %[[TRUNK:[a-z0-9.]+]] = fptrunc <123 x float> %[[fmul]] to <123 x bfloat> | ||
|
||
define bfloat @scalar_bfloat_fdiv(bfloat %a, bfloat %b) { | ||
%fdiv_res = fdiv bfloat %a, %b | ||
ret bfloat %fdiv_res | ||
} | ||
; CHECK-LABEL: define bfloat @scalar_bfloat_fdiv | ||
; CHECK-DAG: %[[A_EXP:[a-z0-9.]+]] = fpext bfloat %a to float | ||
; CHECK-DAG: %[[B_EXP:[a-z0-9.]+]] = fpext bfloat %b to float | ||
; CHECK: %[[fdiv:[a-z0-9.]+]] = fdiv float %[[A_EXP]], %[[B_EXP]] | ||
; CHECK: %[[TRUNK:[a-z0-9.]+]] = fptrunc float %[[fdiv]] to bfloat | ||
|
||
define <123 x bfloat> @vector_bfloat_fdiv(<123 x bfloat> %a, <123 x bfloat> %b) { | ||
%fdiv_res = fdiv <123 x bfloat> %a, %b | ||
ret <123 x bfloat> %fdiv_res | ||
} | ||
; CHECK-LABEL: define <123 x bfloat> @vector_bfloat_fdiv | ||
; CHECK-DAG: %[[A_EXP:[a-z0-9.]+]] = fpext <123 x bfloat> %a to <123 x float> | ||
; CHECK-DAG: %[[B_EXP:[a-z0-9.]+]] = fpext <123 x bfloat> %b to <123 x float> | ||
; CHECK: %[[fdiv:[a-z0-9.]+]] = fdiv <123 x float> %[[A_EXP]], %[[B_EXP]] | ||
; CHECK: %[[TRUNK:[a-z0-9.]+]] = fptrunc <123 x float> %[[fdiv]] to <123 x bfloat> | ||
|
||
define bfloat @scalar_bfloat_frem(bfloat %a, bfloat %b) { | ||
%frem_res = frem bfloat %a, %b | ||
ret bfloat %frem_res | ||
} | ||
; CHECK-LABEL: define bfloat @scalar_bfloat_frem | ||
; CHECK-DAG: %[[A_EXP:[a-z0-9.]+]] = fpext bfloat %a to float | ||
; CHECK-DAG: %[[B_EXP:[a-z0-9.]+]] = fpext bfloat %b to float | ||
; CHECK: %[[frem:[a-z0-9.]+]] = frem float %[[A_EXP]], %[[B_EXP]] | ||
; CHECK: %[[TRUNK:[a-z0-9.]+]] = fptrunc float %[[frem]] to bfloat | ||
|
||
define <123 x bfloat> @vector_bfloat_frem(<123 x bfloat> %a, <123 x bfloat> %b) { | ||
%frem_res = frem <123 x bfloat> %a, %b | ||
ret <123 x bfloat> %frem_res | ||
} | ||
; CHECK-LABEL: define <123 x bfloat> @vector_bfloat_frem | ||
; CHECK-DAG: %[[A_EXP:[a-z0-9.]+]] = fpext <123 x bfloat> %a to <123 x float> | ||
; CHECK-DAG: %[[B_EXP:[a-z0-9.]+]] = fpext <123 x bfloat> %b to <123 x float> | ||
; CHECK: %[[frem:[a-z0-9.]+]] = frem <123 x float> %[[A_EXP]], %[[B_EXP]] | ||
; CHECK: %[[TRUNK:[a-z0-9.]+]] = fptrunc <123 x float> %[[frem]] to <123 x bfloat> | ||
|
||
define i1 @scalar_bfloat_fcmp(bfloat %a, bfloat %b) { | ||
%fcmp_res = fcmp une bfloat %a, %b | ||
ret i1 %fcmp_res | ||
} | ||
; CHECK-LABEL: define i1 @scalar_bfloat_fcmp | ||
; CHECK-DAG: %[[A_EXP:[a-z0-9.]+]] = fpext bfloat %a to float | ||
; CHECK-DAG: %[[B_EXP:[a-z0-9.]+]] = fpext bfloat %b to float | ||
; CHECK: %[[fcmp:[a-z0-9.]+]] = fcmp une float %[[A_EXP]], %[[B_EXP]] | ||
|
||
define <123 x i1> @vector_bfloat_fcmp(<123 x bfloat> %a, <123 x bfloat> %b) { | ||
%fcmp_res = fcmp une <123 x bfloat> %a, %b | ||
ret <123 x i1> %fcmp_res | ||
} | ||
; CHECK-LABEL: define <123 x i1> @vector_bfloat_fcmp | ||
; CHECK-DAG: %[[A_EXP:[a-z0-9.]+]] = fpext <123 x bfloat> %a to <123 x float> | ||
; CHECK-DAG: %[[B_EXP:[a-z0-9.]+]] = fpext <123 x bfloat> %b to <123 x float> | ||
; CHECK: %[[fcmp:[a-z0-9.]+]] = fcmp une <123 x float> %[[A_EXP]], %[[B_EXP]] | ||
|
||
|
||
define bfloat @scalar_bfloat_fneg(bfloat %a) { | ||
%fneg_res = fneg bfloat %a | ||
ret bfloat %fneg_res | ||
} | ||
; CHECK-LABEL: define bfloat @scalar_bfloat_fneg | ||
; CHECK-DAG: %[[A_EXP:[a-z0-9.]+]] = fpext bfloat %a to float | ||
; CHECK: %[[fcmp:[a-z0-9.]+]] = fneg float %[[A_EXP]] | ||
|
||
define <123 x bfloat> @vector_bfloat_fneg(<123 x bfloat> %a) { | ||
%fneg_res = fneg <123 x bfloat> %a | ||
ret <123 x bfloat> %fneg_res | ||
} | ||
; CHECK-LABEL: define <123 x bfloat> @vector_bfloat_fneg | ||
; CHECK-DAG: %[[A_EXP:[a-z0-9.]+]] = fpext <123 x bfloat> %a to <123 x float> | ||
; CHECK: %[[fcmp:[a-z0-9.]+]] = fneg <123 x float> %[[A_EXP]] | ||
|