From e78b23b121b00661c39c3521ca77f5d12077d019 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 30 Oct 2024 11:15:29 -0700 Subject: [PATCH 01/92] save --- csrc/device_lower/pass/circular_buffer.cpp | 23 +++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/csrc/device_lower/pass/circular_buffer.cpp b/csrc/device_lower/pass/circular_buffer.cpp index 40bb8fc3294..f2eb30297a5 100644 --- a/csrc/device_lower/pass/circular_buffer.cpp +++ b/csrc/device_lower/pass/circular_buffer.cpp @@ -700,6 +700,17 @@ class CloneTmaCircularBufferLoopAndInsertSync return wait_exprs; } + // If there is already an if-then-else with electSync() predicate, use it. + // Otherwise, create a new one. + kir::IfThenElse* getElectSyncIfThenElse() { + if (elect_sync_if_then_else_ == nullptr) { + elect_sync_if_then_else_ = IrBuilder::create( + IrBuilder::create(PredicateType::ElectSync)); + for_loop_stack_.back()->body().push_back(elect_sync_if_then_else_); + } + return elect_sync_if_then_else_; + } + // This function selects a single thread to launch tma load and mbarrier // arrive_expected_tx operations. The remaining threads will simply arrive // at the mbarrier. @@ -719,16 +730,14 @@ class CloneTmaCircularBufferLoopAndInsertSync NVF_ERROR(mbarrier_arrive_tx_ != nullptr); NVF_ERROR(expr != nullptr); - // Create the if-then-else with electSync() predicate for the arrive expect - // transaction. - kir::IfThenElse* if_expr = IrBuilder::create( - IrBuilder::create(PredicateType::ElectSync)); + // Use the if-then-else with electSync() predicate for the arrive expect + // and cpAsyncBulk operations. + kir::IfThenElse* if_expr = getElectSyncIfThenElse(); // A single thread issues arriveExpectTx with expected transactions and // launches the TMA load. if_expr->thenBody().push_back(mbarrier_arrive_tx_); if_expr->thenBody().push_back(expr); - for_loop_stack_.back()->body().push_back(if_expr); mbarrier_arrive_tx_ = nullptr; } @@ -841,6 +850,10 @@ class CloneTmaCircularBufferLoopAndInsertSync // Mbarrier_ArriveExpectTx to add to cloned_top_level_loop kir::MBarrierArriveExpectTx* mbarrier_arrive_tx_ = nullptr; + // ElectSync if-then-else for the cloned loop. We put all the circular buffer + // load TMA operations under this if-then-else. + kir::IfThenElse* elect_sync_if_then_else_ = nullptr; + // The circular buffered TVs for the loop being cloned std::unordered_set circular_buffer_load_tvs_; }; From 639c99156cefcc57fa98643578d5e5964a0d68da Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Wed, 30 Oct 2024 11:58:08 -0700 Subject: [PATCH 02/92] try --- csrc/device_lower/pass/circular_buffer.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/csrc/device_lower/pass/circular_buffer.cpp b/csrc/device_lower/pass/circular_buffer.cpp index f2eb30297a5..5db7f0cc603 100644 --- a/csrc/device_lower/pass/circular_buffer.cpp +++ b/csrc/device_lower/pass/circular_buffer.cpp @@ -534,6 +534,10 @@ class CloneTmaCircularBufferLoopAndInsertSync return out_tv == circular_buffer_tv; }); + if (!is_circular_buffer_load_expr) { + elect_sync_if_then_else_ = nullptr; + } + insertMBarrierWaitBeforeFirstRead(expr); // Handle Short-Circuit conditions From c02a22c571c0e60aa895b6bb559efac87a158d47 Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Wed, 30 Oct 2024 12:06:33 -0700 Subject: [PATCH 03/92] save --- csrc/device_lower/pass/circular_buffer.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/device_lower/pass/circular_buffer.cpp b/csrc/device_lower/pass/circular_buffer.cpp index 5db7f0cc603..09141202ea2 100644 --- a/csrc/device_lower/pass/circular_buffer.cpp +++ b/csrc/device_lower/pass/circular_buffer.cpp @@ -707,7 +707,7 @@ class CloneTmaCircularBufferLoopAndInsertSync // If there is already an if-then-else with electSync() predicate, use it. // Otherwise, create a new one. kir::IfThenElse* getElectSyncIfThenElse() { - if (elect_sync_if_then_else_ == nullptr) { + if (true || elect_sync_if_then_else_ == nullptr) { elect_sync_if_then_else_ = IrBuilder::create( IrBuilder::create(PredicateType::ElectSync)); for_loop_stack_.back()->body().push_back(elect_sync_if_then_else_); From cbd751dc3fbf9e2a532f7bf031c2df0e4dd471cd Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Wed, 30 Oct 2024 12:38:42 -0700 Subject: [PATCH 04/92] save --- csrc/device_lower/pass/circular_buffer.cpp | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/csrc/device_lower/pass/circular_buffer.cpp b/csrc/device_lower/pass/circular_buffer.cpp index 09141202ea2..bc4c157d6ba 100644 --- a/csrc/device_lower/pass/circular_buffer.cpp +++ b/csrc/device_lower/pass/circular_buffer.cpp @@ -707,10 +707,14 @@ class CloneTmaCircularBufferLoopAndInsertSync // If there is already an if-then-else with electSync() predicate, use it. // Otherwise, create a new one. kir::IfThenElse* getElectSyncIfThenElse() { - if (true || elect_sync_if_then_else_ == nullptr) { + if (loop_containing_elec_sync_ != for_loop_stack_.back()) { + elect_sync_if_then_else_ = nullptr; + } + if (elect_sync_if_then_else_ == nullptr) { elect_sync_if_then_else_ = IrBuilder::create( IrBuilder::create(PredicateType::ElectSync)); - for_loop_stack_.back()->body().push_back(elect_sync_if_then_else_); + loop_containing_elec_sync_ = for_loop_stack_.back(); + loop_containing_elec_sync_->body().push_back(elect_sync_if_then_else_); } return elect_sync_if_then_else_; } @@ -857,6 +861,7 @@ class CloneTmaCircularBufferLoopAndInsertSync // ElectSync if-then-else for the cloned loop. We put all the circular buffer // load TMA operations under this if-then-else. kir::IfThenElse* elect_sync_if_then_else_ = nullptr; + ForLoop* loop_containing_elec_sync_ = nullptr; // The circular buffered TVs for the loop being cloned std::unordered_set circular_buffer_load_tvs_; From cec00303595fa0b697b1a16d629b027212529993 Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Wed, 30 Oct 2024 12:43:10 -0700 Subject: [PATCH 05/92] save --- csrc/device_lower/pass/circular_buffer.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/device_lower/pass/circular_buffer.cpp b/csrc/device_lower/pass/circular_buffer.cpp index bc4c157d6ba..34a6eb21b1b 100644 --- a/csrc/device_lower/pass/circular_buffer.cpp +++ b/csrc/device_lower/pass/circular_buffer.cpp @@ -707,9 +707,9 @@ class CloneTmaCircularBufferLoopAndInsertSync // If there is already an if-then-else with electSync() predicate, use it. // Otherwise, create a new one. kir::IfThenElse* getElectSyncIfThenElse() { - if (loop_containing_elec_sync_ != for_loop_stack_.back()) { - elect_sync_if_then_else_ = nullptr; - } + // if (loop_containing_elec_sync_ != for_loop_stack_.back()) { + // elect_sync_if_then_else_ = nullptr; + // } if (elect_sync_if_then_else_ == nullptr) { elect_sync_if_then_else_ = IrBuilder::create( IrBuilder::create(PredicateType::ElectSync)); From 87f2d4b56fe481c5dc23718fc08ee42863dec34f Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Wed, 30 Oct 2024 12:51:42 -0700 Subject: [PATCH 06/92] save --- csrc/index_compute.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/csrc/index_compute.cpp b/csrc/index_compute.cpp index 81c27dddfc8..553787c66bf 100644 --- a/csrc/index_compute.cpp +++ b/csrc/index_compute.cpp @@ -2238,6 +2238,7 @@ kir::TensorIndex* Index::getConsumerIndex( DataType as_type) { Val* index = nullptr; if (!ir_utils::hasRootToLoopLinearTransformations(consumer) || + ir_utils::isCpAsyncBulkLoad(consumer->definition()) || (isIdModelOptionEnabled(IdModelEnableOption::ConsumerIndex) && GpuLower::current()->isTensorIndexerEnabled())) { index = GpuLower::current()->tensorIndexer().getLinearIndex( From 3d8927bdafa6a0cdd9f5937b106afd8d638c11bc Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Wed, 30 Oct 2024 12:55:57 -0700 Subject: [PATCH 07/92] save --- csrc/device_lower/pass/circular_buffer.cpp | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/csrc/device_lower/pass/circular_buffer.cpp b/csrc/device_lower/pass/circular_buffer.cpp index 34a6eb21b1b..5db7f0cc603 100644 --- a/csrc/device_lower/pass/circular_buffer.cpp +++ b/csrc/device_lower/pass/circular_buffer.cpp @@ -707,14 +707,10 @@ class CloneTmaCircularBufferLoopAndInsertSync // If there is already an if-then-else with electSync() predicate, use it. // Otherwise, create a new one. kir::IfThenElse* getElectSyncIfThenElse() { - // if (loop_containing_elec_sync_ != for_loop_stack_.back()) { - // elect_sync_if_then_else_ = nullptr; - // } if (elect_sync_if_then_else_ == nullptr) { elect_sync_if_then_else_ = IrBuilder::create( IrBuilder::create(PredicateType::ElectSync)); - loop_containing_elec_sync_ = for_loop_stack_.back(); - loop_containing_elec_sync_->body().push_back(elect_sync_if_then_else_); + for_loop_stack_.back()->body().push_back(elect_sync_if_then_else_); } return elect_sync_if_then_else_; } @@ -861,7 +857,6 @@ class CloneTmaCircularBufferLoopAndInsertSync // ElectSync if-then-else for the cloned loop. We put all the circular buffer // load TMA operations under this if-then-else. kir::IfThenElse* elect_sync_if_then_else_ = nullptr; - ForLoop* loop_containing_elec_sync_ = nullptr; // The circular buffered TVs for the loop being cloned std::unordered_set circular_buffer_load_tvs_; From 2bd04124210dab97a00c17581b9fe5644b897832 Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Wed, 30 Oct 2024 12:56:30 -0700 Subject: [PATCH 08/92] save --- csrc/device_lower/pass/circular_buffer.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/csrc/device_lower/pass/circular_buffer.cpp b/csrc/device_lower/pass/circular_buffer.cpp index 5db7f0cc603..f2eb30297a5 100644 --- a/csrc/device_lower/pass/circular_buffer.cpp +++ b/csrc/device_lower/pass/circular_buffer.cpp @@ -534,10 +534,6 @@ class CloneTmaCircularBufferLoopAndInsertSync return out_tv == circular_buffer_tv; }); - if (!is_circular_buffer_load_expr) { - elect_sync_if_then_else_ = nullptr; - } - insertMBarrierWaitBeforeFirstRead(expr); // Handle Short-Circuit conditions From 6fd4489cd582a557b819b1ae3b845ff914f9b2ee Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 30 Oct 2024 17:02:52 -0700 Subject: [PATCH 09/92] save --- __tmp_kernel_none_f0_c0_r0_g0.cu | 11281 +++++++++++++++++++++++++++++ 1 file changed, 11281 insertions(+) create mode 100644 __tmp_kernel_none_f0_c0_r0_g0.cu diff --git a/__tmp_kernel_none_f0_c0_r0_g0.cu b/__tmp_kernel_none_f0_c0_r0_g0.cu new file mode 100644 index 00000000000..118d456c5e9 --- /dev/null +++ b/__tmp_kernel_none_f0_c0_r0_g0.cu @@ -0,0 +1,11281 @@ + +#ifdef __NVCC__ +#include +#endif // __NVCC__ +namespace { + +using int8_t = signed char; +using uint8_t = unsigned char; +using int16_t = short int; +using uint16_t = unsigned short int; +using int32_t = int; +using uint32_t = unsigned int; +using int64_t = long long int; +using uint64_t = unsigned long long int; + +// Modified from cuda.h +struct TensorMap { + alignas(64) + uint64_t opaque[16]; +}; +typedef int nvfuser_index_t; + +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#ifdef __NVCC__ +#include +#else +// The following namespace std is modified from LLVM, see the following +// copyright information +// +// -*- C++ -*- +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// copy-pasted from some llvm files: +// - https://github.com/llvm/llvm-project/blob/main/libcxx/include/type_traits +// - +// https://github.com/llvm/llvm-project/blob/main/clang/test/Headers/Inputs/include/type_traits +namespace std { + +template +_Tp&& __declval(int); +template +_Tp __declval(long); +template +decltype(__declval<_Tp>(0)) declval() noexcept; + +template +struct integral_constant { + static const _Tp value = __v; + typedef _Tp value_type; + typedef integral_constant type; +}; + +typedef integral_constant true_type; +typedef integral_constant false_type; + +// is_same, functional +template +struct is_same : public false_type {}; +template +struct is_same<_Tp, _Tp> : public true_type {}; + +// is_integral, for some types. +template +struct is_integral : public integral_constant {}; +template <> +struct is_integral : public integral_constant {}; +template <> +struct is_integral : public integral_constant {}; +template <> +struct is_integral : public integral_constant {}; +template <> +struct is_integral : public integral_constant {}; +template <> +struct is_integral : public integral_constant {}; +template <> +struct is_integral : public integral_constant {}; + +// enable_if, functional +template +struct enable_if {}; +template +struct enable_if { + using type = _Tp; +}; +template +using enable_if_t = typename enable_if::type; + +template +struct remove_const { + typedef _Tp type; +}; +template +struct remove_const { + typedef _Tp type; +}; +template +using remove_const_t = typename remove_const<_Tp>::type; + +template +struct remove_volatile { + typedef _Tp type; +}; +template +struct remove_volatile { + typedef _Tp type; +}; +template +using remove_volatile_t = typename remove_volatile<_Tp>::type; + +template +struct remove_cv { + typedef typename remove_volatile::type>::type type; +}; +template +using remove_cv_t = typename remove_cv<_Tp>::type; + +template +struct __libcpp_is_floating_point : public false_type {}; +template <> +struct __libcpp_is_floating_point : public true_type {}; +template <> +struct __libcpp_is_floating_point : public true_type {}; +template <> +struct __libcpp_is_floating_point : public true_type {}; + +template +struct is_floating_point + : public __libcpp_is_floating_point::type> {}; + +template +struct is_arithmetic + : public integral_constant< + bool, + is_integral<_Tp>::value || is_floating_point<_Tp>::value> {}; +template +inline constexpr bool is_arithmetic_v = is_arithmetic<_Tp>::value; + +template +struct __numeric_type { + static void __test(...); + static float __test(float); + static double __test(char); + static double __test(int); + static double __test(unsigned); + static double __test(long); + static double __test(unsigned long); + static double __test(long long); + static double __test(unsigned long long); + static double __test(double); + static long double __test(long double); + + typedef decltype(__test(declval<_Tp>())) type; + static const bool value = !is_same::value; +}; + +template <> +struct __numeric_type { + static const bool value = true; +}; + +// __promote + +template < + class _A1, + class _A2 = void, + class _A3 = void, + bool = __numeric_type<_A1>::value && __numeric_type<_A2>::value && + __numeric_type<_A3>::value> +class __promote_imp { + public: + static const bool value = false; +}; + +template +class __promote_imp<_A1, _A2, _A3, true> { + private: + typedef typename __promote_imp<_A1>::type __type1; + typedef typename __promote_imp<_A2>::type __type2; + typedef typename __promote_imp<_A3>::type __type3; + + public: + typedef decltype(__type1() + __type2() + __type3()) type; + static const bool value = true; +}; + +template +class __promote_imp<_A1, _A2, void, true> { + private: + typedef typename __promote_imp<_A1>::type __type1; + typedef typename __promote_imp<_A2>::type __type2; + + public: + typedef decltype(__type1() + __type2()) type; + static const bool value = true; +}; + +template +class __promote_imp<_A1, void, void, true> { + public: + typedef typename __numeric_type<_A1>::type type; + static const bool value = true; +}; + +template +class __promote : public __promote_imp<_A1, _A2, _A3> {}; + +} // namespace std +#endif + +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#ifdef __NVCC__ +#include +#else + +namespace std { + +template +std::enable_if_t bit_cast( + const From& src) noexcept { + return *reinterpret_cast(&src); +} + +} // namespace std + +#endif + +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#ifndef __NVCC__ +#define POS_INFINITY __int_as_float(0x7f800000) +#define INFINITY POS_INFINITY +#define NEG_INFINITY __int_as_float(0xff800000) +#define NAN __int_as_float(0x7fffffff) +//===----------------------------------------------------------------------===// +// The following namespace std is modified from LLVM, see the following +// copyright information +// +// -*- C++ -*- +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// copy-pasted from the following llvm file: +// https://github.com/llvm/llvm-project/blob/main/libcxx/include/complex +namespace std { + +template +class complex; + +template +complex<_Tp> operator*(const complex<_Tp>& __z, const complex<_Tp>& __w); +template +complex<_Tp> operator/(const complex<_Tp>& __x, const complex<_Tp>& __y); + +template +class complex { + public: + typedef _Tp value_type; + + private: + value_type __re_; + value_type __im_; + + public: + constexpr complex( + const value_type& __re = value_type(), + const value_type& __im = value_type()) + : __re_(__re), __im_(__im) {} + template + constexpr complex(const complex<_Xp>& __c) + : __re_(__c.real()), __im_(__c.imag()) {} + + constexpr value_type real() const { + return __re_; + } + constexpr value_type imag() const { + return __im_; + } + + void real(value_type __re) { + __re_ = __re; + } + void imag(value_type __im) { + __im_ = __im; + } + + constexpr operator bool() const { + return real() || imag(); + } + + complex& operator=(const value_type& __re) { + __re_ = __re; + __im_ = value_type(); + return *this; + } + complex& operator+=(const value_type& __re) { + __re_ += __re; + return *this; + } + complex& operator-=(const value_type& __re) { + __re_ -= __re; + return *this; + } + complex& operator*=(const value_type& __re) { + __re_ *= __re; + __im_ *= __re; + return *this; + } + complex& operator/=(const value_type& __re) { + __re_ /= __re; + __im_ /= __re; + return *this; + } + + template + complex& operator=(const complex<_Xp>& __c) { + __re_ = __c.real(); + __im_ = __c.imag(); + return *this; + } + template + complex& operator+=(const complex<_Xp>& __c) { + __re_ += __c.real(); + __im_ += __c.imag(); + return *this; + } + template + complex& operator-=(const complex<_Xp>& __c) { + __re_ -= __c.real(); + __im_ -= __c.imag(); + return *this; + } + template + complex& operator*=(const complex<_Xp>& __c) { + *this = *this * complex(__c.real(), __c.imag()); + return *this; + } + template + complex& operator/=(const complex<_Xp>& __c) { + *this = *this / complex(__c.real(), __c.imag()); + return *this; + } +}; + +template <> +class complex; + +template <> +class complex { + float __re_; + float __im_; + + public: + typedef float value_type; + + constexpr complex(float __re = 0.0f, float __im = 0.0f) + : __re_(__re), __im_(__im) {} + + explicit constexpr complex(const complex& __c); + + // copy volatile to non-volatile + constexpr complex(const volatile complex& other) + : __re_(other.__re_), __im_(other.__im_) {} + + constexpr complex(const complex& other) + : __re_(other.__re_), __im_(other.__im_) {} + + constexpr float real() const { + return __re_; + } + constexpr float imag() const { + return __im_; + } + + void real(value_type __re) { + __re_ = __re; + } + void imag(value_type __im) { + __im_ = __im; + } + + constexpr operator bool() const { + return real() || imag(); + } + + complex& operator=(float __re) { + __re_ = __re; + __im_ = value_type(); + return *this; + } + complex& operator+=(float __re) { + __re_ += __re; + return *this; + } + complex& operator-=(float __re) { + __re_ -= __re; + return *this; + } + complex& operator*=(float __re) { + __re_ *= __re; + __im_ *= __re; + return *this; + } + complex& operator/=(float __re) { + __re_ /= __re; + __im_ /= __re; + return *this; + } + + template + complex& operator=(const complex<_Xp>& __c) { + __re_ = __c.real(); + __im_ = __c.imag(); + return *this; + } + + // non-volatile to volatile + template + volatile complex& operator=(const complex<_Xp>& __c) volatile { + __re_ = __c.real(); + __im_ = __c.imag(); + return *this; + } + // volatile to non-volatile + template + complex& operator=(const volatile complex<_Xp>& __c) { + __re_ = __c.real(); + __im_ = __c.imag(); + return *this; + } + // volatile to volatile + template + volatile complex& operator=(const volatile complex<_Xp>& __c) volatile { + __re_ = __c.real(); + __im_ = __c.imag(); + return *this; + } + + template + complex& operator+=(const complex<_Xp>& __c) { + __re_ += __c.real(); + __im_ += __c.imag(); + return *this; + } + template + complex& operator-=(const complex<_Xp>& __c) { + __re_ -= __c.real(); + __im_ -= __c.imag(); + return *this; + } + template + complex& operator*=(const complex<_Xp>& __c) { + *this = *this * complex(__c.real(), __c.imag()); + return *this; + } + template + complex& operator/=(const complex<_Xp>& __c) { + *this = *this / complex(__c.real(), __c.imag()); + return *this; + } +}; + +template <> +class complex { + double __re_; + double __im_; + + public: + typedef double value_type; + + constexpr complex(double __re = 0.0, double __im = 0.0) + : __re_(__re), __im_(__im) {} + + constexpr complex(const complex& __c); + + // copy volatile to non-volatile + constexpr complex(const volatile complex& other) + : __re_(other.__re_), __im_(other.__im_) {} + + constexpr complex(const complex& other) + : __re_(other.__re_), __im_(other.__im_) {} + + constexpr double real() const { + return __re_; + } + constexpr double imag() const { + return __im_; + } + + void real(value_type __re) { + __re_ = __re; + } + void imag(value_type __im) { + __im_ = __im; + } + + constexpr operator bool() const { + return real() || imag(); + } + + complex& operator=(double __re) { + __re_ = __re; + __im_ = value_type(); + return *this; + } + complex& operator+=(double __re) { + __re_ += __re; + return *this; + } + complex& operator-=(double __re) { + __re_ -= __re; + return *this; + } + complex& operator*=(double __re) { + __re_ *= __re; + __im_ *= __re; + return *this; + } + complex& operator/=(double __re) { + __re_ /= __re; + __im_ /= __re; + return *this; + } + + template + complex& operator=(const complex<_Xp>& __c) { + __re_ = __c.real(); + __im_ = __c.imag(); + return *this; + } + + // non-volatile to volatile + template + volatile complex& operator=(const complex<_Xp>& __c) volatile { + __re_ = __c.real(); + __im_ = __c.imag(); + return *this; + } + // volatile to non-volatile + template + complex& operator=(const volatile complex<_Xp>& __c) { + __re_ = __c.real(); + __im_ = __c.imag(); + return *this; + } + // volatile to volatile + template + volatile complex& operator=(const volatile complex<_Xp>& __c) volatile { + __re_ = __c.real(); + __im_ = __c.imag(); + return *this; + } + + template + complex& operator+=(const complex<_Xp>& __c) { + __re_ += __c.real(); + __im_ += __c.imag(); + return *this; + } + template + complex& operator-=(const complex<_Xp>& __c) { + __re_ -= __c.real(); + __im_ -= __c.imag(); + return *this; + } + template + complex& operator*=(const complex<_Xp>& __c) { + *this = *this * complex(__c.real(), __c.imag()); + return *this; + } + template + complex& operator/=(const complex<_Xp>& __c) { + *this = *this / complex(__c.real(), __c.imag()); + return *this; + } +}; + +inline constexpr complex::complex(const complex& __c) + : __re_(__c.real()), __im_(__c.imag()) {} + +inline constexpr complex::complex(const complex& __c) + : __re_(__c.real()), __im_(__c.imag()) {} + +// 26.3.6 operators: + +template +inline complex<_Tp> operator+( + const complex<_Tp>& __x, + const complex<_Tp>& __y) { + complex<_Tp> __t(__x); + __t += __y; + return __t; +} + +template +inline complex<_Tp> operator+(const complex<_Tp>& __x, const _Tp& __y) { + complex<_Tp> __t(__x); + __t += __y; + return __t; +} + +template +inline complex<_Tp> operator+(const _Tp& __x, const complex<_Tp>& __y) { + complex<_Tp> __t(__y); + __t += __x; + return __t; +} + +template +inline complex<_Tp> operator-( + const complex<_Tp>& __x, + const complex<_Tp>& __y) { + complex<_Tp> __t(__x); + __t -= __y; + return __t; +} + +template +inline complex<_Tp> operator-(const complex<_Tp>& __x, const _Tp& __y) { + complex<_Tp> __t(__x); + __t -= __y; + return __t; +} + +template +inline complex<_Tp> operator-(const _Tp& __x, const complex<_Tp>& __y) { + complex<_Tp> __t(-__y); + __t += __x; + return __t; +} + +template +complex<_Tp> operator*(const complex<_Tp>& __z, const complex<_Tp>& __w) { + _Tp __a = __z.real(); + _Tp __b = __z.imag(); + _Tp __c = __w.real(); + _Tp __d = __w.imag(); + _Tp __ac = __a * __c; + _Tp __bd = __b * __d; + _Tp __ad = __a * __d; + _Tp __bc = __b * __c; + _Tp __x = __ac - __bd; + _Tp __y = __ad + __bc; + if (isnan(__x) && isnan(__y)) { + bool __recalc = false; + if (isinf(__a) || isinf(__b)) { + __a = copysign(isinf(__a) ? _Tp(1) : _Tp(0), __a); + __b = copysign(isinf(__b) ? _Tp(1) : _Tp(0), __b); + if (isnan(__c)) + __c = copysign(_Tp(0), __c); + if (isnan(__d)) + __d = copysign(_Tp(0), __d); + __recalc = true; + } + if (isinf(__c) || isinf(__d)) { + __c = copysign(isinf(__c) ? _Tp(1) : _Tp(0), __c); + __d = copysign(isinf(__d) ? _Tp(1) : _Tp(0), __d); + if (isnan(__a)) + __a = copysign(_Tp(0), __a); + if (isnan(__b)) + __b = copysign(_Tp(0), __b); + __recalc = true; + } + if (!__recalc && + (isinf(__ac) || isinf(__bd) || isinf(__ad) || isinf(__bc))) { + if (isnan(__a)) + __a = copysign(_Tp(0), __a); + if (isnan(__b)) + __b = copysign(_Tp(0), __b); + if (isnan(__c)) + __c = copysign(_Tp(0), __c); + if (isnan(__d)) + __d = copysign(_Tp(0), __d); + __recalc = true; + } + if (__recalc) { + __x = _Tp(INFINITY) * (__a * __c - __b * __d); + __y = _Tp(INFINITY) * (__a * __d + __b * __c); + } + } + return complex<_Tp>(__x, __y); +} + +template +inline complex<_Tp> operator*(const complex<_Tp>& __x, const _Tp& __y) { + complex<_Tp> __t(__x); + __t *= __y; + return __t; +} + +template +inline complex<_Tp> operator*(const _Tp& __x, const complex<_Tp>& __y) { + complex<_Tp> __t(__y); + __t *= __x; + return __t; +} + +template +complex<_Tp> operator/(const complex<_Tp>& __z, const complex<_Tp>& __w) { + int __ilogbw = 0; + _Tp __a = __z.real(); + _Tp __b = __z.imag(); + _Tp __c = __w.real(); + _Tp __d = __w.imag(); + _Tp __logbw = logb(fmax(fabs(__c), fabs(__d))); + if (isfinite(__logbw)) { + __ilogbw = static_cast(__logbw); + __c = scalbn(__c, -__ilogbw); + __d = scalbn(__d, -__ilogbw); + } + _Tp __denom = __c * __c + __d * __d; + _Tp __x = scalbn((__a * __c + __b * __d) / __denom, -__ilogbw); + _Tp __y = scalbn((__b * __c - __a * __d) / __denom, -__ilogbw); + if (isnan(__x) && isnan(__y)) { + if ((__denom == _Tp(0)) && (!isnan(__a) || !isnan(__b))) { + __x = copysign(_Tp(INFINITY), __c) * __a; + __y = copysign(_Tp(INFINITY), __c) * __b; + } else if ((isinf(__a) || isinf(__b)) && isfinite(__c) && isfinite(__d)) { + __a = copysign(isinf(__a) ? _Tp(1) : _Tp(0), __a); + __b = copysign(isinf(__b) ? _Tp(1) : _Tp(0), __b); + __x = _Tp(INFINITY) * (__a * __c + __b * __d); + __y = _Tp(INFINITY) * (__b * __c - __a * __d); + } else if ( + isinf(__logbw) && __logbw > _Tp(0) && isfinite(__a) && isfinite(__b)) { + __c = copysign(isinf(__c) ? _Tp(1) : _Tp(0), __c); + __d = copysign(isinf(__d) ? _Tp(1) : _Tp(0), __d); + __x = _Tp(0) * (__a * __c + __b * __d); + __y = _Tp(0) * (__b * __c - __a * __d); + } + } + return complex<_Tp>(__x, __y); +} + +template +inline complex<_Tp> operator/(const complex<_Tp>& __x, const _Tp& __y) { + return complex<_Tp>(__x.real() / __y, __x.imag() / __y); +} + +template +inline complex<_Tp> operator/(const _Tp& __x, const complex<_Tp>& __y) { + complex<_Tp> __t(__x); + __t /= __y; + return __t; +} + +template +inline complex<_Tp> operator+(const complex<_Tp>& __x) { + return __x; +} + +template +inline complex<_Tp> operator-(const complex<_Tp>& __x) { + return complex<_Tp>(-__x.real(), -__x.imag()); +} + +template +inline constexpr bool operator==( + const complex<_Tp>& __x, + const complex<_Tp>& __y) { + return __x.real() == __y.real() && __x.imag() == __y.imag(); +} + +template +inline constexpr bool operator==(const complex<_Tp>& __x, const _Tp& __y) { + return __x.real() == __y && __x.imag() == 0; +} + +template +inline constexpr bool operator==(const _Tp& __x, const complex<_Tp>& __y) { + return __x == __y.real() && 0 == __y.imag(); +} + +template +inline constexpr bool operator!=( + const complex<_Tp>& __x, + const complex<_Tp>& __y) { + return !(__x == __y); +} + +template +inline constexpr bool operator!=(const complex<_Tp>& __x, const _Tp& __y) { + return !(__x == __y); +} + +template +inline constexpr bool operator!=(const _Tp& __x, const complex<_Tp>& __y) { + return !(__x == __y); +} + +template +inline constexpr bool operator&&( + const complex<_Tp>& __x, + const complex<_Tp>& __y) { + return bool(__x) && bool(__y); +} + +template +inline constexpr bool isnan(const complex<_Tp>& __x) { + return isnan(__x.real()) || isnan(__x.imag()); +} + +template +inline constexpr bool operator||( + const complex<_Tp>& __x, + const complex<_Tp>& __y) { + return bool(__x) || bool(__y); +} + +// 26.3.7 values: + +template < + class _Tp, + bool = is_integral<_Tp>::value, + bool = is_floating_point<_Tp>::value> +struct __libcpp_complex_overload_traits {}; + +// Integral Types +template +struct __libcpp_complex_overload_traits<_Tp, true, false> { + typedef double _ValueType; + typedef complex _ComplexType; +}; + +// Floating point types +template +struct __libcpp_complex_overload_traits<_Tp, false, true> { + typedef _Tp _ValueType; + typedef complex<_Tp> _ComplexType; +}; + +// real + +template +inline constexpr _Tp real(const complex<_Tp>& __c) { + return __c.real(); +} + +template +inline constexpr typename __libcpp_complex_overload_traits<_Tp>::_ValueType real( + _Tp __re) { + return __re; +} + +// imag + + +template +inline constexpr _Tp imag(const complex<_Tp>& __c) { + return __c.imag(); +} + +template +inline constexpr typename __libcpp_complex_overload_traits<_Tp>::_ValueType imag( + _Tp) { + return 0; +} + +// abs + +template +inline _Tp abs(const complex<_Tp>& __c) { + return hypot(__c.real(), __c.imag()); +} + +// arg + +template +inline _Tp arg(const complex<_Tp>& __c) { + return atan2(__c.imag(), __c.real()); +} + +template +inline typename enable_if< + is_integral<_Tp>::value || is_same<_Tp, double>::value, + double>::type +arg(_Tp __re) { + return atan2(0., __re); +} + +template +inline typename enable_if::value, float>::type arg( + _Tp __re) { + return atan2f(0.F, __re); +} + +} // namespace std + +namespace std { + +using ::isfinite; +using ::isinf; +using ::isnan; +using ::signbit; + +using ::abs; + +using ::acos; +using ::acosf; +using ::asin; +using ::asinf; +using ::atan; +using ::atan2; +using ::atan2f; +using ::atanf; +using ::ceil; +using ::ceilf; +using ::cos; +using ::cosf; +using ::cosh; +using ::coshf; + +using ::exp; +using ::expf; + +using ::fabs; +using ::fabsf; +using ::floor; +using ::floorf; + +using ::fmod; +using ::fmodf; + +using ::frexp; +using ::frexpf; +using ::ldexp; +using ::ldexpf; + +using ::log; +using ::logf; + +using ::log10; +using ::log10f; +using ::modf; +using ::modff; + +using ::pow; +using ::powf; + +using ::sin; +using ::sinf; +using ::sinh; +using ::sinhf; + +using ::sqrt; +using ::sqrtf; +using ::tan; +using ::tanf; + +using ::tanh; +using ::tanhf; + +using ::acosh; +using ::acoshf; +using ::asinh; +using ::asinhf; +using ::atanh; +using ::atanhf; +using ::cbrt; +using ::cbrtf; + +using ::copysign; +using ::copysignf; + +using ::erf; +using ::erfc; +using ::erfcf; +using ::erff; +using ::exp2; +using ::exp2f; +using ::expm1; +using ::expm1f; +using ::fdim; +using ::fdimf; +using ::fma; +using ::fmaf; +using ::fmax; +using ::fmaxf; +using ::fmin; +using ::fminf; +using ::hypot; +using ::hypotf; +using ::ilogb; +using ::ilogbf; +using ::lgamma; +using ::lgammaf; +using ::llrint; +using ::llrintf; +using ::llround; +using ::llroundf; +using ::log1p; +using ::log1pf; +using ::log2; +using ::log2f; +using ::logb; +using ::logbf; +using ::lrint; +using ::lrintf; +using ::lround; +using ::lroundf; + +using ::nan; +using ::nanf; + +using ::nearbyint; +using ::nearbyintf; +using ::nextafter; +using ::nextafterf; +using ::remainder; +using ::remainderf; +using ::remquo; +using ::remquof; +using ::rint; +using ::rintf; +using ::round; +using ::roundf; +using ::scalbln; +using ::scalblnf; +using ::scalbn; +using ::scalbnf; +using ::tgamma; +using ::tgammaf; +using ::trunc; +using ::truncf; + +} // namespace std + +namespace std { + +// norm + +template +inline _Tp norm(const complex<_Tp>& __c) { + if (isinf(__c.real())) + return abs(__c.real()); + if (isinf(__c.imag())) + return abs(__c.imag()); + return __c.real() * __c.real() + __c.imag() * __c.imag(); +} + +template +inline typename __libcpp_complex_overload_traits<_Tp>::_ValueType norm( + _Tp __re) { + typedef typename __libcpp_complex_overload_traits<_Tp>::_ValueType _ValueType; + return static_cast<_ValueType>(__re) * __re; +} + +// conj + +template +inline complex<_Tp> conj(const complex<_Tp>& __c) { + return complex<_Tp>(__c.real(), -__c.imag()); +} + +template +inline typename __libcpp_complex_overload_traits<_Tp>::_ComplexType conj( + _Tp __re) { + typedef + typename __libcpp_complex_overload_traits<_Tp>::_ComplexType _ComplexType; + return _ComplexType(__re); +} + +// proj + +template +inline complex<_Tp> proj(const complex<_Tp>& __c) { + complex<_Tp> __r = __c; + if (isinf(__c.real()) || isinf(__c.imag())) + __r = complex<_Tp>(INFINITY, copysign(_Tp(0), __c.imag())); + return __r; +} + +template +inline typename enable_if< + is_floating_point<_Tp>::value, + typename __libcpp_complex_overload_traits<_Tp>::_ComplexType>::type +proj(_Tp __re) { + if (isinf(__re)) + __re = abs(__re); + return complex<_Tp>(__re); +} + +template +inline typename enable_if< + is_integral<_Tp>::value, + typename __libcpp_complex_overload_traits<_Tp>::_ComplexType>::type +proj(_Tp __re) { + typedef + typename __libcpp_complex_overload_traits<_Tp>::_ComplexType _ComplexType; + return _ComplexType(__re); +} + +// polar + +template +complex<_Tp> polar(const _Tp& __rho, const _Tp& __theta = _Tp()) { + if (isnan(__rho) || signbit(__rho)) + return complex<_Tp>(_Tp(NAN), _Tp(NAN)); + if (isnan(__theta)) { + if (isinf(__rho)) + return complex<_Tp>(__rho, __theta); + return complex<_Tp>(__theta, __theta); + } + if (isinf(__theta)) { + if (isinf(__rho)) + return complex<_Tp>(__rho, _Tp(NAN)); + return complex<_Tp>(_Tp(NAN), _Tp(NAN)); + } + _Tp __x = __rho * cos(__theta); + if (isnan(__x)) + __x = 0; + _Tp __y = __rho * sin(__theta); + if (isnan(__y)) + __y = 0; + return complex<_Tp>(__x, __y); +} + +// log + +template +inline complex<_Tp> log(const complex<_Tp>& __x) { + return complex<_Tp>(log(abs(__x)), arg(__x)); +} + +// log10 + +template +inline complex<_Tp> log10(const complex<_Tp>& __x) { + return log(__x) / log(_Tp(10)); +} + +// log2 + +template +inline complex<_Tp> log2(const complex<_Tp>& __x) { + return log(__x) / log(_Tp(2)); +} + +// sqrt + +template +complex<_Tp> sqrt(const complex<_Tp>& __x) { + if (isinf(__x.imag())) + return complex<_Tp>(_Tp(INFINITY), __x.imag()); + if (isinf(__x.real())) { + if (__x.real() > _Tp(0)) + return complex<_Tp>( + __x.real(), + isnan(__x.imag()) ? __x.imag() : copysign(_Tp(0), __x.imag())); + return complex<_Tp>( + isnan(__x.imag()) ? __x.imag() : _Tp(0), + copysign(__x.real(), __x.imag())); + } + return polar(sqrt(abs(__x)), arg(__x) / _Tp(2)); +} + +// exp + +template +complex<_Tp> exp(const complex<_Tp>& __x) { + _Tp __i = __x.imag(); + if (__i == 0) { + return complex<_Tp>(exp(__x.real()), copysign(_Tp(0), __x.imag())); + } + if (isinf(__x.real())) { + if (__x.real() < _Tp(0)) { + if (!isfinite(__i)) + __i = _Tp(1); + } else if (__i == 0 || !isfinite(__i)) { + if (isinf(__i)) + __i = _Tp(NAN); + return complex<_Tp>(__x.real(), __i); + } + } + _Tp __e = exp(__x.real()); + return complex<_Tp>(__e * cos(__i), __e * sin(__i)); +} + +// pow + +template +inline complex<_Tp> pow(const complex<_Tp>& __x, const complex<_Tp>& __y) { + return exp(__y * log(__x)); +} + +template +inline complex::type> pow( + const complex<_Tp>& __x, + const complex<_Up>& __y) { + typedef complex::type> result_type; + return std::pow(result_type(__x), result_type(__y)); +} + +template +inline typename enable_if< + is_arithmetic<_Up>::value, + complex::type>>::type +pow(const complex<_Tp>& __x, const _Up& __y) { + typedef complex::type> result_type; + return std::pow(result_type(__x), result_type(__y)); +} + +template +inline typename enable_if< + is_arithmetic<_Tp>::value, + complex::type>>::type +pow(const _Tp& __x, const complex<_Up>& __y) { + typedef complex::type> result_type; + return std::pow(result_type(__x), result_type(__y)); +} + +// __sqr, computes pow(x, 2) + +template +inline complex<_Tp> __sqr(const complex<_Tp>& __x) { + return complex<_Tp>( + (__x.real() - __x.imag()) * (__x.real() + __x.imag()), + _Tp(2) * __x.real() * __x.imag()); +} + +// asinh + +template +complex<_Tp> asinh(const complex<_Tp>& __x) { + const _Tp __pi(atan2(+0., -0.)); + if (isinf(__x.real())) { + if (isnan(__x.imag())) + return __x; + if (isinf(__x.imag())) + return complex<_Tp>(__x.real(), copysign(__pi * _Tp(0.25), __x.imag())); + return complex<_Tp>(__x.real(), copysign(_Tp(0), __x.imag())); + } + if (isnan(__x.real())) { + if (isinf(__x.imag())) + return complex<_Tp>(__x.imag(), __x.real()); + if (__x.imag() == 0) + return __x; + return complex<_Tp>(__x.real(), __x.real()); + } + if (isinf(__x.imag())) + return complex<_Tp>( + copysign(__x.imag(), __x.real()), copysign(__pi / _Tp(2), __x.imag())); + complex<_Tp> __z = log(__x + sqrt(__sqr(__x) + _Tp(1))); + return complex<_Tp>( + copysign(__z.real(), __x.real()), copysign(__z.imag(), __x.imag())); +} + +// acosh + +template +complex<_Tp> acosh(const complex<_Tp>& __x) { + const _Tp __pi(atan2(+0., -0.)); + if (isinf(__x.real())) { + if (isnan(__x.imag())) + return complex<_Tp>(abs(__x.real()), __x.imag()); + if (isinf(__x.imag())) { + if (__x.real() > 0) + return complex<_Tp>(__x.real(), copysign(__pi * _Tp(0.25), __x.imag())); + else + return complex<_Tp>( + -__x.real(), copysign(__pi * _Tp(0.75), __x.imag())); + } + if (__x.real() < 0) + return complex<_Tp>(-__x.real(), copysign(__pi, __x.imag())); + return complex<_Tp>(__x.real(), copysign(_Tp(0), __x.imag())); + } + if (isnan(__x.real())) { + if (isinf(__x.imag())) + return complex<_Tp>(abs(__x.imag()), __x.real()); + return complex<_Tp>(__x.real(), __x.real()); + } + if (isinf(__x.imag())) + return complex<_Tp>(abs(__x.imag()), copysign(__pi / _Tp(2), __x.imag())); + complex<_Tp> __z = log(__x + sqrt(__sqr(__x) - _Tp(1))); + return complex<_Tp>( + copysign(__z.real(), _Tp(0)), copysign(__z.imag(), __x.imag())); +} + +// atanh + +template +complex<_Tp> atanh(const complex<_Tp>& __x) { + const _Tp __pi(atan2(+0., -0.)); + if (isinf(__x.imag())) { + return complex<_Tp>( + copysign(_Tp(0), __x.real()), copysign(__pi / _Tp(2), __x.imag())); + } + if (isnan(__x.imag())) { + if (isinf(__x.real()) || __x.real() == 0) + return complex<_Tp>(copysign(_Tp(0), __x.real()), __x.imag()); + return complex<_Tp>(__x.imag(), __x.imag()); + } + if (isnan(__x.real())) { + return complex<_Tp>(__x.real(), __x.real()); + } + if (isinf(__x.real())) { + return complex<_Tp>( + copysign(_Tp(0), __x.real()), copysign(__pi / _Tp(2), __x.imag())); + } + if (abs(__x.real()) == _Tp(1) && __x.imag() == _Tp(0)) { + return complex<_Tp>( + copysign(_Tp(INFINITY), __x.real()), copysign(_Tp(0), __x.imag())); + } + complex<_Tp> __z = log((_Tp(1) + __x) / (_Tp(1) - __x)) / _Tp(2); + return complex<_Tp>( + copysign(__z.real(), __x.real()), copysign(__z.imag(), __x.imag())); +} + +// sinh + +template +complex<_Tp> sinh(const complex<_Tp>& __x) { + if (isinf(__x.real()) && !isfinite(__x.imag())) + return complex<_Tp>(__x.real(), _Tp(NAN)); + if (__x.real() == 0 && !isfinite(__x.imag())) + return complex<_Tp>(__x.real(), _Tp(NAN)); + if (__x.imag() == 0 && !isfinite(__x.real())) + return __x; + return complex<_Tp>( + sinh(__x.real()) * cos(__x.imag()), cosh(__x.real()) * sin(__x.imag())); +} + +// cosh + +template +complex<_Tp> cosh(const complex<_Tp>& __x) { + if (isinf(__x.real()) && !isfinite(__x.imag())) + return complex<_Tp>(abs(__x.real()), _Tp(NAN)); + if (__x.real() == 0 && !isfinite(__x.imag())) + return complex<_Tp>(_Tp(NAN), __x.real()); + if (__x.real() == 0 && __x.imag() == 0) + return complex<_Tp>(_Tp(1), __x.imag()); + if (__x.imag() == 0 && !isfinite(__x.real())) + return complex<_Tp>(abs(__x.real()), __x.imag()); + return complex<_Tp>( + cosh(__x.real()) * cos(__x.imag()), sinh(__x.real()) * sin(__x.imag())); +} + +// tanh + +template +complex<_Tp> tanh(const complex<_Tp>& __x) { + if (isinf(__x.real())) { + if (!isfinite(__x.imag())) + return complex<_Tp>(copysign(_Tp(1), __x.real()), _Tp(0)); + return complex<_Tp>( + copysign(_Tp(1), __x.real()), + copysign(_Tp(0), sin(_Tp(2) * __x.imag()))); + } + if (isnan(__x.real()) && __x.imag() == 0) + return __x; + _Tp __2r(_Tp(2) * __x.real()); + _Tp __2i(_Tp(2) * __x.imag()); + _Tp __d(cosh(__2r) + cos(__2i)); + _Tp __2rsh(sinh(__2r)); + if (isinf(__2rsh) && isinf(__d)) + return complex<_Tp>( + __2rsh > _Tp(0) ? _Tp(1) : _Tp(-1), __2i > _Tp(0) ? _Tp(0) : _Tp(-0.)); + return complex<_Tp>(__2rsh / __d, sin(__2i) / __d); +} + +// asin + +template +complex<_Tp> asin(const complex<_Tp>& __x) { + complex<_Tp> __z = asinh(complex<_Tp>(-__x.imag(), __x.real())); + return complex<_Tp>(__z.imag(), -__z.real()); +} + +// acos + +template +complex<_Tp> acos(const complex<_Tp>& __x) { + const _Tp __pi(atan2(+0., -0.)); + if (isinf(__x.real())) { + if (isnan(__x.imag())) + return complex<_Tp>(__x.imag(), __x.real()); + if (isinf(__x.imag())) { + if (__x.real() < _Tp(0)) + return complex<_Tp>(_Tp(0.75) * __pi, -__x.imag()); + return complex<_Tp>(_Tp(0.25) * __pi, -__x.imag()); + } + if (__x.real() < _Tp(0)) + return complex<_Tp>(__pi, signbit(__x.imag()) ? -__x.real() : __x.real()); + return complex<_Tp>(_Tp(0), signbit(__x.imag()) ? __x.real() : -__x.real()); + } + if (isnan(__x.real())) { + if (isinf(__x.imag())) + return complex<_Tp>(__x.real(), -__x.imag()); + return complex<_Tp>(__x.real(), __x.real()); + } + if (isinf(__x.imag())) + return complex<_Tp>(__pi / _Tp(2), -__x.imag()); + if (__x.real() == 0 && (__x.imag() == 0 || isnan(__x.imag()))) + return complex<_Tp>(__pi / _Tp(2), -__x.imag()); + complex<_Tp> __z = log(__x + sqrt(__sqr(__x) - _Tp(1))); + if (signbit(__x.imag())) + return complex<_Tp>(abs(__z.imag()), abs(__z.real())); + return complex<_Tp>(abs(__z.imag()), -abs(__z.real())); +} + +// atan + +template +complex<_Tp> atan(const complex<_Tp>& __x) { + complex<_Tp> __z = atanh(complex<_Tp>(-__x.imag(), __x.real())); + return complex<_Tp>(__z.imag(), -__z.real()); +} + +// sin + +template +complex<_Tp> sin(const complex<_Tp>& __x) { + complex<_Tp> __z = sinh(complex<_Tp>(-__x.imag(), __x.real())); + return complex<_Tp>(__z.imag(), -__z.real()); +} + +// cos + +template +inline complex<_Tp> cos(const complex<_Tp>& __x) { + return cosh(complex<_Tp>(-__x.imag(), __x.real())); +} + +// tan + +template +complex<_Tp> tan(const complex<_Tp>& __x) { + complex<_Tp> __z = tanh(complex<_Tp>(-__x.imag(), __x.real())); + return complex<_Tp>(__z.imag(), -__z.real()); +} + +// Literal suffix for complex number literals [complex.literals] +inline namespace literals { +inline namespace complex_literals { +constexpr complex operator""i(long double __im) { + return {0.0, static_cast(__im)}; +} + +constexpr complex operator""i(unsigned long long __im) { + return {0.0, static_cast(__im)}; +} + +constexpr complex operator""if(long double __im) { + return {0.0f, static_cast(__im)}; +} + +constexpr complex operator""if(unsigned long long __im) { + return {0.0f, static_cast(__im)}; +} +} // namespace complex_literals +} // namespace literals + +} // namespace std + +__device__ std::complex lerp( + std::complex start, + std::complex end, + std::complex weight) { + if (abs(weight) < 0.5) { + return start + weight * (end - start); + } else { + return end - (end - start) * (1.0 - weight); + } +} + +__device__ std::complex lerp( + std::complex start, + std::complex end, + std::complex weight) { + if (abs(weight) < 0.5f) { + return start + weight * (end - start); + } else { + + return end - (end - start) * (1.0f - weight); + } +} + +__device__ std::complex reciprocal(std::complex x) { + return 1.0 / x; +} + +__device__ std::complex reciprocal(std::complex x) { + return 1.0f / x; +} + +__device__ std::complex sigmoid(std::complex x) { + return 1.0 / (1.0 + exp(-x)); +} + +__device__ std::complex sigmoid(std::complex x) { + return 1.0f / (1.0f + exp(-x)); +} + +// The reciprocal of a complex number z is +// 1/z = conj(z)/|z|^2. +// The principal square root of a complex number z can be obtained by [1] +// sqrt(z) = sqrt(|z|) (z + |z|) / |z + |z||. +// Combining these formulas we have +// 1/sqrt(z) = (conj(z) + |z|) / (sqrt(|z|) |z + |z||). +// [1] https://math.stackexchange.com/a/44500 +__device__ std::complex rsqrt(std::complex z) { + auto a = std::real(z); + auto b = std::imag(z); + auto absa = ::fabsf(a); + auto absb = ::fabsf(b); + // scale to avoid precision loss due to underflow/overflow + auto scale = fmax(absa, absb); + a /= scale; + b /= scale; + auto a_sq = a * a; + auto b_sq = b * b; + auto modz_sq = a_sq + b_sq; + auto modz = ::sqrtf(modz_sq); + auto a_plus_modz = a + modz; + auto mod_zplusmodz_sq = a_plus_modz * a_plus_modz + b_sq; + auto fac = ::rsqrtf(scale * modz * mod_zplusmodz_sq); + return std::complex(a_plus_modz * fac, -b * fac); +} + +__device__ std::complex rsqrt(std::complex z) { + auto a = std::real(z); + auto b = std::imag(z); + auto absa = ::abs(a); + auto absb = ::abs(b); + // scale to avoid precision loss due to underflow/overflow + auto scale = fmax(absa, absb); + a /= scale; + b /= scale; + auto a_sq = a * a; + auto b_sq = b * b; + auto modz_sq = a_sq + b_sq; + auto modz = ::sqrt(modz_sq); + auto a_plus_modz = a + modz; + auto mod_zplusmodz_sq = a_plus_modz * a_plus_modz + b_sq; + auto fac = ::rsqrt(scale * modz * mod_zplusmodz_sq); + return std::complex(a_plus_modz * fac, -b * fac); +} + +template +bool isfinite(std::complex x) { + return ::isfinite(std::real(x)) && ::isfinite(std::imag(x)); +} + +template +bool isinf(std::complex x) { + return ::isinf(std::real(x)) || ::isinf(std::imag(x)); +} + +template +bool isreal(std::complex x) { + return std::imag(x) == 0; +} +#endif // __NVCC__ + +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on + +#define __NVFUSER_HALF_TO_US(var) *(reinterpret_cast(&(var))) +#define __NVFUSER_HALF_TO_CUS(var) \ + *(reinterpret_cast(&(var))) + +struct __half; +__device__ __inline__ __half __float2half(const float); + +struct __align__(2) __half { + __half() = default; + + __half(const __half& other) { + __x = other.__x; + } + + __half(const __half&& other) { + __x = other.__x; + } + + __half(const volatile __half& other) { + __x = other.__x; + } + + __half(const volatile __half&& other) { + __x = other.__x; + } + + // Note: not returning reference for `__half::operator=` + // Doing so would requires us to return `volatile __half&` for the volatile + // variants, which would trigger a gcc warning `implicit dereference will not + // access object of type ‘volatile S’ in statement` + __device__ void operator=(const __half& other) { + __x = other.__x; + } + + __device__ void operator=(const __half&& other) { + __x = other.__x; + } + + __device__ void operator=(const volatile __half& other) { + __x = other.__x; + } + + __device__ void operator=(const volatile __half&& other) { + __x = other.__x; + } + + __device__ void operator=(const __half& other) volatile { + __x = other.__x; + } + + __device__ void operator=(const __half&& other) volatile { + __x = other.__x; + } + + __device__ void operator=(const volatile __half& other) volatile { + __x = other.__x; + } + + __device__ void operator=(const volatile __half&& other) volatile { + __x = other.__x; + } + + __device__ __half(const float f) { + __x = __float2half(f).__x; + } + + __device__ uint16_t raw() const { + return __x; + } + + protected: + unsigned short __x; +}; + +__device__ __inline__ __half __float2half(const float f) { + __half val; + asm("{ cvt.rn.f16.f32 %0, %1;}\n" + : "=h"(__NVFUSER_HALF_TO_US(val)) + : "f"(f)); + return val; +} + +__device__ __inline__ __half __double2half(const double d) { + __half val; + asm("{ cvt.rn.f16.f64 %0, %1;}\n" + : "=h"(__NVFUSER_HALF_TO_US(val)) + : "d"(d)); + return val; +} + +__device__ __inline__ __half __int2half(const int i) { + __half val; + asm("{ cvt.rn.f16.s32 %0, %1;}\n" + : "=h"(__NVFUSER_HALF_TO_US(val)) + : "r"(i)); + return val; +} + +__device__ __inline__ __half __int2half(const int64_t i64) { + __half val; + asm("{ cvt.rn.f16.s64 %0, %1;}\n" + : "=h"(__NVFUSER_HALF_TO_US(val)) + : "l"(i64)); + return val; +} + +__device__ __inline__ __half __int2half(const uint32_t i) { + __half val; + asm("{ cvt.rn.f16.u32 %0, %1;}\n" + : "=h"(__NVFUSER_HALF_TO_US(val)) + : "r"(i)); + return val; +} + +__device__ __inline__ __half __int2half(const uint64_t i64) { + __half val; + asm("{ cvt.rn.f16.u64 %0, %1;}\n" + : "=h"(__NVFUSER_HALF_TO_US(val)) + : "l"(i64)); + return val; +} + +__device__ __inline__ __half __bool2half(const bool b) { + return __int2half((int)b); +} + +__device__ __inline__ float __half2float(const __half h) { + float val; + asm("{ cvt.f32.f16 %0, %1;}\n" : "=f"(val) : "h"(__NVFUSER_HALF_TO_CUS(h))); + return val; +} + +__device__ __inline__ double __half2double(const __half h) { + double val; + asm("{ cvt.f64.f16 %0, %1;}\n" : "=d"(val) : "h"(__NVFUSER_HALF_TO_CUS(h))); + return val; +} + +__device__ int __half2int32(const __half h) { + int val; + asm("{ cvt.rzi.s32.f16 %0, %1;}\n" + : "=r"(val) + : "h"(__NVFUSER_HALF_TO_CUS(h))); + return val; +} + +__device__ __inline__ int64_t __half2int(const __half h) { + int64_t val; + asm("{ cvt.rzi.s64.f16 %0, %1;}\n" + : "=l"(val) + : "h"(__NVFUSER_HALF_TO_CUS(h))); + return val; +} + +__device__ int __half2uint32(const __half h) { + int val; + asm("{ cvt.rzi.u32.f16 %0, %1;}\n" + : "=r"(val) + : "h"(__NVFUSER_HALF_TO_CUS(h))); + return val; +} + +__device__ __inline__ int64_t __half2uint(const __half h) { + int64_t val; + asm("{ cvt.rzi.u64.f16 %0, %1;}\n" + : "=l"(val) + : "h"(__NVFUSER_HALF_TO_CUS(h))); + return val; +} + +__device__ __inline__ void __half2int(const __half h, int& output) { + output = __half2int32(h); +} + +__device__ __inline__ void __half2int(const __half h, int64_t& output) { + output = __half2int(h); +} + +__device__ __inline__ void __half2int(const __half h, uint32_t& output) { + output = __half2uint32(h); +} + +__device__ __inline__ void __half2int(const __half h, uint64_t& output) { + output = __half2uint(h); +} + +__device__ __inline__ nvfuser_index_t __half2index(const __half h) { + nvfuser_index_t result; + __half2int(h, result); + return result; +} + +__device__ __inline__ bool __half2bool(const __half h) { + return (bool)__half2float(h) != 0; +} + +__device__ __inline__ __half __real_then_2half(const std::complex c) { + return __float2half(std::real(c)); +} + +__device__ __inline__ __half __real_then_2half(const std::complex c) { + return __double2half(std::real(c)); +} + +__device__ __inline__ bool __heq(const __half a, const __half b) { + // From cuda_fp16.hpp + unsigned short val; + asm("{ .reg .pred __$temp3;\n" + " setp.eq.f16 __$temp3, %1, %2;\n" + " selp.u16 %0, 1, 0, __$temp3;}" + : "=h"(val) + : "h"(__NVFUSER_HALF_TO_CUS(a)), "h"(__NVFUSER_HALF_TO_CUS(b))); + return (val != 0U) ? true : false; +} + +__device__ __inline__ __half operator|(const __half x, const __half y) { + __half val; + asm("{ or.b16 %0, %1, %2;}\n" + : "=h"(__NVFUSER_HALF_TO_US(val)) + : "h"(__NVFUSER_HALF_TO_CUS(x)), "h"(__NVFUSER_HALF_TO_CUS(y))); + return val; +} + +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on + +#define __NVFUSER_BFLOAT_TO_US(var) *(reinterpret_cast(&(var))) +#define __NVFUSER_BFLOAT_TO_CUS(var) \ + *(reinterpret_cast(&(var))) + +struct __bfloat; +__device__ __inline__ __bfloat __float2bfloat(const float); + +struct __align__(2) __bfloat { + __bfloat() = default; + + __bfloat(const __bfloat& other) { + __x = other.__x; + } + + __bfloat(const __bfloat&& other) { + __x = other.__x; + } + + __bfloat(const volatile __bfloat& other) { + __x = other.__x; + } + + __bfloat(const volatile __bfloat&& other) { + __x = other.__x; + } + + // Note: not returning reference for `__bfloat::operator=` + // Doing so would requires us to return `volatile __bfloat&` for the volatile + // variants, which would trigger a gcc warning `implicit dereference will not + // access object of type ‘volatile S’ in statement` + __device__ void operator=(const __bfloat& other) { + __x = other.__x; + } + + __device__ void operator=(const __bfloat&& other) { + __x = other.__x; + } + + __device__ void operator=(const volatile __bfloat& other) { + __x = other.__x; + } + + __device__ void operator=(const volatile __bfloat&& other) { + __x = other.__x; + } + + __device__ void operator=(const __bfloat& other) volatile { + __x = other.__x; + } + + __device__ void operator=(const __bfloat&& other) volatile { + __x = other.__x; + } + + __device__ void operator=(const volatile __bfloat& other) volatile { + __x = other.__x; + } + + __device__ void operator=(const volatile __bfloat&& other) volatile { + __x = other.__x; + } + + __device__ __bfloat(const float f) { + __x = __float2bfloat(f).__x; + } + + __device__ uint16_t raw() const { + return __x; + } + + protected: + unsigned short __x; +}; + +__device__ __inline__ __bfloat __float2bfloat(const float f) { + __bfloat val; + asm("{ cvt.rn.bf16.f32 %0, %1;}\n" + : "=h"(__NVFUSER_BFLOAT_TO_US(val)) + : "f"(f)); + return val; +} + +__device__ __inline__ __bfloat __double2bfloat(const double d) { +#if __CUDA_ARCH__ >= 900 + __bfloat val; + asm("{ cvt.rn.bf16.f64 %0, %1;}\n" + : "=h"(__NVFUSER_BFLOAT_TO_US(val)) + : "d"(d)); + return val; +#else + return __float2bfloat(static_cast(d)); +#endif +} + +__device__ __inline__ __bfloat __int2bfloat(const int i) { +#if __CUDA_ARCH__ >= 900 + __bfloat val; + asm("{ cvt.rn.bf16.s32 %0, %1;}\n" + : "=h"(__NVFUSER_BFLOAT_TO_US(val)) + : "r"(i)); + return val; +#else + return __float2bfloat(static_cast(i)); +#endif +} + +__device__ __inline__ __bfloat __int2bfloat(const int64_t i64) { +#if __CUDA_ARCH__ >= 900 + __bfloat val; + asm("{ cvt.rn.bf16.s64 %0, %1;}\n" + : "=h"(__NVFUSER_BFLOAT_TO_US(val)) + : "l"(i64)); + return val; +#else + return __float2bfloat(static_cast(i64)); +#endif +} + +__device__ __inline__ __bfloat __int2bfloat(const uint32_t i) { +#if __CUDA_ARCH__ >= 900 + __bfloat val; + asm("{ cvt.rn.bf16.u32 %0, %1;}\n" + : "=h"(__NVFUSER_BFLOAT_TO_US(val)) + : "r"(i)); + return val; +#else + return __float2bfloat(static_cast(i)); +#endif +} + +__device__ __inline__ __bfloat __int2bfloat(const uint64_t i64) { +#if __CUDA_ARCH__ >= 900 + __bfloat val; + asm("{ cvt.rn.bf16.u64 %0, %1;}\n" + : "=h"(__NVFUSER_BFLOAT_TO_US(val)) + : "l"(i64)); + return val; +#else + return __float2bfloat(static_cast(i64)); +#endif +} + +__device__ __inline__ __bfloat __bool2bfloat(const bool b) { + return __int2bfloat((int)b); +} + +__device__ __inline__ float __bfloat2float(const __bfloat h) { + float val; + asm("{ mov.b32 %0, {0,%1};}\n" + : "=f"(val) + : "h"(__NVFUSER_BFLOAT_TO_CUS(h))); + return val; +} + +__device__ __inline__ double __bfloat2double(const __bfloat h) { +#if __CUDA_ARCH__ >= 900 + double val; + asm("{ cvt.f64.bf16 %0, %1;}\n" + : "=d"(val) + : "h"(__NVFUSER_BFLOAT_TO_CUS(h))); + return val; +#else + return static_cast(__bfloat2float(h)); +#endif +} + +__device__ int __bfloat2int32(const __bfloat h) { +#if __CUDA_ARCH__ >= 900 + int val; + asm("{ cvt.rzi.s32.bf16 %0, %1;}\n" + : "=r"(val) + : "h"(__NVFUSER_BFLOAT_TO_CUS(h))); + return val; +#else + return static_cast(__bfloat2float(h)); +#endif +} + +__device__ __inline__ int64_t __bfloat2int(const __bfloat h) { +#if __CUDA_ARCH__ >= 900 + int64_t val; + asm("{ cvt.rzi.s64.bf16 %0, %1;}\n" + : "=l"(val) + : "h"(__NVFUSER_BFLOAT_TO_CUS(h))); + return val; +#else + return static_cast(__bfloat2float(h)); +#endif +} + +__device__ int __bfloat2uint32(const __bfloat h) { +#if __CUDA_ARCH__ >= 900 + int val; + asm("{ cvt.rzi.u32.bf16 %0, %1;}\n" + : "=r"(val) + : "h"(__NVFUSER_BFLOAT_TO_CUS(h))); + return val; +#else + return static_cast(__bfloat2float(h)); +#endif +} + +__device__ __inline__ int64_t __bfloat2uint(const __bfloat h) { +#if __CUDA_ARCH__ >= 900 + int64_t val; + asm("{ cvt.rzi.u64.bf16 %0, %1;}\n" + : "=l"(val) + : "h"(__NVFUSER_BFLOAT_TO_CUS(h))); + return val; +#else + return static_cast(__bfloat2float(h)); +#endif +} + +__device__ __inline__ void __bfloat2int(const __bfloat h, int& output) { + output = __bfloat2int32(h); +} + +__device__ __inline__ void __bfloat2int(const __bfloat h, int64_t& output) { + output = __bfloat2int(h); +} + +__device__ __inline__ void __bfloat2int(const __bfloat h, uint32_t& output) { + output = __bfloat2uint32(h); +} + +__device__ __inline__ void __bfloat2int(const __bfloat h, uint64_t& output) { + output = __bfloat2uint(h); +} + +__device__ __inline__ nvfuser_index_t __bfloat2index( + const __bfloat h, + bool& output) { + nvfuser_index_t result; + __bfloat2int(h, result); + return result; +} + +__device__ __inline__ bool __bfloat2bool(const __bfloat h) { + return (bool)__bfloat2float(h) != 0; +} + +__device__ __inline__ __bfloat __half2bfloat(const __half h) { +#if __CUDA_ARCH__ >= 900 + __bfloat val; + asm("{ cvt.rn.bf16.f16 %0, %1;}\n" + : "=h"(__NVFUSER_BFLOAT_TO_US(val)) + : "h"(__NVFUSER_HALF_TO_CUS(h))); + return val; +#else + return __float2bfloat(__half2float(h)); +#endif +} + +__device__ __inline__ __half __bfloat2half(const __bfloat h) { +#if __CUDA_ARCH__ >= 900 + __half val; + asm("{ cvt.rn.f16.bf16 %0, %1;}\n" + : "=h"(__NVFUSER_HALF_TO_US(val)) + : "h"(__NVFUSER_BFLOAT_TO_CUS(h))); + return val; +#else + return __float2half(__bfloat2float(h)); +#endif +} + +__device__ __inline__ __bfloat __real_then_2bfloat( + const std::complex c) { + return __float2bfloat(std::real(c)); +} + +__device__ __inline__ __bfloat __real_then_2bfloat( + const std::complex c) { + return __double2bfloat(std::real(c)); +} + +__device__ __inline__ bool __heq(const __bfloat a, const __bfloat b) { +// From cuda_bf16.hpp +#if __CUDA_ARCH__ >= 900 + unsigned short val; + asm("{ .reg .pred __$temp3;\n" + " setp.eq.bf16 __$temp3, %1, %2;\n" + " selp.u16 %0, 1, 0, __$temp3;}" + : "=h"(val) + : "h"(__NVFUSER_BFLOAT_TO_CUS(a)), "h"(__NVFUSER_BFLOAT_TO_CUS(b))); +#else + unsigned int val; + asm("{.reg .b32 a,b;\n" + " mov.b32 a, {0, %1};\n" + " mov.b32 b, {0, %2};\n" + " set.eq.f32.f32 %0, a, b;}\n" + : "=r"(val) + : "h"(__NVFUSER_BFLOAT_TO_CUS(a)), "h"(__NVFUSER_BFLOAT_TO_CUS(b))); +#endif + return (val != 0U) ? true : false; +} + +__device__ __inline__ __bfloat operator|(const __bfloat x, const __bfloat y) { + __bfloat val; + asm("{ or.b16 %0, %1, %2;}\n" + : "=h"(__NVFUSER_BFLOAT_TO_US(val)) + : "h"(__NVFUSER_BFLOAT_TO_CUS(x)), "h"(__NVFUSER_BFLOAT_TO_CUS(y))); + return val; +} + +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on + +struct __e4m3; +__device__ __inline__ __e4m3 __float2e4m3(const float); +__device__ __inline__ __e4m3 __double2e4m3(const double); + +struct __align__(1) __e4m3 { + __e4m3() = default; + + __e4m3(const __e4m3& other) { + __x = other.__x; + } + + __e4m3(const __e4m3&& other) { + __x = other.__x; + } + + __e4m3(const volatile __e4m3& other) { + __x = other.__x; + } + + __e4m3(const volatile __e4m3&& other) { + __x = other.__x; + } + + // Note: not returning reference for `__e4m3::operator=` + // Doing so would requires us to return `volatile __e4m3&` for the volatile + // variants, which would trigger a gcc warning `implicit dereference will not + // access object of type ‘volatile S’ in statement` + __device__ void operator=(const __e4m3& other) { + __x = other.__x; + } + + __device__ void operator=(const __e4m3&& other) { + __x = other.__x; + } + + __device__ void operator=(const volatile __e4m3& other) { + __x = other.__x; + } + + __device__ void operator=(const volatile __e4m3&& other) { + __x = other.__x; + } + + __device__ void operator=(const __e4m3& other) volatile { + __x = other.__x; + } + + __device__ void operator=(const __e4m3&& other) volatile { + __x = other.__x; + } + + __device__ void operator=(const volatile __e4m3& other) volatile { + __x = other.__x; + } + + __device__ void operator=(const volatile __e4m3&& other) volatile { + __x = other.__x; + } + + __device__ __e4m3(const float f) { + __x = __float2e4m3(f).__x; + } + + __device__ __e4m3(const double f) { + __x = __double2e4m3(f).__x; + } + + __device__ __e4m3(const int x) : __x(x) {} + + __device__ __e4m3(const long long x) : __x(x) {} + + __device__ __e4m3(const uint8_t x) : __x(x) {} + + __device__ __e4m3(const uint16_t x) : __x(x) {} + + __device__ uint8_t raw() const { + return __x; + } + + protected: + uint8_t __x; +}; + +__device__ __inline__ __e4m3 __double2e4m3(const double f) { + unsigned short _tmp_buffer; + __e4m3 val; + asm("{\n\t" + ".reg .b16 buf0;\n\t" + ".reg .b32 buf1;\n\t" + "cvt.rn.f16.f64 buf0, %1;\n\t" + "cvt.u32.u16 buf1, buf0;\n\t" + "cvt.rn.satfinite.e4m3x2.f16x2 %0, buf1;\n\t" + "}" + : "=h"(_tmp_buffer) + : "d"(f)); + memcpy(&val, &_tmp_buffer, sizeof(uint8_t)); + return val; +} + +__device__ __inline__ double __e4m32double(const __e4m3 h) { + unsigned short _tmp_buffer; + memcpy(&_tmp_buffer, &h, sizeof(uint8_t)); + double val; + asm("{\n\t" + ".reg .b32 buf0;\n\t" + "cvt.rn.f16x2.e4m3x2 buf0, %1;\n\t" + "cvt.u16.u32 %1, buf0;\n\t" + "cvt.f64.f16 %0, %1;" + "}" + : "=d"(val) + : "h"(_tmp_buffer)); + + return val; +} + +__device__ __inline__ __e4m3 __float2e4m3(const float f) { + constexpr float f_const_zero = 0.f; + unsigned short _tmp_buffer; + __e4m3 val; + asm("{cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;}" + : "=h"(_tmp_buffer) + : "f"(f_const_zero), "f"(f)); + memcpy(&val, &_tmp_buffer, sizeof(uint8_t)); + + return val; +} + +__device__ __inline__ float __e4m32float(const __e4m3 h) { + unsigned short _tmp_buffer; + memcpy(&_tmp_buffer, &h, sizeof(uint8_t)); + float val; + asm("{\n\t" + ".reg .b32 buf0;\n\t" + "cvt.rn.f16x2.e4m3x2 buf0, %1;\n\t" + "cvt.u16.u32 %1, buf0;\n\t" + "cvt.f32.f16 %0, %1;\n\t" + "}" + : "=f"(val) + : "h"(_tmp_buffer)); + + return val; +} + +__device__ __inline__ __e4m3 __half2e4m3(const __half h) { + uint32_t buffer; + memcpy(&buffer, &h, sizeof(__half)); + unsigned short _tmp_buffer; + __e4m3 val; + asm("{cvt.rn.satfinite.e4m3x2.f16x2 %0, %1;}\n\t" + : "=h"(_tmp_buffer) + : "r"(buffer)); + memcpy(&val, &_tmp_buffer, sizeof(uint8_t)); + + return val; +} + +__device__ __inline__ __half __e4m32half(const __e4m3 h) { + unsigned short _tmp_buffer; + memcpy(&_tmp_buffer, &h, sizeof(uint8_t)); + __half val; + asm("{\n\t" + ".reg .b32 buf0;\n\t" + "cvt.rn.f16x2.e4m3x2 buf0, %1;\n\t" + "cvt.u16.u32 %0, buf0;\n\t" + "}" + : "=h"(__NVFUSER_HALF_TO_US(val)) + : "h"(_tmp_buffer)); + + return val; +} + +__device__ __inline__ __e4m3 __bfloat2e4m3(const __bfloat h) { + unsigned short _tmp_buffer; + __e4m3 val; + asm("{\n\t" + ".reg .b32 buf0;\n\t" + "cvt.rn.f16.bf16 %1, %1;\n\t" + "cvt.u32.u16 buf0, %1;\n\t" + "cvt.rn.satfinite.e4m3x2.f16x2 %0, buf0;\n\t" + "}" + : "=h"(_tmp_buffer) + : "h"(__NVFUSER_BFLOAT_TO_CUS(h))); + memcpy(&val, &_tmp_buffer, sizeof(uint8_t)); + return val; +} + +__device__ __inline__ __bfloat __e4m32bfloat(const __e4m3 h) { + unsigned short _tmp_buffer; + memcpy(&_tmp_buffer, &h, sizeof(uint8_t)); + __bfloat val; + asm("{\n\t" + ".reg .b32 buf0;\n\t" + "cvt.rn.f16x2.e4m3x2 buf0, %1;\n\t" + "cvt.u16.u32 %0, buf0;\n\t" + "cvt.bf16.f16 %0, %0;\n\t" + "}" + : "=h"(__NVFUSER_BFLOAT_TO_US(val)) + : "h"(_tmp_buffer)); + + return val; +} + +__device__ __inline__ __e4m3 operator|(const __e4m3 x, const __e4m3 y) { + unsigned short val; + unsigned short x_val = x.raw(); + unsigned short y_val = y.raw(); + asm("{ or.b16 %0, %1, %2;}\n" : "=h"(val) : "h"(x_val), "h"(y_val)); + return __e4m3(val); +} + +struct __e5m2; +__device__ __inline__ __e5m2 __float2e5m2(const float); +__device__ __inline__ __e5m2 __double2e5m2(const double); + +struct __align__(1) __e5m2 { + __e5m2() = default; + + __e5m2(const __e5m2& other) { + __x = other.__x; + } + + __e5m2(const __e5m2&& other) { + __x = other.__x; + } + + __e5m2(const volatile __e5m2& other) { + __x = other.__x; + } + + __e5m2(const volatile __e5m2&& other) { + __x = other.__x; + } + + // Note: not returning reference for `__e5m2::operator=` + // Doing so would requires us to return `volatile __e5m2&` for the volatile + // variants, which would trigger a gcc warning `implicit dereference will not + // access object of type ‘volatile S’ in statement` + __device__ void operator=(const __e5m2& other) { + __x = other.__x; + } + + __device__ void operator=(const __e5m2&& other) { + __x = other.__x; + } + + __device__ void operator=(const volatile __e5m2& other) { + __x = other.__x; + } + + __device__ void operator=(const volatile __e5m2&& other) { + __x = other.__x; + } + + __device__ void operator=(const __e5m2& other) volatile { + __x = other.__x; + } + + __device__ void operator=(const __e5m2&& other) volatile { + __x = other.__x; + } + + __device__ void operator=(const volatile __e5m2& other) volatile { + __x = other.__x; + } + + __device__ void operator=(const volatile __e5m2&& other) volatile { + __x = other.__x; + } + + __device__ __e5m2(const float f) { + __x = __float2e5m2(f).__x; + } + + __device__ __e5m2(const double f) { + __x = __double2e5m2(f).__x; + } + + __device__ __e5m2(const int x) : __x(x) {} + + __device__ __e5m2(const long long x) : __x(x) {} + + __device__ __e5m2(const uint8_t x) : __x(x) {} + + __device__ __e5m2(const uint16_t x) : __x(x) {} + + __device__ uint8_t raw() const { + return __x; + } + + protected: + uint8_t __x; +}; + +__device__ __inline__ __e5m2 __double2e5m2(const double f) { + unsigned short _tmp_buffer; + __e5m2 val; + asm("{\n\t" + ".reg .b16 buf0;\n\t" + ".reg .b32 buf1;\n\t" + "cvt.rn.f16.f64 buf0, %1;\n\t" + "cvt.u32.u16 buf1, buf0;\n\t" + "cvt.rn.satfinite.e5m2x2.f16x2 %0, buf1;\n\t" + "}" + : "=h"(_tmp_buffer) + : "d"(f)); + memcpy(&val, &_tmp_buffer, sizeof(uint8_t)); + return val; +} + +__device__ __inline__ double __e5m22double(const __e5m2 h) { + unsigned short _tmp_buffer; + memcpy(&_tmp_buffer, &h, sizeof(uint8_t)); + double val; + asm("{\n\t" + ".reg .b32 buf0;\n\t" + "cvt.rn.f16x2.e5m2x2 buf0, %1;\n\t" + "cvt.u16.u32 %1, buf0;\n\t" + "cvt.f64.f16 %0, %1;" + "}" + : "=d"(val) + : "h"(_tmp_buffer)); + + return val; +} + +__device__ __inline__ __e5m2 __float2e5m2(const float f) { + constexpr float f_const_zero = 0.f; + unsigned short _tmp_buffer; + __e5m2 val; + asm("{cvt.rn.satfinite.e5m2x2.f32 %0, %1, %2;}" + : "=h"(_tmp_buffer) + : "f"(f_const_zero), "f"(f)); + memcpy(&val, &_tmp_buffer, sizeof(uint8_t)); + + return val; +} + +__device__ __inline__ float __e5m22float(const __e5m2 h) { + unsigned short _tmp_buffer; + memcpy(&_tmp_buffer, &h, sizeof(uint8_t)); + float val; + asm("{\n\t" + ".reg .b32 buf0;\n\t" + "cvt.rn.f16x2.e5m2x2 buf0, %1;\n\t" + "cvt.u16.u32 %1, buf0;\n\t" + "cvt.f32.f16 %0, %1;\n\t" + "}" + : "=f"(val) + : "h"(_tmp_buffer)); + + return val; +} + +__device__ __inline__ __e5m2 __half2e5m2(const __half h) { + uint32_t buffer; + memcpy(&buffer, &h, sizeof(__half)); + unsigned short _tmp_buffer; + __e5m2 val; + asm("{cvt.rn.satfinite.e5m2x2.f16x2 %0, %1;}\n\t" + : "=h"(_tmp_buffer) + : "r"(buffer)); + memcpy(&val, &_tmp_buffer, sizeof(uint8_t)); + + return val; +} + +__device__ __inline__ __half __e5m22half(const __e5m2 h) { + unsigned short _tmp_buffer; + memcpy(&_tmp_buffer, &h, sizeof(uint8_t)); + __half val; + asm("{\n\t" + ".reg .b32 buf0;\n\t" + "cvt.rn.f16x2.e5m2x2 buf0, %1;\n\t" + "cvt.u16.u32 %0, buf0;\n\t" + "}" + : "=h"(__NVFUSER_HALF_TO_US(val)) + : "h"(_tmp_buffer)); + + return val; +} + +__device__ __inline__ __e5m2 __bfloat2e5m2(const __bfloat h) { + unsigned short _tmp_buffer; + __e5m2 val; + asm("{\n\t" + ".reg .b32 buf0;\n\t" + "cvt.rn.f16.bf16 %1, %1;\n\t" + "cvt.u32.u16 buf0, %1;\n\t" + "cvt.rn.satfinite.e5m2x2.f16x2 %0, buf0;\n\t" + "}" + : "=h"(_tmp_buffer) + : "h"(__NVFUSER_BFLOAT_TO_CUS(h))); + memcpy(&val, &_tmp_buffer, sizeof(uint8_t)); + return val; +} + +__device__ __inline__ __bfloat __e5m22bfloat(const __e5m2 h) { + unsigned short _tmp_buffer; + memcpy(&_tmp_buffer, &h, sizeof(uint8_t)); + __bfloat val; + asm("{\n\t" + ".reg .b32 buf0;\n\t" + "cvt.rn.f16x2.e5m2x2 buf0, %1;\n\t" + "cvt.u16.u32 %0, buf0;\n\t" + "cvt.bf16.f16 %0, %0;\n\t" + "}" + : "=h"(__NVFUSER_BFLOAT_TO_US(val)) + : "h"(_tmp_buffer)); + + return val; +} + +__device__ __inline__ __e5m2 operator|(const __e5m2 x, const __e5m2 y) { + unsigned short val; + unsigned short x_val = x.raw(); + unsigned short y_val = y.raw(); + asm("{ or.b16 %0, %1, %2;}\n" : "=h"(val) : "h"(x_val), "h"(y_val)); + return __e5m2(val); +} + +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +// Type trait utils +template +struct MaybeVolatile; + +template +struct MaybeVolatile { + using type = volatile Type; +}; + +template +struct MaybeVolatile { + using type = Type; +}; + +template +struct TypeList {}; + +template +struct TypeSelector { + using type = typename TypeSelector::type; +}; + +template +struct TypeSelector<0, T, Types...> { + using type = T; +}; + +template +struct IsSameType { + static constexpr bool value = false; +}; + +template +struct IsSameType { + static constexpr bool value = true; +}; + +template +struct IsPointerType { + static constexpr bool value = false; +}; + +template +struct IsPointerType { + static constexpr bool value = true; +}; + +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +// aligned register array for vectorized load/store +template +struct alignas(sizeof(scalar_t) * align_size) Array { + scalar_t array[size]; + + __device__ void set(scalar_t v) { +#pragma unroll + for (int i = 0; i < size; ++i) { + array[i] = v; + } + } + + __device__ scalar_t& operator[](const unsigned int i) { + return array[i]; + } + + __device__ const scalar_t& operator[](const unsigned int i) const { + return array[i]; + } + + Array& operator=(const Array& a) { +#pragma unroll + for (int i = 0; i < size; ++i) { + array[i] = a[i]; + } + return *this; + } +}; + +// Used for vectorized allocations that are not in registers +template +__device__ void arraySet(scalar_t* buff, scalar_t val) { +#pragma unroll + for (int i = 0; i < vec_size; ++i) { + buff[i] = val; + } +} + +template +__device__ void loadGeneric(scalar_t* to, scalar_t* from) { + // It would be really nice to use memcpy here, but one example was failing + // with: + // + // memcpy(to, from, vec_size * sizeof(scalar_t)); + // + // Yet passing with: + // + // for(int i = 0; i < vec_size; i++){ + // to[i] = from[i]; + // } + + switch (sizeof(scalar_t) * vec_size) { + case 1: + *reinterpret_cast(to) = *reinterpret_cast(from); + break; + case 2: + *reinterpret_cast(to) = *reinterpret_cast(from); + break; + case 4: + *reinterpret_cast(to) = *reinterpret_cast(from); + break; + case 8: + *reinterpret_cast(to) = *reinterpret_cast(from); + break; + case 12: + *reinterpret_cast(to) = *reinterpret_cast(from); + break; + case 16: + *reinterpret_cast(to) = *reinterpret_cast(from); + break; + } +} + +// Volatile version only works with c++ fundamnetal types +template < + typename scalar_t, + int vec_size, + bool is_volatile_to, + bool is_volatile_from> +__device__ void loadGenericVolatile( + typename MaybeVolatile::type* to, + typename MaybeVolatile::type* from) { + switch (sizeof(scalar_t) * vec_size) { + // Reinterpret cast like this with volatile types only works for C++ + // fundamental types otherwise the = operator is not defined + case 1: + *reinterpret_cast< + typename MaybeVolatile::type*>(to) = + *reinterpret_cast< + typename MaybeVolatile::type*>( + from); + break; + case 2: + *reinterpret_cast::type*>( + to) = + *reinterpret_cast< + typename MaybeVolatile::type*>(from); + break; + case 4: + *reinterpret_cast< + typename MaybeVolatile::type*>(to) = + *reinterpret_cast< + typename MaybeVolatile::type*>( + from); + break; + case 8: + *reinterpret_cast::type*>( + to) = + *reinterpret_cast< + typename MaybeVolatile::type*>(from); + break; + } +} + +template +__device__ void loadLocalToGlobal( + typename MaybeVolatile::type* to, + scalar_t* from) { + switch (sizeof(scalar_t) * vec_size) { + case 1: + case 2: + case 4: + loadGenericVolatile(to, from); + break; + case 8: { + uint2 const& data = *reinterpret_cast(from); + if (is_volatile) { + asm volatile( + "st.volatile.global.v2.s32 [%0], {%1,%2};" ::"l"( + (typename MaybeVolatile::type*)to), + "r"(data.x), + "r"(data.y)); + } else { + asm volatile( + "st.global.cs.v2.s32 [%0], {%1,%2};" ::"l"( + (typename MaybeVolatile::type*)to), + "r"(data.x), + "r"(data.y)); + } + break; + } + case 16: { + uint4 const& data = *reinterpret_cast(from); + if (is_volatile) { + asm volatile( + "st.volatile.global.v4.s32 [%0], {%1,%2,%3,%4};" ::"l"( + (typename MaybeVolatile::type*)to), + "r"(data.x), + "r"(data.y), + "r"(data.z), + "r"(data.w)); + } else { + asm volatile( + "st.global.cs.v4.s32 [%0], {%1,%2,%3,%4};" ::"l"( + (typename MaybeVolatile::type*)to), + "r"(data.x), + "r"(data.y), + "r"(data.z), + "r"(data.w)); + } + break; + } + } +} + +// This is copied from csrc/type.h and should be kept consistent. +enum class CacheOp { + AllLevels, + Streaming, + Global, +}; + +template +__device__ void loadGlobalToLocalCached(void* to, void* from) { + T* typed_to = reinterpret_cast(to); + T* typed_from = reinterpret_cast(from); + switch (cache_op) { + case CacheOp::AllLevels: + *typed_to = __ldca(typed_from); + break; + case CacheOp::Streaming: + *typed_to = __ldcs(typed_from); + break; + case CacheOp::Global: + *typed_to = __ldcg(typed_from); + break; + } +} + +// For simplicity, cache_op is only used for non-volatile loads written in +// inline assembly. Other loads are done with the default cache operator -- +// cache all levels. ld.volatile doesn't accept cache operator anyway. +template +__device__ void loadGlobalToLocal( + scalar_t* to, + typename MaybeVolatile::type* from) { + switch (sizeof(scalar_t) * vec_size) { + case 1: + case 2: + case 4: + loadGenericVolatile(to, from); + break; + case 8: { + if (is_volatile) { + uint2& data = *reinterpret_cast(to); + asm volatile("ld.volatile.global.v2.s32 {%0,%1}, [%2];" + : "=r"(data.x), "=r"(data.y) + : "l"((uint2*)from)); + } else { + loadGlobalToLocalCached( + to, const_cast(from)); + } + break; + } + case 16: { + if (is_volatile) { + uint4& data = *reinterpret_cast(to); + asm volatile("ld.volatile.global.v4.s32 {%0,%1,%2,%3}, [%4];" + : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w) + : "l"((uint4*)from)); + } else { + loadGlobalToLocalCached( + to, const_cast(from)); + } + break; + } + } +} + +template < + typename scalar_t, + int vec_size, + bool is_volatile_to, + bool is_volatile_from> +__device__ void loadGlobalToGlobal( + typename MaybeVolatile::type* to, + typename MaybeVolatile::type* from) { + switch (sizeof(scalar_t) * vec_size) { + // Reinterpret cast like this with volatile types only works for C++ + // fundamental types otherwise the = operator is not defined + case 1: + case 2: + case 4: + case 8: + loadGenericVolatile( + to, from); + break; + case 12: { + uint3 local_intermediate; + loadGlobalToLocal< + scalar_t, + vec_size, + is_volatile_from, + CacheOp::Streaming>( + reinterpret_cast(&local_intermediate), from); + loadLocalToGlobal( + to, reinterpret_cast(&local_intermediate)); + break; + } + case 16: { + uint4 local_intermediate; + loadGlobalToLocal< + scalar_t, + vec_size, + is_volatile_from, + CacheOp::Streaming>( + reinterpret_cast(&local_intermediate), from); + loadLocalToGlobal( + to, reinterpret_cast(&local_intermediate)); + break; + } + } +} + +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +template +struct Tensor { + __device__ T& operator[](nvfuser_index_t ind) { + return data[ind]; + }; + + T* data; + Array logical_size; + Array alloc_stride; +}; + +// Specialization for 0-dim case as it does not need size and stride arrays. +// They will be an error as well since zero-length arrays are not allowed. +template +struct Tensor { + __device__ T& operator[](nvfuser_index_t i) { + return *data; + }; + + T* data; +}; + +// Specialization for 0-dim case that's easy to pass in a CPU based tensor. +template +struct CpuScalarTensor { + __device__ T& operator[](int i) { + return data; + }; + + T data; +}; + +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +__device__ unsigned int mulhilo32( + unsigned int a, + unsigned int b, + unsigned int* result_high) { + *result_high = __umulhi(a, b); + return a * b; +} + +__device__ uint4 single_round(uint4 ctr, uint2 key) { + constexpr unsigned long kPhiloxSA = 0xD2511F53; + constexpr unsigned long kPhiloxSB = 0xCD9E8D57; + unsigned int hi0; + unsigned int hi1; + unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0); + unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1); + uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0}; + return ret; +} + +__device__ uint4 philox( + unsigned long long seed, + unsigned long long subsequence, + unsigned long long offset) { + constexpr unsigned long kPhilox10A = 0x9E3779B9; + constexpr unsigned long kPhilox10B = 0xBB67AE85; + uint2 key = {}; + key.x = (unsigned int)seed; + key.y = (unsigned int)(seed >> 32); + uint4 counter = make_uint4(0, 0, 0, 0); + counter.x = (unsigned int)(offset); + counter.y = (unsigned int)(offset >> 32); + counter.z = (unsigned int)(subsequence); + counter.w = (unsigned int)(subsequence >> 32); + + uint4 output = {}; + uint2 key_ = key; + uint4 counter_ = counter; + for (int i = 0; i < 9; i++) { + counter_ = single_round(counter_, key_); + key_.x += (kPhilox10A); + key_.y += (kPhilox10B); + } + output = single_round(counter_, key_); + return output; +} + +// This is a uniform double in the range (0, 1] +__device__ double raw_uniform_double(unsigned int x, unsigned int y) { + constexpr double scale = 1.0 / (double)(1ll << 53); + const unsigned long long z = + (unsigned long long)x ^ ((unsigned long long)y << (53 - 32)); + return (double)z * scale + 0.5 * scale; +} + +// This is a uniform float in the range (0, 1] +__device__ float raw_uniform_float(unsigned int x) { + constexpr float scale = (float)(1.0 / (double)(1ll << 32)); + return (float)x * scale + 0.5f * scale; +} + +__device__ __half uniform_half(unsigned int x) { + __half result = __float2half(raw_uniform_float(x)); + return __heq(result, __float2half(1.0f)) ? __float2half(0.0f) : result; +} + +__device__ __bfloat uniform_bfloat(unsigned int x) { + __bfloat result = __float2bfloat(raw_uniform_float(x)); + return __heq(result, __float2bfloat(1.0f)) ? __float2bfloat(0.0f) : result; +} + +__device__ float uniformf(unsigned int x) { + float result = raw_uniform_float(x); + return result == 1.0f ? 0.0f : result; +} + +__device__ double uniform(unsigned int x, unsigned int y) { + double result = raw_uniform_double(x, y); + return result == 1.0 ? 0.0 : result; +} + +__device__ double rng_uniform(const uint4& rng_result, int rng_component) { + return uniform( + (&rng_result.x)[rng_component * 2], + (&rng_result.x)[rng_component * 2 + 1]); +} + +__device__ float rng_uniformf(const uint4& rng_result, int rng_component) { + return uniformf((&rng_result.x)[rng_component]); +} + +__device__ __half rng_uniform_half(const uint4& rng_result, int rng_component) { + return uniform_half((&rng_result.x)[rng_component]); +} + +__device__ __bfloat +rng_uniform_bfloat(const uint4& rng_result, int rng_component) { + return uniform_bfloat((&rng_result.x)[rng_component]); +} + +__device__ double rng_uniform_range( + const uint4& rng_result, + int rng_component, + double from, + double to) { + auto range = to - from; + auto uniform01 = rng_uniform(rng_result, rng_component); + return from + range * uniform01; +} + +__device__ float rng_uniform_rangef( + const uint4& rng_result, + int rng_component, + float from, + float to) { + auto range = to - from; + auto uniform01 = rng_uniformf(rng_result, rng_component); + return from + range * uniform01; +} + +__device__ __half rng_uniform_range_half( + const uint4& rng_result, + int rng_component, + float from, + float to) { + auto range = to - from; + float uniform01 = raw_uniform_float((&rng_result.x)[rng_component]); + __half result = __float2half(from + range * uniform01); + return __heq(result, __float2half(to)) ? __float2half(from) : result; +} + +__device__ __bfloat rng_uniform_range_bfloat( + const uint4& rng_result, + int rng_component, + float from, + float to) { + auto range = to - from; + float uniform01 = raw_uniform_float((&rng_result.x)[rng_component]); + __bfloat result = __float2bfloat(from + range * uniform01); + return __heq(result, __float2bfloat(to)) ? __float2bfloat(from) : result; +} + +__device__ float normalf(unsigned int x, unsigned int y, int rng_component) { + float u = uniformf(x); + float v = uniformf(y) * 6.2831855f; + + if (rng_component % 2 == 0) { + return sqrtf(-2.0f * logf(u)) * sinf(v); + } else { + return sqrtf(-2.0f * logf(u)) * cosf(v); + } +} + +__device__ double normal( + unsigned int x0, + unsigned int x1, + unsigned int y0, + unsigned int y1, + int rng_component) { + double u = uniform(x0, x1); + double v = uniform(y0, y1) * 6.2831853071795860; + + if (rng_component % 2 == 0) { + return sqrt(-2.0 * log(u)) * sin(v); + } else { + return sqrt(-2.0 * log(u)) * cos(v); + } +} + +__device__ double rng_normal_standard( + const uint4& rng_result, + int rng_component) { + return normal( + rng_result.x, rng_result.y, rng_result.z, rng_result.w, rng_component); +} + +__device__ float rng_normal_standardf( + const uint4& rng_result, + int rng_component) { + return normalf( + (&rng_result.x)[rng_component / 2 * 2], + (&rng_result.y)[rng_component / 2 * 2], + rng_component); +} + +__device__ __half +rng_normal_standard_half(const uint4& rng_result, int rng_component) { + return __float2half(normalf( + (&rng_result.x)[rng_component / 2 * 2], + (&rng_result.y)[rng_component / 2 * 2], + rng_component)); +} + +__device__ __bfloat +rng_normal_standard_bfloat(const uint4& rng_result, int rng_component) { + return __float2bfloat(normalf( + (&rng_result.x)[rng_component / 2 * 2], + (&rng_result.y)[rng_component / 2 * 2], + rng_component)); +} + +__device__ double rng_normal_general( + const uint4& rng_result, + int rng_component, + double mean, + double std) { + auto normal01 = rng_normal_standard(rng_result, rng_component); + return normal01 * std + mean; +} + +__device__ float rng_normal_generalf( + const uint4& rng_result, + int rng_component, + float mean, + float std) { + auto normal01 = rng_normal_standardf(rng_result, rng_component); + return normal01 * std + mean; +} + +__device__ __half rng_normal_general_half( + const uint4& rng_result, + int rng_component, + float mean, + float std) { + auto normal01 = normalf( + (&rng_result.x)[rng_component / 2 * 2], + (&rng_result.y)[rng_component / 2 * 2], + rng_component); + return __float2half(normal01 * std + mean); +} + +__device__ __bfloat rng_normal_general_bfloat( + const uint4& rng_result, + int rng_component, + float mean, + float std) { + auto normal01 = normalf( + (&rng_result.x)[rng_component / 2 * 2], + (&rng_result.y)[rng_component / 2 * 2], + rng_component); + return __float2bfloat(normal01 * std + mean); +} + +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#define NVFUSER_DEFINE_MAGIC_ZERO \ + __shared__ int nvfuser_zero_s; \ + if (threadIdx.x == 0) \ + nvfuser_zero_s = 0; \ + __syncthreads(); \ + atomicMin(&nvfuser_zero_s, threadIdx.x); \ + int nvfuser_zero = nvfuser_zero_s; + +#define NVFUSER_UPDATE_MAGIC_ZERO \ + do { \ + nvfuser_zero <<= 1; \ + } while (0); + +#ifdef __NVCC__ +#include +#endif // __NVCC__ + +__device__ constexpr int ceilDiv(int a, int b) { + return (a + b - 1) / b; +} + +__device__ constexpr int64_t ceilDiv(int64_t a, int64_t b) { + return (a + b - 1) / b; +} + +__device__ constexpr int64_t ceilDiv(int64_t a, int b) { + return ceilDiv(a, (int64_t)b); +} + +__device__ constexpr int64_t ceilDiv(int a, int64_t b) { + return ceilDiv((int64_t)a, b); +} + +__device__ constexpr double ceilDiv(double a, double b) { + return std::ceil(a / b); +} + +__device__ constexpr double ceilDiv(double a, int64_t b) { + return std::ceil(a / b); +} + +__device__ constexpr double ceilDiv(int64_t a, double b) { + return std::ceil(a / b); +} + +// Monotonic and precise lerp is described here: +// https://math.stackexchange.com/a/1798323 +__device__ double lerp(double start, double end, double weight) { + if (weight < 0.5) { + return start + weight * (end - start); + } else { + return end - (end - start) * (1.0 - weight); + } +} + +__device__ float lerp(float start, float end, float weight) { + if (weight < 0.5f) { + return start + weight * (end - start); + } else { + return end - (end - start) * (1.0f - weight); + } +} + +__device__ float lerp(float start, float end, double weight) { + return lerp(start, end, static_cast(weight)); +} + +__device__ constexpr int max(int a, int b) { + return a > b ? a : b; +} + +__device__ constexpr int64_t max(int64_t a, int b) { + return a > (int64_t)b ? a : (int64_t)b; +} + +__device__ constexpr int64_t max(int a, int64_t b) { + return (int64_t)a > b ? (int64_t)a : b; +} + +__device__ constexpr int64_t max(int64_t a, int64_t b) { + return a > b ? a : b; +} + +__device__ double fmax(double a, double b) { + // check and propagate NaN + if (a != a) { + return a; + } else { // If b is nan, it will be returned in the next line + return a > b ? a : b; + } +} + +__device__ float fmax(float a, float b) { + // check and propagate NaN + if (a != a) { + return a; + } else { // If b is nan, it will be returned in the next line + return a > b ? a : b; + } +} + +__device__ constexpr int min(int a, int b) { + return a > b ? b : a; +} + +__device__ constexpr int64_t min(int64_t a, int b) { + return (int64_t)a > b ? b : (int64_t)a; +} + +__device__ constexpr int64_t min(int a, int64_t b) { + return a > (int64_t)b ? (int64_t)b : a; +} + +__device__ constexpr int64_t min(int64_t a, int64_t b) { + return a > b ? b : a; +} + +__device__ double fmin(double a, double b) { + // check and propagate NaN + if (b != b) { + return b; + } else { // If a is nan, it will be returned in the next line + return a > b ? b : a; + } +} + +__device__ float fmin(float a, float b) { + // check and propagate NaN + if (b != b) { + return b; + } else { // If a is nan, it will be returned in the next line + return a > b ? b : a; + } +} + +__device__ constexpr int alignBufferSize(int buffer, int size) { + return (buffer + (size - 1)) & ~(size - 1); +} + +__device__ double clamp(double x, double minv, double maxv) { + return fmin(fmax(x, minv), maxv); +} + +__device__ float clamp(float x, double minv, double maxv) { + return fmin(fmax((double)x, minv), maxv); +} + +__device__ int clamp(int x, int64_t minv, int64_t maxv) { + return min(max((int64_t)x, minv), maxv); +} + +__device__ int64_t clamp(int64_t x, int64_t minv, int64_t maxv) { + return min(max(x, minv), maxv); +} + +__device__ double frac(double x) { + return x - trunc(x); +} + +__device__ float frac(float x) { + return x - trunc(x); +} + +__device__ double reciprocal(double x) { + return 1 / x; +} + +__device__ float reciprocal(float x) { + return 1 / x; +} + +__device__ double relu(double x) { + return x <= 0 ? 0 : x; +} + +__device__ float relu(float x) { + return x <= 0 ? 0 : x; +} + +__device__ float relu(int64_t x) { + return x <= 0 ? 0 : x; +} + +__device__ float relu(int x) { + return x <= 0 ? 0 : x; +} + +__device__ double remainder(double a, double b) { + auto mod = ::fmod(a, b); + if ((mod != 0) && ((b < 0) != (mod < 0))) + mod += b; + return mod; +} + +__device__ float remainder(float a, float b) { + auto mod = ::fmod(a, b); + if ((mod != 0) && ((b < 0) != (mod < 0))) + mod += b; + return mod; +} + +__device__ double sigmoid(double x) { + return 1.0 / (1.0 + exp(-x)); +} + +__device__ float sigmoid(float x) { + return 1.0f / (1.0f + exp(-x)); +} + +__device__ double silu(double x) { + return x * sigmoid(x); +} + +__device__ float silu(float x) { + return x * sigmoid(x); +} + +__device__ double threshold(double x, double t, double v) { + return x <= t ? v : x; +} + +__device__ float threshold(float x, double t, double v) { + return x <= t ? v : x; +} + +__device__ int threshold(int x, int64_t t, int64_t v) { + return x <= t ? v : x; +} + +__device__ int64_t threshold(int64_t x, int64_t t, int64_t v) { + return x <= t ? v : x; +} + +__device__ constexpr int64_t remainder(int64_t a, int64_t b) { + auto mod = a % b; + if ((mod != 0) && ((b < 0) != (mod < 0))) + mod += b; + return mod; +} + +__device__ constexpr int remainder(int a, int b) { + auto mod = a % b; + if ((mod != 0) && ((b < 0) != (mod < 0))) + mod += b; + return mod; +} + +__device__ constexpr int64_t fmod(int64_t a, int64_t b) { + return a % b; +} + +__device__ constexpr int fmod(int a, int b) { + return a % b; +} + +__device__ constexpr double fmod(double a, double b) { + return ::fmod(a, b); +} + +__device__ constexpr float fmod(float a, float b) { + return ::fmod(a, b); +} + +__device__ constexpr double nextafter(double a, double b) { + return ::nextafter(a, b); +} + +__device__ constexpr float nextafter(float a, float b) { + return ::nextafterf(a, b); +} + +template +__device__ T pow(T a, T b) { + if (b < 0) { + if (a == 1) { + return 1; + } else if (a == -1) { + auto negative = (-b) % static_cast(2); + return negative ? -1 : 1; + } else { + return 0; + } + } else { + T result = 1; + while (b) { + if (b & 1) { + result *= a; + } + b /= 2; + a *= a; + } + return result; + } +} + +template __device__ int pow(int a, int b); +template __device__ int64_t pow(int64_t a, int64_t b); + +template <> +__device__ float pow(float a, float b) { + return ::pow(a, b); +} + +template <> +__device__ double pow(double a, double b) { + return ::pow(a, b); +} + +__device__ float pow(float a, int b) { + return pow(a, (float)b); +} + +__device__ double pow(double a, int b) { + return pow(a, (double)b); +} + +__device__ float pow(float a, int64_t b) { + return pow(a, (float)b); +} + +__device__ double pow(double a, int64_t b) { + return pow(a, (double)b); +} + +__device__ int64_t pow(int64_t a, int b) { + return pow(a, (int64_t)b); +} + +__device__ int64_t pow(int a, int64_t b) { + return pow((int64_t)a, b); +} + +__device__ double rsqrt(double z) { + return ::rsqrt(z); +} + +__device__ float rsqrt(float z) { + return ::rsqrtf(z); +} + +__device__ int rsqrt(int z) { + return ::rsqrtf((float)z); +} + +__device__ int64_t rsqrt(int64_t z) { + return ::rsqrt((double)z); +} + +__device__ double signbit(double a) { + return ::signbit(a); +} + +__device__ float signbit(float a) { + return ::signbit(a); +} + +__device__ int signbit(int a) { + return a < 0; +} + +__device__ int64_t signbit(int64_t a) { + return a < 0; +} + +// Reference: +// https://en.wikipedia.org/wiki/Euclidean_algorithm#Implementations +// https://github.com/pytorch/pytorch/blob/c9f4f01981fd73fcc7c27676cc50230cd1b5bc22/aten/src/ATen/native/Math.h#L1232 +template +__device__ T gcd(T a, T b) { + a = abs(a); + b = abs(b); + while (b != 0) { + auto t = b; + b = a % b; + a = t; + } + return a; +} + +template +bool isfinite(T x) { + return ::isfinite(x); +} + +// ref: +// https://github.com/NVIDIA/cutlass/blob/6fbc0d33800008d3180d3fefed4e1a653e5f72a0/include/cutlass/bfloat16.h#L213 +template <> +bool isfinite<__bfloat>(__bfloat x) { + const auto exponent_biased = int((x.raw() >> 7) & 0x0ff); + return exponent_biased != 0x0ff; +} + +// ref: +// https://github.com/NVIDIA/cutlass/blob/6fbc0d33800008d3180d3fefed4e1a653e5f72a0/include/cutlass/half.h#L511 +template <> +bool isfinite<__half>(__half x) { + const auto exponent_biased = int((x.raw() >> 10) & 0x1f); + return exponent_biased != 0x1f; +} + +template +bool isinf(T x) { + return ::isinf(x); +} + +//////////////////////////////////////////////////////////// +// TODO: the following overloads are only needed for CUDA // +// 10.2 Please remove when CUDA 10.2 support is dropped // +//////////////////////////////////////////////////////////// + +bool isinf(int64_t x) { + return false; +} + +bool isinf(int x) { + return false; +} + +bool isinf(short x) { + return false; +} + +bool isinf(char x) { + return false; +} + +bool isinf(unsigned char x) { + return false; +} + +bool isinf(bool x) { + return false; +} + +bool isfinite(int64_t x) { + return true; +} + +bool isfinite(int x) { + return true; +} + +bool isfinite(short x) { + return true; +} + +bool isfinite(char x) { + return true; +} + +bool isfinite(unsigned char x) { + return true; +} + +bool isfinite(bool x) { + return true; +} + +//////////////////////////////////////////////////////////// +// End TODO // +//////////////////////////////////////////////////////////// + +template +bool isnan(T x) { + return x != x; +} + +template +bool isneginf(T x) { + return x < 0 && isinf(x); +} + +template +bool isposinf(T x) { + return x > 0 && isinf(x); +} + +template +bool isreal(T x) { + return true; +} + +// Return the current value of the cycle counter +__device__ inline int64_t readCycleCounter() { + // Ensures preceding memory operations are completed. Doing this + // would make sense for measuring elapsed times enclosed with this + // function. + __threadfence(); + return clock64(); +} + +__device__ float print_impl(const char* name, float value) { + printf( + "%s = %f @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n", + name, + value, + (int)threadIdx.x, + (int)threadIdx.y, + (int)threadIdx.z, + (int)blockIdx.x, + (int)blockIdx.y, + (int)blockIdx.z); + return value; +} + +__device__ double print_impl(const char* name, double value) { + printf( + "%s = %lf @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n", + name, + value, + (int)threadIdx.x, + (int)threadIdx.y, + (int)threadIdx.z, + (int)blockIdx.x, + (int)blockIdx.y, + (int)blockIdx.z); + return value; +} + +__device__ int print_impl(const char* name, int value) { + printf( + "%s = %d @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n", + name, + value, + (int)threadIdx.x, + (int)threadIdx.y, + (int)threadIdx.z, + (int)blockIdx.x, + (int)blockIdx.y, + (int)blockIdx.z); + return value; +} + +__device__ int64_t print_impl(const char* name, int64_t value) { + printf( + "%s = %ld @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n", + name, + value, + (int)threadIdx.x, + (int)threadIdx.y, + (int)threadIdx.z, + (int)blockIdx.x, + (int)blockIdx.y, + (int)blockIdx.z); + return value; +} + +__device__ bool print_impl(const char* name, bool value) { + printf( + "%s = %s @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n", + name, + value ? "true" : "false", + (int)threadIdx.x, + (int)threadIdx.y, + (int)threadIdx.z, + (int)blockIdx.x, + (int)blockIdx.y, + (int)blockIdx.z); + return value; +} + +__device__ __half print_impl(const char* name, __half value) { + printf( + "%s = %f @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n", + name, + __half2float(value), + (int)threadIdx.x, + (int)threadIdx.y, + (int)threadIdx.z, + (int)blockIdx.x, + (int)blockIdx.y, + (int)blockIdx.z); + return value; +} + +#if __CUDACC_VER_MAJOR__ >= 11 +__device__ __bfloat print_impl(const char* name, __bfloat value) { + printf( + "%s = %f @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n", + name, + __bfloat2float(value), + (int)threadIdx.x, + (int)threadIdx.y, + (int)threadIdx.z, + (int)blockIdx.x, + (int)blockIdx.y, + (int)blockIdx.z); + return value; +} +#endif + +#define print(...) print_impl(#__VA_ARGS__, (__VA_ARGS__)) + +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +namespace index_utils { + +// Utility functions + +// Total size of provided dimension +template +__device__ __forceinline__ nvfuser_index_t size(const _dim3& d) { + return (nvfuser_index_t)d.x * (nvfuser_index_t)d.y * (nvfuser_index_t)d.z; +} + +// Linearized indexing of idx based on dim, if bool==false that dimension does +// not participate +template +__device__ nvfuser_index_t maskedOffset(const _dim3& idx, const _dim3_2& dim) { + nvfuser_index_t offset = 0; + if (Z) + offset += idx.z; + if (Y) + offset = offset * dim.y + idx.y; + if (X) + offset = offset * dim.x + idx.x; + return offset; +} + +// Linearized indexing of idx based on dim. All dimensions participate. +template +__device__ nvfuser_index_t offset(const _dim3& idx, const _dim3_2& dim) { + nvfuser_index_t offset = idx.z; + offset = offset * dim.y + idx.y; + offset = offset * dim.x + idx.x; + return offset; +} + +// Masks the provided dim3, those == false get truncated to 1 +template +__device__ dim3 maskedDims(const _dim3& dim) { + return dim3{ + X ? (unsigned)dim.x : 1U, + Y ? (unsigned)dim.y : 1U, + Z ? (unsigned)dim.z : 1U}; +} + +// Provides total size of dim with masking, those dims == false do not +// participate in the size calculation +template +__device__ nvfuser_index_t maskedSize(const _dim3& dim) { + return size(maskedDims(dim)); +} + +// Checks if provided idx is zero on those dims == true +template +__device__ bool maskedIsZero(const _dim3& idx) { + bool isZero = true; + if (X) + isZero = isZero && idx.x == 0; + if (Y) + isZero = isZero && idx.y == 0; + if (Z) + isZero = isZero && idx.z == 0; + return isZero; +} + +// Checks if provided idx is zero on those dims == true +template +__device__ bool maskedIsLast(const _dim3& idx, const _dim3_2& dim) { + bool isZero = true; + if (X) + isZero = isZero && idx.x == dim.x - 1; + if (Y) + isZero = isZero && idx.y == dim.y - 1; + if (Z) + isZero = isZero && idx.z == dim.z - 1; + return isZero; +} + +} // namespace index_utils + +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +// std::tuple-like type +template +struct Tuple; + +#define TUPLE_INCREMENT_PTR(idx) \ + do { \ + static_assert( \ + IsPointerType::value, "Invalid for non-pointer types"); \ + val##idx += offset; \ + } while (0) + +template +struct Tuple { + T0 val0; + + Tuple() = default; + + __device__ Tuple(T0 _val0) : val0(_val0) {} + + // Only valid when instantiated for pointer types + __device__ void operator+=(nvfuser_index_t offset) { + TUPLE_INCREMENT_PTR(0); + } +}; + +template +struct Tuple { + T0 val0; + T1 val1; + + Tuple() = default; + + __device__ Tuple(T0 _val0, T1 _val1) : val0(_val0), val1(_val1) {} + + // Only valid when instantiated for pointer types + __device__ void operator+=(nvfuser_index_t offset) { + TUPLE_INCREMENT_PTR(0); + TUPLE_INCREMENT_PTR(1); + } +}; + +template +struct Tuple { + T0 val0; + T1 val1; + T2 val2; + + Tuple() = default; + + __device__ Tuple(T0 _val0, T1 _val1, T2 _val2) + : val0(_val0), val1(_val1), val2(_val2) {} + + // Only valid when instantiated for pointer types + __device__ void operator+=(nvfuser_index_t offset) { + TUPLE_INCREMENT_PTR(0); + TUPLE_INCREMENT_PTR(1); + TUPLE_INCREMENT_PTR(2); + } +}; + +template +struct Tuple { + T0 val0; + T1 val1; + T2 val2; + T3 val3; + + Tuple() = default; + + __device__ Tuple(T0 _val0, T1 _val1, T2 _val2, T3 _val3) + : val0(_val0), val1(_val1), val2(_val2), val3(_val3) {} + + // Only valid when instantiated for pointer types + __device__ void operator+=(nvfuser_index_t offset) { + TUPLE_INCREMENT_PTR(0); + TUPLE_INCREMENT_PTR(1); + TUPLE_INCREMENT_PTR(2); + TUPLE_INCREMENT_PTR(3); + } +}; + +template +struct Tuple { + T0 val0; + T1 val1; + T2 val2; + T3 val3; + T4 val4; + + Tuple() = default; + + __device__ Tuple(T0 _val0, T1 _val1, T2 _val2, T3 _val3, T4 _val4) + : val0(_val0), val1(_val1), val2(_val2), val3(_val3), val4(_val4) {} + + // Only valid when instantiated for pointer types + __device__ void operator+=(nvfuser_index_t offset) { + TUPLE_INCREMENT_PTR(0); + TUPLE_INCREMENT_PTR(1); + TUPLE_INCREMENT_PTR(2); + TUPLE_INCREMENT_PTR(3); + TUPLE_INCREMENT_PTR(4); + } +}; + +template < + typename T0, + typename T1, + typename T2, + typename T3, + typename T4, + typename T5> +struct Tuple { + T0 val0; + T1 val1; + T2 val2; + T3 val3; + T4 val4; + T5 val5; + + Tuple() = default; + + __device__ Tuple(T0 _val0, T1 _val1, T2 _val2, T3 _val3, T4 _val4, T5 _val5) + : val0(_val0), + val1(_val1), + val2(_val2), + val3(_val3), + val4(_val4), + val5(_val5) {} + + // Only valid when instantiated for pointer types + __device__ void operator+=(nvfuser_index_t offset) { + TUPLE_INCREMENT_PTR(0); + TUPLE_INCREMENT_PTR(1); + TUPLE_INCREMENT_PTR(2); + TUPLE_INCREMENT_PTR(3); + TUPLE_INCREMENT_PTR(4); + TUPLE_INCREMENT_PTR(5); + } +}; + +template < + typename T0, + typename T1, + typename T2, + typename T3, + typename T4, + typename T5, + typename T6> +struct Tuple { + T0 val0; + T1 val1; + T2 val2; + T3 val3; + T4 val4; + T5 val5; + T6 val6; + + Tuple() = default; + + __device__ Tuple( + T0 _val0, + T1 _val1, + T2 _val2, + T3 _val3, + T4 _val4, + T5 _val5, + T6 _val6) + : val0(_val0), + val1(_val1), + val2(_val2), + val3(_val3), + val4(_val4), + val5(_val5), + val6(_val6) {} + + // Only valid when instantiated for pointer types + __device__ void operator+=(nvfuser_index_t offset) { + TUPLE_INCREMENT_PTR(0); + TUPLE_INCREMENT_PTR(1); + TUPLE_INCREMENT_PTR(2); + TUPLE_INCREMENT_PTR(3); + TUPLE_INCREMENT_PTR(4); + TUPLE_INCREMENT_PTR(5); + TUPLE_INCREMENT_PTR(6); + } +}; + +template < + typename T0, + typename T1, + typename T2, + typename T3, + typename T4, + typename T5, + typename T6, + typename T7> +struct Tuple { + T0 val0; + T1 val1; + T2 val2; + T3 val3; + T4 val4; + T5 val5; + T6 val6; + T7 val7; + + Tuple() = default; + + __device__ Tuple( + T0 _val0, + T1 _val1, + T2 _val2, + T3 _val3, + T4 _val4, + T5 _val5, + T6 _val6, + T7 _val7) + : val0(_val0), + val1(_val1), + val2(_val2), + val3(_val3), + val4(_val4), + val5(_val5), + val6(_val6), + val7(_val7) {} + + // Only valid when instantiated for pointer types + __device__ void operator+=(nvfuser_index_t offset) { + TUPLE_INCREMENT_PTR(0); + TUPLE_INCREMENT_PTR(1); + TUPLE_INCREMENT_PTR(2); + TUPLE_INCREMENT_PTR(3); + TUPLE_INCREMENT_PTR(4); + TUPLE_INCREMENT_PTR(5); + TUPLE_INCREMENT_PTR(6); + TUPLE_INCREMENT_PTR(7); + } +}; + +template < + typename T0, + typename T1, + typename T2, + typename T3, + typename T4, + typename T5, + typename T6, + typename T7, + typename T8, + typename T9, + typename T10, + typename T11, + typename T12, + typename T13, + typename T14, + typename T15> +struct Tuple< + T0, + T1, + T2, + T3, + T4, + T5, + T6, + T7, + T8, + T9, + T10, + T11, + T12, + T13, + T14, + T15> { + T0 val0; + T1 val1; + T2 val2; + T3 val3; + T4 val4; + T5 val5; + T6 val6; + T7 val7; + T8 val8; + T9 val9; + T10 val10; + T11 val11; + T12 val12; + T13 val13; + T14 val14; + T15 val15; + + Tuple() = default; + + __device__ Tuple( + T0 _val0, + T1 _val1, + T2 _val2, + T3 _val3, + T4 _val4, + T5 _val5, + T6 _val6, + T7 _val7, + T8 _val8, + T9 _val9, + T10 _val10, + T11 _val11, + T12 _val12, + T13 _val13, + T14 _val14, + T15 _val15) + : val0(_val0), + val1(_val1), + val2(_val2), + val3(_val3), + val4(_val4), + val5(_val5), + val6(_val6), + val7(_val7), + val8(_val8), + val9(_val9), + val10(_val10), + val11(_val11), + val12(_val12), + val13(_val13), + val14(_val14), + val15(_val15) {} + + // Only valid when instantiated for pointer types + __device__ void operator+=(nvfuser_index_t offset) { + TUPLE_INCREMENT_PTR(0); + TUPLE_INCREMENT_PTR(1); + TUPLE_INCREMENT_PTR(2); + TUPLE_INCREMENT_PTR(3); + TUPLE_INCREMENT_PTR(4); + TUPLE_INCREMENT_PTR(5); + TUPLE_INCREMENT_PTR(6); + TUPLE_INCREMENT_PTR(7); + TUPLE_INCREMENT_PTR(8); + TUPLE_INCREMENT_PTR(9); + TUPLE_INCREMENT_PTR(10); + TUPLE_INCREMENT_PTR(11); + TUPLE_INCREMENT_PTR(12); + TUPLE_INCREMENT_PTR(13); + TUPLE_INCREMENT_PTR(14); + TUPLE_INCREMENT_PTR(15); + } +}; + +#undef TUPLE_INCREMENT_PTR + +// Accessor for Tuple +template +struct get; + +#define DEFINE_TUPLE_GET(idx) \ + template <> \ + struct get { \ + template \ + __device__ auto& operator()(Tuple& vals) { \ + return vals.val##idx; \ + } \ + template \ + __device__ const auto& operator()(const Tuple& vals) { \ + return vals.val##idx; \ + } \ + }; + +DEFINE_TUPLE_GET(0); +DEFINE_TUPLE_GET(1); +DEFINE_TUPLE_GET(2); +DEFINE_TUPLE_GET(3); +DEFINE_TUPLE_GET(4); +DEFINE_TUPLE_GET(5); +DEFINE_TUPLE_GET(6); +DEFINE_TUPLE_GET(7); +DEFINE_TUPLE_GET(8); +DEFINE_TUPLE_GET(9); +DEFINE_TUPLE_GET(10); +DEFINE_TUPLE_GET(11); +DEFINE_TUPLE_GET(12); +DEFINE_TUPLE_GET(13); +DEFINE_TUPLE_GET(14); +DEFINE_TUPLE_GET(15); +#undef DEFINE_TUPLE_GET + +template +__inline__ __device__ static void copyTuple( + DstType& dst, + nvfuser_index_t dst_offset, + const SrcType& src, + nvfuser_index_t src_offset = 0); + +template +__inline__ __device__ static void copyTuple( + DstType& dst, + const SrcType& src, + nvfuser_index_t src_offset = 0); + +template +__inline__ __device__ static void setTuple( + DstType& dst, + typename DstType::template ValType<0> src); + +template +class LocalTuple { + public: + static constexpr int num_vals = sizeof...(Types); + using ValTypes = TypeList; + + template + using ValType = typename TypeSelector::type; + + LocalTuple() = default; + + __device__ explicit LocalTuple(Types... args) : vals_(args...) {} + + __device__ LocalTuple(const LocalTuple& other) : vals_(other.vals_) {} + + template