Skip to content

Commit

Permalink
Automerge: [RISCV][CostModel] Add cost for fabs/fsqrt of type bf16/f1…
Browse files Browse the repository at this point in the history
…6 (#118608)
  • Loading branch information
LiqinWeng authored and github-actions[bot] committed Jan 10, 2025
2 parents da98578 + 98e5962 commit 2533bb2
Show file tree
Hide file tree
Showing 3 changed files with 243 additions and 163 deletions.
68 changes: 57 additions & 11 deletions llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "llvm/CodeGen/BasicTTIImpl.h"
#include "llvm/CodeGen/CostTable.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/PatternMatch.h"
#include <cmath>
Expand Down Expand Up @@ -1035,21 +1036,66 @@ RISCVTTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
}
break;
}
case Intrinsic::fabs:
case Intrinsic::fabs: {
auto LT = getTypeLegalizationCost(RetTy);
if (ST->hasVInstructions() && LT.second.isVector()) {
// lui a0, 8
// addi a0, a0, -1
// vsetvli a1, zero, e16, m1, ta, ma
// vand.vx v8, v8, a0
// f16 with zvfhmin and bf16 with zvfhbmin
if (LT.second.getVectorElementType() == MVT::bf16 ||
(LT.second.getVectorElementType() == MVT::f16 &&
!ST->hasVInstructionsF16()))
return LT.first * getRISCVInstructionCost(RISCV::VAND_VX, LT.second,
CostKind) +
2;
else
return LT.first *
getRISCVInstructionCost(RISCV::VFSGNJX_VV, LT.second, CostKind);
}
break;
}
case Intrinsic::sqrt: {
auto LT = getTypeLegalizationCost(RetTy);
// TODO: add f16/bf16, bf16 with zvfbfmin && f16 with zvfhmin
if (ST->hasVInstructions() && LT.second.isVector()) {
unsigned Op;
switch (ICA.getID()) {
case Intrinsic::fabs:
Op = RISCV::VFSGNJX_VV;
break;
case Intrinsic::sqrt:
Op = RISCV::VFSQRT_V;
break;
SmallVector<unsigned, 4> ConvOp;
SmallVector<unsigned, 2> FsqrtOp;
MVT ConvType = LT.second;
MVT FsqrtType = LT.second;
// f16 with zvfhmin and bf16 with zvfbfmin and the type of nxv32[b]f16
// will be spilt.
if (LT.second.getVectorElementType() == MVT::bf16) {
if (LT.second == MVT::nxv32bf16) {
ConvOp = {RISCV::VFWCVTBF16_F_F_V, RISCV::VFWCVTBF16_F_F_V,
RISCV::VFNCVTBF16_F_F_W, RISCV::VFNCVTBF16_F_F_W};
FsqrtOp = {RISCV::VFSQRT_V, RISCV::VFSQRT_V};
ConvType = MVT::nxv16f16;
FsqrtType = MVT::nxv16f32;
} else {
ConvOp = {RISCV::VFWCVTBF16_F_F_V, RISCV::VFNCVTBF16_F_F_W};
FsqrtOp = {RISCV::VFSQRT_V};
FsqrtType = TLI->getTypeToPromoteTo(ISD::FSQRT, FsqrtType);
}
} else if (LT.second.getVectorElementType() == MVT::f16 &&
!ST->hasVInstructionsF16()) {
if (LT.second == MVT::nxv32f16) {
ConvOp = {RISCV::VFWCVT_F_F_V, RISCV::VFWCVT_F_F_V,
RISCV::VFNCVT_F_F_W, RISCV::VFNCVT_F_F_W};
FsqrtOp = {RISCV::VFSQRT_V, RISCV::VFSQRT_V};
ConvType = MVT::nxv16f16;
FsqrtType = MVT::nxv16f32;
} else {
ConvOp = {RISCV::VFWCVT_F_F_V, RISCV::VFNCVT_F_F_W};
FsqrtOp = {RISCV::VFSQRT_V};
FsqrtType = TLI->getTypeToPromoteTo(ISD::FSQRT, FsqrtType);
}
} else {
FsqrtOp = {RISCV::VFSQRT_V};
}
return LT.first * getRISCVInstructionCost(Op, LT.second, CostKind);

return LT.first * (getRISCVInstructionCost(FsqrtOp, FsqrtType, CostKind) +
getRISCVInstructionCost(ConvOp, ConvType, CostKind));
}
break;
}
Expand Down
Loading

0 comments on commit 2533bb2

Please sign in to comment.