Skip to content

Commit

Permalink
Using raft::resources in rsvd (#1543)
Browse files Browse the repository at this point in the history
Authors:
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Divye Gala (https://github.com/divyegala)

URL: #1543
  • Loading branch information
cjnolet authored May 23, 2023
1 parent a9dd440 commit 42c9c18
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 6 deletions.
9 changes: 5 additions & 4 deletions cpp/include/raft/linalg/detail/rsvd.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma once

#include <raft/core/resource/cublas_handle.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/cusolver_dn_handle.hpp>
#include <raft/linalg/eig.cuh>
#include <raft/linalg/gemm.cuh>
Expand All @@ -38,7 +39,7 @@ namespace linalg {
namespace detail {

template <typename math_t>
void randomized_svd(const raft::device_resources& handle,
void randomized_svd(const raft::resources& handle,
const math_t* in,
std::size_t n_rows,
std::size_t n_cols,
Expand All @@ -61,8 +62,8 @@ void randomized_svd(const raft::device_resources& handle,
#if CUDART_VERSION < 11050
RAFT_EXPECTS(gen_U && gen_V, "not computing U or V is not supported in CUDA version < 11.5");
#endif
cudaStream_t stream = handle.get_stream();
cusolverDnHandle_t cusolverH = handle.get_cusolver_dn_handle();
cudaStream_t stream = resource::get_cuda_stream(handle);
cusolverDnHandle_t cusolverH = resource::get_cusolver_dn_handle(handle);

char jobu = gen_U ? 'S' : 'N';
char jobv = gen_V ? 'S' : 'N';
Expand Down Expand Up @@ -123,7 +124,7 @@ void randomized_svd(const raft::device_resources& handle,

int dev_info;
raft::update_host(&dev_info, devInfo.data_handle(), 1, stream);
handle.sync_stream(stream);
resource::sync_stream(handle);
ASSERT(dev_info == 0, "rsvd.cuh: Invalid parameter encountered.");
}

Expand Down
4 changes: 2 additions & 2 deletions cpp/include/raft/linalg/rsvd.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -814,7 +814,7 @@ void rsvd_perc_symmetric_jacobi(Args... args)
* @param[in] niters: Number of iteration of power method. (2 is recommended)
*/
template <typename math_t, typename idx_t>
void randomized_svd(const raft::device_resources& handle,
void randomized_svd(const raft::resources& handle,
raft::device_matrix_view<const math_t, idx_t, raft::col_major> in,
raft::device_vector_view<math_t, idx_t> S,
std::optional<raft::device_matrix_view<math_t, idx_t, raft::col_major>> U,
Expand Down Expand Up @@ -859,7 +859,7 @@ void randomized_svd(const raft::device_resources& handle,
* Please see above for documentation of `randomized_svd`.
*/
template <typename math_t, typename idx_t, typename opt_u_vec_t, typename opt_v_vec_t>
void randomized_svd(const raft::device_resources& handle,
void randomized_svd(const raft::resources& handle,
raft::device_matrix_view<const math_t, idx_t, raft::col_major> in,
raft::device_vector_view<math_t, idx_t> S,
opt_u_vec_t&& U,
Expand Down
1 change: 1 addition & 0 deletions cpp/test/linalg/randomized_svd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "../test_utils.cuh"
#include <gtest/gtest.h>
#include <raft/core/device_resources.hpp>
#include <raft/linalg/rsvd.cuh>
#include <raft/linalg/svd.cuh>
#include <raft/matrix/diagonal.cuh>
Expand Down

0 comments on commit 42c9c18

Please sign in to comment.