Skip to content

Commit

Permalink
[FEA] Add randomized svd from cusolver (#1000)
Browse files Browse the repository at this point in the history
Authors:
  - Micka (https://github.com/lowener)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1000
  • Loading branch information
lowener authored May 22, 2023
1 parent 87597a8 commit a9dd440
Show file tree
Hide file tree
Showing 6 changed files with 579 additions and 20 deletions.
115 changes: 114 additions & 1 deletion cpp/include/raft/linalg/detail/cusolver_wrappers.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -693,6 +693,119 @@ inline cusolverStatus_t CUSOLVERAPI cusolverDngesvdj( // NOLINT
return cusolverDnDgesvdj(
handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, work, lwork, info, params);
}

#if CUDART_VERSION >= 11010
/**
 * @brief Typed wrapper around cusolverDnXgesvdr_bufferSize.
 *
 * Binds @p stream to @p handle, creates a temporary cusolverDnParams_t, queries
 * the device/host workspace sizes required by the matching cusolverDnXgesvdr
 * call, and destroys the temporary params object before returning.
 *
 * Only `float` and `double` are supported; every matrix/vector argument and the
 * compute type are described to cuSOLVER with the same cudaDataType derived
 * from @p T.
 *
 * @tparam T element type of a, Srand, Urand and Vrand (float or double)
 * @param[in]  handle   cuSOLVER dense handle
 * @param[in]  jobu     forwarded unchanged to cuSOLVER (U computation mode)
 * @param[in]  jobv     forwarded unchanged to cuSOLVER (V computation mode)
 * @param[in]  m        forwarded unchanged (presumably the number of rows of a)
 * @param[in]  n        forwarded unchanged (presumably the number of columns of a)
 * @param[in]  k        forwarded unchanged (presumably the requested rank)
 * @param[in]  p        forwarded unchanged (presumably the oversampling size)
 * @param[in]  niters   forwarded unchanged (presumably the power-iteration count)
 * @param[in]  a        input matrix pointer
 * @param[in]  lda      leading dimension of a
 * @param[in]  Srand    singular values buffer
 * @param[in]  Urand    left singular vectors buffer
 * @param[in]  ldUrand  leading dimension of Urand
 * @param[in]  Vrand    right singular vectors buffer
 * @param[in]  ldVrand  leading dimension of Vrand
 * @param[out] workspaceInBytesOnDevice required device workspace size, in bytes
 * @param[out] workspaceInBytesOnHost   required host workspace size, in bytes
 * @param[in]  stream   CUDA stream bound to the handle before the query
 * @return status returned by cusolverDnXgesvdr_bufferSize
 */
template <typename T>
cusolverStatus_t cusolverDnxgesvdr_bufferSize(  // NOLINT
  cusolverDnHandle_t handle,
  signed char jobu,
  signed char jobv,
  int64_t m,
  int64_t n,
  int64_t k,
  int64_t p,
  int64_t niters,
  const T* a,
  int64_t lda,
  const T* Srand,
  const T* Urand,
  int64_t ldUrand,
  const T* Vrand,
  int64_t ldVrand,
  size_t* workspaceInBytesOnDevice,
  size_t* workspaceInBytesOnHost,
  cudaStream_t stream)
{
  // Accept exactly the two types that the ternary below can describe. A plain
  // is_floating_point check would also admit long double and silently label it
  // as CUDA_R_64F.
  RAFT_EXPECTS((std::is_same_v<T, float> || std::is_same_v<T, double>), "Unsupported data type");
  cudaDataType dataType = std::is_same_v<T, float> ? CUDA_R_32F : CUDA_R_64F;

  RAFT_CUSOLVER_TRY(cusolverDnSetStream(handle, stream));

  cusolverDnParams_t dn_params = nullptr;
  RAFT_CUSOLVER_TRY(cusolverDnCreateParams(&dn_params));

  // The status is captured (not thrown) so the params object is always
  // destroyed before the result is handed back to the caller.
  auto result = cusolverDnXgesvdr_bufferSize(handle,
                                             dn_params,
                                             jobu,
                                             jobv,
                                             m,
                                             n,
                                             k,
                                             p,
                                             niters,
                                             dataType,
                                             a,
                                             lda,
                                             dataType,
                                             Srand,
                                             dataType,
                                             Urand,
                                             ldUrand,
                                             dataType,
                                             Vrand,
                                             ldVrand,
                                             dataType,
                                             workspaceInBytesOnDevice,
                                             workspaceInBytesOnHost);

  RAFT_CUSOLVER_TRY(cusolverDnDestroyParams(dn_params));
  return result;
}
/**
 * @brief Typed wrapper around cusolverDnXgesvdr.
 *
 * Binds @p stream to @p handle, creates a temporary cusolverDnParams_t, invokes
 * cusolverDnXgesvdr with every matrix/vector argument and the compute type
 * described by the cudaDataType derived from @p T, and destroys the temporary
 * params object before returning.
 *
 * Only `float` and `double` are supported.
 *
 * @tparam T element type of a, Srand, Urand and Vrand (float or double)
 * @param[in]     handle   cuSOLVER dense handle
 * @param[in]     jobu     forwarded unchanged to cuSOLVER (U computation mode)
 * @param[in]     jobv     forwarded unchanged to cuSOLVER (V computation mode)
 * @param[in]     m        forwarded unchanged (presumably the number of rows of a)
 * @param[in]     n        forwarded unchanged (presumably the number of columns of a)
 * @param[in]     k        forwarded unchanged (presumably the requested rank)
 * @param[in]     p        forwarded unchanged (presumably the oversampling size)
 * @param[in]     niters   forwarded unchanged (presumably the power-iteration count)
 * @param[in,out] a        matrix pointer passed to cuSOLVER as non-const
 * @param[in]     lda      leading dimension of a
 * @param[out]    Srand    singular values buffer
 * @param[out]    Urand    left singular vectors buffer
 * @param[in]     ldUrand  leading dimension of Urand
 * @param[out]    Vrand    right singular vectors buffer
 * @param[in]     ldVrand  leading dimension of Vrand
 * @param[in]     bufferOnDevice           device workspace
 * @param[in]     workspaceInBytesOnDevice size of bufferOnDevice, in bytes
 * @param[in]     bufferOnHost             host workspace
 * @param[in]     workspaceInBytesOnHost   size of bufferOnHost, in bytes
 * @param[out]    d_info   status integer written by cuSOLVER
 * @param[in]     stream   CUDA stream bound to the handle before the call
 * @return status returned by cusolverDnXgesvdr
 */
template <typename T>
cusolverStatus_t cusolverDnxgesvdr(  // NOLINT
  cusolverDnHandle_t handle,
  signed char jobu,
  signed char jobv,
  int64_t m,
  int64_t n,
  int64_t k,
  int64_t p,
  int64_t niters,
  T* a,
  int64_t lda,
  T* Srand,
  T* Urand,
  int64_t ldUrand,
  T* Vrand,
  int64_t ldVrand,
  void* bufferOnDevice,
  size_t workspaceInBytesOnDevice,
  void* bufferOnHost,
  size_t workspaceInBytesOnHost,
  int* d_info,
  cudaStream_t stream)
{
  // Same guard as the bufferSize query: without it any T other than float
  // would be silently described to cuSOLVER as CUDA_R_64F.
  RAFT_EXPECTS((std::is_same_v<T, float> || std::is_same_v<T, double>), "Unsupported data type");
  cudaDataType dataType = std::is_same_v<T, float> ? CUDA_R_32F : CUDA_R_64F;

  RAFT_CUSOLVER_TRY(cusolverDnSetStream(handle, stream));

  cusolverDnParams_t dn_params = nullptr;
  RAFT_CUSOLVER_TRY(cusolverDnCreateParams(&dn_params));

  // The status is captured (not thrown) so the params object is always
  // destroyed before the result is handed back to the caller.
  auto result = cusolverDnXgesvdr(handle,
                                  dn_params,
                                  jobu,
                                  jobv,
                                  m,
                                  n,
                                  k,
                                  p,
                                  niters,
                                  dataType,
                                  a,
                                  lda,
                                  dataType,
                                  Srand,
                                  dataType,
                                  Urand,
                                  ldUrand,
                                  dataType,
                                  Vrand,
                                  ldVrand,
                                  dataType,
                                  bufferOnDevice,
                                  workspaceInBytesOnDevice,
                                  bufferOnHost,
                                  workspaceInBytesOnHost,
                                  d_info);

  RAFT_CUSOLVER_TRY(cusolverDnDestroyParams(dn_params));
  return result;
}
#endif // CUDART_VERSION >= 11010

/** @} */

/**
Expand Down
90 changes: 90 additions & 0 deletions cpp/include/raft/linalg/detail/rsvd.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,96 @@ namespace raft {
namespace linalg {
namespace detail {

/**
 * @brief Randomized SVD of a dense matrix via cuSOLVER's gesvdr routine.
 *
 * Validates the arguments, queries the workspace sizes, allocates device and
 * host workspaces, runs cusolverDnxgesvdr on the handle's stream, then copies
 * the device status integer back to the host, synchronizes the stream and
 * asserts that the status is zero.
 *
 * The leading dimensions passed to cuSOLVER are n_rows for the input and U,
 * and n_cols for V.
 *
 * @tparam math_t element type of the matrices
 * @param[in]  handle  raft resources providing the stream and cuSOLVER handle
 * @param[in]  in      input matrix (passed to cuSOLVER through a const_cast)
 * @param[in]  n_rows  number of rows of the input
 * @param[in]  n_cols  number of columns of the input
 * @param[in]  k       must satisfy k < min(n_rows, n_cols)
 * @param[in]  p       must satisfy k + p < min(n_rows, n_cols)
 * @param[in]  niters  forwarded to cuSOLVER (presumably the power-iteration count)
 * @param[out] S       singular values buffer
 * @param[out] U       left singular vectors buffer; must be non-null when gen_U is true
 * @param[out] V       right singular vectors buffer; must be non-null when gen_V is true
 * @param[in]  gen_U   request computation of U (job 'S'), otherwise job 'N'
 * @param[in]  gen_V   request computation of V (job 'S'), otherwise job 'N'
 */
template <typename math_t>
void randomized_svd(const raft::device_resources& handle,
                    const math_t* in,
                    std::size_t n_rows,
                    std::size_t n_cols,
                    std::size_t k,
                    std::size_t p,
                    std::size_t niters,
                    math_t* S,
                    math_t* U,
                    math_t* V,
                    bool gen_U,
                    bool gen_V)
{
  // The arguments are std::size_t, so the printf-style conversion must be %zu;
  // %d would be a specifier/argument mismatch.
  common::nvtx::range<common::nvtx::domain::raft> fun_scope(
    "raft::linalg::randomized_svd(%zu, %zu, %zu)", n_rows, n_cols, k);

  RAFT_EXPECTS(k < std::min(n_rows, n_cols), "k must be < min(n_rows, n_cols)");
  RAFT_EXPECTS((k + p) < std::min(n_rows, n_cols), "k + p must be < min(n_rows, n_cols)");
  RAFT_EXPECTS(!gen_U || (U != nullptr), "computation of U vector requested but found nullptr");
  RAFT_EXPECTS(!gen_V || (V != nullptr), "computation of V vector requested but found nullptr");
#if CUDART_VERSION < 11050
  RAFT_EXPECTS(gen_U && gen_V, "not computing U or V is not supported in CUDA version < 11.5");
#endif
  cudaStream_t stream          = handle.get_stream();
  cusolverDnHandle_t cusolverH = handle.get_cusolver_dn_handle();

  // 'S' requests the corresponding factor, 'N' skips it.
  char jobu = gen_U ? 'S' : 'N';
  char jobv = gen_V ? 'S' : 'N';

  auto lda = n_rows;
  auto ldu = n_rows;
  auto ldv = n_cols;
  // The cuSOLVER wrapper takes a non-const pointer for the input matrix.
  auto* in_ptr = const_cast<math_t*>(in);

  // Step 1: ask cuSOLVER how much scratch space it needs on device and host.
  size_t workspaceDevice = 0;
  size_t workspaceHost   = 0;
  RAFT_CUSOLVER_TRY(cusolverDnxgesvdr_bufferSize(cusolverH,
                                                 jobu,
                                                 jobv,
                                                 n_rows,
                                                 n_cols,
                                                 k,
                                                 p,
                                                 niters,
                                                 in_ptr,
                                                 lda,
                                                 S,
                                                 U,
                                                 ldu,
                                                 V,
                                                 ldv,
                                                 &workspaceDevice,
                                                 &workspaceHost,
                                                 stream));

  // Step 2: allocate the workspaces and the device-side status integer.
  auto d_workspace = raft::make_device_vector<char>(handle, workspaceDevice);
  auto h_workspace = raft::make_host_vector<char>(workspaceHost);
  auto devInfo     = raft::make_device_scalar<int>(handle, 0);

  // Step 3: run the decomposition on the handle's stream.
  RAFT_CUSOLVER_TRY(cusolverDnxgesvdr(cusolverH,
                                      jobu,
                                      jobv,
                                      n_rows,
                                      n_cols,
                                      k,
                                      p,
                                      niters,
                                      in_ptr,
                                      lda,
                                      S,
                                      U,
                                      ldu,
                                      V,
                                      ldv,
                                      d_workspace.data_handle(),
                                      workspaceDevice,
                                      h_workspace.data_handle(),
                                      workspaceHost,
                                      devInfo.data_handle(),
                                      stream));

  RAFT_CUDA_TRY(cudaGetLastError());

  // Step 4: fetch the status integer; the stream must be synchronized before
  // the host copy is read.
  int dev_info;
  raft::update_host(&dev_info, devInfo.data_handle(), 1, stream);
  handle.sync_stream(stream);
  ASSERT(dev_info == 0, "rsvd.cuh: Invalid parameter encountered.");
}

/**
* @brief randomized singular value decomposition (RSVD) on the column major
* float type input matrix (Jacobi-based), by specifying no. of PCs and
Expand Down
Loading

0 comments on commit a9dd440

Please sign in to comment.