diff --git a/libcudacxx/include/cuda/std/__random/linear_congruential_engine.h b/libcudacxx/include/cuda/std/__random/linear_congruential_engine.h index 988d596e6b9..abfadf4d662 100644 --- a/libcudacxx/include/cuda/std/__random/linear_congruential_engine.h +++ b/libcudacxx/include/cuda/std/__random/linear_congruential_engine.h @@ -28,85 +28,89 @@ #include #include +#if !_CCCL_COMPILER(NVRTC) +# include +#endif // !_CCCL_COMPILER(NVRTC) + #include _CCCL_BEGIN_NAMESPACE_CUDA_STD -template (_Mp - __c) / __a), - bool _OverflowOK = ((__m | (__m - 1)) > __m), // m = 2^n - bool _SchrageOK = (__a != 0 && __m != 0 && __m % __a <= __m / __a)> // r <= q + bool _MightOverflow = (__A != 0 && __M != 0 && __M - 1 > (_Mp - __C) / __A), + bool _OverflowOk = ((__M | (__M - 1)) > __M), // m = 2^n + bool _SchrageOk = (__A != 0 && __M != 0 && __M % __A <= __M / __A)> // r <= q struct __lce_alg_picker { - static_assert(__a != 0 || __m != 0 || !_MightOverflow || _OverflowOK || _SchrageOK, + static_assert(__A != 0 || __M != 0 || !_MightOverflow || _OverflowOk || _SchrageOk, "The current values of a, c, and m cannot generate a number " "within bounds of linear_congruential_engine."); - static constexpr const bool __use_schrage = _MightOverflow && !_OverflowOK && _SchrageOK; + static constexpr const bool __use_schrage = _MightOverflow && !_OverflowOk && _SchrageOk; }; -template ::__use_schrage> + bool _UseSchrage = __lce_alg_picker<__A, __C, __M, _Mp>::__use_schrage> struct __lce_ta; // 64 -template -struct __lce_ta<__a, __c, __m, ~uint64_t{0}, true> +template +struct __lce_ta<__A, __C, __M, ~uint64_t{0}, true> { using result_type = uint64_t; - [[nodiscard]] _CCCL_API static result_type next(result_type __x) noexcept + [[nodiscard]] _CCCL_API static constexpr result_type next(result_type __x) noexcept { // Schrage's algorithm - constexpr result_type __q = __m / __a; - constexpr result_type __r = __m % __a; - const result_type __t0 = __a * (__x % __q); + constexpr result_type __q = __M / __A; + constexpr result_type __r = __M % __A; + const result_type __t0 = __A * (__x % __q); const result_type __t1 = __r * (__x / __q); - __x = __t0 + (__t0 < __t1) * __m - __t1; - __x += __c - (__x >= __m - __c) * __m; + __x = __t0 + (__t0 < __t1) * __M - __t1; + __x += __C - (__x >= __M - __C) * __M; return __x; } }; -template -struct __lce_ta<__a, 0, __m, ~uint64_t{0}, true> +template +struct __lce_ta<__A, 0, __M, ~uint64_t{0}, true> { using result_type = uint64_t; - [[nodiscard]] _CCCL_API static result_type next(result_type __x) noexcept + [[nodiscard]] _CCCL_API static constexpr result_type next(result_type __x) noexcept { // Schrage's algorithm - constexpr result_type __q = __m / __a; - constexpr result_type __r = __m % __a; - const result_type __t0 = __a * (__x % __q); + constexpr result_type __q = __M / __A; + constexpr result_type __r = __M % __A; + const result_type __t0 = __A * (__x % __q); const result_type __t1 = __r * (__x / __q); - __x = __t0 + (__t0 < __t1) * __m - __t1; + __x = __t0 + (__t0 < __t1) * __M - __t1; return __x; } }; -template -struct __lce_ta<__a, __c, __m, ~uint64_t{0}, false> +template +struct __lce_ta<__A, __C, __M, ~uint64_t{0}, false> { using result_type = uint64_t; - [[nodiscard]] _CCCL_API static result_type next(result_type __x) noexcept + [[nodiscard]] _CCCL_API static constexpr result_type next(result_type __x) noexcept { - return (__a * __x + __c) % __m; + return (__A * __x + __C) % __M; } }; -template -struct __lce_ta<__a, __c, 0, ~uint64_t{0}, false> +template +struct __lce_ta<__A, __C, 0, ~uint64_t{0}, false> { using result_type = uint64_t; - [[nodiscard]] _CCCL_API static result_type next(result_type __x) noexcept + [[nodiscard]] _CCCL_API static constexpr result_type next(result_type __x) noexcept { - return __a * __x + __c; + return __A * __x + __C; } }; @@ -116,18 +120,18 @@ template struct __lce_ta<_Ap, _Cp, _Mp, ~uint32_t{0}, true> { using result_type = uint32_t; - [[nodiscard]] _CCCL_API static result_type next(result_type __x) noexcept + [[nodiscard]] _CCCL_API static constexpr result_type next(result_type __x) noexcept { - constexpr auto __a = static_cast(_Ap); - constexpr auto __c = static_cast(_Cp); - constexpr auto __m = static_cast(_Mp); + constexpr auto __A = static_cast(_Ap); + constexpr auto __C = static_cast(_Cp); + constexpr auto __M = static_cast(_Mp); // Schrage's algorithm - constexpr result_type __q = __m / __a; - constexpr result_type __r = __m % __a; - const result_type __t0 = __a * (__x % __q); + constexpr result_type __q = __M / __A; + constexpr result_type __r = __M % __A; + const result_type __t0 = __A * (__x % __q); const result_type __t1 = __r * (__x / __q); - __x = __t0 + (__t0 < __t1) * __m - __t1; - __x += __c - (__x >= __m - __c) * __m; + __x = __t0 + (__t0 < __t1) * __M - __t1; + __x += __C - (__x >= __M - __C) * __M; return __x; } }; @@ -136,16 +140,16 @@ template struct __lce_ta<_Ap, 0, _Mp, ~uint32_t{0}, true> { using result_type = uint32_t; - [[nodiscard]] _CCCL_API static result_type next(result_type __x) noexcept + [[nodiscard]] _CCCL_API static constexpr result_type next(result_type __x) noexcept { - constexpr result_type __a = static_cast(_Ap); - constexpr result_type __m = static_cast(_Mp); + constexpr result_type __A = static_cast(_Ap); + constexpr result_type __M = static_cast(_Mp); // Schrage's algorithm - constexpr result_type __q = __m / __a; - constexpr result_type __r = __m % __a; - const result_type __t0 = __a * (__x % __q); + constexpr result_type __q = __M / __A; + constexpr result_type __r = __M % __A; + const result_type __t0 = __A * (__x % __q); const result_type __t1 = __r * (__x / __q); - __x = __t0 + (__t0 < __t1) * __m - __t1; + __x = __t0 + (__t0 < __t1) * __M - __t1; return __x; } }; @@ -154,12 +158,12 @@ template struct __lce_ta<_Ap, _Cp, _Mp, ~uint32_t{0}, false> { using result_type = uint32_t; - [[nodiscard]] _CCCL_API static result_type next(result_type __x) noexcept + [[nodiscard]] _CCCL_API static constexpr result_type next(result_type __x) noexcept { - constexpr result_type __a = static_cast(_Ap); - constexpr result_type __c = static_cast(_Cp); - constexpr result_type __m = static_cast(_Mp); - return (__a * __x + __c) % __m; + constexpr result_type __A = static_cast(_Ap); + constexpr result_type __C = static_cast(_Cp); + constexpr result_type __M = static_cast(_Mp); + return (__A * __x + __C) % __M; } }; @@ -167,40 +171,30 @@ template struct __lce_ta<_Ap, _Cp, 0, ~uint32_t{0}, false> { using result_type = uint32_t; - [[nodiscard]] _CCCL_API static result_type next(result_type __x) noexcept + [[nodiscard]] _CCCL_API static constexpr result_type next(result_type __x) noexcept { - constexpr result_type __a = static_cast(_Ap); - constexpr result_type __c = static_cast(_Cp); - return __a * __x + __c; + constexpr result_type __A = static_cast(_Ap); + constexpr result_type __C = static_cast(_Cp); + return __A * __x + __C; } }; // 16 -template -struct __lce_ta<__a, __c, __m, static_cast(~0), __b> +template +struct __lce_ta<__A, __C, __M, static_cast(~0), __b> { using result_type = uint16_t; - [[nodiscard]] _CCCL_API static result_type next(result_type __x) noexcept + [[nodiscard]] _CCCL_API static constexpr result_type next(result_type __x) noexcept { - return static_cast(__lce_ta<__a, __c, __m, ~uint32_t{0}>::next(__x)); + return static_cast(__lce_ta<__A, __C, __M, ~uint32_t{0}>::next(__x)); } }; -template +template class _CCCL_TYPE_VISIBILITY_DEFAULT linear_congruential_engine; -#if 0 // Not Implemented -template -_CCCL_API basic_ostream<_CharT, _Traits>& -operator<<(basic_ostream<_CharT, _Traits>& __os, const linear_congruential_engine<_Up, _Ap, _Cp, _Np>&); - -template -_CCCL_API basic_istream<_CharT, _Traits>& -operator>>(basic_istream<_CharT, _Traits>& __is, linear_congruential_engine<_Up, _Ap, _Cp, _Np>& __x); -#endif // - -template +template class _CCCL_TYPE_VISIBILITY_DEFAULT linear_congruential_engine { public: @@ -208,23 +202,23 @@ class _CCCL_TYPE_VISIBILITY_DEFAULT linear_congruential_engine using result_type = _UIntType; private: - result_type __x_; + result_type __x_{}; static constexpr const result_type _Mp = result_type(~0); - static_assert(__m == 0 || __a < __m, "linear_congruential_engine invalid parameters"); - static_assert(__m == 0 || __c < __m, "linear_congruential_engine invalid parameters"); + static_assert(__M == 0 || __A < __M, "linear_congruential_engine invalid parameters"); + static_assert(__M == 0 || __C < __M, "linear_congruential_engine invalid parameters"); static_assert(is_unsigned_v<_UIntType>, "_UIntType must be uint32_t type"); public: - static constexpr const result_type _Min = __c == 0u ? 1u : 0u; - static constexpr const result_type _Max = __m - _UIntType(1u); + static constexpr const result_type _Min = __C == 0u ? 1u : 0u; + static constexpr const result_type _Max = __M - _UIntType(1u); static_assert(_Min < _Max, "linear_congruential_engine invalid parameters"); // engine characteristics - static constexpr const result_type multiplier = __a; - static constexpr const result_type increment = __c; - static constexpr const result_type modulus = __m; + static constexpr const result_type multiplier = __A; + static constexpr const result_type increment = __C; + static constexpr const result_type modulus = __M; [[nodiscard]] _CCCL_API static constexpr result_type min() noexcept { return _Min; @@ -236,157 +230,163 @@ class _CCCL_TYPE_VISIBILITY_DEFAULT linear_congruential_engine static constexpr const result_type default_seed = 1u; // constructors and seeding functions - _CCCL_API linear_congruential_engine() noexcept + _CCCL_API constexpr linear_congruential_engine() noexcept : linear_congruential_engine(default_seed) {} - _CCCL_API explicit linear_congruential_engine(result_type __s) noexcept + _CCCL_API explicit constexpr linear_congruential_engine(result_type __s) noexcept { seed(__s); } template , int> = 0> - _CCCL_API explicit linear_congruential_engine(_Sseq& __q) noexcept + _CCCL_API explicit constexpr linear_congruential_engine(_Sseq& __q) noexcept { seed(__q); } - _CCCL_API void seed(result_type __s = default_seed) + _CCCL_API constexpr void seed(result_type __s = default_seed) noexcept { - seed(integral_constant(), integral_constant(), __s); + seed(integral_constant(), integral_constant(), __s); } template , int> = 0> - _CCCL_API void seed(_Sseq& __q) noexcept + _CCCL_API constexpr void seed(_Sseq& __q) noexcept { __seed(__q, integral_constant 0x100000000ull))>()); + 1 + (__M == 0 ? (sizeof(result_type) * CHAR_BIT - 1) / 32 : (__M > 0x100000000ull))>()); } // generating functions - [[nodiscard]] _CCCL_API result_type operator()() noexcept + _CCCL_API constexpr result_type operator()() noexcept { - return __x_ = static_cast(__lce_ta<__a, __c, __m, _Mp>::next(__x_)); + return __x_ = static_cast(__lce_ta<__A, __C, __M, _Mp>::next(__x_)); } - _CCCL_API void discard(uint64_t __z) noexcept + + _CCCL_API constexpr void discard(uint64_t __z) noexcept { - for (; __z; --__z) + constexpr bool __can_overflow = (__A != 0 && __M != 0 && __M - 1 > (_Mp - __C) / __A); + // Fallback implementation + if constexpr (__can_overflow) { - (void) operator()(); + for (; __z; --__z) + { + (void) operator()(); + } + } + else + { + uint64_t __acc_mult = 1; + [[maybe_unused]] uint64_t __acc_plus = 0; + uint64_t __cur_mult = multiplier; + [[maybe_unused]] uint64_t __cur_plus = increment; + while (__z > 0) + { + if (__z & 1) + { + __acc_mult = (__acc_mult * __cur_mult) % modulus; + if constexpr (increment != 0) + { + __acc_plus = (__acc_plus * __cur_mult + __cur_plus) % modulus; + } + } + if constexpr (increment != 0) + { + __cur_plus = ((__cur_mult + 1) * __cur_plus) % modulus; + } + __cur_mult = (__cur_mult * __cur_mult) % modulus; + __z >>= 1; + } + __x_ = (__acc_mult * __x_ + __acc_plus) % modulus; } } - [[nodiscard]] _CCCL_API friend bool + [[nodiscard]] _CCCL_API friend constexpr bool operator==(const linear_congruential_engine& __x, const linear_congruential_engine& __y) noexcept { return __x.__x_ == __y.__x_; } - [[nodiscard]] _CCCL_API friend bool + [[nodiscard]] _CCCL_API friend constexpr bool operator!=(const linear_congruential_engine& __x, const linear_congruential_engine& __y) noexcept { return !(__x == __y); } +#if !_CCCL_COMPILER(NVRTC) + template + _CCCL_API friend ::std::basic_ostream<_CharT, _Traits>& + operator<<(::std::basic_ostream<_CharT, _Traits>& __os, const linear_congruential_engine& __e) + { + using _Ostream = ::std::basic_ostream<_CharT, _Traits>; + const typename _Ostream::fmtflags __flags = __os.flags(); + __os.flags(_Ostream::dec | _Ostream::left); + __os.fill(__os.widen(' ')); + __os.flags(__flags); + return __os << __e.__x_; + } + template + _CCCL_API friend ::std::basic_istream<_CharT, _Traits>& + operator>>(::std::basic_istream<_CharT, _Traits>& __is, linear_congruential_engine& __e) + { + using _Istream = ::std::basic_istream<_CharT, _Traits>; + const typename _Istream::fmtflags __flags = __is.flags(); + __is.flags(_Istream::dec | _Istream::skipws); + _UIntType __t; + __is >> __t; + if (!__is.fail()) + { + __e.__x_ = __t; + } + __is.flags(__flags); + return __is; + } +#endif // !_CCCL_COMPILER(NVRTC) + private: - _CCCL_API void seed(true_type, true_type, result_type __s) noexcept + _CCCL_API constexpr void seed(true_type, true_type, result_type __s) noexcept { __x_ = __s == 0 ? 1 : __s; } - _CCCL_API void seed(true_type, false_type, result_type __s) noexcept + _CCCL_API constexpr void seed(true_type, false_type, result_type __s) noexcept { __x_ = __s; } - _CCCL_API void seed(false_type, true_type, result_type __s) noexcept + _CCCL_API constexpr void seed(false_type, true_type, result_type __s) noexcept { - __x_ = __s % __m == 0 ? 1 : __s % __m; + __x_ = __s % __M == 0 ? 1 : __s % __M; } - _CCCL_API void seed(false_type, false_type, result_type __s) noexcept + _CCCL_API constexpr void seed(false_type, false_type, result_type __s) noexcept { - __x_ = __s % __m; + __x_ = __s % __M; } template - _CCCL_API void __seed(_Sseq& __q, integral_constant) noexcept; + _CCCL_API constexpr void __seed(_Sseq& __q, integral_constant) noexcept; template - _CCCL_API void __seed(_Sseq& __q, integral_constant) noexcept; - -#if 0 // Not Implemented - template - friend basic_ostream<_CharT, _Traits>& - operator<<(basic_ostream<_CharT, _Traits>& __os, const linear_congruential_engine<_Up, _Ap, _Cp, _Np>&); - - template - friend basic_istream<_CharT, _Traits>& - operator>>(basic_istream<_CharT, _Traits>& __is, linear_congruential_engine<_Up, _Ap, _Cp, _Np>& __x); -#endif // Not Implemented + _CCCL_API constexpr void __seed(_Sseq& __q, integral_constant) noexcept; }; -template -constexpr const typename linear_congruential_engine<_UIntType, __a, __c, __m>::result_type - linear_congruential_engine<_UIntType, __a, __c, __m>::multiplier; - -template -constexpr const typename linear_congruential_engine<_UIntType, __a, __c, __m>::result_type - linear_congruential_engine<_UIntType, __a, __c, __m>::increment; - -template -constexpr const typename linear_congruential_engine<_UIntType, __a, __c, __m>::result_type - linear_congruential_engine<_UIntType, __a, __c, __m>::modulus; - -template -constexpr const typename linear_congruential_engine<_UIntType, __a, __c, __m>::result_type - linear_congruential_engine<_UIntType, __a, __c, __m>::default_seed; - -template +template template -_CCCL_API void -linear_congruential_engine<_UIntType, __a, __c, __m>::__seed(_Sseq& __q, integral_constant) noexcept +_CCCL_API constexpr void +linear_congruential_engine<_UIntType, __A, __C, __M>::__seed(_Sseq& __q, integral_constant) noexcept { constexpr uint32_t __k = 1; - uint32_t __ar[__k + 3]; + uint32_t __ar[__k + 3] = {}; __q.generate(__ar, __ar + __k + 3); - result_type __s = static_cast(__ar[3] % __m); - __x_ = __c == 0 && __s == 0 ? result_type(1) : __s; + result_type __s = static_cast(__ar[3] % __M); + __x_ = __C == 0 && __s == 0 ? result_type(1) : __s; } -template +template template -_CCCL_API void -linear_congruential_engine<_UIntType, __a, __c, __m>::__seed(_Sseq& __q, integral_constant) noexcept +_CCCL_API constexpr void +linear_congruential_engine<_UIntType, __A, __C, __M>::__seed(_Sseq& __q, integral_constant) noexcept { constexpr uint32_t __k = 2; - uint32_t __ar[__k + 3]; + uint32_t __ar[__k + 3] = {}; __q.generate(__ar, __ar + __k + 3); - result_type __s = static_cast((__ar[3] + ((uint64_t) __ar[4] << 32)) % __m); - __x_ = __c == 0 && __s == 0 ? result_type(1) : __s; -} - -#if 0 // Not Implemented -template -_CCCL_API basic_ostream<_CharT, _Traits>& -operator<<(basic_ostream<_CharT, _Traits>& __os, const linear_congruential_engine<_UIntType, __a, __c, __m>& __x) -{ - __save_flags<_CharT, _Traits> __lx(__os); - using _Ostream = basic_ostream<_CharT, _Traits>; - __os.flags(_Ostream::dec | _Ostream::left); - __os.fill(__os.widen(' ')); - return __os << __x.__x_; -} - -template -_CCCL_API basic_istream<_CharT, _Traits>& -operator>>(basic_istream<_CharT, _Traits>& __is, linear_congruential_engine<_UIntType, __a, __c, __m>& __x) -{ - __save_flags<_CharT, _Traits> __lx(__is); - using _Istream = basic_istream<_CharT, _Traits>; - __is.flags(_Istream::dec | _Istream::skipws); - _UIntType __t; - __is >> __t; - if (!__is.fail()) - { - __x.__x_ = __t; - } - return __is; + result_type __s = static_cast((__ar[3] + ((uint64_t) __ar[4] << 32)) % __M); + __x_ = __C == 0 && __s == 0 ? result_type(1) : __s; } -#endif // Not Implemented using minstd_rand0 = linear_congruential_engine; using minstd_rand = linear_congruential_engine; diff --git a/libcudacxx/test/libcudacxx/std/random/engine/lcg.pass.cpp b/libcudacxx/test/libcudacxx/std/random/engine/lcg.pass.cpp new file mode 100644 index 00000000000..eee2f4411e4 --- /dev/null +++ b/libcudacxx/test/libcudacxx/std/random/engine/lcg.pass.cpp @@ -0,0 +1,25 @@ +//===----------------------------------------------------------------------===// +// +// Part of libcu++, the C++ Standard Library for your entire system, +// under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. +// +//===----------------------------------------------------------------------===// + +#include + +#include "test_engine.h" + +__host__ __device__ void test() +{ + test_engine(); + test_engine(); +} + +int main(int, char**) +{ + test(); + return 0; +}