Skip to content

Commit

Permalink
[sve] Support PMULLB/T for Q destination elements (#126)
Browse files Browse the repository at this point in the history
Extend the SVE2 PMULLB and PMULLT instructions to support Q-sized destination
elements when the SVEPmull128 CPU feature is available.
  • Loading branch information
mmc28a authored Jan 24, 2025
1 parent aca39bd commit 4415fe4
Show file tree
Hide file tree
Showing 10 changed files with 201 additions and 23 deletions.
12 changes: 6 additions & 6 deletions src/aarch64/assembler-sve-aarch64.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7410,13 +7410,13 @@ void Assembler::pmullb(const ZRegister& zd,
// size<23:22> | Zm<20:16> | op<12> | U<11> | T<10> | Zn<9:5> | Zd<4:0>

VIXL_ASSERT(CPUHas(CPUFeatures::kSVE2));
VIXL_ASSERT(CPUHas(CPUFeatures::kSVEPmull128) || !zd.IsLaneSizeQ());
VIXL_ASSERT(AreSameLaneSize(zn, zm));
VIXL_ASSERT(!zd.IsLaneSizeB() && !zd.IsLaneSizeS());
VIXL_ASSERT(zd.GetLaneSizeInBytes() == zn.GetLaneSizeInBytes() * 2);
// SVEPmull128 is not supported
VIXL_ASSERT(!zd.IsLaneSizeQ());
Instr size = zd.IsLaneSizeQ() ? 0 : SVESize(zd);

Emit(0x45006800 | SVESize(zd) | Rd(zd) | Rn(zn) | Rm(zm));
Emit(0x45006800 | size | Rd(zd) | Rn(zn) | Rm(zm));
}

void Assembler::pmullt(const ZRegister& zd,
Expand All @@ -7427,13 +7427,13 @@ void Assembler::pmullt(const ZRegister& zd,
// size<23:22> | Zm<20:16> | op<12> | U<11> | T<10> | Zn<9:5> | Zd<4:0>

VIXL_ASSERT(CPUHas(CPUFeatures::kSVE2));
VIXL_ASSERT(CPUHas(CPUFeatures::kSVEPmull128) || !zd.IsLaneSizeQ());
VIXL_ASSERT(AreSameLaneSize(zn, zm));
VIXL_ASSERT(!zd.IsLaneSizeB() && !zd.IsLaneSizeS());
VIXL_ASSERT(zd.GetLaneSizeInBytes() == zn.GetLaneSizeInBytes() * 2);
// SVEPmull128 is not supported
VIXL_ASSERT(!zd.IsLaneSizeQ());
Instr size = zd.IsLaneSizeQ() ? 0 : SVESize(zd);

Emit(0x45006c00 | SVESize(zd) | Rd(zd) | Rn(zn) | Rm(zm));
Emit(0x45006c00 | size | Rd(zd) | Rn(zn) | Rm(zm));
}

void Assembler::raddhnb(const ZRegister& zd,
Expand Down
4 changes: 4 additions & 0 deletions src/aarch64/cpu-features-auditor-aarch64.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1882,6 +1882,10 @@ void CPUFeaturesAuditor::Visit(Metadata* metadata, const Instruction* instr) {
CPUFeatures(CPUFeatures::kNEON, CPUFeatures::kSHA512)},
{"sha512su1_vvv2_cryptosha512_3"_h,
CPUFeatures(CPUFeatures::kNEON, CPUFeatures::kSHA512)},
{"pmullb_z_zz_q"_h,
CPUFeatures(CPUFeatures::kSVE2, CPUFeatures::kSVEPmull128)},
{"pmullt_z_zz_q"_h,
CPUFeatures(CPUFeatures::kSVE2, CPUFeatures::kSVEPmull128)},
};

if (features.count(form_hash_) > 0) {
Expand Down
25 changes: 19 additions & 6 deletions src/aarch64/disasm-aarch64.cc
Original file line number Diff line number Diff line change
Expand Up @@ -418,8 +418,8 @@ const Disassembler::FormToVisitorFnMap *Disassembler::GetFormToVisitorFnMap() {
{"nbsl_z_zzz"_h, &Disassembler::DisassembleSVEBitwiseTernary},
{"nmatch_p_p_zz"_h, &Disassembler::Disassemble_PdT_PgZ_ZnT_ZmT},
{"pmul_z_zz"_h, &Disassembler::Disassemble_ZdB_ZnB_ZmB},
{"pmullb_z_zz"_h, &Disassembler::Disassemble_ZdT_ZnTb_ZmTb},
{"pmullt_z_zz"_h, &Disassembler::Disassemble_ZdT_ZnTb_ZmTb},
{"pmullb_z_zz"_h, &Disassembler::DisassembleSVEPmull},
{"pmullt_z_zz"_h, &Disassembler::DisassembleSVEPmull},
{"raddhnb_z_zz"_h, &Disassembler::DisassembleSVEAddSubHigh},
{"raddhnt_z_zz"_h, &Disassembler::DisassembleSVEAddSubHigh},
{"rax1_z_zz"_h, &Disassembler::Disassemble_ZdD_ZnD_ZmD},
Expand Down Expand Up @@ -761,6 +761,8 @@ const Disassembler::FormToVisitorFnMap *Disassembler::GetFormToVisitorFnMap() {
{"sha512h_qqv_cryptosha512_3"_h, &Disassembler::DisassembleSHA512},
{"sha512su0_vv2_cryptosha512_2"_h, &Disassembler::DisassembleSHA512},
{"sha512su1_vvv2_cryptosha512_3"_h, &Disassembler::DisassembleSHA512},
{"pmullb_z_zz_q"_h, &Disassembler::DisassembleSVEPmull128},
{"pmullt_z_zz_q"_h, &Disassembler::DisassembleSVEPmull128},
};
return &form_to_visitor;
} // NOLINT(readability/fn_size)
Expand Down Expand Up @@ -5852,15 +5854,26 @@ void Disassembler::Disassemble_ZdT_ZnTb(const Instruction *instr) {
}
}

// Disassemble the "pmullb_z_zz" / "pmullt_z_zz" forms. An S-sized vector
// format is reported as unallocated; every other size is handed to the generic
// Zd.T, Zn.Tb, Zm.Tb disassembler.
void Disassembler::DisassembleSVEPmull(const Instruction *instr) {
  const bool is_s_lane = (instr->GetSVEVectorFormat() == kFormatVnS);
  if (is_s_lane) {
    VisitUnallocated(instr);
    return;
  }
  Disassemble_ZdT_ZnTb_ZmTb(instr);
}

// Disassemble the "pmullb_z_zz_q" / "pmullt_z_zz_q" forms. The lane sizes are
// fixed for these forms, so the operand string is spelled out literally: a
// Q-sized destination with D-sized sources.
void Disassembler::DisassembleSVEPmull128(const Instruction *instr) {
  const char *operands = "'Zd.q, 'Zn.d, 'Zm.d";
  FormatWithDecodedMnemonic(instr, operands);
}

// Disassemble SVE forms printed as "Zd.<T>, Zn.<Tb>, Zm.<Tb>", using the 't
// and 'th format substitutions for the destination and source lane sizes.
// A B-sized vector format is not decoded here; it is printed as
// "unimplemented" instead.
void Disassembler::Disassemble_ZdT_ZnTb_ZmTb(const Instruction *instr) {
const char *form = "'Zd.'t, 'Zn.'th, 'Zm.'th";
if (instr->GetSVEVectorFormat() == kFormatVnB) {
// TODO: This is correct for saddlbt, ssublbt, subltb, which don't have
// b-lane sized form, and for pmull[b|t] as feature `SVEPmull128` isn't
// supported, but may need changes for other instructions reaching here.
// b-lane sized form, but may need changes for other instructions reaching
// here.
Format(instr, "unimplemented", "(ZdT_ZnTb_ZmTb)");
} else {
Format(instr, mnemonic_.c_str(), form);
FormatWithDecodedMnemonic(instr, "'Zd.'t, 'Zn.'th, 'Zm.'th");
}
}

Expand Down
2 changes: 2 additions & 0 deletions src/aarch64/disasm-aarch64.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,8 @@ class Disassembler : public DecoderVisitor {
void DisassembleSVEBitwiseTernary(const Instruction* instr);
void DisassembleSVEFlogb(const Instruction* instr);
void DisassembleSVEFPPair(const Instruction* instr);
void DisassembleSVEPmull(const Instruction* instr);
void DisassembleSVEPmull128(const Instruction* instr);

void DisassembleNoArgs(const Instruction* instr);

Expand Down
2 changes: 2 additions & 0 deletions src/aarch64/instructions-aarch64.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1047,6 +1047,8 @@ VectorFormat VectorFormatHalfWidth(VectorFormat vform) {
return kFormatVnH;
case kFormatVnD:
return kFormatVnS;
case kFormatVnQ:
return kFormatVnD;
default:
VIXL_UNREACHABLE();
return kFormatUndefined;
Expand Down
27 changes: 23 additions & 4 deletions src/aarch64/simulator-aarch64.cc
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,8 @@ const Simulator::FormToVisitorFnMap* Simulator::GetFormToVisitorFnMap() {
{"sha512h2_qqv_cryptosha512_3"_h, &Simulator::SimulateSHA512},
{"sha512su0_vv2_cryptosha512_2"_h, &Simulator::SimulateSHA512},
{"sha512su1_vvv2_cryptosha512_3"_h, &Simulator::SimulateSHA512},
{"pmullb_z_zz_q"_h, &Simulator::SimulateSVEPmull128},
{"pmullt_z_zz_q"_h, &Simulator::SimulateSVEPmull128},
};
return &form_to_visitor;
}
Expand Down Expand Up @@ -2909,6 +2911,23 @@ void Simulator::SimulateSVEInterleavedArithLong(const Instruction* instr) {
}
}

void Simulator::SimulateSVEPmull128(const Instruction* instr) {
SimVRegister& zd = ReadVRegister(instr->GetRd());
SimVRegister& zm = ReadVRegister(instr->GetRm());
SimVRegister& zn = ReadVRegister(instr->GetRn());
SimVRegister zn_temp, zm_temp;

if (form_hash_ == "pmullb_z_zz_q"_h) {
pack_even_elements(kFormatVnD, zn_temp, zn);
pack_even_elements(kFormatVnD, zm_temp, zm);
} else {
VIXL_ASSERT(form_hash_ == "pmullt_z_zz_q"_h);
pack_odd_elements(kFormatVnD, zn_temp, zn);
pack_odd_elements(kFormatVnD, zm_temp, zm);
}
pmull(kFormatVnQ, zd, zn_temp, zm_temp);
}

void Simulator::SimulateSVEIntMulLongVec(const Instruction* instr) {
VectorFormat vform = instr->GetSVEVectorFormat();
SimVRegister& zd = ReadVRegister(instr->GetRd());
Expand All @@ -2923,15 +2942,15 @@ void Simulator::SimulateSVEIntMulLongVec(const Instruction* instr) {

switch (form_hash_) {
case "pmullb_z_zz"_h:
// '00' is reserved for Q-sized lane.
if (vform == kFormatVnB) {
// Size '10' is undefined.
if (vform == kFormatVnS) {
VIXL_UNIMPLEMENTED();
}
pmull(vform, zd, zn_b, zm_b);
break;
case "pmullt_z_zz"_h:
// '00' is reserved for Q-sized lane.
if (vform == kFormatVnB) {
// Size '10' is undefined.
if (vform == kFormatVnS) {
VIXL_UNIMPLEMENTED();
}
pmull(vform, zd, zn_t, zm_t);
Expand Down
8 changes: 4 additions & 4 deletions src/aarch64/simulator-aarch64.h
Original file line number Diff line number Diff line change
Expand Up @@ -872,10 +872,9 @@ class LogicVRegister {
SetUint(vform, index, value.second);
return;
}
// TODO: Extend this to SVE.
VIXL_ASSERT((vform == kFormat1Q) && (index == 0));
SetUint(kFormat2D, 0, value.second);
SetUint(kFormat2D, 1, value.first);
VIXL_ASSERT((vform == kFormat1Q) || (vform == kFormatVnQ));
SetUint(kFormatVnD, 2 * index, value.second);
SetUint(kFormatVnD, 2 * index + 1, value.first);
}

void SetUintArray(VectorFormat vform, const uint64_t* src) const {
Expand Down Expand Up @@ -1504,6 +1503,7 @@ class Simulator : public DecoderVisitor {
void SimulateSVESaturatingMulAddHigh(const Instruction* instr);
void SimulateSVESaturatingMulHighIndex(const Instruction* instr);
void SimulateSVEFPConvertLong(const Instruction* instr);
void SimulateSVEPmull128(const Instruction* instr);
void SimulateMatrixMul(const Instruction* instr);
void SimulateSVEFPMatrixMul(const Instruction* instr);
void SimulateNEONMulByElementLong(const Instruction* instr);
Expand Down
8 changes: 8 additions & 0 deletions test/aarch64/test-cpu-features-aarch64.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3856,5 +3856,13 @@ TEST_FEAT(sm4e, sm4e(v12.V4S(), v13.V4S()))
TEST_FEAT(sm4ekey, sm4ekey(v12.V4S(), v13.V4S(), v14.V4S()))
#undef TEST_FEAT

// PMULLB/PMULLT with Q-sized destination lanes: each test below is
// instantiated with the feature pair kSVE2 + kSVEPmull128 and is named
// SVE_PMULL128_<NAME>.
#define TEST_FEAT(NAME, ASM) \
TEST_TEMPLATE(CPUFeatures(CPUFeatures::kSVE2, CPUFeatures::kSVEPmull128), \
SVE_PMULL128_##NAME, \
ASM)
TEST_FEAT(pmullb, pmullb(z12.VnQ(), z21.VnD(), z12.VnD()))
TEST_FEAT(pmullt, pmullt(z12.VnQ(), z21.VnD(), z12.VnD()))
#undef TEST_FEAT

} // namespace aarch64
} // namespace vixl
11 changes: 8 additions & 3 deletions test/aarch64/test-disasm-sve-aarch64.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7673,13 +7673,14 @@ TEST(sve2_integer_multiply_long_vector) {
COMPARE(sqdmullt(z7.VnD(), z4.VnS(), z0.VnS(), 0),
"sqdmullt z7.d, z4.s, z0.s[0]");

// Feature `SVEPmull128` is not supported.
// COMPARE(pmullb(z12.VnQ(), z21.VnD(), z12.VnD()),
// "pmullb z12.q, z21.d, z12.d");
COMPARE(pmullb(z12.VnH(), z21.VnB(), z12.VnB()),
"pmullb z12.h, z21.b, z12.b");
COMPARE(pmullt(z31.VnD(), z30.VnS(), z26.VnS()),
"pmullt z31.d, z30.s, z26.s");
COMPARE(pmullb(z12.VnQ(), z21.VnD(), z12.VnD()),
"pmullb z12.q, z21.d, z12.d");
COMPARE(pmullt(z12.VnQ(), z21.VnD(), z12.VnD()),
"pmullt z12.q, z21.d, z12.d");

COMPARE(smullb(z10.VnD(), z4.VnS(), z4.VnS()), "smullb z10.d, z4.s, z4.s");
COMPARE(smullb(z11.VnH(), z14.VnB(), z14.VnB()),
Expand All @@ -7701,6 +7702,10 @@ TEST(sve2_integer_multiply_long_vector) {
COMPARE(umullt(z24.VnH(), z7.VnB(), z16.VnB()), "umullt z24.h, z7.b, z16.b");
COMPARE(umullt(z24.VnS(), z8.VnH(), z26.VnH()), "umullt z24.s, z8.h, z26.h");

// Check related but undefined encodings.
COMPARE(dci(0x45806800), "unallocated (Unallocated)"); // pmullb s, h, h
COMPARE(dci(0x45806c00), "unallocated (Unallocated)"); // pmullt s, h, h

CLEANUP();
}

Expand Down
125 changes: 125 additions & 0 deletions test/aarch64/test-simulator-sve2-aarch64.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9117,5 +9117,130 @@ TEST_SVE(sve2_extract) {
}
}

// Simulator test for PMULLB/PMULLT with Q-sized destination lanes (requires
// kSVEPmull128). It emits a fixed sequence of 40 raw encodings, hashes the
// resulting machine state, and compares that hash with an expected value
// selected by the vector length.
TEST_SVE(sve2_pmull128) {
SVE_SETUP_WITH_FEATURES(CPUFeatures::kSVE,
CPUFeatures::kSVE2,
CPUFeatures::kNEON,
CPUFeatures::kCRC32,
CPUFeatures::kSVEPmull128);
START();

SetInitialMachineState(&masm);
// state = 0xe2bd2480

{
// The scope size must match the number of dci() calls below (40).
// The "vl128 state" comments record the running hash after each
// instruction for a 128-bit vector length.
ExactAssemblyScope scope(&masm, 40 * kInstructionSize);
__ dci(0x45006800); // pmullb z0.q, z0.d, z0.d
// vl128 state = 0x4107ca0c
__ dci(0x45006a28); // pmullb z8.q, z17.d, z0.d
// vl128 state = 0xa87d231a
__ dci(0x45016a6c); // pmullb z12.q, z19.d, z1.d
// vl128 state = 0xc547fcf6
__ dci(0x45116e68); // pmullt z8.q, z19.d, z17.d
// vl128 state = 0x6a01d521
__ dci(0x45106a69); // pmullb z9.q, z19.d, z16.d
// vl128 state = 0x64a7ba8a
__ dci(0x45006a4d); // pmullb z13.q, z18.d, z0.d
// vl128 state = 0xe59e3f8e
__ dci(0x45086e5d); // pmullt z29.q, z18.d, z8.d
// vl128 state = 0xbfbb9316
__ dci(0x450a6e75); // pmullt z21.q, z19.d, z10.d
// vl128 state = 0x29f6a4c7
__ dci(0x45126e74); // pmullt z20.q, z19.d, z18.d
// vl128 state = 0x4ced9406
__ dci(0x45176e75); // pmullt z21.q, z19.d, z23.d
// vl128 state = 0xd09e5676
__ dci(0x45176e77); // pmullt z23.q, z19.d, z23.d
// vl128 state = 0x568c0e25
__ dci(0x45176e75); // pmullt z21.q, z19.d, z23.d
// vl128 state = 0xb2f13c36
__ dci(0x45176b71); // pmullb z17.q, z27.d, z23.d
// vl128 state = 0x160bec4f
__ dci(0x451f6b30); // pmullb z16.q, z25.d, z31.d
// vl128 state = 0x2d7e7f49
__ dci(0x451f6b20); // pmullb z0.q, z25.d, z31.d
// vl128 state = 0x113d828b
__ dci(0x451f6b90); // pmullb z16.q, z28.d, z31.d
// vl128 state = 0xb8b3b3d9
__ dci(0x451f6f12); // pmullt z18.q, z24.d, z31.d
// vl128 state = 0x277aacb8
__ dci(0x451f6f16); // pmullt z22.q, z24.d, z31.d
// vl128 state = 0xef79c8da
__ dci(0x450b6f17); // pmullt z23.q, z24.d, z11.d
// vl128 state = 0x1dc19104
__ dci(0x450a6e1f); // pmullt z31.q, z16.d, z10.d
// vl128 state = 0x3ccb4ea8
__ dci(0x451a6e2f); // pmullt z15.q, z17.d, z26.d
// vl128 state = 0x14e13481
__ dci(0x45126a3f); // pmullb z31.q, z17.d, z18.d
// vl128 state = 0x4e6502f9
__ dci(0x451a6b3e); // pmullb z30.q, z25.d, z26.d
// vl128 state = 0xf6f18478
__ dci(0x45126a3a); // pmullb z26.q, z17.d, z18.d
// vl128 state = 0xdd4f14fb
__ dci(0x45126afb); // pmullb z27.q, z23.d, z18.d
// vl128 state = 0xcbf3bee2
__ dci(0x45126aff); // pmullb z31.q, z23.d, z18.d
// vl128 state = 0x627bec09
__ dci(0x45126aef); // pmullb z15.q, z23.d, z18.d
// vl128 state = 0xf5de1fa9
__ dci(0x45106abf); // pmullb z31.q, z21.d, z16.d
// vl128 state = 0x44bb6385
__ dci(0x451a6abb); // pmullb z27.q, z21.d, z26.d
// vl128 state = 0x5c5fa224
__ dci(0x450a68b3); // pmullb z19.q, z5.d, z10.d
// vl128 state = 0x28b6085c
__ dci(0x450e69b2); // pmullb z18.q, z13.d, z14.d
// vl128 state = 0x450898d6
__ dci(0x450e69b6); // pmullb z22.q, z13.d, z14.d
// vl128 state = 0x79d7911b
__ dci(0x450e69b4); // pmullb z20.q, z13.d, z14.d
// vl128 state = 0x98bf6939
__ dci(0x450f6924); // pmullb z4.q, z9.d, z15.d
// vl128 state = 0xb8a1bbc7
__ dci(0x45176925); // pmullb z5.q, z9.d, z23.d
// vl128 state = 0x631b41c8
__ dci(0x451f69a4); // pmullb z4.q, z13.d, z31.d
// vl128 state = 0x617fc272
__ dci(0x451b69e0); // pmullb z0.q, z15.d, z27.d
// vl128 state = 0x77780ac1
__ dci(0x451b69e8); // pmullb z8.q, z15.d, z27.d
// vl128 state = 0xce5ae18f
__ dci(0x450f69e0); // pmullb z0.q, z15.d, z15.d
// vl128 state = 0xa037371a
__ dci(0x450b6be8); // pmullb z8.q, z31.d, z11.d
// vl128 state = 0xb59be233
}

// Hash the machine state (presumably written to `state` when the generated
// code runs), then load that 32-bit value into w0 for the check below.
uint32_t state;
ComputeMachineStateHash(&masm, &state);
__ Mov(x0, reinterpret_cast<uint64_t>(&state));
__ Ldr(w0, MemOperand(x0));

END();
if (CAN_RUN()) {
RUN();
// Expected final hashes, indexed by (number of Q-sized lanes - 1). The first
// entry equals the last "vl128 state" value recorded above.
uint32_t expected_hashes[] = {
0xb59be233,
0x32430624,
0x5cc3ec66,
0xecfdffe7,
0x6d77a270,
0xa0d604f2,
0x2178aa11,
0xabdcbeaa,
0xab3b974f,
0x11a874f5,
0xf2eb6131,
0x6d311c6c,
0xd4e99b72,
0x5177ce8e,
0x32aa02f0,
0x681ef977,
};
ASSERT_EQUAL_64(expected_hashes[core.GetSVELaneCount(kQRegSize) - 1], x0);
}
}

} // namespace aarch64
} // namespace vixl

0 comments on commit 4415fe4

Please sign in to comment.