Skip to content

[GlobalISel] Support saturated truncate #150219

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,25 @@ class CombinerHelper {
bool matchUMulHToLShr(MachineInstr &MI) const;
void applyUMulHToLShr(MachineInstr &MI) const;

// Combine trunc(smin(smax(x, C1), C2)) -> truncssat_s(x)
// or trunc(smax(smin(x, C2), C1)) -> truncssat_s(x).
bool matchTruncSSatS(MachineInstr &MI, Register &MatchInfo) const;
void applyTruncSSatS(MachineInstr &MI, Register &MatchInfo) const;

// Combine trunc(smin(smax(x, 0), C)) -> truncssat_u(x)
// or trunc(smax(smin(x, C), 0)) -> truncssat_u(x)
// or trunc(umin(smax(x, 0), C)) -> truncssat_u(x)
bool matchTruncSSatU(MachineInstr &MI, Register &MatchInfo) const;
void applyTruncSSatU(MachineInstr &MI, Register &MatchInfo) const;

// Combine trunc(umin(x, C)) -> truncusat_u(x).
bool matchTruncUSatU(MachineInstr &MI, Register &MatchInfo) const;
void applyTruncUSatU(MachineInstr &MI, Register &MatchInfo) const;

// Combine truncusat_u(fptoui(x)) -> fptoui_sat(x)
bool matchTruncUSatUToFPTOUISat(MachineInstr &MI, Register &MatchInfo) const;
void applyTruncUSatUToFPTOUISat(MachineInstr &MI, Register &MatchInfo) const;

/// Try to transform \p MI by using all of the above
/// combine functions. Returns true if changed.
bool tryCombine(MachineInstr &MI) const;
Expand Down
27 changes: 27 additions & 0 deletions llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -874,6 +874,9 @@ class GCastOp : public GenericMachineInstr {
case TargetOpcode::G_SEXT:
case TargetOpcode::G_SITOFP:
case TargetOpcode::G_TRUNC:
case TargetOpcode::G_TRUNC_SSAT_S:
case TargetOpcode::G_TRUNC_SSAT_U:
case TargetOpcode::G_TRUNC_USAT_U:
case TargetOpcode::G_UITOFP:
case TargetOpcode::G_ZEXT:
case TargetOpcode::G_ANYEXT:
Expand Down Expand Up @@ -916,6 +919,30 @@ class GTrunc : public GCastOp {
};
};

/// Represents a saturated trunc from a signed input to a signed result.
class GTruncSSatS : public GCastOp {
public:
static bool classof(const MachineInstr *MI) {
return MI->getOpcode() == TargetOpcode::G_TRUNC_SSAT_S;
};
};

/// Represents a saturated trunc from a signed input to an unsigned result.
class GTruncSSatU : public GCastOp {
public:
static bool classof(const MachineInstr *MI) {
return MI->getOpcode() == TargetOpcode::G_TRUNC_SSAT_U;
};
};

/// Represents a saturated trunc from an unsigned input to an unsigned result.
class GTruncUSatU : public GCastOp {
public:
static bool classof(const MachineInstr *MI) {
return MI->getOpcode() == TargetOpcode::G_TRUNC_USAT_U;
};
};

/// Represents a vscale.
class GVScale : public GenericMachineInstr {
public:
Expand Down
12 changes: 12 additions & 0 deletions llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,18 @@ m_GFPTrunc(const SrcTy &Src) {
return UnaryOp_match<SrcTy, TargetOpcode::G_FPTRUNC>(Src);
}

template <typename SrcTy>
inline UnaryOp_match<SrcTy, TargetOpcode::G_FPTOSI>
m_GFPToSI(const SrcTy &Src) {
return UnaryOp_match<SrcTy, TargetOpcode::G_FPTOSI>(Src);
}

template <typename SrcTy>
inline UnaryOp_match<SrcTy, TargetOpcode::G_FPTOUI>
m_GFPToUI(const SrcTy &Src) {
return UnaryOp_match<SrcTy, TargetOpcode::G_FPTOUI>(Src);
}

template <typename SrcTy>
inline UnaryOp_match<SrcTy, TargetOpcode::G_FABS> m_GFabs(const SrcTy &Src) {
return UnaryOp_match<SrcTy, TargetOpcode::G_FABS>(Src);
Expand Down
28 changes: 27 additions & 1 deletion llvm/include/llvm/Target/GlobalISel/Combine.td
Original file line number Diff line number Diff line change
Expand Up @@ -1242,6 +1242,32 @@ def mulh_to_lshr : GICombineRule<

def mulh_combines : GICombineGroup<[mulh_to_lshr]>;

def trunc_ssats : GICombineRule<
(defs root:$root, register_matchinfo:$matchinfo),
(match (G_TRUNC $dst, $src):$root,
[{ return Helper.matchTruncSSatS(*${root}, ${matchinfo}); }]),
(apply [{ Helper.applyTruncSSatS(*${root}, ${matchinfo}); }])>;

def trunc_ssatu : GICombineRule<
(defs root:$root, register_matchinfo:$matchinfo),
(match (G_TRUNC $dst, $src):$root,
[{ return Helper.matchTruncSSatU(*${root}, ${matchinfo}); }]),
(apply [{ Helper.applyTruncSSatU(*${root}, ${matchinfo}); }])>;

def trunc_usatu : GICombineRule<
(defs root:$root, register_matchinfo:$matchinfo),
(match (G_TRUNC $dst, $src):$root,
[{ return Helper.matchTruncUSatU(*${root}, ${matchinfo}); }]),
(apply [{ Helper.applyTruncUSatU(*${root}, ${matchinfo}); }])>;

def truncusatu_to_fptouisat : GICombineRule<
(defs root:$root, register_matchinfo:$matchinfo),
(match (G_TRUNC_USAT_U $dst, $src):$root,
[{ return Helper.matchTruncUSatUToFPTOUISat(*${root}, ${matchinfo}); }]),
(apply [{ Helper.applyTruncUSatUToFPTOUISat(*${root}, ${matchinfo}); }])>;

def truncsat_combines : GICombineGroup<[trunc_ssats, trunc_ssatu, trunc_usatu, truncusatu_to_fptouisat]>;

def redundant_neg_operands: GICombineRule<
(defs root:$root, build_fn_matchinfo:$matchinfo),
(match (wip_match_opcode G_FADD, G_FSUB, G_FMUL, G_FDIV, G_FMAD, G_FMA):$root,
Expand Down Expand Up @@ -2066,7 +2092,7 @@ def all_combines : GICombineGroup<[integer_reassoc_combines, trivial_combines,
fsub_to_fneg, commute_constant_to_rhs, match_ands, match_ors,
simplify_neg_minmax, combine_concat_vector,
sext_trunc, zext_trunc, prefer_sign_combines, shuffle_combines,
combine_use_vector_truncate, merge_combines, overflow_combines]>;
combine_use_vector_truncate, merge_combines, overflow_combines, truncsat_combines]>;

// A combine group used to for prelegalizer combiners at -O0. The combines in
// this group have been selected based on experiments to balance code size and
Expand Down
3 changes: 3 additions & 0 deletions llvm/include/llvm/Target/GlobalISel/SelectionDAGCompat.td
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ def : GINodeEquiv<G_SEXT, sext>;
def : GINodeEquiv<G_SEXT_INREG, sext_inreg>;
def : GINodeEquiv<G_ZEXT, zext>;
def : GINodeEquiv<G_TRUNC, trunc>;
def : GINodeEquiv<G_TRUNC_SSAT_S, truncssat_s>;
def : GINodeEquiv<G_TRUNC_SSAT_U, truncssat_u>;
def : GINodeEquiv<G_TRUNC_USAT_U, truncusat_u>;
def : GINodeEquiv<G_BITCAST, bitconvert>;
// G_INTTOPTR - SelectionDAG has no equivalent.
// G_PTRTOINT - SelectionDAG has no equivalent.
Expand Down
128 changes: 128 additions & 0 deletions llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5844,6 +5844,134 @@ void CombinerHelper::applyUMulHToLShr(MachineInstr &MI) const {
MI.eraseFromParent();
}

bool CombinerHelper::matchTruncSSatS(MachineInstr &MI,
Register &MatchInfo) const {
Register Dst = MI.getOperand(0).getReg();
Register Src = MI.getOperand(1).getReg();
LLT DstTy = MRI.getType(Dst);
LLT SrcTy = MRI.getType(Src);
unsigned NumDstBits = DstTy.getScalarSizeInBits();
unsigned NumSrcBits = SrcTy.getScalarSizeInBits();
assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");

APInt MinConst, MaxConst;
APInt SignedMax = APInt::getSignedMaxValue(NumDstBits).sext(NumSrcBits);
APInt SignedMin = APInt::getSignedMinValue(NumDstBits).sext(NumSrcBits);

if (isLegal({TargetOpcode::G_TRUNC_SSAT_S, {DstTy, SrcTy}})) {
if (mi_match(Src, MRI,
m_GSMin(m_GSMax(m_Reg(MatchInfo), m_ICstOrSplat(MinConst)),
m_ICstOrSplat(MaxConst))) &&
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this use m_SpecificICstSplat?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

m_SpecificICstSplat is for int64_t, but it would be good to be able to match like that here, I could add an APInt equivalent to be used?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that sounds good. It sounds like a useful thing to have to me.

APInt::isSameValue(MinConst, SignedMin) &&
APInt::isSameValue(MaxConst, SignedMax))
return true;
if (mi_match(Src, MRI,
m_GSMax(m_GSMin(m_Reg(MatchInfo), m_ICstOrSplat(MaxConst)),
m_ICstOrSplat(MinConst))) &&
APInt::isSameValue(MinConst, SignedMin) &&
APInt::isSameValue(MaxConst, SignedMax))
return true;
}
return false;
}

void CombinerHelper::applyTruncSSatS(MachineInstr &MI,
Register &MatchInfo) const {
Register Dst = MI.getOperand(0).getReg();
Builder.buildTruncSSatS(Dst, MatchInfo);
MI.eraseFromParent();
}

bool CombinerHelper::matchTruncSSatU(MachineInstr &MI,
Register &MatchInfo) const {
Register Dst = MI.getOperand(0).getReg();
Register Src = MI.getOperand(1).getReg();
LLT DstTy = MRI.getType(Dst);
LLT SrcTy = MRI.getType(Src);
unsigned NumDstBits = DstTy.getScalarSizeInBits();
unsigned NumSrcBits = SrcTy.getScalarSizeInBits();
assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");

APInt MaxConst;
APInt UnsignedMax = APInt::getMaxValue(NumDstBits).zext(NumSrcBits);

if (isLegal({TargetOpcode::G_TRUNC_SSAT_U, {DstTy, SrcTy}})) {
if (mi_match(Src, MRI,
m_GSMin(m_GSMax(m_Reg(MatchInfo), m_SpecificICstOrSplat(0)),
m_ICstOrSplat(MaxConst))) &&
APInt::isSameValue(MaxConst, UnsignedMax))
return true;
if (mi_match(Src, MRI,
m_GSMax(m_GSMin(m_Reg(MatchInfo), m_ICstOrSplat(MaxConst)),
m_SpecificICstOrSplat(0))) &&
APInt::isSameValue(MaxConst, UnsignedMax))
return true;
if (mi_match(Src, MRI,
m_GUMin(m_GSMax(m_Reg(MatchInfo), m_SpecificICstOrSplat(0)),
m_ICstOrSplat(MaxConst))) &&
APInt::isSameValue(MaxConst, UnsignedMax))
return true;
}
return false;
}

void CombinerHelper::applyTruncSSatU(MachineInstr &MI,
Register &MatchInfo) const {
Register Dst = MI.getOperand(0).getReg();
Builder.buildTruncSSatU(Dst, MatchInfo);
MI.eraseFromParent();
}

bool CombinerHelper::matchTruncUSatU(MachineInstr &MI,
Register &MatchInfo) const {
Register Dst = MI.getOperand(0).getReg();
Register Src = MI.getOperand(1).getReg();
LLT DstTy = MRI.getType(Dst);
LLT SrcTy = MRI.getType(Src);
unsigned NumDstBits = DstTy.getScalarSizeInBits();
unsigned NumSrcBits = SrcTy.getScalarSizeInBits();
assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");

APInt MaxConst;
APInt UnsignedMax = APInt::getMaxValue(NumDstBits).zext(NumSrcBits);

if (isLegal({TargetOpcode::G_TRUNC_SSAT_U, {DstTy, SrcTy}})) {
if (mi_match(Src, MRI,
m_GUMin(m_Reg(MatchInfo), m_ICstOrSplat(MaxConst))) &&
APInt::isSameValue(MaxConst, UnsignedMax))
return true;
}
return false;
}

void CombinerHelper::applyTruncUSatU(MachineInstr &MI,
Register &MatchInfo) const {
Register Dst = MI.getOperand(0).getReg();
Builder.buildTruncUSatU(Dst, MatchInfo);
MI.eraseFromParent();
}

bool CombinerHelper::matchTruncUSatUToFPTOUISat(MachineInstr &MI,
Register &MatchInfo) const {
Register Dst = MI.getOperand(0).getReg();
Register Src = MI.getOperand(1).getReg();
LLT DstTy = MRI.getType(Dst);
LLT SrcTy = MRI.getType(Src);

if (isLegalOrBeforeLegalizer({TargetOpcode::G_FPTOUI_SAT, {DstTy, SrcTy}})) {
if (mi_match(Src, MRI, m_GFPToUI((m_Reg(MatchInfo)))))
return true;
}
return false;
}

void CombinerHelper::applyTruncUSatUToFPTOUISat(MachineInstr &MI,
Register &MatchInfo) const {
Register Dst = MI.getOperand(0).getReg();
Builder.buildFPTOUI_SAT(Dst, MatchInfo);
MI.eraseFromParent();
}

bool CombinerHelper::matchRedundantNegOperands(MachineInstr &MI,
BuildFnTy &MatchInfo) const {
unsigned Opc = MI.getOpcode();
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -797,6 +797,9 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
.clampMinNumElements(0, s16, 4)
.alwaysLegal();

getActionDefinitionsBuilder({G_TRUNC_SSAT_S, G_TRUNC_SSAT_U, G_TRUNC_USAT_U})
.legalFor({{v8s8, v8s16}, {v4s16, v4s32}, {v2s32, v2s64}});

getActionDefinitionsBuilder(G_SEXT_INREG)
.legalFor({s32, s64})
.legalFor(PackedVectorAllTypeList)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -321,14 +321,16 @@
# DEBUG-NEXT: .. type index coverage check SKIPPED: user-defined predicate detected
# DEBUG-NEXT: .. imm index coverage check SKIPPED: user-defined predicate detected
# DEBUG-NEXT: G_TRUNC_SSAT_S (opcode {{[0-9]+}}): 2 type indices, 0 imm indices
# DEBUG-NEXT: .. type index coverage check SKIPPED: no rules defined
# DEBUG-NEXT: .. imm index coverage check SKIPPED: no rules defined
# DEBUG-NEXT: .. the first uncovered type index: 2, OK
# DEBUG-NEXT: .. the first uncovered imm index: 0, OK
# DEBUG-NEXT: G_TRUNC_SSAT_U (opcode {{[0-9]+}}): 2 type indices, 0 imm indices
# DEBUG-NEXT: .. type index coverage check SKIPPED: no rules defined
# DEBUG-NEXT: .. imm index coverage check SKIPPED: no rules defined
# DEBUG-NEXT: .. opcode {{[0-9]+}} is aliased to {{[0-9]+}}
# DEBUG-NEXT: .. the first uncovered type index: 2, OK
# DEBUG-NEXT: .. the first uncovered imm index: 0, OK
# DEBUG-NEXT: G_TRUNC_USAT_U (opcode {{[0-9]+}}): 2 type indices, 0 imm indices
# DEBUG-NEXT: .. type index coverage check SKIPPED: no rules defined
# DEBUG-NEXT: .. imm index coverage check SKIPPED: no rules defined
# DEBUG-NEXT: .. opcode {{[0-9]+}} is aliased to {{[0-9]+}}
# DEBUG-NEXT: .. the first uncovered type index: 2, OK
# DEBUG-NEXT: .. the first uncovered imm index: 0, OK
# DEBUG-NEXT: G_CONSTANT (opcode {{[0-9]+}}): 1 type index, 0 imm indices
# DEBUG-NEXT: .. the first uncovered type index: 1, OK
# DEBUG-NEXT: .. the first uncovered imm index: 0, OK
Expand Down
79 changes: 79 additions & 0 deletions llvm/test/CodeGen/AArch64/truncsat.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we use the existing test coverage in for example qmovn.ll and maybe fpclamptosat_vec.ll?

; RUN: llc < %s -mtriple=aarch64-unknown-unknown -global-isel=0 | FileCheck %s --check-prefixes=CHECK,CHECK-SD
; RUN: llc < %s -mtriple=aarch64-unknown-unknown -global-isel=1 | FileCheck %s --check-prefixes=CHECK,CHECK-GI


define <4 x i16> @ssats_1(<4 x i32> %x) {
; CHECK-LABEL: ssats_1:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sqxtn v0.4h, v0.4s
; CHECK-NEXT: ret
entry:
%spec.store.select = call <4 x i32> @llvm.smin.v4i32(<4 x i32> %x, <4 x i32> <i32 32767, i32 32767, i32 32767, i32 32767>)
%spec.store.select7 = call <4 x i32> @llvm.smax.v4i32(<4 x i32> %spec.store.select, <4 x i32> <i32 -32768, i32 -32768, i32 -32768, i32 -32768>)
%conv6 = trunc <4 x i32> %spec.store.select7 to <4 x i16>
ret <4 x i16> %conv6
}

define <4 x i16> @ssats_2(<4 x i32> %x) {
; CHECK-LABEL: ssats_2:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sqxtn v0.4h, v0.4s
; CHECK-NEXT: ret
entry:
%spec.store.select = call <4 x i32> @llvm.smax.v4i32(<4 x i32> %x, <4 x i32> <i32 -32768, i32 -32768, i32 -32768, i32 -32768>)
%spec.store.select7 = call <4 x i32> @llvm.smin.v4i32(<4 x i32> %spec.store.select, <4 x i32> <i32 32767, i32 32767, i32 32767, i32 32767>)
%conv6 = trunc <4 x i32> %spec.store.select7 to <4 x i16>
ret <4 x i16> %conv6
}

define <4 x i16> @ssatu_1(<4 x i32> %x) {
; CHECK-LABEL: ssatu_1:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sqxtun v0.4h, v0.4s
; CHECK-NEXT: ret
entry:
%spec.store.select = call <4 x i32> @llvm.smin.v4i32(<4 x i32> %x, <4 x i32> <i32 65535, i32 65535, i32 65535, i32 65535>)
%spec.store.select7 = call <4 x i32> @llvm.smax.v4i32(<4 x i32> %spec.store.select, <4 x i32> zeroinitializer)
%conv6 = trunc <4 x i32> %spec.store.select7 to <4 x i16>
ret <4 x i16> %conv6
}

define <4 x i16> @ssatu_2(<4 x i32> %x) {
; CHECK-LABEL: ssatu_2:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sqxtun v0.4h, v0.4s
; CHECK-NEXT: ret
entry:
%spec.store.select = call <4 x i32> @llvm.smax.v4i32(<4 x i32> %x, <4 x i32> zeroinitializer)
%spec.store.select7 = call <4 x i32> @llvm.smin.v4i32(<4 x i32> %spec.store.select, <4 x i32> <i32 65535, i32 65535, i32 65535, i32 65535>)
%conv6 = trunc <4 x i32> %spec.store.select7 to <4 x i16>
ret <4 x i16> %conv6
}

define <4 x i16> @ssatu_3(<4 x i32> %x) {
; CHECK-LABEL: ssatu_3:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sqxtun v0.4h, v0.4s
; CHECK-NEXT: ret
entry:
%spec.store.select = call <4 x i32> @llvm.smax.v4i32(<4 x i32> %x, <4 x i32> zeroinitializer)
%spec.store.select7 = call <4 x i32> @llvm.umin.v4i32(<4 x i32> %spec.store.select, <4 x i32> <i32 65535, i32 65535, i32 65535, i32 65535>)
%conv6 = trunc <4 x i32> %spec.store.select7 to <4 x i16>
ret <4 x i16> %conv6
}

define <4 x i16> @usatu(<4 x i32> %x) {
; CHECK-LABEL: usatu:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: uqxtn v0.4h, v0.4s
; CHECK-NEXT: ret
entry:
%spec.store.select = call <4 x i32> @llvm.umin.v4i32(<4 x i32> %x, <4 x i32> <i32 65535, i32 65535, i32 65535, i32 65535>)
%conv6 = trunc <4 x i32> %spec.store.select to <4 x i16>
ret <4 x i16> %conv6
}

;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
; CHECK-GI: {{.*}}
; CHECK-SD: {{.*}}