Skip to content

Commit

Permalink
work around broken gcc _mm512_cmp_ph_mask, reduce sign conversion war…
Browse files Browse the repository at this point in the history
…nings. Refs #2494

PiperOrigin-RevId: 732900038
  • Loading branch information
jan-wassenberg authored and copybara-github committed Mar 3, 2025
1 parent c766697 commit e1dffe4
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 10 deletions.
31 changes: 22 additions & 9 deletions hwy/ops/x86_256-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -2413,29 +2413,36 @@ HWY_API Vec256<int32_t> BroadcastSignBit(const Vec256<int32_t> v) {
return ShiftRight<31>(v);
}

#if HWY_TARGET <= HWY_AVX3

template <int kBits>
HWY_API Vec256<int64_t> ShiftRight(const Vec256<int64_t> v) {
return Vec256<int64_t>{
_mm256_srai_epi64(v.raw, static_cast<Shift64Count>(kBits))};
}

HWY_API Vec256<int64_t> BroadcastSignBit(const Vec256<int64_t> v) {
return ShiftRight<63>(v);
}

#else // AVX2

HWY_API Vec256<int64_t> BroadcastSignBit(const Vec256<int64_t> v) {
#if HWY_TARGET == HWY_AVX2
const DFromV<decltype(v)> d;
return VecFromMask(v < Zero(d));
#else
return Vec256<int64_t>{_mm256_srai_epi64(v.raw, 63)};
#endif
}

template <int kBits>
HWY_API Vec256<int64_t> ShiftRight(const Vec256<int64_t> v) {
#if HWY_TARGET <= HWY_AVX3
return Vec256<int64_t>{
_mm256_srai_epi64(v.raw, static_cast<Shift64Count>(kBits))};
#else
const Full256<int64_t> di;
const Full256<uint64_t> du;
const auto right = BitCast(di, ShiftRight<kBits>(BitCast(du, v)));
const auto sign = ShiftLeft<64 - kBits>(BroadcastSignBit(v));
return right | sign;
#endif
}

#endif // #if HWY_TARGET <= HWY_AVX3

// ------------------------------ IfNegativeThenElse (BroadcastSignBit)
HWY_API Vec256<int8_t> IfNegativeThenElse(Vec256<int8_t> v, Vec256<int8_t> yes,
Vec256<int8_t> no) {
Expand Down Expand Up @@ -2495,6 +2502,10 @@ HWY_API Vec256<int32_t> IfNegativeThenNegOrUndefIfZero(Vec256<int32_t> mask,

// ------------------------------ ShiftLeftSame

// Disable sign conversion warnings for GCC debug intrinsics.
HWY_DIAGNOSTICS(push)
HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion")

HWY_API Vec256<uint16_t> ShiftLeftSame(const Vec256<uint16_t> v,
const int bits) {
#if HWY_COMPILER_GCC
Expand Down Expand Up @@ -2642,6 +2653,8 @@ HWY_API Vec256<int8_t> ShiftRightSame(Vec256<int8_t> v, const int bits) {
return (shifted ^ shifted_sign) - shifted_sign;
}

HWY_DIAGNOSTICS(pop)

// ------------------------------ Neg (Xor, Sub)

// Tag dispatch instead of SFINAE for MSVC 2017 compatibility
Expand Down
17 changes: 16 additions & 1 deletion hwy/ops/x86_512-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1478,7 +1478,11 @@ HWY_API Vec512<T> Ror(Vec512<T> a, Vec512<T> b) {
// ------------------------------ ShiftLeftSame

// GCC <14 and Clang <11 do not follow the Intel documentation for AVX-512
// shift-with-immediate: the counts should all be unsigned int.
// shift-with-immediate: the counts should all be unsigned int. Despite casting,
// we still see warnings in GCC debug builds, hence disable.
HWY_DIAGNOSTICS(push)
HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion")

#if HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1100
using Shift16Count = int;
using Shift3264Count = int;
Expand Down Expand Up @@ -1642,6 +1646,8 @@ HWY_API Vec512<int8_t> ShiftRightSame(Vec512<int8_t> v, const int bits) {
return (shifted ^ shifted_sign) - shifted_sign;
}

HWY_DIAGNOSTICS(pop)

// ------------------------------ Minimum

// Unsigned
Expand Down Expand Up @@ -2946,6 +2952,15 @@ HWY_API Vec512<int64_t> BroadcastSignBit(Vec512<int64_t> v) {

// ------------------------------ Floating-point classification (Not)

#if HWY_COMPILER_GCC_ACTUAL && defined(_mm512_fpclass_ph_mask)
// GCC's _mm512_cmp_ph_mask uses `__mmask8` instead of `__mmask32`, hence only
// the first 8 lanes are set.
#undef _mm512_fpclass_ph_mask
#define _mm512_fpclass_ph_mask(x, c) \
((__mmask32)__builtin_ia32_fpclassph512_mask((__v32hf)(__m512h)(x), \
(int)(c), (__mmask32) - 1))
#endif

#if HWY_HAVE_FLOAT16 || HWY_IDE

HWY_API Mask512<float16_t> IsNaN(Vec512<float16_t> v) {
Expand Down

0 comments on commit e1dffe4

Please sign in to comment.