diff --git a/cpp/include/raft/linalg/detail/rsvd.cuh b/cpp/include/raft/linalg/detail/rsvd.cuh index 9c2cea6b66..422b19c0e5 100644 --- a/cpp/include/raft/linalg/detail/rsvd.cuh +++ b/cpp/include/raft/linalg/detail/rsvd.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include #include @@ -38,7 +39,7 @@ namespace linalg { namespace detail { template -void randomized_svd(const raft::device_resources& handle, +void randomized_svd(const raft::resources& handle, const math_t* in, std::size_t n_rows, std::size_t n_cols, @@ -61,8 +62,8 @@ void randomized_svd(const raft::device_resources& handle, #if CUDART_VERSION < 11050 RAFT_EXPECTS(gen_U && gen_V, "not computing U or V is not supported in CUDA version < 11.5"); #endif - cudaStream_t stream = handle.get_stream(); - cusolverDnHandle_t cusolverH = handle.get_cusolver_dn_handle(); + cudaStream_t stream = resource::get_cuda_stream(handle); + cusolverDnHandle_t cusolverH = resource::get_cusolver_dn_handle(handle); char jobu = gen_U ? 'S' : 'N'; char jobv = gen_V ? 'S' : 'N'; @@ -123,7 +124,7 @@ void randomized_svd(const raft::device_resources& handle, int dev_info; raft::update_host(&dev_info, devInfo.data_handle(), 1, stream); - handle.sync_stream(stream); + resource::sync_stream(handle); ASSERT(dev_info == 0, "rsvd.cuh: Invalid parameter encountered."); } diff --git a/cpp/include/raft/linalg/rsvd.cuh b/cpp/include/raft/linalg/rsvd.cuh index 8037611a54..2dece5b957 100644 --- a/cpp/include/raft/linalg/rsvd.cuh +++ b/cpp/include/raft/linalg/rsvd.cuh @@ -814,7 +814,7 @@ void rsvd_perc_symmetric_jacobi(Args... args) * @param[in] niters: Number of iteration of power method. (2 is recommended) */ template -void randomized_svd(const raft::device_resources& handle, +void randomized_svd(const raft::resources& handle, raft::device_matrix_view in, raft::device_vector_view S, std::optional> U, @@ -859,7 +859,7 @@ void randomized_svd(const raft::device_resources& handle, * Please see above for documentation of `randomized_svd`. */ template -void randomized_svd(const raft::device_resources& handle, +void randomized_svd(const raft::resources& handle, raft::device_matrix_view in, raft::device_vector_view S, opt_u_vec_t&& U, diff --git a/cpp/test/linalg/randomized_svd.cu b/cpp/test/linalg/randomized_svd.cu index 2d55fd7579..9e1d3df6dc 100644 --- a/cpp/test/linalg/randomized_svd.cu +++ b/cpp/test/linalg/randomized_svd.cu @@ -16,6 +16,7 @@ #include "../test_utils.cuh" #include +#include #include #include #include