Skip to content

Commit

Permalink
Implement approximation for pow
Browse files Browse the repository at this point in the history
  • Loading branch information
stijnh committed Nov 18, 2024
1 parent f89cf98 commit 014e32f
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 8 deletions.
25 changes: 22 additions & 3 deletions include/kernel_float/binops.h
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,7 @@ struct multiply<bool> {
namespace detail {
template<typename Policy, typename T, size_t N>
struct apply_impl<Policy, ops::divide<T>, N, T, T, T> {
KERNEL_FLOAT_INLINE static void
call(ops::divide<T> fun, T* result, const T* lhs, const T* rhs) {
KERNEL_FLOAT_INLINE static void call(ops::divide<T>, T* result, const T* lhs, const T* rhs) {
T rhs_rcp[N];

// Fast way to perform division is to multiply by the reciprocal
Expand All @@ -310,13 +309,33 @@ struct apply_impl<accurate_policy, ops::divide<T>, N, T, T, T>:
// Specialization of the division kernel for a single `float` under
// `fast_policy`: the quotient is computed by one call to `__fdividef`.
// NOTE(review): `__fdividef` is presumably CUDA's fast, reduced-accuracy
// division intrinsic, which is why this sits behind a preprocessor guard
// (see the `#endif` below) - confirm against the guard condition.
template<>
struct apply_impl<fast_policy, ops::divide<float>, 1, float, float, float> {
KERNEL_FLOAT_INLINE static void
call(ops::divide<float> fun, float* result, const float* lhs, const float* rhs) {
call(ops::divide<float>, float* result, const float* lhs, const float* rhs) {
*result = __fdividef(*lhs, *rhs);
}
};
#endif
} // namespace detail

namespace detail {
/**
 * Override `pow` for every policy except `accurate_policy` (see the
 * specialization below) by rewriting it as
 *
 *     pow(lhs, rhs) == exp2(rhs * log2(lhs))
 *
 * Each of the three steps is dispatched through `apply_impl` with the same
 * `Policy`, so whatever implementation that policy selects for `log2`,
 * `multiply` and `exp2` is used here as well.
 *
 * NOTE(review): the identity relies on `log2(lhs)` being well defined, so
 * results for `lhs <= 0` presumably differ from an exact `pow` - confirm that
 * this is acceptable for the non-accurate policies.
 *
 * @param result  Output array of `N` elements.
 * @param lhs     Array of `N` bases.
 * @param rhs     Array of `N` exponents.
 */
template<typename Policy, typename T, size_t N>
struct apply_impl<Policy, ops::pow<T>, N, T, T, T> {
    // The functor parameter must be `ops::pow<T>` to match the operation this
    // specialization is selected for (it is unused, hence unnamed).
    KERNEL_FLOAT_INLINE static void call(ops::pow<T>, T* result, const T* lhs, const T* rhs) {
        T lhs_log[N];
        T result_log[N];

        // lhs_log = log2(lhs)
        apply_impl<Policy, ops::log2<T>, N, T, T>::call({}, lhs_log, lhs);

        // result_log = log2(lhs) * rhs
        apply_impl<Policy, ops::multiply<T>, N, T, T, T>::call({}, result_log, lhs_log, rhs);

        // result = exp2(result_log); `exp2` is unary, so exactly one input
        // type follows the output type (same shape as the `log2` call above).
        apply_impl<Policy, ops::exp2<T>, N, T, T>::call({}, result, result_log);
    }
};

// `accurate_policy` must not use the log2/exp2 rewrite: fall back to the base
// implementation of `pow`.
template<typename T, size_t N>
struct apply_impl<accurate_policy, ops::pow<T>, N, T, T, T>:
    apply_base_impl<accurate_policy, ops::pow<T>, N, T, T, T> {};
}  // namespace detail

template<typename L, typename R, typename T = promoted_vector_value_type<L, R>>
KERNEL_FLOAT_INLINE zip_common_type<ops::divide<T>, T, T>
fast_divide(const L& left, const R& right) {
Expand Down
29 changes: 24 additions & 5 deletions single_include/kernel_float.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

//================================================================================
// this file has been auto-generated, do not modify its contents!
// date: 2024-11-18 13:40:03.668017
// git hash: ae0e6b16ac2d626e69bb08554044a77671f408ab
// date: 2024-11-18 13:50:24.614671
// git hash: f89cf98f79e78ab6013063dea4b4b516ce163855
//================================================================================

#ifndef KERNEL_FLOAT_MACROS_H
Expand Down Expand Up @@ -1950,8 +1950,7 @@ struct multiply<bool> {
namespace detail {
template<typename Policy, typename T, size_t N>
struct apply_impl<Policy, ops::divide<T>, N, T, T, T> {
KERNEL_FLOAT_INLINE static void
call(ops::divide<T> fun, T* result, const T* lhs, const T* rhs) {
KERNEL_FLOAT_INLINE static void call(ops::divide<T>, T* result, const T* lhs, const T* rhs) {
T rhs_rcp[N];

// Fast way to perform division is to multiply by the reciprocal
Expand All @@ -1968,13 +1967,33 @@ struct apply_impl<accurate_policy, ops::divide<T>, N, T, T, T>:
// Specialization of the division kernel for a single `float` under
// `fast_policy`: the quotient is computed by one call to `__fdividef`.
// NOTE(review): `__fdividef` is presumably CUDA's fast, reduced-accuracy
// division intrinsic, which is why this sits behind a preprocessor guard
// (see the `#endif` below) - confirm against the guard condition.
template<>
struct apply_impl<fast_policy, ops::divide<float>, 1, float, float, float> {
KERNEL_FLOAT_INLINE static void
call(ops::divide<float> fun, float* result, const float* lhs, const float* rhs) {
call(ops::divide<float>, float* result, const float* lhs, const float* rhs) {
*result = __fdividef(*lhs, *rhs);
}
};
#endif
} // namespace detail

namespace detail {
/**
 * Override `pow` for every policy except `accurate_policy` (see the
 * specialization below) by rewriting it as
 *
 *     pow(lhs, rhs) == exp2(rhs * log2(lhs))
 *
 * Each of the three steps is dispatched through `apply_impl` with the same
 * `Policy`, so whatever implementation that policy selects for `log2`,
 * `multiply` and `exp2` is used here as well.
 *
 * NOTE(review): the identity relies on `log2(lhs)` being well defined, so
 * results for `lhs <= 0` presumably differ from an exact `pow` - confirm that
 * this is acceptable for the non-accurate policies.
 *
 * @param result  Output array of `N` elements.
 * @param lhs     Array of `N` bases.
 * @param rhs     Array of `N` exponents.
 */
template<typename Policy, typename T, size_t N>
struct apply_impl<Policy, ops::pow<T>, N, T, T, T> {
    // The functor parameter must be `ops::pow<T>` to match the operation this
    // specialization is selected for (it is unused, hence unnamed).
    KERNEL_FLOAT_INLINE static void call(ops::pow<T>, T* result, const T* lhs, const T* rhs) {
        T lhs_log[N];
        T result_log[N];

        // lhs_log = log2(lhs)
        apply_impl<Policy, ops::log2<T>, N, T, T>::call({}, lhs_log, lhs);

        // result_log = log2(lhs) * rhs
        apply_impl<Policy, ops::multiply<T>, N, T, T, T>::call({}, result_log, lhs_log, rhs);

        // result = exp2(result_log); `exp2` is unary, so exactly one input
        // type follows the output type (same shape as the `log2` call above).
        apply_impl<Policy, ops::exp2<T>, N, T, T>::call({}, result, result_log);
    }
};

// `accurate_policy` must not use the log2/exp2 rewrite: fall back to the base
// implementation of `pow`.
template<typename T, size_t N>
struct apply_impl<accurate_policy, ops::pow<T>, N, T, T, T>:
    apply_base_impl<accurate_policy, ops::pow<T>, N, T, T, T> {};
}  // namespace detail

template<typename L, typename R, typename T = promoted_vector_value_type<L, R>>
KERNEL_FLOAT_INLINE zip_common_type<ops::divide<T>, T, T>
fast_divide(const L& left, const R& right) {
Expand Down

0 comments on commit 014e32f

Please sign in to comment.