Commit

Provide an extra as_batch() method for batch_constant and as_batch_bool() for batch_bool_constant

They come in addition to the implicit conversion operator, which is cumbersome to use.

serge-sans-paille committed Mar 28, 2024
1 parent cf66d84 commit fe0d160
Showing 8 changed files with 60 additions and 25 deletions.
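To make the motivation concrete, here is a minimal before/after sketch (not part of the commit; the reverse_lanes generator and the values are hypothetical, only the as_batch() method comes from this change):

    #include <cstddef>
    #include <cstdint>
    #include "xsimd/xsimd.hpp"

    // Hypothetical generator for illustration: reverse the lane order.
    struct reverse_lanes
    {
        static constexpr uint32_t get(std::size_t index, std::size_t size)
        {
            return static_cast<uint32_t>(size - 1 - index);
        }
    };

    int main()
    {
        using arch = xsimd::default_arch;
        constexpr auto mask = xsimd::make_batch_constant<uint32_t, arch, reverse_lanes>();

        // Before this commit: spell out the target type to trigger the
        // implicit conversion operator.
        xsimd::batch<uint32_t, arch> via_cast = (xsimd::batch<uint32_t, arch>)mask;

        // After this commit: the named method performs the same conversion.
        xsimd::batch<uint32_t, arch> via_method = mask.as_batch();

        return xsimd::all(via_cast == via_method) ? 0 : 1;
    }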
8 changes: 4 additions & 4 deletions include/xsimd/arch/xsimd_avx.hpp
@@ -1517,8 +1517,8 @@ namespace xsimd
batch_constant<uint32_t, A, (V0 % 4), (V1 % 4), (V2 % 4), (V3 % 4), (V4 % 4), (V5 % 4), (V6 % 4), (V7 % 4)> half_mask;

// permute within each lane
__m256 r0 = _mm256_permutevar_ps(low_low, (batch<uint32_t, A>)half_mask);
__m256 r1 = _mm256_permutevar_ps(hi_hi, (batch<uint32_t, A>)half_mask);
__m256 r0 = _mm256_permutevar_ps(low_low, half_mask.as_batch());
__m256 r1 = _mm256_permutevar_ps(hi_hi, half_mask.as_batch());

// mask to choose the right lane
batch_bool_constant<uint32_t, A, (V0 >= 4), (V1 >= 4), (V2 >= 4), (V3 >= 4), (V4 >= 4), (V5 >= 4), (V6 >= 4), (V7 >= 4)> blend_mask;
@@ -1542,8 +1542,8 @@ namespace xsimd
batch_constant<uint64_t, A, (V0 % 2) * -1, (V1 % 2) * -1, (V2 % 2) * -1, (V3 % 2) * -1> half_mask;

// permute within each lane
__m256d r0 = _mm256_permutevar_pd(low_low, (batch<uint64_t, A>)half_mask);
__m256d r1 = _mm256_permutevar_pd(hi_hi, (batch<uint64_t, A>)half_mask);
__m256d r0 = _mm256_permutevar_pd(low_low, half_mask.as_batch());
__m256d r1 = _mm256_permutevar_pd(hi_hi, half_mask.as_batch());

// mask to choose the right lane
batch_bool_constant<uint64_t, A, (V0 >= 2), (V1 >= 2), (V2 >= 2), (V3 >= 2)> blend_mask;
4 changes: 2 additions & 2 deletions include/xsimd/arch/xsimd_avx2.hpp
@@ -914,7 +914,7 @@ namespace xsimd
template <class A, uint32_t V0, uint32_t V1, uint32_t V2, uint32_t V3, uint32_t V4, uint32_t V5, uint32_t V6, uint32_t V7>
inline batch<float, A> swizzle(batch<float, A> const& self, batch_constant<uint32_t, A, V0, V1, V2, V3, V4, V5, V6, V7> mask, requires_arch<avx2>) noexcept
{
return _mm256_permutevar8x32_ps(self, (batch<uint32_t, A>)mask);
return _mm256_permutevar8x32_ps(self, mask.as_batch());
}

template <class A, uint64_t V0, uint64_t V1, uint64_t V2, uint64_t V3>
@@ -938,7 +938,7 @@ namespace xsimd
template <class A, uint32_t V0, uint32_t V1, uint32_t V2, uint32_t V3, uint32_t V4, uint32_t V5, uint32_t V6, uint32_t V7>
inline batch<uint32_t, A> swizzle(batch<uint32_t, A> const& self, batch_constant<uint32_t, A, V0, V1, V2, V3, V4, V5, V6, V7> mask, requires_arch<avx2>) noexcept
{
return _mm256_permutevar8x32_epi32(self, (batch<uint32_t, A>)mask);
return _mm256_permutevar8x32_epi32(self, mask.as_batch());
}
template <class A, uint32_t V0, uint32_t V1, uint32_t V2, uint32_t V3, uint32_t V4, uint32_t V5, uint32_t V6, uint32_t V7>
inline batch<int32_t, A> swizzle(batch<int32_t, A> const& self, batch_constant<uint32_t, A, V0, V1, V2, V3, V4, V5, V6, V7> mask, requires_arch<avx2>) noexcept
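For reference, a small sketch of how user code reaches these kernels: a compile-time index mask passed to xsimd::swizzle. The rotate_left generator is hypothetical and only for illustration; the dispatch path is the one shown in the diff above.

    #include <cstddef>
    #include <cstdint>
    #include "xsimd/xsimd.hpp"

    // Hypothetical generator for illustration: rotate the lanes left by one.
    struct rotate_left
    {
        static constexpr uint32_t get(std::size_t index, std::size_t size)
        {
            return static_cast<uint32_t>((index + 1) % size);
        }
    };

    int main()
    {
        using arch = xsimd::default_arch;
        xsimd::batch<float, arch> v(1.5f); // all lanes equal, so the result is easy to check
        constexpr auto mask = xsimd::make_batch_constant<uint32_t, arch, rotate_left>();

        // Dispatches to the batch_constant overload of swizzle; on an AVX2 target
        // this lands in the kernel above, which now forwards mask.as_batch() to
        // _mm256_permutevar8x32_ps instead of casting.
        auto rotated = xsimd::swizzle(v, mask);

        return xsimd::all(rotated == v) ? 0 : 1;
    }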
8 changes: 4 additions & 4 deletions include/xsimd/arch/xsimd_avx512bw.hpp
@@ -619,25 +619,25 @@ namespace xsimd
template <class A, uint16_t... Vs>
inline batch<uint16_t, A> swizzle(batch<uint16_t, A> const& self, batch_constant<uint16_t, A, Vs...> mask, requires_arch<avx512bw>) noexcept
{
return swizzle(self, (batch<uint16_t, A>)mask, avx512bw {});
return swizzle(self, mask.as_batch(), avx512bw {});
}

template <class A, uint16_t... Vs>
inline batch<int16_t, A> swizzle(batch<int16_t, A> const& self, batch_constant<uint16_t, A, Vs...> mask, requires_arch<avx512bw>) noexcept
{
return swizzle(self, (batch<uint16_t, A>)mask, avx512bw {});
return swizzle(self, mask.as_batch(), avx512bw {});
}

template <class A, uint8_t... Vs>
inline batch<uint8_t, A> swizzle(batch<uint8_t, A> const& self, batch_constant<uint8_t, A, Vs...> mask, requires_arch<avx512bw>) noexcept
{
return swizzle(self, (batch<uint8_t, A>)mask, avx512bw {});
return swizzle(self, mask.as_batch(), avx512bw {});
}

template <class A, uint8_t... Vs>
inline batch<int8_t, A> swizzle(batch<int8_t, A> const& self, batch_constant<uint8_t, A, Vs...> mask, requires_arch<avx512bw>) noexcept
{
return swizzle(self, (batch<uint8_t, A>)mask, avx512bw {});
return swizzle(self, mask.as_batch(), avx512bw {});
}

// zip_hi
16 changes: 8 additions & 8 deletions include/xsimd/arch/xsimd_avx512f.hpp
@@ -1423,7 +1423,7 @@ namespace xsimd
inline T reduce_max(batch<T, A> const& self, requires_arch<avx512f>) noexcept
{
constexpr batch_constant<uint64_t, A, 5, 6, 7, 8, 0, 0, 0, 0> mask;
batch<T, A> step = _mm512_permutexvar_epi64((batch<uint64_t, A>)mask, self);
batch<T, A> step = _mm512_permutexvar_epi64(mask.as_batch(), self);
batch<T, A> acc = max(self, step);
__m256i low = _mm512_castsi512_si256(acc);
return reduce_max(batch<T, avx2>(low));
@@ -1434,7 +1434,7 @@ namespace xsimd
inline T reduce_min(batch<T, A> const& self, requires_arch<avx512f>) noexcept
{
constexpr batch_constant<uint64_t, A, 5, 6, 7, 8, 0, 0, 0, 0> mask;
batch<T, A> step = _mm512_permutexvar_epi64((batch<uint64_t, A>)mask, self);
batch<T, A> step = _mm512_permutexvar_epi64(mask.as_batch(), self);
batch<T, A> acc = min(self, step);
__m256i low = _mm512_castsi512_si256(acc);
return reduce_min(batch<T, avx2>(low));
@@ -1919,37 +1919,37 @@ namespace xsimd
template <class A, uint32_t... Vs>
inline batch<float, A> swizzle(batch<float, A> const& self, batch_constant<uint32_t, A, Vs...> mask, requires_arch<avx512f>) noexcept
{
return swizzle(self, (batch<uint32_t, A>)mask, avx512f {});
return swizzle(self, mask.as_batch(), avx512f {});
}

template <class A, uint64_t... Vs>
inline batch<double, A> swizzle(batch<double, A> const& self, batch_constant<uint64_t, A, Vs...> mask, requires_arch<avx512f>) noexcept
{
return swizzle(self, (batch<uint64_t, A>)mask, avx512f {});
return swizzle(self, mask.as_batch(), avx512f {});
}

template <class A, uint64_t... Vs>
inline batch<uint64_t, A> swizzle(batch<uint64_t, A> const& self, batch_constant<uint64_t, A, Vs...> mask, requires_arch<avx512f>) noexcept
{
return swizzle(self, (batch<uint64_t, A>)mask, avx512f {});
return swizzle(self, mask.as_batch(), avx512f {});
}

template <class A, uint64_t... Vs>
inline batch<int64_t, A> swizzle(batch<int64_t, A> const& self, batch_constant<uint64_t, A, Vs...> mask, requires_arch<avx512f>) noexcept
{
return swizzle(self, (batch<uint64_t, A>)mask, avx512f {});
return swizzle(self, mask.as_batch(), avx512f {});
}

template <class A, uint32_t... Vs>
inline batch<uint32_t, A> swizzle(batch<uint32_t, A> const& self, batch_constant<uint32_t, A, Vs...> mask, requires_arch<avx512f>) noexcept
{
return swizzle(self, (batch<uint32_t, A>)mask, avx512f {});
return swizzle(self, mask.as_batch(), avx512f {});
}

template <class A, uint32_t... Vs>
inline batch<int32_t, A> swizzle(batch<int32_t, A> const& self, batch_constant<uint32_t, A, Vs...> mask, requires_arch<avx512f>) noexcept
{
return swizzle(self, (batch<uint32_t, A>)mask, avx512f {});
return swizzle(self, mask.as_batch(), avx512f {});
}

namespace detail
6 changes: 3 additions & 3 deletions include/xsimd/arch/xsimd_ssse3.hpp
@@ -145,7 +145,7 @@ namespace xsimd
constexpr batch_constant<uint8_t, A, 2 * V0, 2 * V0 + 1, 2 * V1, 2 * V1 + 1, 2 * V2, 2 * V2 + 1, 2 * V3, 2 * V3 + 1,
2 * V4, 2 * V4 + 1, 2 * V5, 2 * V5 + 1, 2 * V6, 2 * V6 + 1, 2 * V7, 2 * V7 + 1>
mask8;
return _mm_shuffle_epi8(self, (batch<uint8_t, A>)mask8);
return _mm_shuffle_epi8(self, mask8.as_batch());
}

template <class A, uint16_t V0, uint16_t V1, uint16_t V2, uint16_t V3, uint16_t V4, uint16_t V5, uint16_t V6, uint16_t V7>
@@ -158,14 +158,14 @@ namespace xsimd
uint8_t V8, uint8_t V9, uint8_t V10, uint8_t V11, uint8_t V12, uint8_t V13, uint8_t V14, uint8_t V15>
inline batch<uint8_t, A> swizzle(batch<uint8_t, A> const& self, batch_constant<uint8_t, A, V0, V1, V2, V3, V4, V5, V6, V7, V8, V9, V10, V11, V12, V13, V14, V15> mask, requires_arch<ssse3>) noexcept
{
return swizzle(self, (batch<uint8_t, A>)mask, ssse3 {});
return swizzle(self, mask.as_batch(), ssse3 {});
}

template <class A, uint8_t V0, uint8_t V1, uint8_t V2, uint8_t V3, uint8_t V4, uint8_t V5, uint8_t V6, uint8_t V7,
uint8_t V8, uint8_t V9, uint8_t V10, uint8_t V11, uint8_t V12, uint8_t V13, uint8_t V14, uint8_t V15>
inline batch<int8_t, A> swizzle(batch<int8_t, A> const& self, batch_constant<uint8_t, A, V0, V1, V2, V3, V4, V5, V6, V7, V8, V9, V10, V11, V12, V13, V14, V15> mask, requires_arch<ssse3>) noexcept
{
return swizzle(self, (batch<uint8_t, A>)mask, ssse3 {});
return swizzle(self, mask.as_batch(), ssse3 {});
}

}
4 changes: 2 additions & 2 deletions include/xsimd/arch/xsimd_sve.hpp
@@ -742,7 +742,7 @@ namespace xsimd
inline batch<T, A> swizzle(batch<T, A> const& arg, batch_constant<I, A, idx...> indices, requires_arch<sve>) noexcept
{
static_assert(batch<T, A>::size == sizeof...(idx), "invalid swizzle indices");
return swizzle(arg, (batch<I, A>)indices, sve {});
return swizzle(arg, indices.as_batch(), sve {});
}

template <class A, class T, class I, I... idx>
@@ -751,7 +751,7 @@ namespace xsimd
requires_arch<sve>) noexcept
{
static_assert(batch<std::complex<T>, A>::size == sizeof...(idx), "invalid swizzle indices");
return swizzle(arg, (batch<I, A>)indices, sve {});
return swizzle(arg, indices.as_batch(), sve {});
}

/*************
17 changes: 15 additions & 2 deletions include/xsimd/types/xsimd_batch_constant.hpp
@@ -34,7 +34,15 @@ namespace xsimd
static_assert(sizeof...(Values) == batch_type::size, "consistent batch size");

public:
constexpr operator batch_bool<T, A>() const noexcept { return { Values... }; }
/**
* @brief Generate a batch of @p batch_type from this @p batch_bool_constant
*/
constexpr batch_type as_batch_bool() const noexcept { return { Values... }; }

/**
* @brief Generate a batch of @p batch_type from this @p batch_bool_constant
*/
constexpr operator batch_type() const noexcept { return as_batch_bool(); }

constexpr bool get(size_t i) const noexcept
{
Expand Down Expand Up @@ -130,7 +138,12 @@ namespace xsimd
/**
* @brief Generate a batch of @p batch_type from this @p batch_constant
*/
inline operator batch_type() const noexcept { return { Values... }; }
inline batch_type as_batch() const noexcept { return { Values... }; }

/**
* @brief Generate a batch of @p batch_type from this @p batch_constant
*/
inline operator batch_type() const noexcept { return as_batch(); }

/**
* @brief Get the @p i th element of this @p batch_constant
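As a quick usage sketch of the two methods introduced here (the iota_gen and even_lanes generators are hypothetical, not part of the commit):

    #include <cstddef>
    #include <cstdint>
    #include "xsimd/xsimd.hpp"

    // Hypothetical generators for illustration.
    struct iota_gen
    {
        static constexpr uint32_t get(std::size_t index, std::size_t /*size*/)
        {
            return static_cast<uint32_t>(index);
        }
    };

    struct even_lanes
    {
        static constexpr bool get(std::size_t index, std::size_t /*size*/)
        {
            return index % 2 == 0;
        }
    };

    int main()
    {
        using arch = xsimd::default_arch;
        constexpr auto idx  = xsimd::make_batch_constant<uint32_t, arch, iota_gen>();
        constexpr auto flag = xsimd::make_batch_bool_constant<uint32_t, arch, even_lanes>();

        // Materialize the compile-time constants as regular runtime values.
        xsimd::batch<uint32_t, arch> lanes     = idx.as_batch();
        xsimd::batch_bool<uint32_t, arch> mask = flag.as_batch_bool();

        // Keep the even lanes, zero the odd ones.
        auto kept = xsimd::select(mask, lanes, xsimd::batch<uint32_t, arch>(0u));

        // kept holds the even-indexed values and zeros elsewhere.
        return xsimd::all(kept <= lanes) ? 0 : 1;
    }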
22 changes: 22 additions & 0 deletions test/test_batch_constant.cpp
@@ -45,6 +45,15 @@ struct constant_batch_test
CHECK_BATCH_EQ((batch_type)b, expected);
}

void test_cast() const
{
constexpr auto cst_b = xsimd::make_batch_constant<value_type, arch_type, generator>();
auto b0 = cst_b.as_batch();
auto b1 = (batch_type)cst_b;
CHECK_BATCH_EQ(b0, b1);
// The actual values are already tested in test_init_from_generator
}

struct arange
{
static constexpr value_type get(size_t index, size_t /*size*/)
@@ -135,6 +144,8 @@ TEST_CASE_TEMPLATE("[constant batch]", B, BATCH_INT_TYPES)
constant_batch_test<B> Test;
SUBCASE("init_from_generator") { Test.test_init_from_generator(); }

SUBCASE("as_batch") { Test.test_cast(); }

SUBCASE("init_from_generator_arange")
{
Test.test_init_from_generator_arange();
@@ -216,6 +227,15 @@ struct constant_bool_batch_test
}
};

void test_cast() const
{
constexpr auto all_true = xsimd::make_batch_bool_constant<value_type, arch_type, constant<true>>();
auto b0 = all_true.as_batch_bool();
auto b1 = (batch_bool_type)all_true;
CHECK_BATCH_EQ(b0, batch_bool_type(true));
CHECK_BATCH_EQ(b1, batch_bool_type(true));
}

void test_ops() const
{
constexpr auto all_true = xsimd::make_batch_bool_constant<value_type, arch_type, constant<true>>();
@@ -252,6 +272,8 @@ TEST_CASE_TEMPLATE("[constant bool batch]", B, BATCH_INT_TYPES)
constant_bool_batch_test<B> Test;
SUBCASE("init_from_generator") { Test.test_init_from_generator(); }

SUBCASE("as_batch") { Test.test_cast(); }

SUBCASE("init_from_generator_split")
{
Test.test_init_from_generator_split();
