[AMDGPU] selecting v_sat_pk instruction, version 2 #123297

Status: Open, wants to merge 27 commits into base: main (changes shown from 11 commits)

Commits:
5844fd4
selecting v_sat_pk instruction, version 2
Shoreshen Jan 17, 2025
8505185
fix format
Shoreshen Jan 17, 2025
506ccd3
update comments
Shoreshen Jan 17, 2025
7fce2b2
upadte comments
Shoreshen Jan 17, 2025
424938f
fix format
Shoreshen Jan 17, 2025
88e52c1
fix comments
Shoreshen Jan 17, 2025
1dc8b9c
fix format
Shoreshen Jan 17, 2025
c5e3e65
fix comment
Shoreshen Jan 17, 2025
97755d2
Merge branch 'llvm:main' into select_v_sat_pk_v2
Shoreshen Jan 20, 2025
c99c42a
Merge branch 'llvm:main' into select_v_sat_pk_v2
Shoreshen Jan 21, 2025
3cbf7aa
update vNi8
Shoreshen Jan 21, 2025
7b166f9
add v3i16 case
Shoreshen Jan 21, 2025
f0e5101
Merge branch 'llvm:main' into select_v_sat_pk_v2
Shoreshen Jan 21, 2025
bb6edd1
Merge remote-tracking branch 'origin/main' into select_v_sat_pk_v2
Shoreshen Jan 22, 2025
b3147a8
handle N=2,4,8 case for now
Shoreshen Jan 22, 2025
28d2560
Merge branch 'llvm:main' into select_v_sat_pk_v2
Shoreshen Jan 23, 2025
17b2a49
Merge branch 'main' into select_v_sat_pk_v2
Shoreshen Jan 27, 2025
64e2125
fix test case
Shoreshen Jan 27, 2025
e7f3a17
Merge remote-tracking branch 'origin/main' into select_v_sat_pk_v2
Shoreshen Jan 29, 2025
6e16e60
fix comments & update main
Shoreshen Jan 29, 2025
c78b5d8
fix format
Shoreshen Jan 29, 2025
81ae12d
Merge branch 'main' into select_v_sat_pk_v2
Shoreshen Jan 31, 2025
9a7d148
Merge branch 'llvm:main' into select_v_sat_pk_v2
Shoreshen Feb 3, 2025
3143c07
Merge branch 'main' into select_v_sat_pk_v2
Shoreshen Feb 3, 2025
c291430
Merge branch 'llvm:main' into select_v_sat_pk_v2
Shoreshen Feb 3, 2025
dbc2d06
Merge branch 'main' into select_v_sat_pk_v2
Shoreshen Feb 8, 2025
a42182c
Merge branch 'llvm:main' into select_v_sat_pk_v2
Shoreshen Feb 8, 2025
1 change: 1 addition & 0 deletions llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -5498,6 +5498,7 @@ const char* AMDGPUTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(UMIN3)
NODE_NAME_CASE(FMED3)
NODE_NAME_CASE(SMED3)
NODE_NAME_CASE(SAT_PK_CAST)
Contributor:
Why do we need this extra node, can't we just select from TRUNC_SSAT_U directly?
Is it because it gets transformed/lost otherwise?

Contributor Author:
This is for printing and dumping; without it, the debug dump shows an unknown node.

Contributor:
Because the type signature is different. This is forcing the pack to use a legal integer type instead of v2i8

NODE_NAME_CASE(UMED3)
NODE_NAME_CASE(FMAXIMUM3)
NODE_NAME_CASE(FMINIMUM3)
1 change: 1 addition & 0 deletions llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
@@ -461,6 +461,7 @@ enum NodeType : unsigned {
FMED3,
SMED3,
UMED3,
SAT_PK_CAST,
FMAXIMUM3,
FMINIMUM3,
FDOT2,
3 changes: 3 additions & 0 deletions llvm/lib/Target/AMDGPU/AMDGPUInstrInfo.td
@@ -332,6 +332,9 @@ def AMDGPUumed3 : SDNode<"AMDGPUISD::UMED3", AMDGPUDTIntTernaryOp,
[]
>;

// Special node to handle v_sat_pk to avoid v2i8
def AMDGPUsat_pk_cast : SDNode<"AMDGPUISD::SAT_PK_CAST", SDTUnaryOp, []>;
Contributor:
Need to document what this is

Contributor:
Explain what the node is, not just to avoid v2i8. It's to pack a v2i18 into i16


def AMDGPUfmed3_impl : SDNode<"AMDGPUISD::FMED3", SDTFPTernaryOp, []>;

def AMDGPUfdot2_impl : SDNode<"AMDGPUISD::FDOT2",
72 changes: 72 additions & 0 deletions llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -816,6 +816,35 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
{MVT::v4f32, MVT::v8f32, MVT::v16f32, MVT::v32f32},
Custom);
}

// True 16 instruction is current not supported
// FIXME: Add support for true 16 when supported
if (!(Subtarget->hasTrue16BitInsts() && Subtarget->useRealTrue16Insts())) {
Contributor:
I don't see how this is checking for the existence of the underlying instruction. Plus the negation should be pushed through the condition

Contributor Author (@Shoreshen, Jan 29, 2025):
Hi @arsenm, I based this on the TableGen file, following the previous patch and Sisyph's comment on the last PR. The instruction is applicable if either the NotHasTrue16BitInsts or the UseFakeTrue16Insts predicate is satisfied, where:

  1. The NotHasTrue16BitInsts predicate requires (Subtarget->hasVOP3PInsts()) && (!Subtarget->hasTrue16BitInsts())
  2. The UseFakeTrue16Insts predicate requires (Subtarget->hasVOP3PInsts()) && (Subtarget->hasTrue16BitInsts() && !Subtarget->useRealTrue16Insts())

Combining the two, we apply the instruction when !Subtarget->hasTrue16BitInsts() or !Subtarget->useRealTrue16Insts() is true.
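
To make the predicate argument above easier to check, here is a minimal standalone C++ sketch that enumerates all eight combinations of the three subtarget queries. `SubtargetModel` and the function names are hypothetical stand-ins, not LLVM code; the two predicate definitions are taken from the comment above as stated.

```cpp
#include <cassert>

// Hypothetical stand-in for the three subtarget queries named in the comment.
struct SubtargetModel {
  bool HasVOP3P;
  bool HasTrue16;
  bool UseRealTrue16;
};

// Union of the two TableGen predicates, as described in the comment.
bool patternApplies(const SubtargetModel &ST) {
  bool NotHasTrue16BitInsts = ST.HasVOP3P && !ST.HasTrue16;
  bool UseFakeTrue16Insts = ST.HasVOP3P && ST.HasTrue16 && !ST.UseRealTrue16;
  return NotHasTrue16BitInsts || UseFakeTrue16Insts;
}

// Condition as written in the patch hunk: !(hasTrue16 && useRealTrue16).
bool patchCondition(const SubtargetModel &ST) {
  return !(ST.HasTrue16 && ST.UseRealTrue16);
}

// Negation pushed through (De Morgan), with the VOP3P requirement kept.
bool simplifiedCondition(const SubtargetModel &ST) {
  return ST.HasVOP3P && (!ST.HasTrue16 || !ST.UseRealTrue16);
}

// Counts the subtarget configurations on which two conditions disagree.
template <typename F, typename G>
int countMismatches(F A, G B) {
  int Mismatches = 0;
  for (int Bits = 0; Bits < 8; ++Bits) {
    SubtargetModel ST{(Bits & 1) != 0, (Bits & 2) != 0, (Bits & 4) != 0};
    if (A(ST) != B(ST))
      ++Mismatches;
  }
  return Mismatches;
}
```

Under this model, `patternApplies` and `simplifiedCondition` agree on every configuration, while `patchCondition` differs exactly on the configurations without VOP3P, which appears to be the gap the reviewer is pointing at.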

// MVT::vNi16 for src type check in foldToSaturated
// MVT::vNi8 for dst type check in CustomLowerNode
setOperationAction(ISD::TRUNCATE_SSAT_U,
{
MVT::v2i16,
Contributor:
You shouldn't have to override every single type that could decompose. Ideally the combiner should be able to figure it out based on the legalizer rules

Contributor Author (@Shoreshen, Jan 21, 2025):
Hi @arsenm, for truncate_ssat_u to be folded, the TLI.isOperationLegalOrCustom check has to pass.

We don't override this function, so it takes the default path and checks getOperationAction(Op, SrcVT) == Custom, which looks up OpActions[(unsigned)VT.getSimpleVT().SimpleTy][Op]; that entry is set here.

If we do not set the action for every vNi16 source type, the corresponding truncate_ssat_u will not be created.
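
The table lookup described above can be sketched with a toy model. Everything here (`ToyLowering`, `ToyType`, `ToyOp`, `wouldFormTruncSSatU`) is illustrative and not LLVM's actual API; in particular, the default action of `Expand` is an assumption of the model, chosen so that an unset entry fails the check.

```cpp
#include <array>
#include <cassert>

// Toy model only: these enums and the class are illustrative, not LLVM's API.
enum class Action { Legal, Expand, Custom };
enum ToyType { V2I16, V4I16, V8I16, NumTypes };
enum ToyOp { TruncSSatU, NumOps };

class ToyLowering {
  // One action per (type, opcode) pair, indexed like OpActions[VT][Op].
  std::array<std::array<Action, NumOps>, NumTypes> OpActions;

public:
  ToyLowering() {
    // Assumption of this model: entries default to Expand until set.
    for (auto &Row : OpActions)
      Row.fill(Action::Expand);
  }

  void setOperationAction(ToyOp Op, ToyType VT, Action A) {
    OpActions[VT][Op] = A;
  }

  Action getOperationAction(ToyOp Op, ToyType VT) const {
    return OpActions[VT][Op];
  }

  bool isOperationLegalOrCustom(ToyOp Op, ToyType VT) const {
    Action A = getOperationAction(Op, VT);
    return A == Action::Legal || A == Action::Custom;
  }
};

// Models the combiner's gate: the fold is formed only when the source type
// has a Legal or Custom entry for the opcode.
bool wouldFormTruncSSatU(const ToyLowering &TLI, ToyType SrcVT) {
  return TLI.isOperationLegalOrCustom(TruncSSatU, SrcVT);
}
```

In this model, marking only `V2I16` as `Custom` leaves `V4I16` and `V8I16` ungated, which is the behaviour the comment gives as the reason for listing every vNi16 type.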

Contributor:
For step 1 I wouldn't do this. The fix for this kind of issue is in the combiner forming them, not the legalizer rules for a specific operation

Contributor Author (@Shoreshen, Jan 21, 2025):
Hi @arsenm, I agree, but changing the combiner means changing LLVM-side code, and I'm not currently planning to do that in this PR.

So I think we can either do this or hook the TLI.isOperationLegalOrCustom function.

But hooking TLI.isOperationLegalOrCustom would be strange, since its inputs are the opcode and SrcVT.

Contributor:
Handle the basic case in the first step, and don't worry about handling every vector perfectly right away.

MVT::v4i16,
MVT::v8i16,
MVT::v16i16,
MVT::v32i16,
MVT::v64i16,
MVT::v128i16,
MVT::v256i16,
MVT::v512i16,
MVT::v2i8,
MVT::v4i8,
MVT::v8i8,
MVT::v16i8,
MVT::v32i8,
MVT::v64i8,
MVT::v128i8,
MVT::v256i8,
MVT::v512i8,
},
Custom);
}
}

setOperationAction({ISD::FNEG, ISD::FABS}, MVT::v4f16, Custom);
@@ -1975,6 +2004,12 @@ bool SITargetLowering::isTypeDesirableForOp(unsigned Op, EVT VT) const {
if (VT == MVT::i1 && Op == ISD::SETCC)
return false;

// Special case for vNi8 handling where N is even
Contributor:
I still don't think you should need anything in isTypeDesirableForOp

Contributor Author (@Shoreshen, Jan 21, 2025):
Hi @arsenm, this function checks the destination type, while TLI.isOperationLegalOrCustom checks the source type.

The destination types are vNi8; if we don't return true here, it falls through to the default implementation, which checks isTypeLegal(DstVT).

The backend hasn't added a register class for vNi8. We could perhaps add the relevant register class, but making vNi8 legal that way may cause unpredictable results.

Personally, I think we could add the register class for the relevant type once it is formally legal in the backend, so I decided to handle it here as a special case.

Contributor:
> Hi @arsenm, this function checks the destination type, while TLI.isOperationLegalOrCustom checks the source type.

This just sounds buggy. The interpretation of which type is the one that matters for the opcode needs to be globally consistent.

> The backend hasn't added a register class for vNi8. We could perhaps add the relevant register class,

Please no, this is a huge amount of work.

Contributor Author:
Hi @arsenm, yeah, the logic is kind of weird. What makes it stranger is that, to get into the ReplaceNodeResults function, it checks TLI.getOperationAction(Opc, DstVT) == Custom.

But if we want to change this, I think we also need to modify the AArch64 backend.

Contributor:
Yes, bugs cause other bugs and all the use points need to be fixed

if (Op == ISD::TRUNCATE_SSAT_U && VT.isVector() &&
VT.getVectorElementType() == MVT::i8 &&
((VT.getVectorNumElements() & 1) == 0))
return true;

return TargetLowering::isTypeDesirableForOp(Op, VT);
}

@@ -6606,6 +6641,43 @@ void SITargetLowering::ReplaceNodeResults(SDNode *N,
Results.push_back(lowerFSQRTF16(SDValue(N, 0), DAG));
break;
}
case ISD::TRUNCATE_SSAT_U: {
SDLoc SL(N);
SDValue Src = N->getOperand(0);
EVT SrcVT = Src.getValueType();
EVT DstVT = N->getValueType(0);

assert(SrcVT.isVector() && DstVT.isVector());
Contributor:
Also should assert the element type is i8


unsigned EleNo = SrcVT.getVectorNumElements();
assert(EleNo == DstVT.getVectorNumElements());

if (EleNo == 2) {
SDValue Op =
DAG.getNode(AMDGPUISD::SAT_PK_CAST, SL, MVT::i16, N->getOperand(0));
Contributor:
Use Src from above

Op = DAG.getNode(ISD::BITCAST, SL, N->getValueType(0), Op);
Results.push_back(Op);
} else {
// Must be even number
assert((EleNo & 1) == 0);
SmallVector<SDValue> DstPairs;
EVT SrcEleVT = SrcVT.getVectorElementType();
EVT DstEleVT = DstVT.getVectorElementType();
EVT SrcPairVT = EVT::getVectorVT(*DAG.getContext(), SrcEleVT, 2);
EVT DstPairVT = EVT::getVectorVT(*DAG.getContext(), DstEleVT, 2);
for (unsigned i = 0; i + 1 < EleNo; i = i + 2) {
Contributor:
Suggested change:
- for (unsigned i = 0; i + 1 < EleNo; i = i + 2) {
+ for (unsigned I = 0; I != EleNo; I += 2) {

SDValue SrcPair = DAG.getNode(ISD::EXTRACT_SUBVECTOR, SL, SrcPairVT,
Src, DAG.getConstant(i, SL, MVT::i32));
SDValue SatPk =
DAG.getNode(AMDGPUISD::SAT_PK_CAST, SL, MVT::i16, SrcPair);
SDValue DstPair = DAG.getNode(ISD::BITCAST, SL, DstPairVT, SatPk);
DstPairs.push_back(DstPair);
}
SDValue Op = DAG.getNode(ISD::CONCAT_VECTORS, SL, DstVT, DstPairs);
Results.push_back(Op);
}
break;
}
default:
AMDGPUTargetLowering::ReplaceNodeResults(N, Results, DAG);
break;
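
For readers following the TRUNCATE_SSAT_U case above, the pairwise lowering can be emulated on plain integers. This is only a sketch under stated assumptions: `satPkCast` and `truncSSatU` are hypothetical names, the node is assumed to clamp each signed 16-bit lane to [0, 255], and lane 0 is assumed to land in the low byte of the packed i16 (little-endian bitcast).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Assumed semantics of the 2-element node: clamp each signed 16-bit lane to
// [0, 255], then pack lane 0 into the low byte and lane 1 into the high byte.
uint16_t satPkCast(int16_t Lo, int16_t Hi) {
  auto Sat = [](int16_t V) -> unsigned {
    if (V < 0)
      return 0u;
    if (V > 255)
      return 255u;
    return static_cast<unsigned>(V);
  };
  return static_cast<uint16_t>(Sat(Lo) | (Sat(Hi) << 8));
}

// Mirrors the pairwise split in the hunk: extract two source lanes, form one
// packed 16-bit value, reinterpret it as two bytes, and concatenate.
std::vector<uint8_t> truncSSatU(const std::vector<int16_t> &Src) {
  assert(Src.size() % 2 == 0 && "element count must be even");
  std::vector<uint8_t> Dst;
  Dst.reserve(Src.size());
  for (std::size_t I = 0; I != Src.size(); I += 2) {
    uint16_t Packed = satPkCast(Src[I], Src[I + 1]);
    Dst.push_back(static_cast<uint8_t>(Packed & 0xff)); // lane 0 of the pair
    Dst.push_back(static_cast<uint8_t>(Packed >> 8));   // lane 1 of the pair
  }
  return Dst;
}
```

The loop uses the `I != Size; I += 2` form from the review suggestion, which relies on the even-count assertion just as the hunk does.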
15 changes: 15 additions & 0 deletions llvm/lib/Target/AMDGPU/SIInstructions.td
@@ -3324,6 +3324,21 @@ def : GCNPat <
(v2i16 (V_LSHL_OR_B32_e64 $src1, (i32 16), (i32 (V_AND_B32_e64 (i32 (V_MOV_B32_e32 (i32 0xffff))), $src0))))
>;

multiclass V_SAT_PK_Pat<Instruction inst> {
def : GCNPat<
Contributor:
This pattern isn't doing much, you should be able to pass the node to the SDNodeOperator argument to the instruction definition

Contributor Author:
Hi @arsenm, could you be more specific? Should I use another type of pattern?

Contributor (@arsenm, Jan 17, 2025):
The node is basically the same as the instruction definition, so you should be able to use the built-in pattern attached to the instruction def.
Something like

in VOP1Instructions.td:
defm V_SAT_PK_U8_I16 : VOP1Inst_t16<"v_sat_pk_u8_i16", VOP_I16_I32, AMDGPUsat_pk_cast>;

Contributor Author (@Shoreshen, Jan 17, 2025):
Hi @arsenm, after adding the node I got the following:

def V_SAT_PK_U8_I16_e64: list<dag> Pattern = [(set i16:$vdst, (AMDGPUsat_pk_cast (i32 (VOP3Mods0 i32:$src0))))];
def V_SAT_PK_U8_I16_fake16_e64: list<dag> Pattern = [(set i16:$vdst, (AMDGPUsat_pk_cast (i32 (VOP3Mods0 i32:$src0))))];
def V_SAT_PK_U8_I16_t16_e64: list<dag> Pattern = [(set i16:$vdst, (AMDGPUsat_pk_cast (i32 (VOP3OpSelMods i32:$src0, i32:$src0_modifiers))))];

I think there are two problems:

  1. The source is i32 instead of v2i16
  2. It requires the operand of AMDGPUsat_pk_cast to be a VOP3Mods0 or VOP3OpSelMods complex pattern

If the instruction cannot cover every form of (i16 (AMDGPUsat_pk_cast v2i8)), we risk failing instruction selection.

I also tried creating a new VOP_I16_V2I16 type, but that turns V_SAT_PK_U8_I16_e64 and V_SAT_PK_U8_I16_fake16_e64 into 4-operand instructions (with modifier, clamp and opsel).

I think that to make passing the node work, I would need to modify the related complex pattern functions and replace (v2i8 (truncssat_u v2i16)) with patterns that fit those functions.

(i16 (AMDGPUsat_pk_cast v2i16:$src)),
(inst VRegSrc_32:$src)
>;
}

let OtherPredicates = [NotHasTrue16BitInsts] in {
defm : V_SAT_PK_Pat<V_SAT_PK_U8_I16_e64>;
} // End OtherPredicates = [NotHasTrue16BitInsts]

let True16Predicate = UseFakeTrue16Insts in {
defm : V_SAT_PK_Pat<V_SAT_PK_U8_I16_fake16_e64>;
} // End True16Predicate = UseFakeTrue16Insts

// With multiple uses of the shift, this will duplicate the shift and
// increase register pressure.
def : GCNPat <