From ae752732bb57c2fb1401f4e936ce56328c5ff242 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Fri, 25 Apr 2025 20:11:27 -0700 Subject: [PATCH 1/9] Use sycl complex extension throughout element-wise and utils --- .../kernels/elementwise_functions/acos.hpp | 17 +++-- .../kernels/elementwise_functions/acosh.hpp | 36 +++++----- .../kernels/elementwise_functions/asin.hpp | 24 +++---- .../kernels/elementwise_functions/asinh.hpp | 19 +++-- .../kernels/elementwise_functions/atan.hpp | 9 ++- .../kernels/elementwise_functions/atanh.hpp | 8 ++- .../elementwise_functions/cabs_impl.hpp | 13 ++-- .../kernels/elementwise_functions/cos.hpp | 27 +++---- .../kernels/elementwise_functions/cosh.hpp | 9 +-- .../kernels/elementwise_functions/exp.hpp | 11 +-- .../kernels/elementwise_functions/exp2.hpp | 13 ++-- .../kernels/elementwise_functions/expm1.hpp | 7 +- .../kernels/elementwise_functions/imag.hpp | 5 +- .../elementwise_functions/isfinite.hpp | 8 ++- .../kernels/elementwise_functions/isinf.hpp | 8 ++- .../kernels/elementwise_functions/isnan.hpp | 8 ++- .../kernels/elementwise_functions/log1p.hpp | 8 ++- .../kernels/elementwise_functions/proj.hpp | 8 ++- .../kernels/elementwise_functions/real.hpp | 7 +- .../kernels/elementwise_functions/round.hpp | 8 ++- .../kernels/elementwise_functions/sin.hpp | 16 +++-- .../kernels/elementwise_functions/sinh.hpp | 9 +-- .../kernels/elementwise_functions/tan.hpp | 14 ++-- .../kernels/elementwise_functions/tanh.hpp | 8 ++- .../libtensor/include/utils/math_utils.hpp | 71 ++++++++++++------- .../source/sorting/rich_comparisons.hpp | 16 +++-- 26 files changed, 230 insertions(+), 157 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp index 7dbfb6618c..41622906e1 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp @@ -72,9 +72,10 @@ template struct AcosFunctor using realT = typename argT::value_type; constexpr realT q_nan = std::numeric_limits::quiet_NaN(); - - const realT x = std::real(in); - const realT y = std::imag(in); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z = sycl_complexT(in); + const realT x = exprm_ns::real(z); + const realT y = exprm_ns::imag(z); if (std::isnan(x)) { /* acos(NaN + I*+-Inf) = NaN + I*-+Inf */ @@ -106,12 +107,10 @@ template struct AcosFunctor constexpr realT r_eps = realT(1) / std::numeric_limits::epsilon(); if (sycl::fabs(x) > r_eps || sycl::fabs(y) > r_eps) { - using sycl_complexT = exprm_ns::complex; - sycl_complexT log_in = - exprm_ns::log(exprm_ns::complex(in)); + sycl_complexT log_z = exprm_ns::log(z); - const realT wx = log_in.real(); - const realT wy = log_in.imag(); + const realT wx = log_z.real(); + const realT wy = log_z.imag(); const realT rx = sycl::fabs(wy); realT ry = wx + sycl::log(realT(2)); @@ -119,7 +118,7 @@ template struct AcosFunctor } /* ordinary cases */ - return exprm_ns::acos(exprm_ns::complex(in)); // acos(in); + return exprm_ns::acos(z); // acos(z); } else { static_assert(std::is_floating_point_v || diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp index a81ff3da99..cdad55c8a4 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp @@ -77,17 +77,19 @@ template struct AcoshFunctor * where the sign is chosen so Re(acosh(in)) >= 0. * So, we first calculate acos(in) and then acosh(in). */ - const realT x = std::real(in); - const realT y = std::imag(in); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z = sycl_complexT(in); + const realT x = exprm_ns::real(z); + const realT y = exprm_ns::imag(z); - resT acos_in; + sycl_complexT acos_z; if (std::isnan(x)) { /* acos(NaN + I*+-Inf) = NaN + I*-+Inf */ if (std::isinf(y)) { - acos_in = resT{q_nan, -y}; + acos_z = resT{q_nan, -y}; } else { - acos_in = resT{q_nan, q_nan}; + acos_z = resT{q_nan, q_nan}; } } else if (std::isnan(y)) { @@ -95,15 +97,15 @@ template struct AcoshFunctor constexpr realT inf = std::numeric_limits::infinity(); if (std::isinf(x)) { - acos_in = resT{q_nan, -inf}; + acos_z = resT{q_nan, -inf}; } /* acos(0 + I*NaN) = Pi/2 + I*NaN with inexact */ else if (x == realT(0)) { const realT pi_half = sycl::atan(realT(1)) * 2; - acos_in = resT{pi_half, q_nan}; + acos_z = resT{pi_half, q_nan}; } else { - acos_in = resT{q_nan, q_nan}; + acos_z = resT{q_nan, q_nan}; } } @@ -113,23 +115,21 @@ template struct AcoshFunctor * For large x or y including acos(+-Inf + I*+-Inf) */ if (sycl::fabs(x) > r_eps || sycl::fabs(y) > r_eps) { - using sycl_complexT = typename exprm_ns::complex; - const sycl_complexT log_in = exprm_ns::log(sycl_complexT(in)); - const realT wx = log_in.real(); - const realT wy = log_in.imag(); + const sycl_complexT log_z = exprm_ns::log(z); + const realT wx = log_z.real(); + const realT wy = log_z.imag(); const realT rx = sycl::fabs(wy); realT ry = wx + sycl::log(realT(2)); - acos_in = resT{rx, (sycl::signbit(y)) ? ry : -ry}; + acos_z = resT{rx, (sycl::signbit(y)) ? ry : -ry}; } else { /* ordinary cases */ - acos_in = - exprm_ns::acos(exprm_ns::complex(in)); // acos(in); + acos_z = exprm_ns::acos(z); // acos(z); } /* Now we calculate acosh(z) */ - const realT rx = std::real(acos_in); - const realT ry = std::imag(acos_in); + const realT rx = exprm_ns::real(acos_z); + const realT ry = exprm_ns::imag(acos_z); /* acosh(NaN + I*NaN) = NaN + I*NaN */ if (std::isnan(rx) && std::isnan(ry)) { @@ -145,7 +145,7 @@ template struct AcoshFunctor return resT{ry, ry}; } /* ordinary cases */ - const realT res_im = sycl::copysign(rx, std::imag(in)); + const realT res_im = sycl::copysign(rx, exprm_ns::imag(z)); return resT{sycl::fabs(ry), res_im}; } else { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp index 70b48895b4..cf15831aab 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp @@ -80,8 +80,10 @@ template struct AsinFunctor * y = imag(I * conj(in)) = real(in) * and then return {imag(w), real(w)} which is asin(in) */ - const realT x = std::imag(in); - const realT y = std::real(in); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z = sycl_complexT(in); + const realT x = exprm_ns::imag(z); + const realT y = exprm_ns::real(z); if (std::isnan(x)) { /* asinh(NaN + I*+-Inf) = opt(+-)Inf + I*NaN */ @@ -120,26 +122,24 @@ template struct AsinFunctor constexpr realT r_eps = realT(1) / std::numeric_limits::epsilon(); if (sycl::fabs(x) > r_eps || sycl::fabs(y) > r_eps) { - using sycl_complexT = exprm_ns::complex; - const sycl_complexT z{x, y}; + const sycl_complexT z1{x, y}; realT wx, wy; if (!sycl::signbit(x)) { - const auto log_z = exprm_ns::log(z); - wx = log_z.real() + sycl::log(realT(2)); - wy = log_z.imag(); + const auto log_z1 = exprm_ns::log(z1); + wx = log_z1.real() + sycl::log(realT(2)); + wy = log_z1.imag(); } else { - const auto log_mz = exprm_ns::log(-z); - wx = log_mz.real() + sycl::log(realT(2)); - wy = log_mz.imag(); + const auto log_mz1 = exprm_ns::log(-z1); + wx = log_mz1.real() + sycl::log(realT(2)); + wy = log_mz1.imag(); } const realT asinh_re = sycl::copysign(wx, x); const realT asinh_im = sycl::copysign(wy, y); return resT{asinh_im, asinh_re}; } /* ordinary cases */ - return exprm_ns::asin( - exprm_ns::complex(in)); // sycl::asin(in); + return exprm_ns::asin(z); // sycl::asin(z); } else { static_assert(std::is_floating_point_v || diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asinh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asinh.hpp index 420ba3246c..61c5b7a75c 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asinh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asinh.hpp @@ -72,9 +72,10 @@ template struct AsinhFunctor using realT = typename argT::value_type; constexpr realT q_nan = std::numeric_limits::quiet_NaN(); - - const realT x = std::real(in); - const realT y = std::imag(in); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z = sycl_complexT(in); + const realT x = exprm_ns::real(z); + const realT y = exprm_ns::imag(z); if (std::isnan(x)) { /* asinh(NaN + I*+-Inf) = opt(+-)Inf + I*NaN */ @@ -109,12 +110,10 @@ template struct AsinhFunctor realT(1) / std::numeric_limits::epsilon(); if (sycl::fabs(x) > r_eps || sycl::fabs(y) > r_eps) { - using sycl_complexT = exprm_ns::complex; - sycl_complexT log_in = (sycl::signbit(x)) - ? exprm_ns::log(sycl_complexT(-in)) - : exprm_ns::log(sycl_complexT(in)); - realT wx = log_in.real() + sycl::log(realT(2)); - realT wy = log_in.imag(); + sycl_complexT log_in = + (sycl::signbit(x)) ? exprm_ns::log(-z) : exprm_ns::log(z); + realT wx = exprm_ns::real(log_in) + sycl::log(realT(2)); + realT wy = exprm_ns::imag(log_in); const realT res_re = sycl::copysign(wx, x); const realT res_im = sycl::copysign(wy, y); @@ -122,7 +121,7 @@ template struct AsinhFunctor } /* ordinary cases */ - return exprm_ns::asinh(exprm_ns::complex(in)); // asinh(in); + return exprm_ns::asinh(z); // asinh(z); } else { static_assert(std::is_floating_point_v || diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan.hpp index 29c4941d76..1f9e4079d1 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan.hpp @@ -83,8 +83,11 @@ template struct AtanFunctor * y = imag(I * conj(in)) = real(in) * and then return {imag(w), real(w)} which is atan(in) */ - const realT x = std::imag(in); - const realT y = std::real(in); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z = sycl_complexT(in); + const realT x = exprm_ns::imag(z); + const realT y = exprm_ns::real(z); + if (std::isnan(x)) { /* atanh(NaN + I*+-Inf) = sign(NaN)*0 + I*+-Pi/2 */ if (std::isinf(y)) { @@ -132,7 +135,7 @@ template struct AtanFunctor return resT{atanh_im, atanh_re}; } /* ordinary cases */ - return exprm_ns::atan(exprm_ns::complex(in)); // atan(in); + return exprm_ns::atan(z); // atan(z); } else { static_assert(std::is_floating_point_v || diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atanh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atanh.hpp index 39f11e0f90..bf702ba575 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atanh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atanh.hpp @@ -73,8 +73,10 @@ template struct AtanhFunctor using realT = typename argT::value_type; constexpr realT q_nan = std::numeric_limits::quiet_NaN(); - const realT x = std::real(in); - const realT y = std::imag(in); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z = sycl_complexT(in); + const realT x = exprm_ns::real(z); + const realT y = exprm_ns::imag(z); if (std::isnan(x)) { /* atanh(NaN + I*+-Inf) = sign(NaN)0 + I*+-PI/2 */ @@ -123,7 +125,7 @@ template struct AtanhFunctor return resT{res_re, res_im}; } /* ordinary cases */ - return exprm_ns::atanh(exprm_ns::complex(in)); // atanh(in); + return exprm_ns::atanh(z); // atanh(z); } else { static_assert(std::is_floating_point_v || diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cabs_impl.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cabs_impl.hpp index afa83a64cb..e1677871a5 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cabs_impl.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cabs_impl.hpp @@ -51,8 +51,10 @@ template realT cabs(std::complex const &z) // * If x is a finite number and y is NaN, the result is NaN. // * If x is NaN and y is NaN, the result is NaN. - const realT x = std::real(z); - const realT y = std::imag(z); + using sycl_complexT = exprm_ns::complex; + sycl_complexT _z = exprm_ns::complex(z); + const realT x = exprm_ns::real(_z); + const realT y = exprm_ns::imag(_z); constexpr realT q_nan = std::numeric_limits::quiet_NaN(); constexpr realT p_inf = std::numeric_limits::infinity(); @@ -60,11 +62,8 @@ template realT cabs(std::complex const &z) const realT res = std::isinf(x) ? p_inf - : ((std::isinf(y) - ? p_inf - : ((std::isnan(x) - ? q_nan - : exprm_ns::abs(exprm_ns::complex(z)))))); + : ((std::isinf(y) ? p_inf + : ((std::isnan(x) ? q_nan : exprm_ns::abs(_z))))); return res; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp index 5940315c62..45c95e0821 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp @@ -72,30 +72,31 @@ template struct CosFunctor using realT = typename argT::value_type; constexpr realT q_nan = std::numeric_limits::quiet_NaN(); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z = sycl_complexT(in); + const realT z_re = exprm_ns::real(z); + const realT z_im = exprm_ns::imag(z); - realT const &in_re = std::real(in); - realT const &in_im = std::imag(in); - - const bool in_re_finite = std::isfinite(in_re); - const bool in_im_finite = std::isfinite(in_im); + const bool z_re_finite = std::isfinite(z_re); + const bool z_im_finite = std::isfinite(z_im); /* * Handle the nearly-non-exceptional cases where * real and imaginary parts of input are finite. */ - if (in_re_finite && in_im_finite) { - return exprm_ns::cos(exprm_ns::complex(in)); // cos(in); + if (z_re_finite && z_im_finite) { + return exprm_ns::cos(z); // cos(z); } /* - * since cos(in) = cosh(I * in), for special cases, - * we return cosh(I * in). + * since cos(z) = cosh(I * z), for special cases, + * we return cosh(I * z). */ - const realT x = -in_im; - const realT y = in_re; + const realT x = -z_im; + const realT y = z_re; - const bool xfinite = in_im_finite; - const bool yfinite = in_re_finite; + const bool xfinite = z_im_finite; + const bool yfinite = z_re_finite; /* * cosh(+-0 +- I Inf) = dNaN + I sign(d(+-0, dNaN))0. * The sign of 0 in the result is unspecified. Choice = normally diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cosh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cosh.hpp index 59468428d1..266180f751 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cosh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cosh.hpp @@ -73,8 +73,10 @@ template struct CoshFunctor constexpr realT q_nan = std::numeric_limits::quiet_NaN(); - const realT x = std::real(in); - const realT y = std::imag(in); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z = sycl_complexT(in); + const realT x = exprm_ns::real(z); + const realT y = exprm_ns::imag(z); const bool xfinite = std::isfinite(x); const bool yfinite = std::isfinite(y); @@ -84,8 +86,7 @@ template struct CoshFunctor * real and imaginary parts of input are finite. */ if (xfinite && yfinite) { - return exprm_ns::cosh( - exprm_ns::complex(in)); // cosh(in); + return exprm_ns::cosh(z); // cosh(z); } /* diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp.hpp index 00f8213251..bc644da10c 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp.hpp @@ -72,12 +72,13 @@ template struct ExpFunctor constexpr realT q_nan = std::numeric_limits::quiet_NaN(); - const realT x = std::real(in); - const realT y = std::imag(in); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z = sycl_complexT(in); + const realT x = exprm_ns::real(z); + const realT y = exprm_ns::imag(z); if (std::isfinite(x)) { if (std::isfinite(y)) { - return exprm_ns::exp( - exprm_ns::complex(in)); // exp(in); + return exprm_ns::exp(z); // exp(z); } else { return resT{q_nan, q_nan}; @@ -86,7 +87,7 @@ template struct ExpFunctor else if (std::isnan(x)) { /* x is nan */ if (y == realT(0)) { - return resT{in}; + return resT{z}; } else { return resT{x, q_nan}; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp2.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp2.hpp index 22291101ca..c7791b70ee 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp2.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp2.hpp @@ -71,15 +71,18 @@ template struct Exp2Functor if constexpr (is_complex::value) { using realT = typename argT::value_type; - const argT tmp = in * sycl::log(realT(2)); - constexpr realT q_nan = std::numeric_limits::quiet_NaN(); - const realT x = std::real(tmp); - const realT y = std::imag(tmp); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z = sycl_complexT(in); + const realT x = exprm_ns::real(z); + const realT y = exprm_ns::imag(z); + + const sycl_complexT tmp = z * sycl::log(realT(2)); + if (std::isfinite(x)) { if (std::isfinite(y)) { - return exprm_ns::exp(exprm_ns::complex(tmp)); + return exprm_ns::exp(tmp); } else { return resT{q_nan, q_nan}; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp index d1d64f4904..097267f2aa 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp @@ -31,6 +31,7 @@ #include #include +#include "sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -73,8 +74,10 @@ template struct Expm1Functor using realT = typename argT::value_type; // expm1(x + I*y) = expm1(x)*cos(y) - 2*sin(y / 2)^2 + // I*exp(x)*sin(y) - const realT x = std::real(in); - const realT y = std::imag(in); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z = sycl_complexT(in); + const realT x = exprm_ns::real(z); + const realT y = exprm_ns::imag(z); // special cases if (std::isinf(x)) { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp index c5e0feea12..f51568f202 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp @@ -31,6 +31,7 @@ #include #include +#include "sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -72,7 +73,9 @@ template struct ImagFunctor resT operator()(const argT &in) const { if constexpr (is_complex_v) { - return std::imag(in); + using realT = typename argT::value_type; + using sycl_complexT = typename exprm_ns::complex; + return exprm_ns::imag(sycl_complexT(in)); } else { static_assert(std::is_same_v); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp index b0651a4d8b..32f6addf2f 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp @@ -30,6 +30,7 @@ #include #include +#include "sycl_complex.hpp" #include "vec_size_util.hpp" #include "utils/offset_utils.hpp" @@ -69,8 +70,11 @@ template struct IsFiniteFunctor resT operator()(const argT &in) const { if constexpr (is_complex::value) { - const bool real_isfinite = std::isfinite(std::real(in)); - const bool imag_isfinite = std::isfinite(std::imag(in)); + using realT = typename argT::value_type; + using sycl_complexT = exprm_ns::complex; + sycl_complexT z = sycl_complexT(in); + const bool real_isfinite = std::isfinite(exprm_ns::real(z)); + const bool imag_isfinite = std::isfinite(exprm_ns::imag(z)); return (real_isfinite && imag_isfinite); } else if constexpr (std::is_same::value || diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp index ec78746143..87a215bd9f 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp @@ -30,6 +30,7 @@ #include #include +#include "sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -69,8 +70,11 @@ template struct IsInfFunctor resT operator()(const argT &in) const { if constexpr (is_complex::value) { - const bool real_isinf = std::isinf(std::real(in)); - const bool imag_isinf = std::isinf(std::imag(in)); + using realT = typename argT::value_type; + using sycl_complexT = exprm_ns::complex; + sycl_complexT z = sycl_complexT(in); + const bool real_isinf = std::isinf(exprm_ns::real(z)); + const bool imag_isinf = std::isinf(exprm_ns::imag(z)); return (real_isinf || imag_isinf); } else if constexpr (std::is_same::value || diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp index fbf6ef9383..abac44be84 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp @@ -29,6 +29,7 @@ #include #include +#include "sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -70,8 +71,11 @@ template struct IsNanFunctor resT operator()(const argT &in) const { if constexpr (is_complex::value) { - const bool real_isnan = sycl::isnan(std::real(in)); - const bool imag_isnan = sycl::isnan(std::imag(in)); + using realT = typename argT::value_type; + using sycl_complexT = exprm_ns::complex; + sycl_complexT z = sycl_complexT(in); + const bool real_isnan = sycl::isnan(exprm_ns::real(z)); + const bool imag_isnan = sycl::isnan(exprm_ns::imag(z)); return (real_isnan || imag_isnan); } else if constexpr (std::is_same::value || diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log1p.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log1p.hpp index b8d993dd94..444e5a346e 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log1p.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log1p.hpp @@ -30,6 +30,7 @@ #include #include +#include "sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -78,8 +79,11 @@ template struct Log1pFunctor // = log1p(x^2 + 2x + y^2) / 2 // + I * atan2(y, x + 1) using realT = typename argT::value_type; - const realT x = std::real(in); - const realT y = std::imag(in); + using realT = typename argT::value_type; + using sycl_complexT = exprm_ns::complex; + sycl_complexT z = sycl_complexT(in); + const realT x = exprm_ns::real(z); + const realT y = exprm_ns::imag(z); // imaginary part of result const realT res_im = sycl::atan2(y, x + 1); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/proj.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/proj.hpp index df5edface1..45c54e3b09 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/proj.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/proj.hpp @@ -32,6 +32,7 @@ #include #include +#include "sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -70,8 +71,11 @@ template struct ProjFunctor resT operator()(const argT &in) const { using realT = typename argT::value_type; - const realT x = std::real(in); - const realT y = std::imag(in); + using realT = typename argT::value_type; + using sycl_complexT = exprm_ns::complex; + sycl_complexT z = sycl_complexT(in); + const realT x = exprm_ns::real(z); + const realT y = exprm_ns::imag(z); if (std::isinf(x)) { return value_at_infinity(y); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp index 9ecb822a20..539abf22f1 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp @@ -31,6 +31,7 @@ #include #include +#include "sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -71,7 +72,9 @@ template struct RealFunctor resT operator()(const argT &in) const { if constexpr (is_complex_v) { - return std::real(in); + using realT = typename argT::value_type; + using sycl_complexT = typename exprm_ns::complex; + return exprm_ns::real(sycl_complexT(in)); } else { static_assert(std::is_same_v); @@ -174,7 +177,7 @@ template struct RealContigFactory template struct RealTypeMapFactory { - /*! @brief get typeid for output type of std::real(T x) */ + /*! @brief get typeid for output type of real(T x) */ std::enable_if_t::value, int> get() { using rT = typename RealOutputType::value_type; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/round.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/round.hpp index 7fbb20ae32..b53126754e 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/round.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/round.hpp @@ -29,6 +29,7 @@ #include #include +#include "sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -67,14 +68,15 @@ template struct RoundFunctor resT operator()(const argT &in) const { - if constexpr (std::is_integral_v) { return in; } else if constexpr (is_complex::value) { using realT = typename argT::value_type; - return resT{round_func(std::real(in)), - round_func(std::imag(in))}; + using sycl_complexT = exprm_ns::complex; + sycl_complexT z = sycl_complexT(in); + return resT{round_func(exprm_ns::real(z)), + round_func(exprm_ns::imag(z))}; } else { return round_func(in); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp index e075a90a88..81dab66026 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp @@ -72,8 +72,11 @@ template struct SinFunctor constexpr realT q_nan = std::numeric_limits::quiet_NaN(); - realT const &in_re = std::real(in); - realT const &in_im = std::imag(in); + using realT = typename argT::value_type; + using sycl_complexT = exprm_ns::complex; + sycl_complexT z = sycl_complexT(in); + realT const &in_re = exprm_ns::real(z); + realT const &in_im = exprm_ns::imag(z); const bool in_re_finite = std::isfinite(in_re); const bool in_im_finite = std::isfinite(in_im); @@ -82,8 +85,7 @@ template struct SinFunctor * real and imaginary parts of input are finite. */ if (in_re_finite && in_im_finite) { - resT res = - exprm_ns::sin(exprm_ns::complex(in)); // sin(in); + resT res = exprm_ns::sin(z); // sin(z); if (in_re == realT(0)) { res.real(sycl::copysign(realT(0), in_re)); } @@ -91,9 +93,9 @@ template struct SinFunctor } /* - * since sin(in) = -I * sinh(I * in), for special cases, - * we calculate real and imaginary parts of z = sinh(I * in) and - * then return { imag(z) , -real(z) } which is sin(in). + * since sin(z) = -I * sinh(I * z), for special cases, + * we calculate real and imaginary parts of z = sinh(I * z) and + * then return { imag(z) , -real(z) } which is sin(z). */ const realT x = -in_im; const realT y = in_re; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sinh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sinh.hpp index 23b3588a3b..4bba379f74 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sinh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sinh.hpp @@ -70,9 +70,10 @@ template struct SinhFunctor { if constexpr (is_complex::value) { using realT = typename argT::value_type; - - const realT x = std::real(in); - const realT y = std::imag(in); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z = sycl_complexT(in); + const realT x = exprm_ns::real(z); + const realT y = exprm_ns::imag(z); const bool xfinite = std::isfinite(x); const bool yfinite = std::isfinite(y); @@ -82,7 +83,7 @@ template struct SinhFunctor * real and imaginary parts of input are finite. */ if (xfinite && yfinite) { - return exprm_ns::sinh(exprm_ns::complex(in)); + return exprm_ns::sinh(z); } /* * sinh(+-0 +- I Inf) = sign(d(+-0, dNaN))0 + I dNaN. diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tan.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tan.hpp index 770518f918..a575a8ec0c 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tan.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tan.hpp @@ -75,12 +75,14 @@ template struct TanFunctor constexpr realT q_nan = std::numeric_limits::quiet_NaN(); /* - * since tan(in) = -I * tanh(I * in), for special cases, - * we calculate real and imaginary parts of z = tanh(I * in) and - * return { imag(z) , -real(z) } which is tan(in). + * since tan(z) = -I * tanh(I * z), for special cases, + * we calculate real and imaginary parts of z = tanh(I * z) and + * return { imag(z) , -real(z) } which is tan(z). */ - const realT x = -std::imag(in); - const realT y = std::real(in); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z = sycl_complexT(in); + const realT x = -exprm_ns::imag(z); + const realT y = exprm_ns::real(z); /* * tanh(NaN + i 0) = NaN + i 0 * @@ -121,7 +123,7 @@ template struct TanFunctor return resT{q_nan, q_nan}; } /* ordinary cases */ - return exprm_ns::tan(exprm_ns::complex(in)); // tan(in); + return exprm_ns::tan(z); // tan(z); } else { static_assert(std::is_floating_point_v || diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tanh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tanh.hpp index 1d06fd3c4f..e88018e933 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tanh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tanh.hpp @@ -75,8 +75,10 @@ template struct TanhFunctor constexpr realT q_nan = std::numeric_limits::quiet_NaN(); - const realT x = std::real(in); - const realT y = std::imag(in); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z = sycl_complexT(in); + const realT x = exprm_ns::real(z); + const realT y = exprm_ns::imag(z); /* * tanh(NaN + i 0) = NaN + i 0 * @@ -115,7 +117,7 @@ template struct TanhFunctor return resT{q_nan, q_nan}; } /* ordinary cases */ - return exprm_ns::tanh(exprm_ns::complex(in)); // tanh(in); + return exprm_ns::tanh(z); // tanh(z); } else { static_assert(std::is_floating_point_v || diff --git a/dpctl/tensor/libtensor/include/utils/math_utils.hpp b/dpctl/tensor/libtensor/include/utils/math_utils.hpp index a49b56b6ba..ecd9c1fe18 100644 --- a/dpctl/tensor/libtensor/include/utils/math_utils.hpp +++ b/dpctl/tensor/libtensor/include/utils/math_utils.hpp @@ -24,7 +24,8 @@ #pragma once #include -#include +#define SYCL_EXT_ONEAPI_COMPLEX +#include #include namespace dpctl @@ -34,13 +35,18 @@ namespace tensor namespace math_utils { +namespace exprm_ns = sycl::ext::oneapi::experimental; + template bool less_complex(const T &x1, const T &x2) { using realT = typename T::value_type; - realT real1 = std::real(x1); - realT real2 = std::real(x2); - realT imag1 = std::imag(x1); - realT imag2 = std::imag(x2); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z1 = sycl_complexT(x1); + sycl_complexT z2 = sycl_complexT(x2); + realT real1 = exprm_ns::real(z1); + realT real2 = exprm_ns::real(z2); + realT imag1 = exprm_ns::imag(z1); + realT imag2 = exprm_ns::imag(z2); return (real1 == real2) ? (imag1 < imag2) @@ -50,10 +56,13 @@ template bool less_complex(const T &x1, const T &x2) template bool greater_complex(const T &x1, const T &x2) { using realT = typename T::value_type; - realT real1 = std::real(x1); - realT real2 = std::real(x2); - realT imag1 = std::imag(x1); - realT imag2 = std::imag(x2); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z1 = sycl_complexT(x1); + sycl_complexT z2 = sycl_complexT(x2); + realT real1 = exprm_ns::real(z1); + realT real2 = exprm_ns::real(z2); + realT imag1 = exprm_ns::imag(z1); + realT imag2 = exprm_ns::imag(z2); return (real1 == real2) ? (imag1 > imag2) @@ -63,10 +72,13 @@ template bool greater_complex(const T &x1, const T &x2) template bool less_equal_complex(const T &x1, const T &x2) { using realT = typename T::value_type; - realT real1 = std::real(x1); - realT real2 = std::real(x2); - realT imag1 = std::imag(x1); - realT imag2 = std::imag(x2); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z1 = sycl_complexT(x1); + sycl_complexT z2 = sycl_complexT(x2); + realT real1 = exprm_ns::real(z1); + realT real2 = exprm_ns::real(z2); + realT imag1 = exprm_ns::imag(z1); + realT imag2 = exprm_ns::imag(z2); return (real1 == real2) ? (imag1 <= imag2) @@ -76,10 +88,13 @@ template bool less_equal_complex(const T &x1, const T &x2) template bool greater_equal_complex(const T &x1, const T &x2) { using realT = typename T::value_type; - realT real1 = std::real(x1); - realT real2 = std::real(x2); - realT imag1 = std::imag(x1); - realT imag2 = std::imag(x2); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z1 = sycl_complexT(x1); + sycl_complexT z2 = sycl_complexT(x2); + realT real1 = exprm_ns::real(z1); + realT real2 = exprm_ns::real(z2); + realT imag1 = exprm_ns::imag(z1); + realT imag2 = exprm_ns::imag(z2); return (real1 == real2) ? (imag1 >= imag2) @@ -89,10 +104,13 @@ template bool greater_equal_complex(const T &x1, const T &x2) template T max_complex(const T &x1, const T &x2) { using realT = typename T::value_type; - realT real1 = std::real(x1); - realT real2 = std::real(x2); - realT imag1 = std::imag(x1); - realT imag2 = std::imag(x2); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z1 = sycl_complexT(x1); + sycl_complexT z2 = sycl_complexT(x2); + realT real1 = exprm_ns::real(z1); + realT real2 = exprm_ns::real(z2); + realT imag1 = exprm_ns::imag(z1); + realT imag2 = exprm_ns::imag(z2); bool isnan_imag1 = std::isnan(imag1); bool gt = (real1 == real2) @@ -104,10 +122,13 @@ template T max_complex(const T &x1, const T &x2) template T min_complex(const T &x1, const T &x2) { using realT = typename T::value_type; - realT real1 = std::real(x1); - realT real2 = std::real(x2); - realT imag1 = std::imag(x1); - realT imag2 = std::imag(x2); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z1 = sycl_complexT(x1); + sycl_complexT z2 = sycl_complexT(x2); + realT real1 = exprm_ns::real(z1); + realT real2 = exprm_ns::real(z2); + realT imag1 = exprm_ns::imag(z1); + realT imag2 = exprm_ns::imag(z2); bool isnan_imag1 = std::isnan(imag1); bool lt = (real1 == real2) diff --git a/dpctl/tensor/libtensor/source/sorting/rich_comparisons.hpp b/dpctl/tensor/libtensor/source/sorting/rich_comparisons.hpp index 2aaa1cfafa..44b70c28ec 100644 --- a/dpctl/tensor/libtensor/source/sorting/rich_comparisons.hpp +++ b/dpctl/tensor/libtensor/source/sorting/rich_comparisons.hpp @@ -24,7 +24,9 @@ #pragma once +#define SYCL_EXT_ONEAPI_COMPLEX #include "sycl/sycl.hpp" +#include #include namespace dpctl @@ -53,6 +55,8 @@ template struct ExtendedRealFPGreater } }; +namespace exprm_ns = sycl::ext::oneapi::experimental; + template struct ExtendedComplexFPLess { /* [(R, R), (R, nan), (nan, R), (nan, nan)] */ @@ -60,15 +64,17 @@ template struct ExtendedComplexFPLess bool operator()(const cT &v1, const cT &v2) const { using realT = typename cT::value_type; - - const realT real1 = std::real(v1); - const realT real2 = std::real(v2); + using sycl_complexT = exprm_ns::complex; + sycl_complexT z1 = sycl_complexT(v1); + sycl_complexT z2 = sycl_complexT(v2); + const realT real1 = exprm_ns::real(z1); + const realT real2 = exprm_ns::real(z2); const bool r1_nan = std::isnan(real1); const bool r2_nan = std::isnan(real2); - const realT imag1 = std::imag(v1); - const realT imag2 = std::imag(v2); + const realT imag1 = exprm_ns::imag(z1); + const realT imag2 = exprm_ns::imag(z2); const bool i1_nan = std::isnan(imag1); const bool i2_nan = std::isnan(imag2); From 88e64c1b61c26716e016565bb39bd98843aeb8c3 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Mon, 28 Apr 2025 23:28:43 -0700 Subject: [PATCH 2/9] Update binary functions multiply and subtract to use experimental SYCL complex type --- .../elementwise_functions/multiply.hpp | 15 ++++++++++- .../elementwise_functions/subtract.hpp | 25 +++++++++++++++++-- 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp index ca24383b44..a296e019ce 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp @@ -419,7 +419,20 @@ template struct MultiplyInplaceFunctor using supports_vec = std::negation< std::disjunction, tu_ns::is_complex>>; - void operator()(resT &res, const argT &in) { res *= in; } + void operator()(resT &res, const argT &in) + { + if constexpr (tu_ns::is_complex_v && tu_ns::is_complex_v) { + using res_rT = typename resT::value_type; + using arg_rT = typename argT::value_type; + + auto res1 = exprm_ns::complex(res); + res1 *= exprm_ns::complex(in); + res = res1; + } + else { + res *= in; + } + } template void operator()(sycl::vec &res, diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/subtract.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/subtract.hpp index 51a3955142..4b2978ffc1 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/subtract.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/subtract.hpp @@ -29,6 +29,7 @@ #include #include +#include "sycl_complex.hpp" #include "vec_size_util.hpp" #include "utils/offset_utils.hpp" @@ -62,7 +63,17 @@ template struct SubtractFunctor resT operator()(const argT1 &in1, const argT2 &in2) const { - return in1 - in2; + if constexpr (tu_ns::is_complex_v && tu_ns::is_complex_v) + { + using realT1 = typename argT1::value_type; + using realT2 = typename argT2::value_type; + + return exprm_ns::complex(in1) - + exprm_ns::complex(in2); + } + else { + return in1 - in2; + } } template @@ -424,7 +435,17 @@ template struct SubtractInplaceFunctor void operator()(sycl::vec &res, const sycl::vec &in) { - res -= in; + if constexpr (tu_ns::is_complex_v && tu_ns::is_complex_v) { + using res_rT = typename resT::value_type; + using arg_rT = typename argT::value_type; + + auto res1 = exprm_ns::complex(res); + res1 -= exprm_ns::complex(in); + res = res1; + } + else { + res -= in; + } } }; From 01f22d3714c2279b2bc6de495712fc6f9b26e272 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 29 Apr 2025 14:56:29 -0700 Subject: [PATCH 3/9] Use experimental SYCL complex in dot product --- .../kernels/linalg_functions/dot_product.hpp | 65 ++++++++++++++----- 1 file changed, 48 insertions(+), 17 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp index 71e2c15b6b..6f74766600 100644 --- a/dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp @@ -40,6 +40,9 @@ #include "utils/sycl_utils.hpp" #include "utils/type_utils.hpp" +#define SYCL_EXT_ONEAPI_COMPLEX +#include + namespace dpctl { namespace tensor @@ -49,6 +52,8 @@ namespace kernels using dpctl::tensor::ssize_t; namespace su_ns = dpctl::tensor::sycl_utils; +namespace tu_ns = dpctl::tensor::type_utils; +namespace exprm_ns = sycl::ext::oneapi::experimental; template ( lhs_[lhs_batch_offset + lhs_reduction_offset]) * convert_impl( @@ -175,7 +180,7 @@ struct DotProductFunctor const auto &rhs_reduction_offset = reduction_offsets_.get_second_offset(); - using dpctl::tensor::type_utils::convert_impl; + using tu_ns::convert_impl; outT val = convert_impl( lhs_[lhs_batch_offset + lhs_reduction_offset]) * convert_impl( @@ -273,7 +278,7 @@ struct DotProductCustomFunctor const auto &rhs_reduction_offset = reduction_offsets_.get_second_offset(); - using dpctl::tensor::type_utils::convert_impl; + using tu_ns::convert_impl; outT val = convert_impl( lhs_[lhs_batch_offset + lhs_reduction_offset]) * convert_impl( @@ -718,13 +723,26 @@ struct DotProductNoAtomicFunctor const auto &rhs_reduction_offset = reduction_offsets_.get_second_offset(); - using dpctl::tensor::type_utils::convert_impl; - outT val = convert_impl( - lhs_[lhs_batch_offset + lhs_reduction_offset]) * - convert_impl( - rhs_[rhs_batch_offset + rhs_reduction_offset]); - - local_red_val += val; + using tu_ns::convert_impl; + using tu_ns::is_complex_v; + if constexpr (is_complex_v) { + using realT = typename outT::value_type; + using sycl_complexT = exprm_ns::complex; + + sycl_complexT val = + sycl_complexT(convert_impl( + lhs_[lhs_batch_offset + lhs_reduction_offset])) * + sycl_complexT(convert_impl( + rhs_[rhs_batch_offset + rhs_reduction_offset])); + local_red_val = outT(sycl_complexT(local_red_val) + val); + } + else { + outT val = convert_impl( + lhs_[lhs_batch_offset + lhs_reduction_offset]) * + convert_impl( + rhs_[rhs_batch_offset + rhs_reduction_offset]); + local_red_val += val; + } } auto work_group = it.get_group(); @@ -819,13 +837,26 @@ struct DotProductNoAtomicCustomFunctor const auto &rhs_reduction_offset = reduction_offsets_.get_second_offset(); - using dpctl::tensor::type_utils::convert_impl; - outT val = convert_impl( - lhs_[lhs_batch_offset + lhs_reduction_offset]) * - convert_impl( - rhs_[rhs_batch_offset + rhs_reduction_offset]); - - local_red_val += val; + using tu_ns::convert_impl; + using tu_ns::is_complex_v; + if constexpr (is_complex_v) { + using realT = typename outT::value_type; + using sycl_complexT = exprm_ns::complex; + + sycl_complexT val = + sycl_complexT(convert_impl( + lhs_[lhs_batch_offset + lhs_reduction_offset])) * + sycl_complexT(convert_impl( + rhs_[rhs_batch_offset + rhs_reduction_offset])); + local_red_val = outT(sycl_complexT(local_red_val) + val); + } + else { + outT val = convert_impl( + lhs_[lhs_batch_offset + lhs_reduction_offset]) * + convert_impl( + rhs_[rhs_batch_offset + rhs_reduction_offset]); + local_red_val += val; + } } auto work_group = it.get_group(); From 34168bd57aa875c0457a7c4960712662454a393d Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 29 Apr 2025 17:43:11 -0700 Subject: [PATCH 4/9] Use experimental namespace in sequential dot product --- .../kernels/linalg_functions/dot_product.hpp | 56 ++++++++++--------- 1 file changed, 31 insertions(+), 25 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp index 6f74766600..7a9afc7f74 100644 --- a/dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp @@ -97,11 +97,23 @@ struct SequentialDotProduct auto lhs_reduction_offset = reduction_offsets.get_first_offset(); auto rhs_reduction_offset = reduction_offsets.get_second_offset(); - using tu_ns::convert_impl; - red_val += convert_impl( - lhs_[lhs_batch_offset + lhs_reduction_offset]) * - convert_impl( - rhs_[rhs_batch_offset + rhs_reduction_offset]); + if constexpr (tu_ns::is_complex_v) { + using realT = typename outT::value_type; + using sycl_complex = exprm_ns::complex; + + auto tmp = sycl_complex(red_val); + tmp += sycl_complex(tu_ns::convert_impl( + lhs_[lhs_batch_offset + lhs_reduction_offset])) * + sycl_complex(tu_ns::convert_impl( + rhs_[rhs_batch_offset + rhs_reduction_offset])); + red_val = outT(tmp); + } + else { + red_val += tu_ns::convert_impl( + lhs_[lhs_batch_offset + lhs_reduction_offset]) * + tu_ns::convert_impl( + rhs_[rhs_batch_offset + rhs_reduction_offset]); + } } out_[out_batch_offset] = red_val; @@ -180,10 +192,9 @@ struct DotProductFunctor const auto &rhs_reduction_offset = reduction_offsets_.get_second_offset(); - using tu_ns::convert_impl; - outT val = convert_impl( + outT val = tu_ns::convert_impl( lhs_[lhs_batch_offset + lhs_reduction_offset]) * - convert_impl( + tu_ns::convert_impl( rhs_[rhs_batch_offset + rhs_reduction_offset]); local_red_val += val; @@ -278,10 +289,9 @@ struct DotProductCustomFunctor const auto &rhs_reduction_offset = reduction_offsets_.get_second_offset(); - using tu_ns::convert_impl; - outT val = convert_impl( + outT val = tu_ns::convert_impl( lhs_[lhs_batch_offset + lhs_reduction_offset]) * - convert_impl( + tu_ns::convert_impl( rhs_[rhs_batch_offset + rhs_reduction_offset]); local_red_val += val; @@ -723,23 +733,21 @@ struct DotProductNoAtomicFunctor const auto &rhs_reduction_offset = reduction_offsets_.get_second_offset(); - using tu_ns::convert_impl; - using tu_ns::is_complex_v; - if constexpr (is_complex_v) { + if constexpr (tu_ns::is_complex_v) { using realT = typename outT::value_type; using sycl_complexT = exprm_ns::complex; sycl_complexT val = - sycl_complexT(convert_impl( + sycl_complexT(tu_ns::convert_impl( lhs_[lhs_batch_offset + lhs_reduction_offset])) * - sycl_complexT(convert_impl( + sycl_complexT(tu_ns::convert_impl( rhs_[rhs_batch_offset + rhs_reduction_offset])); local_red_val = outT(sycl_complexT(local_red_val) + val); } else { - outT val = convert_impl( + outT val = tu_ns::convert_impl( lhs_[lhs_batch_offset + lhs_reduction_offset]) * - convert_impl( + tu_ns::convert_impl( rhs_[rhs_batch_offset + rhs_reduction_offset]); local_red_val += val; } @@ -837,23 +845,21 @@ struct DotProductNoAtomicCustomFunctor const auto &rhs_reduction_offset = reduction_offsets_.get_second_offset(); - using tu_ns::convert_impl; - using tu_ns::is_complex_v; - if constexpr (is_complex_v) { + if constexpr (tu_ns::is_complex_v) { using realT = typename outT::value_type; using sycl_complexT = exprm_ns::complex; sycl_complexT val = - sycl_complexT(convert_impl( + sycl_complexT(tu_ns::convert_impl( lhs_[lhs_batch_offset + lhs_reduction_offset])) * - sycl_complexT(convert_impl( + sycl_complexT(tu_ns::convert_impl( rhs_[rhs_batch_offset + rhs_reduction_offset])); local_red_val = outT(sycl_complexT(local_red_val) + val); } else { - outT val = convert_impl( + outT val = tu_ns::convert_impl( lhs_[lhs_batch_offset + lhs_reduction_offset]) * - convert_impl( + tu_ns::convert_impl( rhs_[rhs_batch_offset + rhs_reduction_offset]); local_red_val += val; } From 0fd49ae36bc43847984511acb4e887dc84d9be7d Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 29 Apr 2025 18:05:08 -0700 Subject: [PATCH 5/9] Use experimental complex namespace in gemm --- .../include/kernels/linalg_functions/gemm.hpp | 104 +++++++++++++----- 1 file changed, 76 insertions(+), 28 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp index 4ad4eb142a..232a9c04ca 100644 --- a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp @@ -40,6 +40,9 @@ #include "utils/sycl_utils.hpp" #include "utils/type_utils.hpp" +#define SYCL_EXT_ONEAPI_COMPLEX +#include + namespace dpctl { namespace tensor @@ -48,6 +51,8 @@ namespace kernels { using dpctl::tensor::ssize_t; +namespace tu_ns = dpctl::tensor::type_utils; +namespace exprm_ns = sycl::ext::oneapi::experimental; namespace gemm_detail { @@ -1082,8 +1087,21 @@ class GemmBatchFunctorThreadNM_vecm #pragma unroll for (std::uint32_t pr_j = 0; pr_j < wi_delta_m_vecs; ++pr_j) { - private_C[pr_i * wi_delta_m_vecs + pr_j] += - pr_lhs[pr_i] * pr_rhs[pr_j]; + if constexpr (tu_ns::is_complex_v) { + using realT = typename resT::value_type; + using sycl_complex = exprm_ns::complex; + + auto tmp = sycl_complex( + private_C[pr_i * wi_delta_m_vecs + pr_j]); + tmp += sycl_complex(pr_lhs[pr_i]) * + sycl_complex(pr_rhs[pr_j]); + private_C[pr_i * wi_delta_m_vecs + pr_j] = + resT(tmp); + } + else { + private_C[pr_i * wi_delta_m_vecs + pr_j] += + pr_lhs[pr_i] * pr_rhs[pr_j]; + } } } } @@ -1949,9 +1967,21 @@ class GemmBatchNoAtomicFunctorThreadNM slmB_t local_sum(identity_); for (std::size_t private_s = 0; private_s < wi_delta_k; ++private_s) { - local_sum = local_sum + - (local_A_block[a_offset + a_pr_offset + private_s] * - local_B_block[b_offset + private_s]); + if constexpr (tu_ns::is_complex_v) { + using realT = typename resT::value_type; + using sycl_complex = exprm_ns::complex; + auto tmp = sycl_complex(local_sum); + tmp += (sycl_complex(local_A_block[a_offset + a_pr_offset + + private_s]) * + sycl_complex(local_B_block[b_offset + private_s])); + local_sum = resT(tmp); + } + else { + local_sum = + local_sum + + (local_A_block[a_offset + a_pr_offset + private_s] * + local_B_block[b_offset + private_s]); + } } const std::size_t gl_i = i + private_i; @@ -2114,12 +2144,28 @@ class GemmBatchNoAtomicFunctorThreadK accV_t private_sum(identity_); constexpr accV_t vec_identity_(identity_); for (std::size_t t = local_s; t < local_B_block.size(); t += delta_k) { - private_sum += - ((i < n) && (t + t_shift < k)) - ? (static_cast( - lhs[lhs_offset + lhs_indexer(global_s_offset + t)]) * - local_B_block[t]) - : vec_identity_; + if constexpr (tu_ns::is_complex_v) { + using realT = typename resT::value_type; + using sycl_complex = exprm_ns::complex; + + auto tmp = sycl_complex(private_sum); + tmp += ((i < n) && (t + t_shift < k)) + ? sycl_complex(static_cast( + lhs[lhs_offset + + lhs_indexer(global_s_offset + t)])) * + sycl_complex(local_B_block[t]) + : sycl_complex(vec_identity_); + private_sum = resT(tmp); + } + else { + private_sum += + ((i < n) && (t + t_shift < k)) + ? (static_cast( + lhs[lhs_offset + + lhs_indexer(global_s_offset + t)]) * + local_B_block[t]) + : vec_identity_; + } } std::size_t workspace_i_shift = local_i * delta_k; @@ -2130,7 +2176,17 @@ class GemmBatchNoAtomicFunctorThreadK if (local_s == 0 && i < n) { accV_t local_sum(workspace[workspace_i_shift]); for (std::size_t t = 1; t < delta_k; ++t) { - local_sum += workspace[workspace_i_shift + t]; + if constexpr (tu_ns::is_complex_v) { + using realT = typename resT::value_type; + using sycl_complex = exprm_ns::complex; + + auto tmp = sycl_complex(local_sum); + tmp += sycl_complex(workspace[workspace_i_shift + t]); + local_sum = resT(tmp); + } + else { + local_sum += workspace[workspace_i_shift + t]; + } } const std::size_t total_offset = @@ -2863,8 +2919,7 @@ sycl::event gemm_batch_tree_impl(sycl::queue &exec_q, } if (max_nm < 64) { - using dpctl::tensor::type_utils::is_complex; - if constexpr (!is_complex::value) { + if constexpr (!tu_ns::is_complex_v) { if (m < 4) { constexpr std::uint32_t m_groups_one = 1; return gemm_batch_tree_k_impl 1, n > k or m > k - using dpctl::tensor::type_utils::is_complex; - if constexpr (!is_complex::value) { + if constexpr (!tu_ns::is_complex_v) { constexpr std::uint32_t m_groups_four = 4; return gemm_batch_tree_nm_impl( exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, batch_nd, @@ -3435,8 +3489,7 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q, } if (max_nm < 64) { - using dpctl::tensor::type_utils::is_complex; - if constexpr (!is_complex::value) { + if constexpr (!tu_ns::is_complex_v) { if (m < 4) { return gemm_batch_contig_tree_k_impl( exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, @@ -3454,8 +3507,7 @@ gemm_batch_contig_tree_impl(sycl::queue &exec_q, } } else { // m > 1, n > k or m > k - using dpctl::tensor::type_utils::is_complex; - if constexpr (!is_complex::value) { + if constexpr (!tu_ns::is_complex_v) { return gemm_batch_contig_tree_nm_impl( exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, depends); } @@ -3840,8 +3892,7 @@ sycl::event gemm_tree_impl(sycl::queue &exec_q, } if (max_nm < 64) { - using dpctl::tensor::type_utils::is_complex; - if constexpr (!is_complex::value) { + if constexpr (!tu_ns::is_complex_v) { if (m < 4) { return gemm_tree_k_impl( exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, @@ -3866,8 +3917,7 @@ sycl::event gemm_tree_impl(sycl::queue &exec_q, } } else { // m > 1, n > k or m > k - using dpctl::tensor::type_utils::is_complex; - if constexpr (!is_complex::value) { + if constexpr (!tu_ns::is_complex_v) { return gemm_tree_nm_impl( exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, lhs_outer_nd, lhs_outer_inner_shapes_strides, rhs_outer_nd, @@ -4191,8 +4241,7 @@ sycl::event gemm_contig_tree_impl(sycl::queue &exec_q, } if (max_nm < 64) { - using dpctl::tensor::type_utils::is_complex; - if constexpr (!is_complex::value) { + if constexpr (!tu_ns::is_complex_v) { if (m < 4) { return gemm_contig_tree_k_impl( exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); @@ -4208,8 +4257,7 @@ sycl::event gemm_contig_tree_impl(sycl::queue &exec_q, } } else { // m > 1, n > k or m > k - using dpctl::tensor::type_utils::is_complex; - if constexpr (!is_complex::value) { + if constexpr (!tu_ns::is_complex_v) { return gemm_contig_tree_nm_impl( exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); } From e3be74ec88e2f6b532ff950098e837e5934a68f4 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 29 Apr 2025 23:08:54 -0700 Subject: [PATCH 6/9] Use specialized functor for multiplying or adding complex inputs converts to experimental sycl complex values, then performs math operations --- .../libtensor/include/utils/math_utils.hpp | 14 +++ .../libtensor/include/utils/sycl_utils.hpp | 100 ++++++++++++------ .../source/accumulators/cumulative_prod.cpp | 10 +- .../source/accumulators/cumulative_sum.cpp | 8 +- .../libtensor/source/reductions/prod.cpp | 28 +++-- .../libtensor/source/reductions/sum.cpp | 25 +++-- 6 files changed, 127 insertions(+), 58 deletions(-) diff --git a/dpctl/tensor/libtensor/include/utils/math_utils.hpp b/dpctl/tensor/libtensor/include/utils/math_utils.hpp index ecd9c1fe18..d0c0475ffa 100644 --- a/dpctl/tensor/libtensor/include/utils/math_utils.hpp +++ b/dpctl/tensor/libtensor/include/utils/math_utils.hpp @@ -154,6 +154,20 @@ template T logaddexp(T x, T y) } } +template T plus_complex(const T &x1, const T &x2) +{ + using realT = typename T::value_type; + using sycl_complexT = exprm_ns::complex; + return T(sycl_complexT(x1) + sycl_complexT(x2)); +} + +template T multiplies_complex(const T &x1, const T &x2) +{ + using realT = typename T::value_type; + using sycl_complexT = exprm_ns::complex; + return T(sycl_complexT(x1) * sycl_complexT(x2)); +} + } // namespace math_utils } // namespace tensor } // namespace dpctl diff --git a/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp b/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp index ece8852643..bcbb54ff39 100644 --- a/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp +++ b/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp @@ -298,7 +298,11 @@ T custom_inclusive_scan_over_group(GroupT &&wg, return scan_val; } -// Reduction functors +// Define identities and operator checking structs + +template struct GetIdentity +{ +}; // Maximum @@ -324,38 +328,6 @@ template struct Maximum } }; -// Minimum - -template struct Minimum -{ - T operator()(const T &x, const T &y) const - { - if constexpr (detail::IsComplex::value) { - using dpctl::tensor::math_utils::min_complex; - return min_complex(x, y); - } - else if constexpr (std::is_floating_point_v || - std::is_same_v) - { - return (std::isnan(x) || x < y) ? x : y; - } - else if constexpr (std::is_same_v) { - return x && y; - } - else { - return (x < y) ? x : y; - } - } -}; - -// Define identities and operator checking structs - -template struct GetIdentity -{ -}; - -// Maximum - template using IsMaximum = std::bool_constant> || std::is_same_v>>; @@ -389,6 +361,28 @@ struct GetIdentity struct Minimum +{ + T operator()(const T &x, const T &y) const + { + if constexpr (detail::IsComplex::value) { + using dpctl::tensor::math_utils::min_complex; + return min_complex(x, y); + } + else if constexpr (std::is_floating_point_v || + std::is_same_v) + { + return (std::isnan(x) || x < y) ? x : y; + } + else if constexpr (std::is_same_v) { + return x && y; + } + else { + return (x < y) ? x : y; + } + } +}; + template using IsMinimum = std::bool_constant> || std::is_same_v>>; @@ -422,19 +416,55 @@ struct GetIdentity struct Plus +{ + T operator()(const T &x, const T &y) const + { + if constexpr (detail::IsComplex::value) { + using dpctl::tensor::math_utils::plus_complex; + return plus_complex(x, y); + } + else { + return sycl::plus(x, y); + } + } +}; + template using IsPlus = std::bool_constant> || - std::is_same_v>>; + std::is_same_v> || + std::is_same_v>>; template using IsSyclPlus = std::bool_constant>>; +template +struct GetIdentity::value>> +{ + static constexpr T value = static_cast(0); +}; + // Multiplies +template struct Multiplies +{ + T operator()(const T &x, const T &y) const + { + if constexpr (detail::IsComplex::value) { + using dpctl::tensor::math_utils::multiplies_complex; + return multiplies_complex(x, y); + } + else { + return sycl::multiplies(x, y); + } + } +}; + template using IsMultiplies = std::bool_constant> || - std::is_same_v>>; + std::is_same_v> || + std::is_same_v>>; template using IsSyclMultiplies = diff --git a/dpctl/tensor/libtensor/source/accumulators/cumulative_prod.cpp b/dpctl/tensor/libtensor/source/accumulators/cumulative_prod.cpp index 045b1b330e..992e3592b6 100644 --- a/dpctl/tensor/libtensor/source/accumulators/cumulative_prod.cpp +++ b/dpctl/tensor/libtensor/source/accumulators/cumulative_prod.cpp @@ -46,6 +46,7 @@ namespace py_internal namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace tu_ns = dpctl::tensor::type_utils; namespace impl { @@ -133,9 +134,12 @@ struct TypePairSupportDataForProdAccumulation }; template -using CumProdScanOpT = std::conditional_t, - sycl::logical_and, - sycl::multiplies>; +using CumProdScanOpT = + std::conditional_t, + sycl::logical_and, + std::conditional_t, + su_ns::Multiplies, + sycl::multiplies>>; template struct CumProd1DContigFactory diff --git a/dpctl/tensor/libtensor/source/accumulators/cumulative_sum.cpp b/dpctl/tensor/libtensor/source/accumulators/cumulative_sum.cpp index e44678e15f..22e32dfb03 100644 --- a/dpctl/tensor/libtensor/source/accumulators/cumulative_sum.cpp +++ b/dpctl/tensor/libtensor/source/accumulators/cumulative_sum.cpp @@ -34,6 +34,7 @@ #include "kernels/accumulators.hpp" #include "utils/sycl_utils.hpp" #include "utils/type_dispatch_building.hpp" +#include "utils/type_utils.hpp" namespace py = pybind11; @@ -46,6 +47,7 @@ namespace py_internal namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace tu_ns = dpctl::tensor::type_utils; namespace impl { @@ -133,8 +135,10 @@ struct TypePairSupportDataForSumAccumulation }; template -using CumSumScanOpT = std:: - conditional_t, sycl::logical_or, sycl::plus>; +using CumSumScanOpT = std::conditional_t< + std::is_same_v, + sycl::logical_or, + std::conditional_t, su_ns::Plus, sycl::plus>>; template struct CumSum1DContigFactory diff --git a/dpctl/tensor/libtensor/source/reductions/prod.cpp b/dpctl/tensor/libtensor/source/reductions/prod.cpp index 7c768ce179..f2822e4773 100644 --- a/dpctl/tensor/libtensor/source/reductions/prod.cpp +++ b/dpctl/tensor/libtensor/source/reductions/prod.cpp @@ -31,7 +31,9 @@ #include #include "kernels/reductions.hpp" +#include "utils/sycl_utils.hpp" #include "utils/type_dispatch_building.hpp" +#include "utils/type_utils.hpp" #include "reduction_atomic_support.hpp" #include "reduction_over_axis.hpp" @@ -45,7 +47,9 @@ namespace tensor namespace py_internal { +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace tu_ns = dpctl::tensor::type_utils; namespace impl { @@ -256,9 +260,11 @@ struct ProductOverAxisTempsStridedFactory if constexpr (TypePairSupportDataForProductReductionTemps< srcTy, dstTy>::is_defined) { - using ReductionOpT = std::conditional_t, - sycl::logical_and, - sycl::multiplies>; + using ReductionOpT = std::conditional_t< + std::is_same_v, sycl::logical_and, + std::conditional_t, + su_ns::Multiplies, + sycl::multiplies>>; return dpctl::tensor::kernels:: reduction_over_group_temps_strided_impl; @@ -315,9 +321,11 @@ struct ProductOverAxis1TempsContigFactory if constexpr (TypePairSupportDataForProductReductionTemps< srcTy, dstTy>::is_defined) { - using ReductionOpT = std::conditional_t, - sycl::logical_and, - sycl::multiplies>; + using ReductionOpT = std::conditional_t< + std::is_same_v, sycl::logical_and, + std::conditional_t, + su_ns::Multiplies, + sycl::multiplies>>; return dpctl::tensor::kernels:: reduction_axis1_over_group_temps_contig_impl; @@ -336,9 +344,11 @@ struct ProductOverAxis0TempsContigFactory if constexpr (TypePairSupportDataForProductReductionTemps< srcTy, dstTy>::is_defined) { - using ReductionOpT = std::conditional_t, - sycl::logical_and, - sycl::multiplies>; + using ReductionOpT = std::conditional_t< + std::is_same_v, sycl::logical_and, + std::conditional_t, + su_ns::Multiplies, + sycl::multiplies>>; return dpctl::tensor::kernels:: reduction_axis0_over_group_temps_contig_impl; diff --git a/dpctl/tensor/libtensor/source/reductions/sum.cpp b/dpctl/tensor/libtensor/source/reductions/sum.cpp index f449a6cde3..a45a86702f 100644 --- a/dpctl/tensor/libtensor/source/reductions/sum.cpp +++ b/dpctl/tensor/libtensor/source/reductions/sum.cpp @@ -31,7 +31,9 @@ #include #include "kernels/reductions.hpp" +#include "utils/sycl_utils.hpp" #include "utils/type_dispatch_building.hpp" +#include "utils/type_utils.hpp" #include "reduction_atomic_support.hpp" #include "reduction_over_axis.hpp" @@ -45,7 +47,9 @@ namespace tensor namespace py_internal { +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace tu_ns = dpctl::tensor::type_utils; namespace impl { @@ -256,9 +260,10 @@ struct SumOverAxisTempsStridedFactory if constexpr (TypePairSupportDataForSumReductionTemps< srcTy, dstTy>::is_defined) { - using ReductionOpT = - std::conditional_t, - sycl::logical_or, sycl::plus>; + using ReductionOpT = std::conditional_t< + std::is_same_v, sycl::logical_or, + std::conditional_t, + su_ns::Plus, sycl::plus>>; return dpctl::tensor::kernels:: reduction_over_group_temps_strided_impl; @@ -315,9 +320,10 @@ struct SumOverAxis1TempsContigFactory if constexpr (TypePairSupportDataForSumReductionTemps< srcTy, dstTy>::is_defined) { - using ReductionOpT = - std::conditional_t, - sycl::logical_or, sycl::plus>; + using ReductionOpT = std::conditional_t< + std::is_same_v, sycl::logical_or, + std::conditional_t, + su_ns::Plus, sycl::plus>>; return dpctl::tensor::kernels:: reduction_axis1_over_group_temps_contig_impl; @@ -336,9 +342,10 @@ struct SumOverAxis0TempsContigFactory if constexpr (TypePairSupportDataForSumReductionTemps< srcTy, dstTy>::is_defined) { - using ReductionOpT = - std::conditional_t, - sycl::logical_or, sycl::plus>; + using ReductionOpT = std::conditional_t< + std::is_same_v, sycl::logical_or, + std::conditional_t, + su_ns::Plus, sycl::plus>>; return dpctl::tensor::kernels:: reduction_axis0_over_group_temps_contig_impl; From 6fba5f29eeef20ef5ee654af3bcc457edce99baa Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 29 Apr 2025 23:14:38 -0700 Subject: [PATCH 7/9] Use custom plus operator in gemm and dot product tree reduction kernels --- .../kernels/linalg_functions/dot_product.hpp | 25 +++--- .../include/kernels/linalg_functions/gemm.hpp | 83 ++++++++++--------- 2 files changed, 56 insertions(+), 52 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp index 7a9afc7f74..7e62c8b861 100644 --- a/dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp @@ -755,9 +755,10 @@ struct DotProductNoAtomicFunctor auto work_group = it.get_group(); - using RedOpT = typename std::conditional, - sycl::logical_or, - sycl::plus>::type; + using RedOpT = std::conditional_t< + std::is_same_v, sycl::logical_or, + std::conditional_t, su_ns::Plus, + sycl::plus>>; outT red_val_over_wg = sycl::reduce_over_group( work_group, local_red_val, outT(0), RedOpT()); @@ -1009,9 +1010,10 @@ sycl::event dot_product_tree_impl(sycl::queue &exec_q, // prevents running out of resources on CPU std::size_t max_wg = reduction_detail::get_work_group_size(d); - using ReductionOpT = typename std::conditional, - sycl::logical_or, - sycl::plus>::type; + using ReductionOpT = std::conditional_t< + std::is_same_v, sycl::logical_or, + std::conditional_t, su_ns::Plus, + sycl::plus>>; std::size_t reductions_per_wi(preferred_reductions_per_wi); if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { @@ -1051,7 +1053,7 @@ sycl::event dot_product_tree_impl(sycl::queue &exec_q, } else { constexpr resTy identity_val = - sycl::known_identity::value; + su_ns::Identity::value; // more than one work-groups is needed, requires a temporary std::size_t reduction_groups = @@ -1252,9 +1254,10 @@ dot_product_contig_tree_impl(sycl::queue &exec_q, // prevents running out of resources on CPU std::size_t max_wg = reduction_detail::get_work_group_size(d); - using ReductionOpT = typename std::conditional, - sycl::logical_or, - sycl::plus>::type; + using ReductionOpT = std::conditional_t< + std::is_same_v, sycl::logical_or, + std::conditional_t, su_ns::Plus, + sycl::plus>>; std::size_t reductions_per_wi(preferred_reductions_per_wi); if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { @@ -1298,7 +1301,7 @@ dot_product_contig_tree_impl(sycl::queue &exec_q, } else { constexpr resTy identity_val = - sycl::known_identity::value; + su_ns::Identity::value; // more than one work-groups is needed, requires a temporary std::size_t reduction_groups = diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp index 232a9c04ca..baf79acf78 100644 --- a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp @@ -51,6 +51,7 @@ namespace kernels { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace tu_ns = dpctl::tensor::type_utils; namespace exprm_ns = sycl::ext::oneapi::experimental; @@ -101,7 +102,7 @@ void scale_gemm_nm_parameters(const std::size_t &local_mem_size, } } // namespace gemm_detail -using dpctl::tensor::sycl_utils::choose_workgroup_size; +using su_ns::choose_workgroup_size; template class gemm_seq_reduction_krn; @@ -2367,12 +2368,12 @@ gemm_batch_tree_k_impl(sycl::queue &exec_q, depends); } else { - using ReductionOpT = - typename std::conditional, - sycl::logical_or, - sycl::plus>::type; + using ReductionOpT = std::conditional_t< + std::is_same_v, sycl::logical_or, + std::conditional_t, su_ns::Plus, + sycl::plus>>; constexpr resTy identity_val = - sycl::known_identity::value; + su_ns::Identity::value; std::size_t iter_nelems = batch_nelems * n * m; std::size_t reduction_nelems = @@ -2663,12 +2664,12 @@ gemm_batch_tree_nm_impl(sycl::queue &exec_q, lhs_indexer, rhs_indexer, res_indexer, depends); } else { - using ReductionOpT = - typename std::conditional, - sycl::logical_or, - sycl::plus>::type; + using ReductionOpT = std::conditional_t< + std::is_same_v, sycl::logical_or, + std::conditional_t, su_ns::Plus, + sycl::plus>>; constexpr resTy identity_val = - sycl::known_identity::value; + su_ns::Identity::value; std::size_t iter_nelems = batch_nelems * n * m; std::size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; @@ -3034,12 +3035,12 @@ gemm_batch_contig_tree_k_impl(sycl::queue &exec_q, depends); } else { - using ReductionOpT = - typename std::conditional, - sycl::logical_or, - sycl::plus>::type; + using ReductionOpT = std::conditional_t< + std::is_same_v, sycl::logical_or, + std::conditional_t, su_ns::Plus, + sycl::plus>>; constexpr resTy identity_val = - sycl::known_identity::value; + su_ns::Identity::value; std::size_t iter_nelems = batch_nelems * n * m; std::size_t reduction_nelems = @@ -3222,12 +3223,12 @@ gemm_batch_contig_tree_nm_impl(sycl::queue &exec_q, lhs_indexer, rhs_indexer, res_indexer, depends); } else { - using ReductionOpT = - typename std::conditional, - sycl::logical_or, - sycl::plus>::type; + using ReductionOpT = std::conditional_t< + std::is_same_v, sycl::logical_or, + std::conditional_t, su_ns::Plus, + sycl::plus>>; constexpr resTy identity_val = - sycl::known_identity::value; + su_ns::Identity::value; std::size_t iter_nelems = batch_nelems * n * m; std::size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; @@ -3591,12 +3592,12 @@ sycl::event gemm_tree_k_impl(sycl::queue &exec_q, res_indexer, depends); } else { - using ReductionOpT = - typename std::conditional, - sycl::logical_or, - sycl::plus>::type; + using ReductionOpT = std::conditional_t< + std::is_same_v, sycl::logical_or, + std::conditional_t, su_ns::Plus, + sycl::plus>>; constexpr resTy identity_val = - sycl::known_identity::value; + su_ns::Identity::value; std::size_t iter_nelems = n * m; std::size_t reduction_nelems = @@ -3745,12 +3746,12 @@ sycl::event gemm_tree_nm_impl(sycl::queue &exec_q, lhs_indexer, rhs_indexer, res_indexer, depends); } else { - using ReductionOpT = - typename std::conditional, - sycl::logical_or, - sycl::plus>::type; + using ReductionOpT = std::conditional_t< + std::is_same_v, sycl::logical_or, + std::conditional_t, su_ns::Plus, + sycl::plus>>; constexpr resTy identity_val = - sycl::known_identity::value; + su_ns::Identity::value; std::size_t iter_nelems = n * m; std::size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; @@ -3979,12 +3980,12 @@ sycl::event gemm_contig_tree_k_impl(sycl::queue &exec_q, res_indexer, depends); } else { - using ReductionOpT = - typename std::conditional, - sycl::logical_or, - sycl::plus>::type; + using ReductionOpT = std::conditional_t< + std::is_same_v, sycl::logical_or, + std::conditional_t, su_ns::Plus, + sycl::plus>>; constexpr resTy identity_val = - sycl::known_identity::value; + su_ns::Identity::value; std::size_t iter_nelems = n * m; std::size_t reduction_nelems = @@ -4118,12 +4119,12 @@ sycl::event gemm_contig_tree_nm_impl(sycl::queue &exec_q, lhs_indexer, rhs_indexer, res_indexer, depends); } else { - using ReductionOpT = - typename std::conditional, - sycl::logical_or, - sycl::plus>::type; + using ReductionOpT = std::conditional_t< + std::is_same_v, sycl::logical_or, + std::conditional_t, su_ns::Plus, + sycl::plus>>; constexpr resTy identity_val = - sycl::known_identity::value; + su_ns::Identity::value; std::size_t iter_nelems = n * m; std::size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; From da15d0e89e782d7600b62e786c9b0fda1fc0e219 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Wed, 30 Apr 2025 14:23:06 -0700 Subject: [PATCH 8/9] Refactor operators when dispatching to tree reductions --- .../kernels/linalg_functions/dot_product.hpp | 23 +++++---- .../include/kernels/linalg_functions/gemm.hpp | 51 +++++++------------ .../libtensor/source/reductions/prod.cpp | 26 ++++------ .../libtensor/source/reductions/sum.cpp | 21 ++++---- 4 files changed, 53 insertions(+), 68 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp index 7e62c8b861..5246fdea61 100644 --- a/dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp @@ -55,6 +55,17 @@ namespace su_ns = dpctl::tensor::sycl_utils; namespace tu_ns = dpctl::tensor::type_utils; namespace exprm_ns = sycl::ext::oneapi::experimental; +namespace detail +{ + +template +using SumTempsOpT = std::conditional_t< + std::is_same_v, + sycl::logical_or, + std::conditional_t, su_ns::Plus, sycl::plus>>; + +} // namespace detail + template , sycl::logical_or, std::conditional_t, su_ns::Plus, - sycl::plus>>; + sycl::plus>>; outT red_val_over_wg = sycl::reduce_over_group( work_group, local_red_val, outT(0), RedOpT()); @@ -1010,10 +1021,7 @@ sycl::event dot_product_tree_impl(sycl::queue &exec_q, // prevents running out of resources on CPU std::size_t max_wg = reduction_detail::get_work_group_size(d); - using ReductionOpT = std::conditional_t< - std::is_same_v, sycl::logical_or, - std::conditional_t, su_ns::Plus, - sycl::plus>>; + using ReductionOpT = detail::SumTempsOpT; std::size_t reductions_per_wi(preferred_reductions_per_wi); if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { @@ -1254,10 +1262,7 @@ dot_product_contig_tree_impl(sycl::queue &exec_q, // prevents running out of resources on CPU std::size_t max_wg = reduction_detail::get_work_group_size(d); - using ReductionOpT = std::conditional_t< - std::is_same_v, sycl::logical_or, - std::conditional_t, su_ns::Plus, - sycl::plus>>; + using ReductionOpT = detail::SumTempsOpT; std::size_t reductions_per_wi(preferred_reductions_per_wi); if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp index baf79acf78..fb5e8dce14 100644 --- a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp @@ -1795,6 +1795,17 @@ sycl::event gemm_batch_contig_impl(sycl::queue &exec_q, // ========== Gemm Tree +namespace gemm_detail +{ + +template +using SumTempsOpT = std::conditional_t< + std::is_same_v, + sycl::logical_or, + std::conditional_t, su_ns::Plus, sycl::plus>>; + +} // namespace gemm_detail + template , sycl::logical_or, - std::conditional_t, su_ns::Plus, - sycl::plus>>; + using ReductionOpT = gemm_detail::SumTempsOpT; constexpr resTy identity_val = su_ns::Identity::value; @@ -2664,10 +2672,7 @@ gemm_batch_tree_nm_impl(sycl::queue &exec_q, lhs_indexer, rhs_indexer, res_indexer, depends); } else { - using ReductionOpT = std::conditional_t< - std::is_same_v, sycl::logical_or, - std::conditional_t, su_ns::Plus, - sycl::plus>>; + using ReductionOpT = gemm_detail::SumTempsOpT; constexpr resTy identity_val = su_ns::Identity::value; std::size_t iter_nelems = batch_nelems * n * m; @@ -3035,10 +3040,7 @@ gemm_batch_contig_tree_k_impl(sycl::queue &exec_q, depends); } else { - using ReductionOpT = std::conditional_t< - std::is_same_v, sycl::logical_or, - std::conditional_t, su_ns::Plus, - sycl::plus>>; + using ReductionOpT = gemm_detail::SumTempsOpT; constexpr resTy identity_val = su_ns::Identity::value; @@ -3223,10 +3225,7 @@ gemm_batch_contig_tree_nm_impl(sycl::queue &exec_q, lhs_indexer, rhs_indexer, res_indexer, depends); } else { - using ReductionOpT = std::conditional_t< - std::is_same_v, sycl::logical_or, - std::conditional_t, su_ns::Plus, - sycl::plus>>; + using ReductionOpT = gemm_detail::SumTempsOpT; constexpr resTy identity_val = su_ns::Identity::value; std::size_t iter_nelems = batch_nelems * n * m; @@ -3592,10 +3591,7 @@ sycl::event gemm_tree_k_impl(sycl::queue &exec_q, res_indexer, depends); } else { - using ReductionOpT = std::conditional_t< - std::is_same_v, sycl::logical_or, - std::conditional_t, su_ns::Plus, - sycl::plus>>; + using ReductionOpT = gemm_detail::SumTempsOpT; constexpr resTy identity_val = su_ns::Identity::value; @@ -3746,10 +3742,7 @@ sycl::event gemm_tree_nm_impl(sycl::queue &exec_q, lhs_indexer, rhs_indexer, res_indexer, depends); } else { - using ReductionOpT = std::conditional_t< - std::is_same_v, sycl::logical_or, - std::conditional_t, su_ns::Plus, - sycl::plus>>; + using ReductionOpT = gemm_detail::SumTempsOpT; constexpr resTy identity_val = su_ns::Identity::value; @@ -3980,10 +3973,7 @@ sycl::event gemm_contig_tree_k_impl(sycl::queue &exec_q, res_indexer, depends); } else { - using ReductionOpT = std::conditional_t< - std::is_same_v, sycl::logical_or, - std::conditional_t, su_ns::Plus, - sycl::plus>>; + using ReductionOpT = gemm_detail::SumTempsOpT; constexpr resTy identity_val = su_ns::Identity::value; @@ -4119,10 +4109,7 @@ sycl::event gemm_contig_tree_nm_impl(sycl::queue &exec_q, lhs_indexer, rhs_indexer, res_indexer, depends); } else { - using ReductionOpT = std::conditional_t< - std::is_same_v, sycl::logical_or, - std::conditional_t, su_ns::Plus, - sycl::plus>>; + using ReductionOpT = gemm_detail::SumTempsOpT; constexpr resTy identity_val = su_ns::Identity::value; diff --git a/dpctl/tensor/libtensor/source/reductions/prod.cpp b/dpctl/tensor/libtensor/source/reductions/prod.cpp index f2822e4773..40a3bf4dc5 100644 --- a/dpctl/tensor/libtensor/source/reductions/prod.cpp +++ b/dpctl/tensor/libtensor/source/reductions/prod.cpp @@ -233,6 +233,14 @@ struct TypePairSupportDataForProductReductionTemps td_ns::NotDefinedEntry>::is_defined; }; +template +using ProdTempsOpT = + std::conditional_t, + sycl::logical_and, + std::conditional_t, + su_ns::Multiplies, + sycl::multiplies>>; + template struct ProductOverAxisAtomicStridedFactory { @@ -260,11 +268,7 @@ struct ProductOverAxisTempsStridedFactory if constexpr (TypePairSupportDataForProductReductionTemps< srcTy, dstTy>::is_defined) { - using ReductionOpT = std::conditional_t< - std::is_same_v, sycl::logical_and, - std::conditional_t, - su_ns::Multiplies, - sycl::multiplies>>; + using ReductionOpT = ProdTempsOpT; return dpctl::tensor::kernels:: reduction_over_group_temps_strided_impl; @@ -321,11 +325,7 @@ struct ProductOverAxis1TempsContigFactory if constexpr (TypePairSupportDataForProductReductionTemps< srcTy, dstTy>::is_defined) { - using ReductionOpT = std::conditional_t< - std::is_same_v, sycl::logical_and, - std::conditional_t, - su_ns::Multiplies, - sycl::multiplies>>; + using ReductionOpT = ProdTempsOpT; return dpctl::tensor::kernels:: reduction_axis1_over_group_temps_contig_impl; @@ -344,11 +344,7 @@ struct ProductOverAxis0TempsContigFactory if constexpr (TypePairSupportDataForProductReductionTemps< srcTy, dstTy>::is_defined) { - using ReductionOpT = std::conditional_t< - std::is_same_v, sycl::logical_and, - std::conditional_t, - su_ns::Multiplies, - sycl::multiplies>>; + using ReductionOpT = ProdTempsOpT; return dpctl::tensor::kernels:: reduction_axis0_over_group_temps_contig_impl; diff --git a/dpctl/tensor/libtensor/source/reductions/sum.cpp b/dpctl/tensor/libtensor/source/reductions/sum.cpp index a45a86702f..e9476c3dfb 100644 --- a/dpctl/tensor/libtensor/source/reductions/sum.cpp +++ b/dpctl/tensor/libtensor/source/reductions/sum.cpp @@ -233,6 +233,12 @@ struct TypePairSupportDataForSumReductionTemps td_ns::NotDefinedEntry>::is_defined; }; +template +using SumTempsOpT = std::conditional_t< + std::is_same_v, + sycl::logical_or, + std::conditional_t, su_ns::Plus, sycl::plus>>; + template struct SumOverAxisAtomicStridedFactory { @@ -260,10 +266,7 @@ struct SumOverAxisTempsStridedFactory if constexpr (TypePairSupportDataForSumReductionTemps< srcTy, dstTy>::is_defined) { - using ReductionOpT = std::conditional_t< - std::is_same_v, sycl::logical_or, - std::conditional_t, - su_ns::Plus, sycl::plus>>; + using ReductionOpT = SumTempsOpT; return dpctl::tensor::kernels:: reduction_over_group_temps_strided_impl; @@ -320,10 +323,7 @@ struct SumOverAxis1TempsContigFactory if constexpr (TypePairSupportDataForSumReductionTemps< srcTy, dstTy>::is_defined) { - using ReductionOpT = std::conditional_t< - std::is_same_v, sycl::logical_or, - std::conditional_t, - su_ns::Plus, sycl::plus>>; + using ReductionOpT = SumTempsOpT; return dpctl::tensor::kernels:: reduction_axis1_over_group_temps_contig_impl; @@ -342,10 +342,7 @@ struct SumOverAxis0TempsContigFactory if constexpr (TypePairSupportDataForSumReductionTemps< srcTy, dstTy>::is_defined) { - using ReductionOpT = std::conditional_t< - std::is_same_v, sycl::logical_or, - std::conditional_t, - su_ns::Plus, sycl::plus>>; + using ReductionOpT = SumTempsOpT; return dpctl::tensor::kernels:: reduction_axis0_over_group_temps_contig_impl; From 52bb73e7704d3b73e401b482ee429e4c013a4c54 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Wed, 30 Apr 2025 16:00:13 -0700 Subject: [PATCH 9/9] Refactor sycl_complex indirect include * Move sycl_complex.hpp to utils * No longer use exprm_ns defined by header, define on per-file basis * Include alias to type sycl_complex_t under sycl_utils namespace * Use identical include macro where inclusion of sycl_complex would be impossible --- .../kernels/elementwise_functions/acos.hpp | 6 +++-- .../kernels/elementwise_functions/acosh.hpp | 6 +++-- .../kernels/elementwise_functions/add.hpp | 27 +++++++++++++++---- .../kernels/elementwise_functions/angle.hpp | 6 +++-- .../kernels/elementwise_functions/asin.hpp | 6 +++-- .../kernels/elementwise_functions/asinh.hpp | 6 +++-- .../kernels/elementwise_functions/atan.hpp | 6 +++-- .../kernels/elementwise_functions/atanh.hpp | 6 +++-- .../elementwise_functions/cabs_impl.hpp | 9 ++++--- .../kernels/elementwise_functions/conj.hpp | 6 +++-- .../kernels/elementwise_functions/cos.hpp | 6 +++-- .../kernels/elementwise_functions/cosh.hpp | 6 +++-- .../kernels/elementwise_functions/equal.hpp | 7 ++--- .../kernels/elementwise_functions/exp.hpp | 6 +++-- .../kernels/elementwise_functions/exp2.hpp | 6 +++-- .../kernels/elementwise_functions/expm1.hpp | 6 +++-- .../kernels/elementwise_functions/imag.hpp | 6 +++-- .../elementwise_functions/isfinite.hpp | 6 +++-- .../kernels/elementwise_functions/isinf.hpp | 6 +++-- .../kernels/elementwise_functions/isnan.hpp | 6 +++-- .../kernels/elementwise_functions/log.hpp | 6 +++-- .../kernels/elementwise_functions/log10.hpp | 6 +++-- .../kernels/elementwise_functions/log1p.hpp | 6 +++-- .../kernels/elementwise_functions/log2.hpp | 6 +++-- .../elementwise_functions/multiply.hpp | 12 +++++---- .../kernels/elementwise_functions/pow.hpp | 12 +++++---- .../kernels/elementwise_functions/proj.hpp | 6 +++-- .../kernels/elementwise_functions/real.hpp | 6 +++-- .../elementwise_functions/reciprocal.hpp | 5 ++-- .../kernels/elementwise_functions/round.hpp | 6 +++-- .../kernels/elementwise_functions/sign.hpp | 4 ++- .../kernels/elementwise_functions/sin.hpp | 6 +++-- .../kernels/elementwise_functions/sinh.hpp | 6 +++-- .../kernels/elementwise_functions/sqrt.hpp | 6 +++-- .../kernels/elementwise_functions/square.hpp | 5 ++-- .../elementwise_functions/subtract.hpp | 12 +++++---- .../kernels/elementwise_functions/tan.hpp | 6 +++-- .../kernels/elementwise_functions/tanh.hpp | 6 +++-- .../elementwise_functions/true_divide.hpp | 18 +++++++------ .../kernels/linalg_functions/dot_product.hpp | 11 +++----- .../include/kernels/linalg_functions/gemm.hpp | 13 ++++----- .../libtensor/include/kernels/reductions.hpp | 14 +++++----- .../libtensor/include/utils/math_utils.hpp | 11 ++++++-- .../sycl_complex.hpp | 18 +++++++++++-- .../libtensor/include/utils/sycl_utils.hpp | 1 + .../source/sorting/rich_comparisons.hpp | 9 ++++--- 46 files changed, 235 insertions(+), 127 deletions(-) rename dpctl/tensor/libtensor/include/{kernels/elementwise_functions => utils}/sycl_complex.hpp (81%) diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp index 41622906e1..3af8f0f4af 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp @@ -29,7 +29,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -49,7 +49,9 @@ namespace acos { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -72,7 +74,7 @@ template struct AcosFunctor using realT = typename argT::value_type; constexpr realT q_nan = std::numeric_limits::quiet_NaN(); - using sycl_complexT = exprm_ns::complex; + using sycl_complexT = su_ns::sycl_complex_t; sycl_complexT z = sycl_complexT(in); const realT x = exprm_ns::real(z); const realT y = exprm_ns::imag(z); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp index cdad55c8a4..2bcd3dbd4e 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp @@ -29,7 +29,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -49,7 +49,9 @@ namespace acosh { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -77,7 +79,7 @@ template struct AcoshFunctor * where the sign is chosen so Re(acosh(in)) >= 0. * So, we first calculate acos(in) and then acosh(in). */ - using sycl_complexT = exprm_ns::complex; + using sycl_complexT = su_ns::sycl_complex_t; sycl_complexT z = sycl_complexT(in); const realT x = exprm_ns::real(z); const realT y = exprm_ns::imag(z); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp index 476e7b52b9..e7b7f0c0e7 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp @@ -29,7 +29,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "utils/offset_utils.hpp" @@ -50,8 +50,10 @@ namespace add { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; +namespace exprm_ns = sycl::ext::oneapi::experimental; template struct AddFunctor { @@ -69,21 +71,22 @@ template struct AddFunctor using rT1 = typename argT1::value_type; using rT2 = typename argT2::value_type; - return exprm_ns::complex(in1) + exprm_ns::complex(in2); + return su_ns::sycl_complex_t(in1) + + su_ns::sycl_complex_t(in2); } else if constexpr (tu_ns::is_complex::value && !tu_ns::is_complex::value) { using rT1 = typename argT1::value_type; - return exprm_ns::complex(in1) + in2; + return su_ns::sycl_complex_t(in1) + in2; } else if constexpr (!tu_ns::is_complex::value && tu_ns::is_complex::value) { using rT2 = typename argT2::value_type; - return in1 + exprm_ns::complex(in2); + return in1 + su_ns::sycl_complex_t(in2); } else { return in1 + in2; @@ -460,7 +463,21 @@ template struct AddInplaceFunctor using supports_vec = std::negation< std::disjunction, tu_ns::is_complex>>; - void operator()(resT &res, const argT &in) { res += in; } + void operator()(resT &res, const argT &in) + { + if constexpr (tu_ns::is_complex_v && tu_ns::is_complex_v) { + using rT1 = typename resT::value_type; + using rT2 = typename argT::value_type; + + auto tmp = su_ns::sycl_complex_t(res); + tmp += su_ns::sycl_complex_t(in); + + res = resT(tmp); + } + else { + res += in; + } + } template void operator()(sycl::vec &res, diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/angle.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/angle.hpp index 726f90ba81..501a73765d 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/angle.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/angle.hpp @@ -30,7 +30,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -50,7 +50,9 @@ namespace angle { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -71,7 +73,7 @@ template struct AngleFunctor { using rT = typename argT::value_type; - return exprm_ns::arg(exprm_ns::complex(in)); // arg(in); + return exprm_ns::arg(su_ns::sycl_complex_t(in)); // arg(in); } }; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp index cf15831aab..9920bca56c 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp @@ -29,7 +29,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -49,7 +49,9 @@ namespace asin { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -80,7 +82,7 @@ template struct AsinFunctor * y = imag(I * conj(in)) = real(in) * and then return {imag(w), real(w)} which is asin(in) */ - using sycl_complexT = exprm_ns::complex; + using sycl_complexT = su_ns::sycl_complex_t; sycl_complexT z = sycl_complexT(in); const realT x = exprm_ns::imag(z); const realT y = exprm_ns::real(z); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asinh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asinh.hpp index 61c5b7a75c..ea686fccc3 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asinh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/asinh.hpp @@ -29,7 +29,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -49,7 +49,9 @@ namespace asinh { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -72,7 +74,7 @@ template struct AsinhFunctor using realT = typename argT::value_type; constexpr realT q_nan = std::numeric_limits::quiet_NaN(); - using sycl_complexT = exprm_ns::complex; + using sycl_complexT = su_ns::sycl_complex_t; sycl_complexT z = sycl_complexT(in); const realT x = exprm_ns::real(z); const realT y = exprm_ns::imag(z); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan.hpp index 1f9e4079d1..2728616841 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan.hpp @@ -30,7 +30,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -50,7 +50,9 @@ namespace atan { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::kernels::vec_size_utils::ContigHyperparameterSetDefault; using dpctl::tensor::kernels::vec_size_utils::UnaryContigHyperparameterSetEntry; @@ -83,7 +85,7 @@ template struct AtanFunctor * y = imag(I * conj(in)) = real(in) * and then return {imag(w), real(w)} which is atan(in) */ - using sycl_complexT = exprm_ns::complex; + using sycl_complexT = su_ns::sycl_complex_t; sycl_complexT z = sycl_complexT(in); const realT x = exprm_ns::imag(z); const realT y = exprm_ns::real(z); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atanh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atanh.hpp index bf702ba575..eee287823d 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atanh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/atanh.hpp @@ -30,7 +30,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -50,7 +50,9 @@ namespace atanh { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -73,7 +75,7 @@ template struct AtanhFunctor using realT = typename argT::value_type; constexpr realT q_nan = std::numeric_limits::quiet_NaN(); - using sycl_complexT = exprm_ns::complex; + using sycl_complexT = su_ns::sycl_complex_t; sycl_complexT z = sycl_complexT(in); const realT x = exprm_ns::real(z); const realT y = exprm_ns::imag(z); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cabs_impl.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cabs_impl.hpp index e1677871a5..c82e986f27 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cabs_impl.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cabs_impl.hpp @@ -27,7 +27,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" namespace dpctl { @@ -38,6 +38,9 @@ namespace kernels namespace detail { +namespace su_ns = dpctl::tensor::sycl_utils; +namespace exprm_ns = sycl::ext::oneapi::experimental; + template realT cabs(std::complex const &z) { // Special values for cabs( x + y * 1j): @@ -51,8 +54,8 @@ template realT cabs(std::complex const &z) // * If x is a finite number and y is NaN, the result is NaN. // * If x is NaN and y is NaN, the result is NaN. - using sycl_complexT = exprm_ns::complex; - sycl_complexT _z = exprm_ns::complex(z); + using sycl_complexT = su_ns::sycl_complex_t; + sycl_complexT _z = su_ns::sycl_complex_t(z); const realT x = exprm_ns::real(_z); const realT y = exprm_ns::imag(_z); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/conj.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/conj.hpp index 19a95df5a1..61859c9efe 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/conj.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/conj.hpp @@ -31,7 +31,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -51,7 +51,9 @@ namespace conj { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -73,7 +75,7 @@ template struct ConjFunctor if constexpr (is_complex::value) { using rT = typename argT::value_type; - return exprm_ns::conj(exprm_ns::complex(in)); // conj(in); + return exprm_ns::conj(su_ns::sycl_complex_t(in)); // conj(in); } else { if constexpr (!std::is_same_v) diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp index 45c95e0821..e1da401886 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp @@ -29,7 +29,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -49,7 +49,9 @@ namespace cos { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -72,7 +74,7 @@ template struct CosFunctor using realT = typename argT::value_type; constexpr realT q_nan = std::numeric_limits::quiet_NaN(); - using sycl_complexT = exprm_ns::complex; + using sycl_complexT = su_ns::sycl_complex_t; sycl_complexT z = sycl_complexT(in); const realT z_re = exprm_ns::real(z); const realT z_im = exprm_ns::imag(z); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cosh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cosh.hpp index 266180f751..4b841c3486 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cosh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/cosh.hpp @@ -29,7 +29,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -49,7 +49,9 @@ namespace cosh { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -73,7 +75,7 @@ template struct CoshFunctor constexpr realT q_nan = std::numeric_limits::quiet_NaN(); - using sycl_complexT = exprm_ns::complex; + using sycl_complexT = su_ns::sycl_complex_t; sycl_complexT z = sycl_complexT(in); const realT x = exprm_ns::real(z); const realT y = exprm_ns::imag(z); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp index a53f6412de..9f3f09791e 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/equal.hpp @@ -29,7 +29,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "utils/offset_utils.hpp" @@ -49,6 +49,7 @@ namespace equal { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; @@ -71,8 +72,8 @@ template struct EqualFunctor using realT1 = typename argT1::value_type; using realT2 = typename argT2::value_type; - return exprm_ns::complex(in1) == - exprm_ns::complex(in2); + return su_ns::sycl_complex_t(in1) == + su_ns::sycl_complex_t(in2); } else { if constexpr (std::is_integral_v && diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp.hpp index bc644da10c..a35481996f 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp.hpp @@ -29,7 +29,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -49,7 +49,9 @@ namespace exp { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -72,7 +74,7 @@ template struct ExpFunctor constexpr realT q_nan = std::numeric_limits::quiet_NaN(); - using sycl_complexT = exprm_ns::complex; + using sycl_complexT = su_ns::sycl_complex_t; sycl_complexT z = sycl_complexT(in); const realT x = exprm_ns::real(z); const realT y = exprm_ns::imag(z); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp2.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp2.hpp index c7791b70ee..93f54970af 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp2.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/exp2.hpp @@ -30,7 +30,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -50,7 +50,9 @@ namespace exp2 { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -73,7 +75,7 @@ template struct Exp2Functor constexpr realT q_nan = std::numeric_limits::quiet_NaN(); - using sycl_complexT = exprm_ns::complex; + using sycl_complexT = su_ns::sycl_complex_t; sycl_complexT z = sycl_complexT(in); const realT x = exprm_ns::real(z); const realT y = exprm_ns::imag(z); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp index 097267f2aa..973da46d63 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp @@ -31,7 +31,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -51,7 +51,9 @@ namespace expm1 { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -74,7 +76,7 @@ template struct Expm1Functor using realT = typename argT::value_type; // expm1(x + I*y) = expm1(x)*cos(y) - 2*sin(y / 2)^2 + // I*exp(x)*sin(y) - using sycl_complexT = exprm_ns::complex; + using sycl_complexT = su_ns::sycl_complex_t; sycl_complexT z = sycl_complexT(in); const realT x = exprm_ns::real(z); const realT y = exprm_ns::imag(z); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp index f51568f202..c1b52f1d14 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp @@ -31,7 +31,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -51,7 +51,9 @@ namespace imag { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; using dpctl::tensor::type_utils::is_complex_v; @@ -74,7 +76,7 @@ template struct ImagFunctor { if constexpr (is_complex_v) { using realT = typename argT::value_type; - using sycl_complexT = typename exprm_ns::complex; + using sycl_complexT = typename su_ns::sycl_complex_t; return exprm_ns::imag(sycl_complexT(in)); } else { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp index 32f6addf2f..83deac1a42 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isfinite.hpp @@ -30,7 +30,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "utils/offset_utils.hpp" @@ -47,7 +47,9 @@ namespace isfinite { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; using dpctl::tensor::type_utils::vec_cast; @@ -71,7 +73,7 @@ template struct IsFiniteFunctor { if constexpr (is_complex::value) { using realT = typename argT::value_type; - using sycl_complexT = exprm_ns::complex; + using sycl_complexT = su_ns::sycl_complex_t; sycl_complexT z = sycl_complexT(in); const bool real_isfinite = std::isfinite(exprm_ns::real(z)); const bool imag_isfinite = std::isfinite(exprm_ns::imag(z)); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp index 87a215bd9f..1c0a2875f4 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isinf.hpp @@ -30,7 +30,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -49,7 +49,9 @@ namespace isinf { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; using dpctl::tensor::type_utils::vec_cast; @@ -71,7 +73,7 @@ template struct IsInfFunctor { if constexpr (is_complex::value) { using realT = typename argT::value_type; - using sycl_complexT = exprm_ns::complex; + using sycl_complexT = su_ns::sycl_complex_t; sycl_complexT z = sycl_complexT(in); const bool real_isinf = std::isinf(exprm_ns::real(z)); const bool imag_isinf = std::isinf(exprm_ns::imag(z)); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp index abac44be84..1317bdc945 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/isnan.hpp @@ -29,7 +29,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -48,7 +48,9 @@ namespace isnan { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; using dpctl::tensor::type_utils::vec_cast; @@ -72,7 +74,7 @@ template struct IsNanFunctor { if constexpr (is_complex::value) { using realT = typename argT::value_type; - using sycl_complexT = exprm_ns::complex; + using sycl_complexT = su_ns::sycl_complex_t; sycl_complexT z = sycl_complexT(in); const bool real_isnan = sycl::isnan(exprm_ns::real(z)); const bool imag_isnan = sycl::isnan(exprm_ns::imag(z)); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log.hpp index 84471a5ef4..c33af596d8 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log.hpp @@ -30,7 +30,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -50,7 +50,9 @@ namespace log { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -71,7 +73,7 @@ template struct LogFunctor { if constexpr (is_complex::value) { using realT = typename argT::value_type; - return exprm_ns::log(exprm_ns::complex(in)); // log(in); + return exprm_ns::log(su_ns::sycl_complex_t(in)); // log(in); } else { return sycl::log(in); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log10.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log10.hpp index d308c85ac9..30868237dc 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log10.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log10.hpp @@ -31,7 +31,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -51,7 +51,9 @@ namespace log10 { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; using dpctl::tensor::type_utils::vec_cast; @@ -75,7 +77,7 @@ template struct Log10Functor if constexpr (is_complex::value) { using realT = typename argT::value_type; // return (log(in) / log(realT{10})); - return exprm_ns::log(exprm_ns::complex(in)) / + return exprm_ns::log(su_ns::sycl_complex_t(in)) / sycl::log(realT{10}); } else { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log1p.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log1p.hpp index 444e5a346e..ee29b3ad4f 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log1p.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log1p.hpp @@ -30,7 +30,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -50,7 +50,9 @@ namespace log1p { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -80,7 +82,7 @@ template struct Log1pFunctor // + I * atan2(y, x + 1) using realT = typename argT::value_type; using realT = typename argT::value_type; - using sycl_complexT = exprm_ns::complex; + using sycl_complexT = su_ns::sycl_complex_t; sycl_complexT z = sycl_complexT(in); const realT x = exprm_ns::real(z); const realT y = exprm_ns::imag(z); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log2.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log2.hpp index 42c837cfa3..4b708b2b93 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log2.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/log2.hpp @@ -31,7 +31,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -51,7 +51,9 @@ namespace log2 { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; using dpctl::tensor::type_utils::vec_cast; @@ -76,7 +78,7 @@ template struct Log2Functor using realT = typename argT::value_type; // log(in) / log(realT{2}); - return exprm_ns::log(exprm_ns::complex(in)) / + return exprm_ns::log(su_ns::sycl_complex_t(in)) / sycl::log(realT{2}); } else { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp index a296e019ce..30bd058252 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/multiply.hpp @@ -30,7 +30,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "utils/offset_utils.hpp" @@ -51,8 +51,10 @@ namespace multiply { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; +namespace exprm_ns = sycl::ext::oneapi::experimental; template struct MultiplyFunctor { @@ -70,8 +72,8 @@ template struct MultiplyFunctor using realT1 = typename argT1::value_type; using realT2 = typename argT2::value_type; - return exprm_ns::complex(in1) * - exprm_ns::complex(in2); + return su_ns::sycl_complex_t(in1) * + su_ns::sycl_complex_t(in2); } else { return in1 * in2; @@ -425,8 +427,8 @@ template struct MultiplyInplaceFunctor using res_rT = typename resT::value_type; using arg_rT = typename argT::value_type; - auto res1 = exprm_ns::complex(res); - res1 *= exprm_ns::complex(in); + auto res1 = su_ns::sycl_complex_t(res); + res1 *= su_ns::sycl_complex_t(in); res = res1; } else { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/pow.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/pow.hpp index d7b0ed909e..f6dfc41899 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/pow.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/pow.hpp @@ -30,7 +30,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "utils/offset_utils.hpp" @@ -51,8 +51,10 @@ namespace pow { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; +namespace exprm_ns = sycl::ext::oneapi::experimental; template struct PowFunctor { @@ -92,8 +94,8 @@ template struct PowFunctor using realT1 = typename argT1::value_type; using realT2 = typename argT2::value_type; - return exprm_ns::pow(exprm_ns::complex(in1), - exprm_ns::complex(in2)); + return exprm_ns::pow(su_ns::sycl_complex_t(in1), + su_ns::sycl_complex_t(in2)); } else { return sycl::pow(in1, in2); @@ -392,8 +394,8 @@ template struct PowInplaceFunctor using r_resT = typename resT::value_type; using r_argT = typename argT::value_type; - res = exprm_ns::pow(exprm_ns::complex(res), - exprm_ns::complex(in)); + res = exprm_ns::pow(su_ns::sycl_complex_t(res), + su_ns::sycl_complex_t(in)); } else { res = sycl::pow(res, in); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/proj.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/proj.hpp index 45c54e3b09..0c43865647 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/proj.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/proj.hpp @@ -32,7 +32,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -52,7 +52,9 @@ namespace proj { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -72,7 +74,7 @@ template struct ProjFunctor { using realT = typename argT::value_type; using realT = typename argT::value_type; - using sycl_complexT = exprm_ns::complex; + using sycl_complexT = su_ns::sycl_complex_t; sycl_complexT z = sycl_complexT(in); const realT x = exprm_ns::real(z); const realT y = exprm_ns::imag(z); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp index 539abf22f1..096f32eec5 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp @@ -31,7 +31,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -51,7 +51,9 @@ namespace real { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; using dpctl::tensor::type_utils::is_complex_v; @@ -73,7 +75,7 @@ template struct RealFunctor { if constexpr (is_complex_v) { using realT = typename argT::value_type; - using sycl_complexT = typename exprm_ns::complex; + using sycl_complexT = typename su_ns::sycl_complex_t; return exprm_ns::real(sycl_complexT(in)); } else { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/reciprocal.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/reciprocal.hpp index 0e46acba39..43e0e3c640 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/reciprocal.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/reciprocal.hpp @@ -32,7 +32,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "utils/offset_utils.hpp" @@ -52,6 +52,7 @@ namespace reciprocal { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -74,7 +75,7 @@ template struct ReciprocalFunctor using realT = typename argT::value_type; - return realT(1) / exprm_ns::complex(in); + return realT(1) / su_ns::sycl_complex_t(in); } else { return argT(1) / in; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/round.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/round.hpp index b53126754e..4382f0d447 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/round.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/round.hpp @@ -29,7 +29,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -49,7 +49,9 @@ namespace round { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -73,7 +75,7 @@ template struct RoundFunctor } else if constexpr (is_complex::value) { using realT = typename argT::value_type; - using sycl_complexT = exprm_ns::complex; + using sycl_complexT = su_ns::sycl_complex_t; sycl_complexT z = sycl_complexT(in); return resT{round_func(exprm_ns::real(z)), round_func(exprm_ns::imag(z))}; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sign.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sign.hpp index baa224942f..94943b73ab 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sign.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sign.hpp @@ -37,6 +37,7 @@ #include "kernels/elementwise_functions/common.hpp" #include "utils/offset_utils.hpp" +#include "utils/sycl_complex.hpp" #include "utils/type_dispatch_building.hpp" #include "utils/type_utils.hpp" @@ -50,6 +51,7 @@ namespace sign { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -82,7 +84,7 @@ template struct SignFunctor return resT(0); } else { - auto z = exprm_ns::complex(in); + auto z = su_ns::sycl_complex_t(in); return (z / detail::cabs(in)); } } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp index 81dab66026..d4bbed564b 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp @@ -29,7 +29,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -49,7 +49,9 @@ namespace sin { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -73,7 +75,7 @@ template struct SinFunctor constexpr realT q_nan = std::numeric_limits::quiet_NaN(); using realT = typename argT::value_type; - using sycl_complexT = exprm_ns::complex; + using sycl_complexT = su_ns::sycl_complex_t; sycl_complexT z = sycl_complexT(in); realT const &in_re = exprm_ns::real(z); realT const &in_im = exprm_ns::imag(z); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sinh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sinh.hpp index 4bba379f74..6c37266781 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sinh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sinh.hpp @@ -29,7 +29,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -49,7 +49,9 @@ namespace sinh { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -70,7 +72,7 @@ template struct SinhFunctor { if constexpr (is_complex::value) { using realT = typename argT::value_type; - using sycl_complexT = exprm_ns::complex; + using sycl_complexT = su_ns::sycl_complex_t; sycl_complexT z = sycl_complexT(in); const realT x = exprm_ns::real(z); const realT y = exprm_ns::imag(z); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp index b83ff72495..b1014a5070 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp @@ -32,7 +32,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -52,7 +52,9 @@ namespace sqrt { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -73,7 +75,7 @@ template struct SqrtFunctor { if constexpr (is_complex::value) { using realT = typename argT::value_type; - return exprm_ns::sqrt(exprm_ns::complex(in)); + return exprm_ns::sqrt(su_ns::sycl_complex_t(in)); } else { return sycl::sqrt(in); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/square.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/square.hpp index f9d9d848c0..b66b53d225 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/square.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/square.hpp @@ -30,7 +30,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -50,6 +50,7 @@ namespace square { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; @@ -74,7 +75,7 @@ template struct SquareFunctor if constexpr (is_complex::value) { using realT = typename argT::value_type; - auto z = exprm_ns::complex(in); + auto z = su_ns::sycl_complex_t(in); return z * z; } diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/subtract.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/subtract.hpp index 4b2978ffc1..bc35026481 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/subtract.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/subtract.hpp @@ -29,7 +29,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "utils/offset_utils.hpp" @@ -50,8 +50,10 @@ namespace subtract { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; +namespace exprm_ns = sycl::ext::oneapi::experimental; template struct SubtractFunctor { @@ -68,8 +70,8 @@ template struct SubtractFunctor using realT1 = typename argT1::value_type; using realT2 = typename argT2::value_type; - return exprm_ns::complex(in1) - - exprm_ns::complex(in2); + return su_ns::sycl_complex_t(in1) - + su_ns::sycl_complex_t(in2); } else { return in1 - in2; @@ -439,8 +441,8 @@ template struct SubtractInplaceFunctor using res_rT = typename resT::value_type; using arg_rT = typename argT::value_type; - auto res1 = exprm_ns::complex(res); - res1 -= exprm_ns::complex(in); + auto res1 = su_ns::sycl_complex_t(res); + res1 -= su_ns::sycl_complex_t(in); res = res1; } else { diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tan.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tan.hpp index a575a8ec0c..8e0404fe02 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tan.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tan.hpp @@ -30,7 +30,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -50,7 +50,9 @@ namespace tan { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -79,7 +81,7 @@ template struct TanFunctor * we calculate real and imaginary parts of z = tanh(I * z) and * return { imag(z) , -real(z) } which is tan(z). */ - using sycl_complexT = exprm_ns::complex; + using sycl_complexT = su_ns::sycl_complex_t; sycl_complexT z = sycl_complexT(in); const realT x = -exprm_ns::imag(z); const realT y = exprm_ns::real(z); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tanh.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tanh.hpp index e88018e933..9ea078f6a4 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tanh.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/tanh.hpp @@ -31,7 +31,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -51,7 +51,9 @@ namespace tanh { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; +namespace exprm_ns = sycl::ext::oneapi::experimental; using dpctl::tensor::type_utils::is_complex; @@ -75,7 +77,7 @@ template struct TanhFunctor constexpr realT q_nan = std::numeric_limits::quiet_NaN(); - using sycl_complexT = exprm_ns::complex; + using sycl_complexT = su_ns::sycl_complex_t; sycl_complexT z = sycl_complexT(in); const realT x = exprm_ns::real(z); const realT y = exprm_ns::imag(z); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp index de6c9a8723..a187c75230 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp @@ -29,7 +29,7 @@ #include #include -#include "sycl_complex.hpp" +#include "utils/sycl_complex.hpp" #include "vec_size_util.hpp" #include "utils/offset_utils.hpp" @@ -50,8 +50,10 @@ namespace true_divide { using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; namespace td_ns = dpctl::tensor::type_dispatch; namespace tu_ns = dpctl::tensor::type_utils; +namespace exprm_ns = sycl::ext::oneapi::experimental; template struct TrueDivideFunctor @@ -70,22 +72,22 @@ struct TrueDivideFunctor using realT1 = typename argT1::value_type; using realT2 = typename argT2::value_type; - return exprm_ns::complex(in1) / - exprm_ns::complex(in2); + return su_ns::sycl_complex_t(in1) / + su_ns::sycl_complex_t(in2); } else if constexpr (tu_ns::is_complex::value && !tu_ns::is_complex::value) { using realT1 = typename argT1::value_type; - return exprm_ns::complex(in1) / in2; + return su_ns::sycl_complex_t(in1) / in2; } else if constexpr (!tu_ns::is_complex::value && tu_ns::is_complex::value) { using realT2 = typename argT2::value_type; - return in1 / exprm_ns::complex(in2); + return in1 / su_ns::sycl_complex_t(in2); } else { return in1 / in2; @@ -435,14 +437,14 @@ template struct TrueDivideInplaceFunctor using res_rT = typename resT::value_type; using arg_rT = typename argT::value_type; - auto res1 = exprm_ns::complex(res); - res1 /= exprm_ns::complex(in); + auto res1 = su_ns::sycl_complex_t(res); + res1 /= su_ns::sycl_complex_t(in); res = res1; } else { using res_rT = typename resT::value_type; - auto res1 = exprm_ns::complex(res); + auto res1 = su_ns::sycl_complex_t(res); res1 /= in; res = res1; } diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp index 5246fdea61..c19fe6812b 100644 --- a/dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp @@ -37,12 +37,10 @@ #include "kernels/reductions.hpp" #include "utils/offset_utils.hpp" #include "utils/sycl_alloc_utils.hpp" +#include "utils/sycl_complex.hpp" #include "utils/sycl_utils.hpp" #include "utils/type_utils.hpp" -#define SYCL_EXT_ONEAPI_COMPLEX -#include - namespace dpctl { namespace tensor @@ -53,7 +51,6 @@ namespace kernels using dpctl::tensor::ssize_t; namespace su_ns = dpctl::tensor::sycl_utils; namespace tu_ns = dpctl::tensor::type_utils; -namespace exprm_ns = sycl::ext::oneapi::experimental; namespace detail { @@ -110,7 +107,7 @@ struct SequentialDotProduct if constexpr (tu_ns::is_complex_v) { using realT = typename outT::value_type; - using sycl_complex = exprm_ns::complex; + using sycl_complex = su_ns::sycl_complex_t; auto tmp = sycl_complex(red_val); tmp += sycl_complex(tu_ns::convert_impl( @@ -746,7 +743,7 @@ struct DotProductNoAtomicFunctor if constexpr (tu_ns::is_complex_v) { using realT = typename outT::value_type; - using sycl_complexT = exprm_ns::complex; + using sycl_complexT = su_ns::sycl_complex_t; sycl_complexT val = sycl_complexT(tu_ns::convert_impl( @@ -859,7 +856,7 @@ struct DotProductNoAtomicCustomFunctor if constexpr (tu_ns::is_complex_v) { using realT = typename outT::value_type; - using sycl_complexT = exprm_ns::complex; + using sycl_complexT = su_ns::sycl_complex_t; sycl_complexT val = sycl_complexT(tu_ns::convert_impl( diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp index fb5e8dce14..49b8868b20 100644 --- a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp @@ -37,12 +37,10 @@ #include "kernels/reductions.hpp" #include "utils/offset_utils.hpp" #include "utils/sycl_alloc_utils.hpp" +#include "utils/sycl_complex.hpp" #include "utils/sycl_utils.hpp" #include "utils/type_utils.hpp" -#define SYCL_EXT_ONEAPI_COMPLEX -#include - namespace dpctl { namespace tensor @@ -53,7 +51,6 @@ namespace kernels using dpctl::tensor::ssize_t; namespace su_ns = dpctl::tensor::sycl_utils; namespace tu_ns = dpctl::tensor::type_utils; -namespace exprm_ns = sycl::ext::oneapi::experimental; namespace gemm_detail { @@ -1090,7 +1087,7 @@ class GemmBatchFunctorThreadNM_vecm { if constexpr (tu_ns::is_complex_v) { using realT = typename resT::value_type; - using sycl_complex = exprm_ns::complex; + using sycl_complex = su_ns::sycl_complex_t; auto tmp = sycl_complex( private_C[pr_i * wi_delta_m_vecs + pr_j]); @@ -1981,7 +1978,7 @@ class GemmBatchNoAtomicFunctorThreadNM { if constexpr (tu_ns::is_complex_v) { using realT = typename resT::value_type; - using sycl_complex = exprm_ns::complex; + using sycl_complex = su_ns::sycl_complex_t; auto tmp = sycl_complex(local_sum); tmp += (sycl_complex(local_A_block[a_offset + a_pr_offset + private_s]) * @@ -2158,7 +2155,7 @@ class GemmBatchNoAtomicFunctorThreadK for (std::size_t t = local_s; t < local_B_block.size(); t += delta_k) { if constexpr (tu_ns::is_complex_v) { using realT = typename resT::value_type; - using sycl_complex = exprm_ns::complex; + using sycl_complex = su_ns::sycl_complex_t; auto tmp = sycl_complex(private_sum); tmp += ((i < n) && (t + t_shift < k)) @@ -2190,7 +2187,7 @@ class GemmBatchNoAtomicFunctorThreadK for (std::size_t t = 1; t < delta_k; ++t) { if constexpr (tu_ns::is_complex_v) { using realT = typename resT::value_type; - using sycl_complex = exprm_ns::complex; + using sycl_complex = su_ns::sycl_complex_t; auto tmp = sycl_complex(local_sum); tmp += sycl_complex(workspace[workspace_i_shift + t]); diff --git a/dpctl/tensor/libtensor/include/kernels/reductions.hpp b/dpctl/tensor/libtensor/include/kernels/reductions.hpp index f056d246c9..042997f56b 100644 --- a/dpctl/tensor/libtensor/include/kernels/reductions.hpp +++ b/dpctl/tensor/libtensor/include/kernels/reductions.hpp @@ -1914,8 +1914,7 @@ struct SequentialSearchReduction using dpctl::tensor::math_utils::less_complex; // less_complex always returns false for NaNs, so check if (less_complex(val, red_val) || - std::isnan(std::real(val)) || - std::isnan(std::imag(val))) + std::isnan(val.real()) || std::isnan(val.imag())) { red_val = val; idx_val = static_cast(m); @@ -1941,8 +1940,7 @@ struct SequentialSearchReduction if constexpr (is_complex::value) { using dpctl::tensor::math_utils::greater_complex; if (greater_complex(val, red_val) || - std::isnan(std::real(val)) || - std::isnan(std::imag(val))) + std::isnan(val.real()) || std::isnan(val.imag())) { red_val = val; idx_val = static_cast(m); @@ -2230,8 +2228,8 @@ struct CustomSearchReduction // less_complex always returns false for NaNs, so // check if (less_complex(val, local_red_val) || - std::isnan(std::real(val)) || - std::isnan(std::imag(val))) + std::isnan(val.real()) || + std::isnan(val.imag())) { local_red_val = val; if constexpr (!First) { @@ -2277,8 +2275,8 @@ struct CustomSearchReduction if constexpr (is_complex::value) { using dpctl::tensor::math_utils::greater_complex; if (greater_complex(val, local_red_val) || - std::isnan(std::real(val)) || - std::isnan(std::imag(val))) + std::isnan(val.real()) || + std::isnan(val.imag())) { local_red_val = val; if constexpr (!First) { diff --git a/dpctl/tensor/libtensor/include/utils/math_utils.hpp b/dpctl/tensor/libtensor/include/utils/math_utils.hpp index d0c0475ffa..9097133773 100644 --- a/dpctl/tensor/libtensor/include/utils/math_utils.hpp +++ b/dpctl/tensor/libtensor/include/utils/math_utils.hpp @@ -24,10 +24,17 @@ #pragma once #include -#define SYCL_EXT_ONEAPI_COMPLEX -#include #include +#ifndef SYCL_EXT_ONEAPI_COMPLEX +#define SYCL_EXT_ONEAPI_COMPLEX 1 +#endif +#if __has_include() +#include +#else +#include +#endif + namespace dpctl { namespace tensor diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sycl_complex.hpp b/dpctl/tensor/libtensor/include/utils/sycl_complex.hpp similarity index 81% rename from dpctl/tensor/libtensor/include/kernels/elementwise_functions/sycl_complex.hpp rename to dpctl/tensor/libtensor/include/utils/sycl_complex.hpp index 3b5a1b9e7b..535bf17241 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sycl_complex.hpp +++ b/dpctl/tensor/libtensor/include/utils/sycl_complex.hpp @@ -26,11 +26,25 @@ #pragma once -#define SYCL_EXT_ONEAPI_COMPLEX +#ifndef SYCL_EXT_ONEAPI_COMPLEX +#define SYCL_EXT_ONEAPI_COMPLEX 1 +#endif #if __has_include() #include #else #include #endif -namespace exprm_ns = sycl::ext::oneapi::experimental; +namespace dpctl +{ +namespace tensor +{ +namespace sycl_utils +{ + +template +using sycl_complex_t = sycl::ext::oneapi::experimental::complex; + +} // namespace sycl_utils +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp b/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp index bcbb54ff39..7d1d9a77d9 100644 --- a/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp +++ b/dpctl/tensor/libtensor/include/utils/sycl_utils.hpp @@ -31,6 +31,7 @@ #include #include "math_utils.hpp" +#include "sycl_complex.hpp" namespace dpctl { diff --git a/dpctl/tensor/libtensor/source/sorting/rich_comparisons.hpp b/dpctl/tensor/libtensor/source/sorting/rich_comparisons.hpp index 44b70c28ec..9f3cf9ffd8 100644 --- a/dpctl/tensor/libtensor/source/sorting/rich_comparisons.hpp +++ b/dpctl/tensor/libtensor/source/sorting/rich_comparisons.hpp @@ -24,11 +24,11 @@ #pragma once -#define SYCL_EXT_ONEAPI_COMPLEX #include "sycl/sycl.hpp" -#include #include +#include "utils/sycl_complex.hpp" + namespace dpctl { namespace tensor @@ -38,6 +38,9 @@ namespace py_internal namespace { + +namespace su_ns = dpctl::tensor::sycl_utils; + template struct ExtendedRealFPLess { /* [R, nan] */ @@ -64,7 +67,7 @@ template struct ExtendedComplexFPLess bool operator()(const cT &v1, const cT &v2) const { using realT = typename cT::value_type; - using sycl_complexT = exprm_ns::complex; + using sycl_complexT = su_ns::sycl_complex_t; sycl_complexT z1 = sycl_complexT(v1); sycl_complexT z2 = sycl_complexT(v2); const realT real1 = exprm_ns::real(z1);