Skip to content

Commit

Permalink
Lower bfloat arithmetic instructions
Browse files Browse the repository at this point in the history
.
  • Loading branch information
igorban-intel authored and igcbot committed Dec 21, 2023
1 parent f8a991b commit ba99a6e
Show file tree
Hide file tree
Showing 6 changed files with 332 additions and 1 deletion.
1 change: 1 addition & 0 deletions IGC/VectorCompiler/lib/GenXCodeGen/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ set(CODEGEN_SOURCES
GenXArgIndirection.cpp
GenXBaling.cpp
GenXBiFPrepare.cpp
GenXBFloatLowering.cpp
GenXBuiltinFunctions.cpp
GenXCFSimplification.cpp
GenXCategory.cpp
Expand Down
1 change: 1 addition & 0 deletions IGC/VectorCompiler/lib/GenXCodeGen/GenX.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ ModulePass *createGenXEarlySimdCFConformancePass();
FunctionPass *createGenXPredToSimdCFPass();
FunctionPass *createGenXReduceIntSizePass();
FunctionPass *createGenXInlineAsmLoweringPass();
FunctionPass *createGenXBFloatLoweringPass();
FunctionPass *createGenXLoweringPass();
FunctionPass *createGenXVectorCombinerPass();
FunctionPass *createGenXLowerAggrCopiesPass();
Expand Down
179 changes: 179 additions & 0 deletions IGC/VectorCompiler/lib/GenXCodeGen/GenXBFloatLowering.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
/*========================== begin_copyright_notice ============================
Copyright (C) 2023 Intel Corporation
SPDX-License-Identifier: MIT
============================= end_copyright_notice ===========================*/

//
/// GenXBFloatLowering
/// ------------
///
/// This pass lowers bfloat instructions into float ones, because bfloat
/// arithmetic operations aren't supported by the hardware.
///
/// It transforms code, like:
/// %fdiv_res = fdiv bfloat %a, %b
/// Into the following code:
/// %a.ext = fpext bfloat %a to float
/// %b.ext = fpext bfloat %b to float
/// %fdiv = fdiv float %a.ext, %b.ext
/// %trunc = fptrunc float %fdiv to bfloat
//===----------------------------------------------------------------------===//

#include "GenX.h"
#include "GenXSubtarget.h"
#include "GenXTargetMachine.h"
#include "GenXUtil.h"

#include <llvm/IR/InstVisitor.h>
#include <llvm/CodeGen/TargetPassConfig.h>

#define DEBUG_TYPE "genx-bfloat-lowering"

using namespace llvm;
using namespace genx;

// Hidden command-line switch, on by default. When it is turned off,
// GenXBFloatLowering::runOnFunction returns immediately without touching IR.
static cl::opt<bool> EnableGenXBFloatLowering(
    "enable-genx-bfloat-lowering", cl::desc("Enable GenX bfloat lowering."),
    cl::Hidden, cl::init(true));

namespace {
// GenXBFloatLowering : function pass that rewrites bfloat arithmetic as float
// arithmetic surrounded by fpext/fptrunc conversions.
//
// The pass derives from InstVisitor, so the visit* methods below are invoked
// for every matching instruction when runOnFunction calls visit(F).
class GenXBFloatLowering : public FunctionPass,
                           public InstVisitor<GenXBFloatLowering> {
public:
  static char ID;

  explicit GenXBFloatLowering() : FunctionPass(ID) {}

  StringRef getPassName() const override { return "GenX BFloat lowering"; }
  void getAnalysisUsage(AnalysisUsage &AU) const override;
  bool runOnFunction(Function &F) override;

#if LLVM_VERSION_MAJOR > 10
  // Binary floating point arithmetic: fadd, fsub, etc.
  void visitBinaryOperator(BinaryOperator &Inst);
  // Unary floating point arithmetic: fneg.
  void visitUnaryOperator(UnaryOperator &Inst);
  // Floating point comparison: fcmp.
  void visitFCmpInst(FCmpInst &Inst);
#endif

private:
  // Set by the visitors whenever an instruction has been replaced; returned
  // from runOnFunction.
  bool Modify = false;
  // Subtarget of the current compilation, fetched in runOnFunction.
  const GenXSubtarget *ST = nullptr;
};

} // end namespace

// Storage for the pass identifier that the constructor hands to the
// FunctionPass base class.
char GenXBFloatLowering::ID = 0;
// Forward declaration of the initializer generated by the INITIALIZE_PASS_*
// macros below; it is called from createGenXBFloatLoweringPass().
namespace llvm {
void initializeGenXBFloatLoweringPass(PassRegistry &);
}
// Registers the pass under the command-line name "GenXBFloatLowering".
// No INITIALIZE_PASS_DEPENDENCY entries are listed between BEGIN and END.
INITIALIZE_PASS_BEGIN(GenXBFloatLowering, "GenXBFloatLowering",
"GenXBFloatLowering", false, false)
INITIALIZE_PASS_END(GenXBFloatLowering, "GenXBFloatLowering",
"GenXBFloatLowering", false, false)

// Factory for the bfloat lowering pass. Makes sure the pass is registered in
// the global PassRegistry before handing a new instance to the caller, who
// takes ownership of it.
FunctionPass *llvm::createGenXBFloatLoweringPass() {
  PassRegistry &Registry = *PassRegistry::getPassRegistry();
  initializeGenXBFloatLoweringPass(Registry);
  return new GenXBFloatLowering();
}

// Declares the analyses this pass needs and keeps intact.
//  - TargetPassConfig is required: runOnFunction uses it to reach the
//    GenXTargetMachine and its subtarget.
//  - The CFG is preserved: the visitors only replace arithmetic and compare
//    instructions in place and never add, remove or retarget basic blocks or
//    terminators, so CFG-only analyses stay valid.
void GenXBFloatLowering::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<TargetPassConfig>();
  AU.setPreservesCFG();
}

#if LLVM_VERSION_MAJOR > 10
// Maps a bfloat type onto the float type of the same shape:
//   bfloat       -> float
//   <N x bfloat> -> <N x float>
// Ty must be bfloat or a fixed vector of bfloat; anything else trips the
// assertion (or the cast for non-fixed vectors).
static Type *getFloatTyFromBfloat(Type *Ty) {
  IGC_ASSERT_EXIT(Ty->getScalarType()->isBFloatTy());
  Type *ScalarFloatTy = Type::getFloatTy(Ty->getContext());

  const bool IsScalar = Ty->isBFloatTy();
  if (!IsScalar) {
    auto Width = cast<IGCLLVM::FixedVectorType>(Ty)->getNumElements();
    return IGCLLVM::FixedVectorType::get(ScalarFloatTy, Width);
  }

  return ScalarFloatTy;
}

// fadd, fsub, fmul, fdiv, frem
//
// Lowers a binary operator whose result is bfloat (scalar or vector):
//   %r = <op> bfloat %a, %b
// becomes
//   %a.ext = fpext bfloat %a to float
//   %b.ext = fpext bfloat %b to float
//   %r.f   = <op> float %a.ext, %b.ext
//   %r     = fptrunc float %r.f to bfloat
// Binary operators of any other type are left untouched. Fast-math flags of
// the original instruction are carried over to the float operation.
void GenXBFloatLowering::visitBinaryOperator(BinaryOperator &Inst) {
  auto *Ty = Inst.getType();
  if (!Ty->getScalarType()->isBFloatTy())
    return;
  LLVM_DEBUG(dbgs() << "GenXBFloatLowering: apply on BinaryOperator\n"
                    << Inst << "\n");
  auto *Src0 = Inst.getOperand(0);
  auto *Src1 = Inst.getOperand(1);
  IGC_ASSERT_EXIT(Src0->getType()->getScalarType()->isBFloatTy());
  IGC_ASSERT_EXIT(Src1->getType()->getScalarType()->isBFloatTy());

  // New instructions are inserted right before the one being replaced.
  IRBuilder<> Builder(&Inst);
  auto *FloatTy = getFloatTyFromBfloat(Src0->getType());
  auto *Op0Conv = Builder.CreateFPExt(Src0, FloatTy);
  auto *Op1Conv = Builder.CreateFPExt(Src1, FloatTy);
  auto *InstUpdate = Builder.CreateBinOp(Inst.getOpcode(), Op0Conv, Op1Conv);
  // The builder may fold constant operands into a constant, in which case
  // there is no instruction to attach the flags to.
  if (auto *NewInst = dyn_cast<Instruction>(InstUpdate))
    NewInst->copyFastMathFlags(&Inst);
  auto *Trunc = Builder.CreateFPTrunc(InstUpdate, Ty);

  Inst.replaceAllUsesWith(Trunc);
  Inst.eraseFromParent();
  Modify = true;
}

// fcmp
//
// Lowers a comparison of bfloat operands (scalar or vector):
//   %c = fcmp <pred> bfloat %a, %b
// becomes
//   %a.ext = fpext bfloat %a to float
//   %b.ext = fpext bfloat %b to float
//   %c     = fcmp <pred> float %a.ext, %b.ext
// The result type of the comparison does not change, so no truncation is
// emitted. Comparisons of other types are left untouched. Fast-math flags of
// the original comparison are carried over to the new one.
void GenXBFloatLowering::visitFCmpInst(FCmpInst &Inst) {
  auto *Src0 = Inst.getOperand(0);
  auto *Src1 = Inst.getOperand(1);
  auto *SrcTy = Src0->getType();
  if (!SrcTy->getScalarType()->isBFloatTy())
    return;
  LLVM_DEBUG(dbgs() << "GenXBFloatLowering: apply on FCmp\n" << Inst << "\n");
  IGC_ASSERT_EXIT(Src1->getType()->getScalarType()->isBFloatTy());

  // New instructions are inserted right before the one being replaced.
  IRBuilder<> Builder(&Inst);
  auto *FloatTy = getFloatTyFromBfloat(SrcTy);
  auto *Op0Conv = Builder.CreateFPExt(Src0, FloatTy);
  auto *Op1Conv = Builder.CreateFPExt(Src1, FloatTy);
  auto *InstUpdate = Builder.CreateFCmp(Inst.getPredicate(), Op0Conv, Op1Conv);
  // The builder may fold constant operands into a constant, in which case
  // there is no instruction to attach the flags to.
  if (auto *NewCmp = dyn_cast<Instruction>(InstUpdate))
    NewCmp->copyFastMathFlags(&Inst);

  Inst.replaceAllUsesWith(InstUpdate);
  Inst.eraseFromParent();
  Modify = true;
}

// fneg
//
// Lowers a unary operator on a bfloat operand (scalar or vector):
//   %r = fneg bfloat %a
// becomes
//   %a.ext = fpext bfloat %a to float
//   %r.f   = fneg float %a.ext
//   %r     = fptrunc float %r.f to bfloat
// FNeg is the only opcode expected here; anything else trips the assertion.
// Unary operators of other types are left untouched. Fast-math flags of the
// original instruction are carried over to the float operation.
void GenXBFloatLowering::visitUnaryOperator(UnaryOperator &Inst) {
  auto *Src = Inst.getOperand(0);
  auto *SrcTy = Src->getType();
  if (!SrcTy->getScalarType()->isBFloatTy())
    return;
  LLVM_DEBUG(dbgs() << "GenXBFloatLowering: apply on UnaryOperator\n"
                    << Inst << "\n");
  IGC_ASSERT_EXIT(Inst.getOpcode() == Instruction::FNeg);

  // New instructions are inserted right before the one being replaced.
  IRBuilder<> Builder(&Inst);
  auto *FloatTy = getFloatTyFromBfloat(SrcTy);
  auto *SrcExt = Builder.CreateFPExt(Src, FloatTy);
  auto *InstUpdate = Builder.CreateUnOp(Inst.getOpcode(), SrcExt);
  // The builder may fold a constant operand into a constant, in which case
  // there is no instruction to attach the flags to.
  if (auto *NewInst = dyn_cast<Instruction>(InstUpdate))
    NewInst->copyFastMathFlags(&Inst);
  auto *Trunc = Builder.CreateFPTrunc(InstUpdate, SrcTy);

  Inst.replaceAllUsesWith(Trunc);
  Inst.eraseFromParent();
  Modify = true;
}
#endif

/***********************************************************************
 * GenXBFloatLowering::runOnFunction : lower bfloat arithmetic in F
 *
 * Does nothing and returns false when the pass is disabled through
 * -enable-genx-bfloat-lowering. Otherwise every instruction of F is
 * visited (LLVM 11 and newer only) and bfloat arithmetic and compares
 * are rewritten by the visit* methods.
 *
 * Return: true if at least one instruction of F has been replaced.
 */
bool GenXBFloatLowering::runOnFunction(Function &F) {
  if (!EnableGenXBFloatLowering)
    return false;
  LLVM_DEBUG(dbgs() << "GenXBFloatLowering started\n");

  // The same pass object is run on function after function, so the change
  // flag has to be cleared here. Otherwise a change made in one function
  // would be reported for every function processed after it.
  Modify = false;

#if LLVM_VERSION_MAJOR > 10
  ST = &getAnalysis<TargetPassConfig>()
            .getTM<GenXTargetMachine>()
            .getGenXSubtarget();

  visit(F);
#endif

  LLVM_DEBUG(dbgs() << "GenXBFloatLowering ended\n");
  return Modify;
}
4 changes: 3 additions & 1 deletion IGC/VectorCompiler/lib/GenXCodeGen/GenXTargetMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ void initializeGenXPasses(PassRegistry &registry) {
// initializeGenXLivenessWrapperPass(registry);
initializeGenXLivenessWrapperPass(registry);
initializeGenXLowerAggrCopiesPass(registry);
initializeGenXBFloatLoweringPass(registry);
initializeGenXLoweringPass(registry);
initializeGenXModulePass(registry);
initializeGenXNumberingWrapperPass(registry);
Expand Down Expand Up @@ -479,7 +480,8 @@ bool GenXTargetMachine::addPassesToEmitFile(PassManagerBase &PM,
BuiltinFunctionKind::PreLegalization));
vc::addPass(PM, createAlwaysInlinerLegacyPass());
}

/// .. include:: GenXBFloatLowering.cpp
vc::addPass(PM, createGenXBFloatLoweringPass());
/// .. include:: GenXLowering.cpp
vc::addPass(PM, createGenXLoweringPass());

Expand Down
1 change: 1 addition & 0 deletions IGC/VectorCompiler/lib/GenXCodeGen/GenXTargetMachine.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ void initializeGenXPasses(PassRegistry &);
void initializeFunctionGroupAnalysisPass(PassRegistry &);
void initializeGenXAddressCommoningWrapperPass(PassRegistry &);
void initializeGenXArgIndirectionWrapperPass(PassRegistry &);
void initializeGenXBFloatLoweringPass(PassRegistry &);
void initializeGenXCategoryWrapperPass(PassRegistry &);
void initializeGenXCFSimplificationPass(PassRegistry &);
void initializeGenXCisaBuilderPass(PassRegistry &);
Expand Down
147 changes: 147 additions & 0 deletions IGC/VectorCompiler/test/GenXBFloatLowering/bfloat.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
;=========================== begin_copyright_notice ============================
;
; Copyright (C) 2023 Intel Corporation
;
; SPDX-License-Identifier: MIT
;
;============================ end_copyright_notice =============================

; REQUIRES: llvm_12_or_greater
; RUN: %opt %use_old_pass_manager% -GenXBFloatLowering -march=genx64 -mcpu=Gen9 -mtriple=spir64-unknown-unknown -S < %s | FileCheck %s
; COM: Supported instructions: FAdd, FSub, FMul, FDiv, FRem, FCmp, FNeg

; COM: Scalar fadd: both operands are extended to float, the addition is done
; COM: in float and the result is truncated back to bfloat.
define bfloat @scalar_bfloat_fadd(bfloat %a, bfloat %b) {
%fadd_res = fadd bfloat %a, %b
ret bfloat %fadd_res
}
; CHECK-LABEL: define bfloat @scalar_bfloat_fadd
; CHECK-DAG: %[[A_EXP:[a-z0-9.]+]] = fpext bfloat %a to float
; CHECK-DAG: %[[B_EXP:[a-z0-9.]+]] = fpext bfloat %b to float
; CHECK: %[[fadd:[a-z0-9.]+]] = fadd float %[[A_EXP]], %[[B_EXP]]
; CHECK: %[[TRUNK:[a-z0-9.]+]] = fptrunc float %[[fadd]] to bfloat

; COM: Vector fadd: same lowering as the scalar case, applied to a vector of
; COM: 123 elements; the element count must be kept by every conversion.
define <123 x bfloat> @vector_bfloat_fadd(<123 x bfloat> %a, <123 x bfloat> %b) {
%fadd_res = fadd <123 x bfloat> %a, %b
ret <123 x bfloat> %fadd_res
}
; CHECK-LABEL: define <123 x bfloat> @vector_bfloat_fadd
; CHECK-DAG: %[[A_EXP:[a-z0-9.]+]] = fpext <123 x bfloat> %a to <123 x float>
; CHECK-DAG: %[[B_EXP:[a-z0-9.]+]] = fpext <123 x bfloat> %b to <123 x float>
; CHECK: %[[fadd:[a-z0-9.]+]] = fadd <123 x float> %[[A_EXP]], %[[B_EXP]]
; CHECK: %[[TRUNK:[a-z0-9.]+]] = fptrunc <123 x float> %[[fadd]] to <123 x bfloat>

; COM: Scalar fsub: both operands are extended to float, the subtraction is
; COM: done in float and the result is truncated back to bfloat.
define bfloat @scalar_bfloat_fsub(bfloat %a, bfloat %b) {
%fsub_res = fsub bfloat %a, %b
ret bfloat %fsub_res
}
; CHECK-LABEL: define bfloat @scalar_bfloat_fsub
; CHECK-DAG: %[[A_EXP:[a-z0-9.]+]] = fpext bfloat %a to float
; CHECK-DAG: %[[B_EXP:[a-z0-9.]+]] = fpext bfloat %b to float
; CHECK: %[[fsub:[a-z0-9.]+]] = fsub float %[[A_EXP]], %[[B_EXP]]
; CHECK: %[[TRUNK:[a-z0-9.]+]] = fptrunc float %[[fsub]] to bfloat

; COM: Vector fsub: same lowering as the scalar case, applied to a vector of
; COM: 123 elements.
define <123 x bfloat> @vector_bfloat_fsub(<123 x bfloat> %a, <123 x bfloat> %b) {
%fsub_res = fsub <123 x bfloat> %a, %b
ret <123 x bfloat> %fsub_res
}
; CHECK-LABEL: define <123 x bfloat> @vector_bfloat_fsub
; CHECK-DAG: %[[A_EXP:[a-z0-9.]+]] = fpext <123 x bfloat> %a to <123 x float>
; CHECK-DAG: %[[B_EXP:[a-z0-9.]+]] = fpext <123 x bfloat> %b to <123 x float>
; CHECK: %[[fsub:[a-z0-9.]+]] = fsub <123 x float> %[[A_EXP]], %[[B_EXP]]
; CHECK: %[[TRUNK:[a-z0-9.]+]] = fptrunc <123 x float> %[[fsub]] to <123 x bfloat>

; COM: Scalar fmul: both operands are extended to float, the multiplication is
; COM: done in float and the result is truncated back to bfloat.
define bfloat @scalar_bfloat_fmul(bfloat %a, bfloat %b) {
%fmul_res = fmul bfloat %a, %b
ret bfloat %fmul_res
}
; CHECK-LABEL: define bfloat @scalar_bfloat_fmul
; CHECK-DAG: %[[A_EXP:[a-z0-9.]+]] = fpext bfloat %a to float
; CHECK-DAG: %[[B_EXP:[a-z0-9.]+]] = fpext bfloat %b to float
; CHECK: %[[fmul:[a-z0-9.]+]] = fmul float %[[A_EXP]], %[[B_EXP]]
; CHECK: %[[TRUNK:[a-z0-9.]+]] = fptrunc float %[[fmul]] to bfloat

; COM: Vector fmul: same lowering as the scalar case, applied to a vector of
; COM: 123 elements.
define <123 x bfloat> @vector_bfloat_fmul(<123 x bfloat> %a, <123 x bfloat> %b) {
%fmul_res = fmul <123 x bfloat> %a, %b
ret <123 x bfloat> %fmul_res
}
; CHECK-LABEL: define <123 x bfloat> @vector_bfloat_fmul
; CHECK-DAG: %[[A_EXP:[a-z0-9.]+]] = fpext <123 x bfloat> %a to <123 x float>
; CHECK-DAG: %[[B_EXP:[a-z0-9.]+]] = fpext <123 x bfloat> %b to <123 x float>
; CHECK: %[[fmul:[a-z0-9.]+]] = fmul <123 x float> %[[A_EXP]], %[[B_EXP]]
; CHECK: %[[TRUNK:[a-z0-9.]+]] = fptrunc <123 x float> %[[fmul]] to <123 x bfloat>

; COM: Scalar fdiv: both operands are extended to float, the division is done
; COM: in float and the result is truncated back to bfloat.
define bfloat @scalar_bfloat_fdiv(bfloat %a, bfloat %b) {
%fdiv_res = fdiv bfloat %a, %b
ret bfloat %fdiv_res
}
; CHECK-LABEL: define bfloat @scalar_bfloat_fdiv
; CHECK-DAG: %[[A_EXP:[a-z0-9.]+]] = fpext bfloat %a to float
; CHECK-DAG: %[[B_EXP:[a-z0-9.]+]] = fpext bfloat %b to float
; CHECK: %[[fdiv:[a-z0-9.]+]] = fdiv float %[[A_EXP]], %[[B_EXP]]
; CHECK: %[[TRUNK:[a-z0-9.]+]] = fptrunc float %[[fdiv]] to bfloat

; COM: Vector fdiv: same lowering as the scalar case, applied to a vector of
; COM: 123 elements.
define <123 x bfloat> @vector_bfloat_fdiv(<123 x bfloat> %a, <123 x bfloat> %b) {
%fdiv_res = fdiv <123 x bfloat> %a, %b
ret <123 x bfloat> %fdiv_res
}
; CHECK-LABEL: define <123 x bfloat> @vector_bfloat_fdiv
; CHECK-DAG: %[[A_EXP:[a-z0-9.]+]] = fpext <123 x bfloat> %a to <123 x float>
; CHECK-DAG: %[[B_EXP:[a-z0-9.]+]] = fpext <123 x bfloat> %b to <123 x float>
; CHECK: %[[fdiv:[a-z0-9.]+]] = fdiv <123 x float> %[[A_EXP]], %[[B_EXP]]
; CHECK: %[[TRUNK:[a-z0-9.]+]] = fptrunc <123 x float> %[[fdiv]] to <123 x bfloat>

; COM: Scalar frem: both operands are extended to float, the remainder is
; COM: computed in float and the result is truncated back to bfloat.
define bfloat @scalar_bfloat_frem(bfloat %a, bfloat %b) {
%frem_res = frem bfloat %a, %b
ret bfloat %frem_res
}
; CHECK-LABEL: define bfloat @scalar_bfloat_frem
; CHECK-DAG: %[[A_EXP:[a-z0-9.]+]] = fpext bfloat %a to float
; CHECK-DAG: %[[B_EXP:[a-z0-9.]+]] = fpext bfloat %b to float
; CHECK: %[[frem:[a-z0-9.]+]] = frem float %[[A_EXP]], %[[B_EXP]]
; CHECK: %[[TRUNK:[a-z0-9.]+]] = fptrunc float %[[frem]] to bfloat

; COM: Vector frem: same lowering as the scalar case, applied to a vector of
; COM: 123 elements.
define <123 x bfloat> @vector_bfloat_frem(<123 x bfloat> %a, <123 x bfloat> %b) {
%frem_res = frem <123 x bfloat> %a, %b
ret <123 x bfloat> %frem_res
}
; CHECK-LABEL: define <123 x bfloat> @vector_bfloat_frem
; CHECK-DAG: %[[A_EXP:[a-z0-9.]+]] = fpext <123 x bfloat> %a to <123 x float>
; CHECK-DAG: %[[B_EXP:[a-z0-9.]+]] = fpext <123 x bfloat> %b to <123 x float>
; CHECK: %[[frem:[a-z0-9.]+]] = frem <123 x float> %[[A_EXP]], %[[B_EXP]]
; CHECK: %[[TRUNK:[a-z0-9.]+]] = fptrunc <123 x float> %[[frem]] to <123 x bfloat>

; COM: Scalar fcmp: both operands are extended to float and compared with the
; COM: original predicate. The result is i1, so no truncation is expected.
define i1 @scalar_bfloat_fcmp(bfloat %a, bfloat %b) {
%fcmp_res = fcmp une bfloat %a, %b
ret i1 %fcmp_res
}
; CHECK-LABEL: define i1 @scalar_bfloat_fcmp
; CHECK-DAG: %[[A_EXP:[a-z0-9.]+]] = fpext bfloat %a to float
; CHECK-DAG: %[[B_EXP:[a-z0-9.]+]] = fpext bfloat %b to float
; CHECK: %[[fcmp:[a-z0-9.]+]] = fcmp une float %[[A_EXP]], %[[B_EXP]]

; COM: Vector fcmp: same lowering as the scalar case, applied to a vector of
; COM: 123 elements; the result stays a vector of i1.
define <123 x i1> @vector_bfloat_fcmp(<123 x bfloat> %a, <123 x bfloat> %b) {
%fcmp_res = fcmp une <123 x bfloat> %a, %b
ret <123 x i1> %fcmp_res
}
; CHECK-LABEL: define <123 x i1> @vector_bfloat_fcmp
; CHECK-DAG: %[[A_EXP:[a-z0-9.]+]] = fpext <123 x bfloat> %a to <123 x float>
; CHECK-DAG: %[[B_EXP:[a-z0-9.]+]] = fpext <123 x bfloat> %b to <123 x float>
; CHECK: %[[fcmp:[a-z0-9.]+]] = fcmp une <123 x float> %[[A_EXP]], %[[B_EXP]]


; COM: Scalar fneg: the operand is extended to float, negated in float and the
; COM: result is truncated back to bfloat, which is then returned.
define bfloat @scalar_bfloat_fneg(bfloat %a) {
  %fneg_res = fneg bfloat %a
  ret bfloat %fneg_res
}
; CHECK-LABEL: define bfloat @scalar_bfloat_fneg
; CHECK: %[[A_EXP:[a-z0-9.]+]] = fpext bfloat %a to float
; CHECK: %[[fneg:[a-z0-9.]+]] = fneg float %[[A_EXP]]
; CHECK: %[[TRUNK:[a-z0-9.]+]] = fptrunc float %[[fneg]] to bfloat
; CHECK: ret bfloat %[[TRUNK]]

; COM: Vector fneg: same lowering as the scalar case, applied to a vector of
; COM: 123 elements; the truncated vector is what gets returned.
define <123 x bfloat> @vector_bfloat_fneg(<123 x bfloat> %a) {
  %fneg_res = fneg <123 x bfloat> %a
  ret <123 x bfloat> %fneg_res
}
; CHECK-LABEL: define <123 x bfloat> @vector_bfloat_fneg
; CHECK: %[[A_EXP:[a-z0-9.]+]] = fpext <123 x bfloat> %a to <123 x float>
; CHECK: %[[fneg:[a-z0-9.]+]] = fneg <123 x float> %[[A_EXP]]
; CHECK: %[[TRUNK:[a-z0-9.]+]] = fptrunc <123 x float> %[[fneg]] to <123 x bfloat>
; CHECK: ret <123 x bfloat> %[[TRUNK]]

0 comments on commit ba99a6e

Please sign in to comment.