diff --git a/cpp/include/raft/linalg/detail/cusolver_wrappers.hpp b/cpp/include/raft/linalg/detail/cusolver_wrappers.hpp index 3eff920dd8..79fd869083 100644 --- a/cpp/include/raft/linalg/detail/cusolver_wrappers.hpp +++ b/cpp/include/raft/linalg/detail/cusolver_wrappers.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -693,6 +693,119 @@ inline cusolverStatus_t CUSOLVERAPI cusolverDngesvdj( // NOLINT return cusolverDnDgesvdj( handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, work, lwork, info, params); } + +#if CUDART_VERSION >= 11010 +template /* Queries the device and host workspace sizes (in bytes) for cusolverDnXgesvdr. Binds stream to handle, creates a temporary cusolverDnParams_t for the query, destroys it afterwards, and returns the status of the size query. NOTE(review): the floating-point check below is a runtime RAFT_EXPECTS on a type trait and has no counterpart in cusolverDnxgesvdr. */ +cusolverStatus_t cusolverDnxgesvdr_bufferSize( // NOLINT + cusolverDnHandle_t handle, + signed char jobu, + signed char jobv, + int64_t m, + int64_t n, + int64_t k, + int64_t p, + int64_t niters, + const T* a, + int64_t lda, + const T* Srand, + const T* Urand, + int64_t ldUrand, + const T* Vrand, + int64_t ldVrand, + size_t* workspaceInBytesOnDevice, + size_t* workspaceInBytesOnHost, + cudaStream_t stream) +{ + RAFT_EXPECTS(std::is_floating_point_v, "Unsupported data type"); + cudaDataType dataType = std::is_same_v ? 
CUDA_R_32F : CUDA_R_64F; + RAFT_CUSOLVER_TRY(cusolverDnSetStream(handle, stream)); + cusolverDnParams_t dn_params = nullptr; + RAFT_CUSOLVER_TRY(cusolverDnCreateParams(&dn_params)); + auto result = cusolverDnXgesvdr_bufferSize(handle, + dn_params, + jobu, + jobv, + m, + n, + k, + p, + niters, + dataType, + a, + lda, + dataType, + Srand, + dataType, + Urand, + ldUrand, + dataType, + Vrand, + ldVrand, + dataType, + workspaceInBytesOnDevice, + workspaceInBytesOnHost); + RAFT_CUSOLVER_TRY(cusolverDnDestroyParams(dn_params)); + return result; +} +template /* Runs cusolverDnXgesvdr on stream using the caller-provided device and host workspaces. The same data type tag (CUDA_R_32F or CUDA_R_64F) is passed for every matrix argument and for the compute type. A temporary cusolverDnParams_t is created for the call and destroyed before the status of the call is returned. */ +cusolverStatus_t cusolverDnxgesvdr( // NOLINT + cusolverDnHandle_t handle, + signed char jobu, + signed char jobv, + int64_t m, + int64_t n, + int64_t k, + int64_t p, + int64_t niters, + T* a, + int64_t lda, + T* Srand, + T* Urand, + int64_t ldUrand, + T* Vrand, + int64_t ldVrand, + void* bufferOnDevice, + size_t workspaceInBytesOnDevice, + void* bufferOnHost, + size_t workspaceInBytesOnHost, + int* d_info, + cudaStream_t stream) +{ + cudaDataType dataType = std::is_same_v ? 
CUDA_R_32F : CUDA_R_64F; + RAFT_CUSOLVER_TRY(cusolverDnSetStream(handle, stream)); + cusolverDnParams_t dn_params = nullptr; + RAFT_CUSOLVER_TRY(cusolverDnCreateParams(&dn_params)); + auto result = cusolverDnXgesvdr(handle, + dn_params, + jobu, + jobv, + m, + n, + k, + p, + niters, + dataType, + a, + lda, + dataType, + Srand, + dataType, + Urand, + ldUrand, + dataType, + Vrand, + ldVrand, + dataType, + bufferOnDevice, + workspaceInBytesOnDevice, + bufferOnHost, + workspaceInBytesOnHost, + d_info); + RAFT_CUSOLVER_TRY(cusolverDnDestroyParams(dn_params)); + return result; +} +#endif // CUDART_VERSION >= 11010 + /** @} */ /** diff --git a/cpp/include/raft/linalg/detail/rsvd.cuh b/cpp/include/raft/linalg/detail/rsvd.cuh index 50cb339ea1..9c2cea6b66 100644 --- a/cpp/include/raft/linalg/detail/rsvd.cuh +++ b/cpp/include/raft/linalg/detail/rsvd.cuh @@ -37,6 +37,96 @@ namespace raft { namespace linalg { namespace detail { +template /* Randomized SVD through the cusolverDnxgesvdr wrappers. Checks k and k + p against min(n_rows, n_cols) and that U and V are non-null when requested, queries the workspace sizes, allocates the device and host workspaces, runs the solver, then copies the solver info value to the host, syncs the stream and asserts that it is 0. jobu and jobv are 'S' when the corresponding vectors are requested and 'N' otherwise; lda and ldu are n_rows, ldv is n_cols. NOTE(review): in is const but reaches the solver through const_cast. NOTE(review): the nvtx format string uses %d for std::size_t arguments - confirm this is safe for the range formatter. */ +void randomized_svd(const raft::device_resources& handle, + const math_t* in, + std::size_t n_rows, + std::size_t n_cols, + std::size_t k, + std::size_t p, + std::size_t niters, + math_t* S, + math_t* U, + math_t* V, + bool gen_U, + bool gen_V) +{ + common::nvtx::range fun_scope( + "raft::linalg::randomized_svd(%d, %d, %d)", n_rows, n_cols, k); + + RAFT_EXPECTS(k < std::min(n_rows, n_cols), "k must be < min(n_rows, n_cols)"); + RAFT_EXPECTS((k + p) < std::min(n_rows, n_cols), "k + p must be < min(n_rows, n_cols)"); + RAFT_EXPECTS(!gen_U || (U != nullptr), "computation of U vector requested but found nullptr"); + RAFT_EXPECTS(!gen_V || (V != nullptr), "computation of V vector requested but found nullptr"); +#if CUDART_VERSION < 11050 + RAFT_EXPECTS(gen_U && gen_V, "not computing U or V is not supported in CUDA version < 11.5"); +#endif + cudaStream_t stream = handle.get_stream(); + cusolverDnHandle_t cusolverH = handle.get_cusolver_dn_handle(); + + char jobu = gen_U ? 'S' : 'N'; + char jobv = gen_V ? 
'S' : 'N'; + + auto lda = n_rows; + auto ldu = n_rows; + auto ldv = n_cols; + auto* in_ptr = const_cast(in); + + size_t workspaceDevice = 0; + size_t workspaceHost = 0; + RAFT_CUSOLVER_TRY(cusolverDnxgesvdr_bufferSize(cusolverH, + jobu, + jobv, + n_rows, + n_cols, + k, + p, + niters, + in_ptr, + lda, + S, + U, + ldu, + V, + ldv, + &workspaceDevice, + &workspaceHost, + stream)); + + auto d_workspace = raft::make_device_vector(handle, workspaceDevice); + auto h_workspace = raft::make_host_vector(workspaceHost); + auto devInfo = raft::make_device_scalar(handle, 0); + + RAFT_CUSOLVER_TRY(cusolverDnxgesvdr(cusolverH, + jobu, + jobv, + n_rows, + n_cols, + k, + p, + niters, + in_ptr, + lda, + S, + U, + ldu, + V, + ldv, + d_workspace.data_handle(), + workspaceDevice, + h_workspace.data_handle(), + workspaceHost, + devInfo.data_handle(), + stream)); + + RAFT_CUDA_TRY(cudaGetLastError()); + + int dev_info; + raft::update_host(&dev_info, devInfo.data_handle(), 1, stream); + handle.sync_stream(stream); + ASSERT(dev_info == 0, "rsvd.cuh: Invalid parameter encountered."); +} + /** * @brief randomized singular value decomposition (RSVD) on the column major * float type input matrix (Jacobi-based), by specifying no. 
of PCs and diff --git a/cpp/include/raft/linalg/rsvd.cuh b/cpp/include/raft/linalg/rsvd.cuh index 4a6c058061..8037611a54 100644 --- a/cpp/include/raft/linalg/rsvd.cuh +++ b/cpp/include/raft/linalg/rsvd.cuh @@ -19,9 +19,8 @@ #pragma once #include "detail/rsvd.cuh" -#include - #include +#include namespace raft { namespace linalg { @@ -176,16 +175,20 @@ void rsvd_fixed_rank(raft::resources const& handle, std::forward(U_in); std::optional> V = std::forward(V_in); + ValueType* U_ptr = nullptr; + ValueType* V_ptr = nullptr; if (U) { RAFT_EXPECTS(M.extent(0) == U.value().extent(0), "Number of rows in M should be equal to U"); RAFT_EXPECTS(S_vec.extent(0) == U.value().extent(1), "Number of columns in U should be equal to length of S"); + U_ptr = U.value().data_handle(); } if (V) { RAFT_EXPECTS(M.extent(1) == V.value().extent(1), "Number of columns in M should be equal to V"); RAFT_EXPECTS(S_vec.extent(0) == V.value().extent(0), "Number of rows in V should be equal to length of S"); + V_ptr = V.value().data_handle(); } rsvdFixedRank(handle, @@ -193,8 +196,8 @@ void rsvd_fixed_rank(raft::resources const& handle, M.extent(0), M.extent(1), S_vec.data_handle(), - U.value().data_handle(), - V.value().data_handle(), + U_ptr, + V_ptr, S_vec.extent(0), p, false, @@ -251,13 +254,17 @@ void rsvd_fixed_rank_symmetric( std::forward(U_in); std::optional> V = std::forward(V_in); + ValueType* U_ptr = nullptr; + ValueType* V_ptr = nullptr; if (U) { + U_ptr = U.value().data_handle(); RAFT_EXPECTS(M.extent(0) == U.value().extent(0), "Number of rows in M should be equal to U"); RAFT_EXPECTS(S_vec.extent(0) == U.value().extent(1), "Number of columns in U should be equal to length of S"); } if (V) { + V_ptr = V.value().data_handle(); RAFT_EXPECTS(M.extent(1) == V.value().extent(1), "Number of columns in M should be equal to V"); RAFT_EXPECTS(S_vec.extent(0) == V.value().extent(0), "Number of rows in V should be equal to length of S"); @@ -268,8 +275,8 @@ void rsvd_fixed_rank_symmetric( 
M.extent(0), M.extent(1), S_vec.data_handle(), - U.value().data_handle(), - V.value().data_handle(), + U_ptr, + V_ptr, S_vec.extent(0), p, true, @@ -329,13 +336,17 @@ void rsvd_fixed_rank_jacobi(raft::resources const& handle, std::forward(U_in); std::optional> V = std::forward(V_in); + ValueType* U_ptr = nullptr; + ValueType* V_ptr = nullptr; if (U) { + U_ptr = U.value().data_handle(); RAFT_EXPECTS(M.extent(0) == U.value().extent(0), "Number of rows in M should be equal to U"); RAFT_EXPECTS(S_vec.extent(0) == U.value().extent(1), "Number of columns in U should be equal to length of S"); } if (V) { + V_ptr = V.value().data_handle(); RAFT_EXPECTS(M.extent(1) == V.value().extent(1), "Number of columns in M should be equal to V"); RAFT_EXPECTS(S_vec.extent(0) == V.value().extent(0), "Number of rows in V should be equal to length of S"); @@ -346,8 +357,8 @@ void rsvd_fixed_rank_jacobi(raft::resources const& handle, M.extent(0), M.extent(1), S_vec.data_handle(), - U.value().data_handle(), - V.value().data_handle(), + U_ptr, + V_ptr, S_vec.extent(0), p, false, @@ -408,13 +419,17 @@ void rsvd_fixed_rank_symmetric_jacobi( std::forward(U_in); std::optional> V = std::forward(V_in); + ValueType* U_ptr = nullptr; + ValueType* V_ptr = nullptr; if (U) { + U_ptr = U.value().data_handle(); RAFT_EXPECTS(M.extent(0) == U.value().extent(0), "Number of rows in M should be equal to U"); RAFT_EXPECTS(S_vec.extent(0) == U.value().extent(1), "Number of columns in U should be equal to length of S"); } if (V) { + V_ptr = V.value().data_handle(); RAFT_EXPECTS(M.extent(1) == V.value().extent(1), "Number of columns in M should be equal to V"); RAFT_EXPECTS(S_vec.extent(0) == V.value().extent(0), "Number of rows in V should be equal to length of S"); @@ -425,8 +440,8 @@ void rsvd_fixed_rank_symmetric_jacobi( M.extent(0), M.extent(1), S_vec.data_handle(), - U.value().data_handle(), - V.value().data_handle(), + U_ptr, + V_ptr, S_vec.extent(0), p, true, @@ -484,13 +499,17 @@ void 
rsvd_perc(raft::resources const& handle, std::forward(U_in); std::optional> V = std::forward(V_in); + ValueType* U_ptr = nullptr; + ValueType* V_ptr = nullptr; if (U) { + U_ptr = U.value().data_handle(); RAFT_EXPECTS(M.extent(0) == U.value().extent(0), "Number of rows in M should be equal to U"); RAFT_EXPECTS(S_vec.extent(0) == U.value().extent(1), "Number of columns in U should be equal to length of S"); } if (V) { + V_ptr = V.value().data_handle(); RAFT_EXPECTS(M.extent(1) == V.value().extent(1), "Number of columns in M should be equal to V"); RAFT_EXPECTS(S_vec.extent(0) == V.value().extent(0), "Number of rows in V should be equal to length of S"); @@ -501,8 +520,8 @@ void rsvd_perc(raft::resources const& handle, M.extent(0), M.extent(1), S_vec.data_handle(), - U.value().data_handle(), - V.value().data_handle(), + U_ptr, + V_ptr, PC_perc, UpS_perc, false, @@ -560,13 +579,17 @@ void rsvd_perc_symmetric(raft::resources const& handle, std::forward(U_in); std::optional> V = std::forward(V_in); + ValueType* U_ptr = nullptr; + ValueType* V_ptr = nullptr; if (U) { + U_ptr = U.value().data_handle(); RAFT_EXPECTS(M.extent(0) == U.value().extent(0), "Number of rows in M should be equal to U"); RAFT_EXPECTS(S_vec.extent(0) == U.value().extent(1), "Number of columns in U should be equal to length of S"); } if (V) { + V_ptr = V.value().data_handle(); RAFT_EXPECTS(M.extent(1) == V.value().extent(1), "Number of columns in M should be equal to V"); RAFT_EXPECTS(S_vec.extent(0) == V.value().extent(0), "Number of rows in V should be equal to length of S"); @@ -577,8 +600,8 @@ void rsvd_perc_symmetric(raft::resources const& handle, M.extent(0), M.extent(1), S_vec.data_handle(), - U.value().data_handle(), - V.value().data_handle(), + U_ptr, + V_ptr, PC_perc, UpS_perc, true, @@ -640,13 +663,17 @@ void rsvd_perc_jacobi(raft::resources const& handle, std::forward(U_in); std::optional> V = std::forward(V_in); + ValueType* U_ptr = nullptr; + ValueType* V_ptr = nullptr; if (U) { + U_ptr 
= U.value().data_handle(); RAFT_EXPECTS(M.extent(0) == U.value().extent(0), "Number of rows in M should be equal to U"); RAFT_EXPECTS(S_vec.extent(0) == U.value().extent(1), "Number of columns in U should be equal to length of S"); } if (V) { + V_ptr = V.value().data_handle(); RAFT_EXPECTS(M.extent(1) == V.value().extent(1), "Number of columns in M should be equal to V"); RAFT_EXPECTS(S_vec.extent(0) == V.value().extent(0), "Number of rows in V should be equal to length of S"); @@ -657,8 +684,8 @@ void rsvd_perc_jacobi(raft::resources const& handle, M.extent(0), M.extent(1), S_vec.data_handle(), - U.value().data_handle(), - V.value().data_handle(), + U_ptr, + V_ptr, PC_perc, UpS_perc, false, @@ -721,13 +748,17 @@ void rsvd_perc_symmetric_jacobi( std::forward(U_in); std::optional> V = std::forward(V_in); + ValueType* U_ptr = nullptr; + ValueType* V_ptr = nullptr; if (U) { + U_ptr = U.value().data_handle(); RAFT_EXPECTS(M.extent(0) == U.value().extent(0), "Number of rows in M should be equal to U"); RAFT_EXPECTS(S_vec.extent(0) == U.value().extent(1), "Number of columns in U should be equal to length of S"); } if (V) { + V_ptr = V.value().data_handle(); RAFT_EXPECTS(M.extent(1) == V.value().extent(1), "Number of columns in M should be equal to V"); RAFT_EXPECTS(S_vec.extent(0) == V.value().extent(0), "Number of rows in V should be equal to length of S"); @@ -738,8 +769,8 @@ void rsvd_perc_symmetric_jacobi( M.extent(0), M.extent(1), S_vec.data_handle(), - U.value().data_handle(), - V.value().data_handle(), + U_ptr, + V_ptr, PC_perc, UpS_perc, true, @@ -764,6 +795,85 @@ void rsvd_perc_symmetric_jacobi(Args... args) rsvd_perc_symmetric_jacobi(std::forward(args)..., std::nullopt, std::nullopt); } +/** + * @brief randomized singular value decomposition (RSVD) using cusolver + * @tparam math_t the data type + * @tparam idx_t index type + * @param[in] handle: raft handle + * @param[in] in: input matrix in col-major format. 
+ * Warning: the content of this matrix is modified by the cuSOLVER routines. + * [dim = n_rows * n_cols] + * @param[out] S: array of singular values of input matrix. The rank k must be less than + * min(n_rows, n_cols). [dim = k] + * @param[out] U: optional left singular vectors of input matrix. Use std::nullopt to not + * generate it. [dim = n_rows * k] + * @param[out] V: optional right singular vectors of input matrix. Use std::nullopt to not + * generate it. [dim = k * n_cols] + * @param[in] p: Oversampling. The size of the subspace will be (k + p). (k + p) must be less than + * min(n_rows, n_cols). (Recommended to be at least 2*k) + * @param[in] niters: Number of iterations of the power method. (2 is recommended) + */ +template +void randomized_svd(const raft::device_resources& handle, + raft::device_matrix_view in, + raft::device_vector_view S, + std::optional> U, + std::optional> V, + std::size_t p, + std::size_t niters) +{ /* k is the length of S; U and V are shape-checked only when they hold a value, and a null pointer is forwarded for an absent one. */ + auto k = S.extent(0); + math_t* left_sing_vecs_ptr = nullptr; + math_t* right_sing_vecs_ptr = nullptr; + auto gen_U = U.has_value(); + auto gen_V = V.has_value(); + if (gen_U) { + RAFT_EXPECTS(in.extent(0) == U.value().extent(0) && k == U.value().extent(1), + "U should have dimensions n_rows * k"); + left_sing_vecs_ptr = U.value().data_handle(); + } + if (gen_V) { + RAFT_EXPECTS(k == V.value().extent(0) && in.extent(1) == V.value().extent(1), + "V should have dimensions k * n_cols"); + right_sing_vecs_ptr = V.value().data_handle(); + } + detail::randomized_svd(handle, + in.data_handle(), + in.extent(0), + in.extent(1), + k, + p, + niters, + S.data_handle(), + left_sing_vecs_ptr, + right_sing_vecs_ptr, + gen_U, + gen_V); +} + +/** + * @brief Overload of `randomized_svd` to help the + * compiler find the above overload, in case users pass in + * `std::nullopt` for the optional arguments. + * + * Please see above for documentation of `randomized_svd`. 
+ */ +template +void randomized_svd(const raft::device_resources& handle, + raft::device_matrix_view in, + raft::device_vector_view S, + opt_u_vec_t&& U, + opt_v_vec_t&& V, + std::size_t p, + std::size_t niters) +{ /* Builds std::optional values from U and V via std::forward, then calls randomized_svd with them. */ + std::optional> opt_u = + std::forward(U); + std::optional> opt_v = + std::forward(V); + randomized_svd(handle, in, S, opt_u, opt_v, p, niters); +} + /** @} */ // end of group rsvd }; // end namespace linalg diff --git a/cpp/include/raft/linalg/transpose.cuh b/cpp/include/raft/linalg/transpose.cuh index 0fe752347d..afe1962223 100644 --- a/cpp/include/raft/linalg/transpose.cuh +++ b/cpp/include/raft/linalg/transpose.cuh @@ -74,7 +74,7 @@ void transpose(math_t* inout, int n, cudaStream_t stream) * * @param[in] handle raft handle for managing expensive cuda resources. * @param[in] in Input matrix. - * @param[out] out Output matirx, storage is pre-allocated by caller. + * @param[out] out Output matrix, storage is pre-allocated by caller. */ template auto transpose(raft::resources const& handle, diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 1b4d269d1b..871869102c 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -215,6 +215,7 @@ if(BUILD_TESTS) test/linalg/norm.cu test/linalg/normalize.cu test/linalg/power.cu + test/linalg/randomized_svd.cu test/linalg/reduce.cu test/linalg/reduce_cols_by_key.cu test/linalg/reduce_rows_by_key.cu diff --git a/cpp/test/linalg/randomized_svd.cu b/cpp/test/linalg/randomized_svd.cu new file mode 100644 index 0000000000..2d55fd7579 --- /dev/null +++ b/cpp/test/linalg/randomized_svd.cu @@ -0,0 +1,245 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" +#include +#include +#include +#include +#include +#include +#include + +namespace raft { +namespace linalg { + +template +struct randomized_svdInputs { /* One test case: comparison tolerance, matrix dimensions (n_row x n_col), number of singular values k, and a seed. NOTE(review): seed is not read by the fixture code shown here. */ + T tolerance; + int n_row; + int n_col; + int k; + unsigned long long int seed; +}; + +template +::std::ostream& operator<<(::std::ostream& os, const randomized_svdInputs& dims) +{ /* Returns the stream unchanged; no field of dims is written. */ + return os; +} + +template +class randomized_svdTest : public ::testing::TestWithParam> { /* Fixture: buffers suffixed _act receive randomized_svd output, buffers suffixed _ref hold the hard-coded reference values. NOTE(review): the constructor's initializer list is not in member declaration order, and reconst is allocated but not otherwise used in the code shown here. */ + public: + randomized_svdTest() + : params(::testing::TestWithParam>::GetParam()), + stream(handle.get_stream()), + data(params.n_row * params.n_col, stream), + reconst(params.n_row * params.n_col, stream), + left_eig_vectors_act(params.n_row * params.k, stream), + right_eig_vectors_act(params.k * params.n_col, stream), + sing_vals_act(params.k, stream), + left_eig_vectors_ref(params.n_row * params.n_col, stream), + right_eig_vectors_ref(params.n_col * params.n_col, stream), + sing_vals_ref(params.k, stream) + { + } + + protected: + void basicTest() + { /* Uploads the fixed 5x5 input and the reference values, then runs randomized_svd requesting both U and V with p = 2 and niters = 2. */ + int len = params.n_row * params.n_col; + ASSERT(params.n_row == 5 && params.n_col == 5, "This test only supports nrows=5 && ncols=5!"); + T data_h[] = {0.76420743, 0.61411544, 0.81724151, 0.42040879, 0.03446089, + 0.03697287, 0.85962444, 0.67584086, 0.45594666, 0.02074835, + 0.42018265, 0.39204509, 0.12657948, 0.90250559, 0.23076218, + 0.50339844, 0.92974961, 0.21213988, 0.63962457, 0.58124562, + 0.58325673, 0.11589871, 0.39831112, 0.21492685, 0.00540355}; + raft::update_device(data.data(), data_h, len, stream); + + T left_eig_vectors_ref_h[] 
= {0.42823088, + 0.59131151, + 0.4220887, + 0.50441194, + 0.18541506, + 0.27047497, + -0.17195579, + 0.69362791, + -0.43253894, + -0.47860724}; + + T right_eig_vectors_ref_h[] = {0.53005494, + 0.44104121, + 0.40720732, + 0.54337293, + 0.25189773, + 0.5789401, + 0.15264214, + -0.45215699, + -0.53184873, + 0.3927082}; + + T sing_vals_ref_h[] = {2.36539241, 0.81117785, 0.68562255, 0.41390509, 0.01519322}; + + raft::update_device( + left_eig_vectors_ref.data(), left_eig_vectors_ref_h, params.n_row * params.k, stream); + raft::update_device( + right_eig_vectors_ref.data(), right_eig_vectors_ref_h, params.k * params.n_col, stream); + raft::update_device(sing_vals_ref.data(), sing_vals_ref_h, params.k, stream); + + randomized_svd(handle, + raft::make_device_matrix_view( + data.data(), params.n_row, params.n_col), + raft::make_device_vector_view(sing_vals_act.data(), params.k), + std::make_optional(raft::make_device_matrix_view( + left_eig_vectors_act.data(), params.n_row, params.k)), + std::make_optional(raft::make_device_matrix_view( + right_eig_vectors_act.data(), params.k, params.n_col)), + 2, + 2); + handle.sync_stream(stream); + } + + void apiTest() + { /* Uploads the same input and reference data as basicTest, then calls randomized_svd three times: V only, U only, and neither. NOTE(review): data is uploaded once and reused by all three calls, although the randomized_svd documentation warns that the input matrix is modified by the solver. */ + int len = params.n_row * params.n_col; + ASSERT(params.n_row == 5 && params.n_col == 5, "This test only supports nrows=5 && ncols=5!"); + T data_h[] = {0.76420743, 0.61411544, 0.81724151, 0.42040879, 0.03446089, + 0.03697287, 0.85962444, 0.67584086, 0.45594666, 0.02074835, + 0.42018265, 0.39204509, 0.12657948, 0.90250559, 0.23076218, + 0.50339844, 0.92974961, 0.21213988, 0.63962457, 0.58124562, + 0.58325673, 0.11589871, 0.39831112, 0.21492685, 0.00540355}; + raft::update_device(data.data(), data_h, len, stream); + + T left_eig_vectors_ref_h[] = {0.42823088, + 0.59131151, + 0.4220887, + 0.50441194, + 0.18541506, + 0.27047497, + -0.17195579, + 0.69362791, + -0.43253894, + -0.47860724}; + + T right_eig_vectors_ref_h[] = {0.53005494, + 0.44104121, + 0.40720732, + 0.54337293, + 0.25189773, + 0.5789401, + 0.15264214, + 
-0.45215699, + -0.53184873, + 0.3927082}; + + T sing_vals_ref_h[] = {2.36539241, 0.81117785, 0.68562255, 0.41390509, 0.01519322}; + + raft::update_device( + left_eig_vectors_ref.data(), left_eig_vectors_ref_h, params.n_row * params.k, stream); + raft::update_device( + right_eig_vectors_ref.data(), right_eig_vectors_ref_h, params.k * params.n_col, stream); + raft::update_device(sing_vals_ref.data(), sing_vals_ref_h, params.k, stream); + randomized_svd(handle, + raft::make_device_matrix_view( + data.data(), params.n_row, params.n_col), + raft::make_device_vector_view(sing_vals_act.data(), params.k), + std::nullopt, + std::make_optional(raft::make_device_matrix_view( + right_eig_vectors_act.data(), params.k, params.n_col)), + 2, + 2); + randomized_svd(handle, + raft::make_device_matrix_view( + data.data(), params.n_row, params.n_col), + raft::make_device_vector_view(sing_vals_act.data(), params.k), + std::make_optional(raft::make_device_matrix_view( + left_eig_vectors_act.data(), params.n_row, params.k)), + std::nullopt, + 2, + 2); + randomized_svd(handle, + raft::make_device_matrix_view( + data.data(), params.n_row, params.n_col), + raft::make_device_vector_view(sing_vals_act.data(), params.k), + std::nullopt, + std::nullopt, + 2, + 2); + handle.sync_stream(stream); + } + + void SetUp() override + { /* Runs apiTest() only when major * 1000 + minor * 10 (from cusolverGetProperty) is at least 11050, then always runs basicTest(). */ + int major = 0; + int minor = 0; + cusolverGetProperty(MAJOR_VERSION, &major); + cusolverGetProperty(MINOR_VERSION, &minor); + int cusolv_version = major * 1000 + minor * 10; + if (cusolv_version >= 11050) apiTest(); + basicTest(); + } + + protected: + raft::device_resources handle; + cudaStream_t stream; + + randomized_svdInputs params; + rmm::device_uvector data, left_eig_vectors_act, right_eig_vectors_act, sing_vals_act, + left_eig_vectors_ref, right_eig_vectors_ref, sing_vals_ref, reconst; +}; + +const std::vector> inputsf1 = {{0.0001f, 5, 5, 2, 1234ULL}}; +const std::vector> inputsd1 = {{0.0001, 5, 5, 2, 1234ULL}}; + +typedef randomized_svdTest randomized_svdTestF; 
+TEST_P(randomized_svdTestF, Result) +{ /* Checks the singular values, then the left and right singular vectors, against the reference buffers using raft::devArrMatch with raft::CompareApproxAbs(params.tolerance). Presumably the single-precision instantiation - confirm against the typedef. */ + ASSERT_TRUE(raft::devArrMatch(sing_vals_ref.data(), + sing_vals_act.data(), + params.k, + raft::CompareApproxAbs(params.tolerance))); + ASSERT_TRUE(raft::devArrMatch(left_eig_vectors_ref.data(), + left_eig_vectors_act.data(), + params.n_row * params.k, + raft::CompareApproxAbs(params.tolerance))); + ASSERT_TRUE(raft::devArrMatch(right_eig_vectors_ref.data(), + right_eig_vectors_act.data(), + params.k * params.n_col, + raft::CompareApproxAbs(params.tolerance))); +} + +typedef randomized_svdTest randomized_svdTestD; +TEST_P(randomized_svdTestD, Result) +{ /* Same three comparisons as the randomized_svdTestF case, run on the randomized_svdTestD fixture. */ + ASSERT_TRUE(raft::devArrMatch(sing_vals_ref.data(), + sing_vals_act.data(), + params.k, + raft::CompareApproxAbs(params.tolerance))); + ASSERT_TRUE(raft::devArrMatch(left_eig_vectors_ref.data(), + left_eig_vectors_act.data(), + params.n_row * params.k, + raft::CompareApproxAbs(params.tolerance))); + ASSERT_TRUE(raft::devArrMatch(right_eig_vectors_ref.data(), + right_eig_vectors_act.data(), + params.k * params.n_col, + raft::CompareApproxAbs(params.tolerance))); +} + +INSTANTIATE_TEST_SUITE_P(randomized_svdTests1, randomized_svdTestF, ::testing::ValuesIn(inputsf1)); +INSTANTIATE_TEST_SUITE_P(randomized_svdTests1, randomized_svdTestD, ::testing::ValuesIn(inputsd1)); +} // end namespace linalg +} // end namespace raft