Skip to content

[WebAssembly] [Backend] Combine and(X, shuffle(X, pow 2 mask)) to all true #145108

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
#include "WebAssemblySubtarget.h"
#include "WebAssemblyTargetMachine.h"
#include "WebAssemblyUtilities.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/CodeGen/CallingConvLower.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineJumpTableInfo.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/SDPatternMatch.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/IR/DiagnosticInfo.h"
Expand Down Expand Up @@ -184,6 +186,10 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
// Combine partial.reduce.add before legalization gets confused.
setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN);

// Combine EXTRACT_VECTOR_ELT(AND(AND(X, SHUFFLE(X)), SHUFFLE(...)), 0),
// i.e. lane 0 of a shuffle-based AND reduction, to all_true.
setTargetDAGCombine(ISD::EXTRACT_VECTOR_ELT);

// Combine wide-vector muls, with extend inputs, to extmul_half.
setTargetDAGCombine(ISD::MUL);

Expand Down Expand Up @@ -3287,6 +3293,87 @@ static SDValue performSETCCCombine(SDNode *N,

return SDValue();
}
/// Build the shuffle mask used by one step of a log2 AND-reduction tree.
///
/// The mask has \p NumElements entries. The leading entries are the
/// consecutive lane indices FromPower, FromPower + 1, ...,
/// NextPowerOf2(FromPower) - 1; every remaining entry is -1.
///
/// For example, with NumElements = 8:
///   FromPower = 1: <1 -1 -1 -1 -1 -1 -1 -1>
///   FromPower = 2: <2  3 -1 -1 -1 -1 -1 -1>
///   FromPower = 4: <4  5  6  7 -1 -1 -1 -1>
///
/// \p FromPower must be strictly smaller than \p NumElements, and
/// NextPowerOf2(FromPower) must not exceed \p NumElements; both are asserted.
static SmallVector<int> buildMaskArrayByPower(unsigned FromPower,
                                              unsigned NumElements) {
  assert(FromPower <= 256);
  const unsigned ToPower = NextPowerOf2(FromPower);
  assert(FromPower < NumElements && ToPower <= NumElements);

  // Emit the run [FromPower, ToPower) first ...
  SmallVector<int> Mask;
  unsigned Lane = FromPower;
  while (Lane < ToPower) {
    Mask.push_back(Lane);
    ++Lane;
  }

  // ... then pad the tail with -1 up to the full vector width.
  Mask.resize(NumElements, -1);
  return Mask;
}
/// Match a shuffle-based AND reduction tree rooted at \p N and return the
/// vector that is compared against zero at its leaf, or an empty SDValue if
/// \p N does not have that shape.
///
/// Recursive case: N = and(X, shuffle(X, poison, mask(Power))), where
///   mask(Power) is the mask produced by buildMaskArrayByPower and X is
///   matched again with the next power of two.
/// Base case (reached when Power equals the lane count): N is
///   setcc(V, 0, ne), optionally behind a bitcast; V is returned.
///
/// \param N     Node to inspect; its first result must be a vector.
/// \param Power Reduction step expected at this level (1 at the root).
static SDValue matchAndOfShuffle(SDNode *N, int Power = 1) {
  using namespace llvm::SDPatternMatch;

  EVT VT = N->getValueType(0);
  if (!VT.isFixedLengthVector())
    return SDValue();

  const int NumElements = VT.getVectorNumElements();
  if (NumElements < Power)
    return SDValue();

  if (N->getOpcode() != ISD::AND && NumElements == Power) {
    SDValue BitCast, Matched;

    // Try for a setcc first.
    if (sd_match(N, m_c_SetCC(m_Value(Matched), m_Zero(),
                              m_SpecificCondCode(ISD::SETNE))))
      return Matched;

    // Now try for a bitcast of a setcc.
    if (!sd_match(N, m_BitCast(m_Value(BitCast))))
      return SDValue();

    // Only look through the bitcast when the compare has no more lanes than
    // the AND chain. If the compare lanes are narrower (e.g. v16i8 viewed as
    // v4i32), ANDing the wide lanes combines each narrow "column"
    // independently, so a non-zero result only proves that one column is all
    // set, not that every narrow lane is.
    // NOTE(review): the wider-lane direction presumably relies on the compare
    // producing all-ones lanes -- confirm against the target's boolean
    // contents.
    EVT SrcVT = BitCast.getValueType();
    if (!SrcVT.isFixedLengthVector() ||
        SrcVT.getVectorNumElements() > static_cast<unsigned>(NumElements))
      return SDValue();

    if (!sd_match(BitCast, m_c_SetCC(m_Value(Matched), m_Zero(),
                                     m_SpecificCondCode(ISD::SETNE))))
      return SDValue();
    return Matched;
  }

  // A further reduction step must fit inside the vector. This rejects an AND
  // reached at the leaf level and lane counts that are not a power of two,
  // both of which would violate the preconditions asserted by
  // buildMaskArrayByPower.
  const uint64_t NextPower = NextPowerOf2(Power);
  if (Power >= NumElements || NextPower > static_cast<uint64_t>(NumElements))
    return SDValue();

  SmallVector<int> PowerIndices = buildMaskArrayByPower(Power, NumElements);

  // m_Deferred makes the shuffle source compare equal to the value bound by
  // m_Value; using m_Value twice would merely rebind LHS and accept
  // and(A, shuffle(B, mask)) for unrelated A and B.
  SDValue LHS;
  if (sd_match(N, m_And(m_Value(LHS),
                        m_Shuffle(m_Deferred(LHS),
                                  m_VectorVT(m_Opc(ISD::POISON)),
                                  m_SpecificMask(PowerIndices)))))
    return matchAndOfShuffle(LHS.getNode(), static_cast<int>(NextPower));

  return SDValue();
}
/// DAG combine for EXTRACT_VECTOR_ELT: when lane 0 is extracted from a vector
/// that matchAndOfShuffle recognizes, replace the extract with a
/// wasm_alltrue intrinsic call on the matched vector, zero-extended or
/// truncated from i32 to the extract's result type.
///
/// Returns an empty SDValue when the pattern does not apply.
///
/// NOTE(review): the intrinsic result is built as an i32 and then
/// zero-extended/truncated; if the extracted lane of the original AND chain
/// can be all-ones rather than 1, users that consume the value directly
/// (instead of comparing it with zero) would observe a different value --
/// confirm this is acceptable.
static SDValue performExtractVecEltCombine(SDNode *N, SelectionDAG &DAG) {
  using namespace llvm::SDPatternMatch;

  assert(N->getOpcode() == ISD::EXTRACT_VECTOR_ELT);

  // Only an extract of lane 0 from a vector-typed operand is of interest.
  SDValue Vec;
  const bool IsLaneZeroExtract =
      sd_match(N, m_ExtractElt(m_VectorVT(m_Value(Vec)), m_Zero()));
  if (!IsLaneZeroExtract)
    return SDValue();

  SDValue Source = matchAndOfShuffle(Vec.getNode());
  if (!Source)
    return SDValue();

  SDLoc DL(N);
  SDValue IntrinsicID =
      DAG.getConstant(Intrinsic::wasm_alltrue, DL, MVT::i32);
  SDValue AllTrue = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::i32,
                                {IntrinsicID, Source});
  return DAG.getZExtOrTrunc(AllTrue, DL, N->getValueType(0));
}

static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG) {
assert(N->getOpcode() == ISD::MUL);
Expand Down Expand Up @@ -3402,6 +3489,8 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
return performTruncateCombine(N, DCI);
case ISD::INTRINSIC_WO_CHAIN:
return performLowerPartialReduction(N, DCI.DAG);
case ISD::EXTRACT_VECTOR_ELT:
return performExtractVecEltCombine(N, DCI.DAG);
case ISD::MUL:
return performMulCombine(N, DCI.DAG);
}
Expand Down
47 changes: 47 additions & 0 deletions llvm/test/CodeGen/WebAssembly/simd-reduceand.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc < %s -verify-machineinstrs -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+simd128 | FileCheck %s
target triple = "wasm64"

; <16 x i8> compare-with-zero, sign-extended, reinterpreted as <4 x i32>,
; AND-reduced and tested against zero; expected to select a single
; i8x16.all_true.
; NOTE(review): the reduction runs on i32 lanes, so the result is non-zero as
; soon as one byte position is set in all four words, which is weaker than
; "all 16 bytes are non-zero" -- confirm this expected output is valid.
define i1 @reduce_and_to_all_true_16i8(<16 x i8> %0) {
; CHECK-LABEL: reduce_and_to_all_true_16i8:
; CHECK: .functype reduce_and_to_all_true_16i8 (v128) -> (i32)
; CHECK-NEXT: # %bb.0:
; CHECK-NEXT: i8x16.all_true $push0=, $0
; CHECK-NEXT: return $pop0
%2 = icmp ne <16 x i8> %0, zeroinitializer
%3 = sext <16 x i1> %2 to <16 x i8>
%4 = bitcast <16 x i8> %3 to <4 x i32>
%5 = tail call i32 @llvm.vector.reduce.and.v4i32(<4 x i32> %4)
%6 = icmp ne i32 %5, 0
ret i1 %6
}
Comment on lines +5 to +17
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This IR is coming from https://godbolt.org/z/YMo1qqccT right? I'm surprised that we end up with v4i32 from the v16i8 type in the C. Do you know where this is being introduced? Perhaps the easiest fix here is to try and keep it in v16i8

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I'm not sure where it comes from, but if I remove the bitcast it still produces `i8x16.all_true`, just with a few extra instructions. I will investigate further:

; Variant without the bitcast: the AND reduction runs directly on <16 x i8>.
; The intrinsic name suffix must match the operand type (<16 x i8> -> .v16i8).
define i1 @reduce_and_to_all_true_16i8(<16 x i8> %0) {
; CHECK-LABEL: reduce_and_to_all_true_16i8:
; CHECK:         .functype reduce_and_to_all_true_16i8 (v128) -> (i32)
; CHECK-NEXT:  # %bb.0:
; CHECK-NEXT:    i8x16.all_true $push0=, $0
; CHECK-NEXT:    i32.const $push1=, 255
; CHECK-NEXT:    i32.and $push2=, $pop0, $pop1
; CHECK-NEXT:    i32.const $push3=, 0
; CHECK-NEXT:    i32.ne $push4=, $pop2, $pop3
; CHECK-NEXT:    return $pop4
  %2 = icmp ne <16 x i8> %0, zeroinitializer
  %3 = sext <16 x i1> %2 to <16 x i8>
  %4 = tail call i8 @llvm.vector.reduce.and.v16i8(<16 x i8> %3)
  %5 = icmp ne i8 %4, 0
  ret i1 %5
}



; <4 x i32> compare-with-zero, sign-extended, AND-reduced and tested against
; zero; expected to select a single i32x4.all_true.
define i1 @reduce_and_to_all_true_4i32(<4 x i32> %0) {
; CHECK-LABEL: reduce_and_to_all_true_4i32:
; CHECK: .functype reduce_and_to_all_true_4i32 (v128) -> (i32)
; CHECK-NEXT: # %bb.0:
; CHECK-NEXT: i32x4.all_true $push0=, $0
; CHECK-NEXT: return $pop0
%2 = icmp ne <4 x i32> %0, zeroinitializer
%3 = sext <4 x i1> %2 to <4 x i32>
%4 = tail call i32 @llvm.vector.reduce.and.v4i32(<4 x i32> %3)
%5 = icmp ne i32 %4, 0
ret i1 %5
}



; The <2 x i64> argument is bitcast to <4 x i32> before the compare, so the
; compare and the AND reduction both operate on i32 lanes; the expected
; instruction is therefore i32x4.all_true.
define i1 @reduce_and_to_all_true_2i64(<2 x i64> %0) {
; CHECK-LABEL: reduce_and_to_all_true_2i64:
; CHECK: .functype reduce_and_to_all_true_2i64 (v128) -> (i32)
; CHECK-NEXT: # %bb.0:
; CHECK-NEXT: i32x4.all_true $push0=, $0
; CHECK-NEXT: return $pop0
%2 = bitcast <2 x i64> %0 to <4 x i32>
%3 = icmp ne <4 x i32> %2, zeroinitializer
%4 = sext <4 x i1> %3 to <4 x i32>
%5 = tail call i32 @llvm.vector.reduce.and.v4i32(<4 x i32> %4)
%6 = icmp ne i32 %5, 0
ret i1 %6
}