Skip to content

Commit

Permalink
enable use of remaining rocblas template functions (ROCm#152)
Browse files Browse the repository at this point in the history
* remove unnecessary ROCSOLVER_EXPORTs on the cpp files of the public APIs
* enable rocblas trmm template
* trmm fixes
* trmm fixes2
* enable rocblas trsm template
* trsm fixes
* clean up
* review fixes
  • Loading branch information
jzuniga-amd authored Sep 15, 2020
1 parent 671dfc3 commit c2cd214
Show file tree
Hide file tree
Showing 84 changed files with 1,535 additions and 1,551 deletions.
15 changes: 11 additions & 4 deletions rocsolver/clients/gtest/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,27 +46,34 @@ set( THREADS_PREFER_PTHREAD_FLAG ON )
find_package( Threads REQUIRED )

set(roclapack_test_source
# vector & matrix manipulations
lacgv_gtest.cpp
laswp_gtest.cpp
# householder reflections
larfg_gtest.cpp
larf_gtest.cpp
larft_gtest.cpp
larfb_gtest.cpp
labrd_gtest.cpp
bdsqr_gtest.cpp
# orthonormal/unitary matrices
orgxr_ungxr_gtest.cpp
orglx_unglx_gtest.cpp
ormxr_unmxr_gtest.cpp
ormlx_unmlx_gtest.cpp
orgbr_ungbr_gtest.cpp
ormbr_unmbr_gtest.cpp
getf2_getrf_gtest.cpp
# bidiagonal matrices and svd
labrd_gtest.cpp
bdsqr_gtest.cpp
# triangular factorizations and linear solvers
potf2_potrf_gtest.cpp
getrs_gtest.cpp
getf2_getrf_gtest.cpp
getri_gtest.cpp
getrs_gtest.cpp
# orthogonal factorizations
geqr2_geqrf_gtest.cpp
geql2_geqlf_gtest.cpp
gelq2_gelqf_gtest.cpp
# bidiagonalization and svd
gebd2_gebrd_gtest.cpp
gesvd_gtest.cpp
)
Expand Down
8 changes: 4 additions & 4 deletions rocsolver/clients/include/testing_getrs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,10 +235,10 @@ void getrs_getPerfData(const rocblas_handle handle,
template <bool BATCHED, bool STRIDED, typename T>
void testing_getrs(Arguments argus) {
rocblas_local_handle handle;
/* Set handle memory size to a large enough value for all tests to pass.
(TODO: A more definitive solution could be implemented once
the handle memory model APIs are enabled in rocsolver)*/
rocblas_set_device_memory_size(handle, 20000000);
// /* Set handle memory size to a large enough value for all tests to pass.
// (TODO: A more definitive solution could be implemented once
// the handle memory model APIs are enabled in rocsolver)*/
// rocblas_set_device_memory_size(handle, 20000000);

// get arguments
rocblas_int m = argus.M;
Expand Down
10 changes: 8 additions & 2 deletions rocsolver/library/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,16 @@ set(THREADS_PREFER_PTHREAD_FLAG ON)
find_package(Threads REQUIRED)

set( rocsolver_auxiliary_source
# vector & matrix manipulations
auxiliary/rocauxiliary_aliases.cpp
auxiliary/rocauxiliary_lacgv.cpp
auxiliary/rocauxiliary_laswp.cpp
# householder reflections
auxiliary/rocauxiliary_larfg.cpp
auxiliary/rocauxiliary_larf.cpp
auxiliary/rocauxiliary_larft.cpp
auxiliary/rocauxiliary_larfb.cpp
auxiliary/rocauxiliary_labrd.cpp
# orthonormal/unitary matrices
auxiliary/rocauxiliary_org2r_ung2r.cpp
auxiliary/rocauxiliary_orgqr_ungqr.cpp
auxiliary/rocauxiliary_orgl2_ungl2.cpp
Expand All @@ -45,10 +47,13 @@ set( rocsolver_auxiliary_source
auxiliary/rocauxiliary_orml2_unml2.cpp
auxiliary/rocauxiliary_ormlq_unmlq.cpp
auxiliary/rocauxiliary_ormbr_unmbr.cpp
# bidiagonal matrices and svd
auxiliary/rocauxiliary_bdsqr.cpp
auxiliary/rocauxiliary_labrd.cpp
)

set( rocsolver_lapack_source
# triangular factorizations and linear solvers
lapack/roclapack_getf2.cpp
lapack/roclapack_getf2_batched.cpp
lapack/roclapack_getf2_strided_batched.cpp
Expand All @@ -68,6 +73,7 @@ set( rocsolver_lapack_source
lapack/roclapack_potrf.cpp
lapack/roclapack_potrf_batched.cpp
lapack/roclapack_potrf_strided_batched.cpp
# orthogonal factorizations
lapack/roclapack_geqr2.cpp
lapack/roclapack_geqr2_batched.cpp
lapack/roclapack_geqr2_strided_batched.cpp
Expand All @@ -87,6 +93,7 @@ set( rocsolver_lapack_source
lapack/roclapack_gelqf.cpp
lapack/roclapack_gelqf_batched.cpp
lapack/roclapack_gelqf_strided_batched.cpp
# bidiagonalization and svd
lapack/roclapack_gebd2.cpp
lapack/roclapack_gebd2_batched.cpp
lapack/roclapack_gebd2_strided_batched.cpp
Expand All @@ -100,7 +107,6 @@ set( rocsolver_lapack_source

set( auxiliaries
buildinfo.cpp
rocblas.cpp
)

prepend_path( ".." rocsolver_headers_public relative_rocsolver_headers_public )
Expand Down
40 changes: 22 additions & 18 deletions rocsolver/library/src/auxiliary/rocauxiliary_bdsqr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,36 +61,40 @@ rocsolver_bdsqr_impl(rocblas_handle handle, const rocblas_fill uplo,

extern "C" {

ROCSOLVER_EXPORT rocblas_status rocsolver_sbdsqr(
rocblas_handle handle, const rocblas_fill uplo, const rocblas_int n,
const rocblas_int nv, const rocblas_int nu, const rocblas_int nc, float *D,
float *E, float *V, const rocblas_int ldv, float *U, const rocblas_int ldu,
float *C, const rocblas_int ldc, rocblas_int *info) {
rocblas_status rocsolver_sbdsqr(rocblas_handle handle, const rocblas_fill uplo,
const rocblas_int n, const rocblas_int nv,
const rocblas_int nu, const rocblas_int nc,
float *D, float *E, float *V,
const rocblas_int ldv, float *U,
const rocblas_int ldu, float *C,
const rocblas_int ldc, rocblas_int *info) {
return rocsolver_bdsqr_impl<float>(handle, uplo, n, nv, nu, nc, D, E, V, ldv,
U, ldu, C, ldc, info);
}

ROCSOLVER_EXPORT rocblas_status rocsolver_dbdsqr(
rocblas_handle handle, const rocblas_fill uplo, const rocblas_int n,
const rocblas_int nv, const rocblas_int nu, const rocblas_int nc, double *D,
double *E, double *V, const rocblas_int ldv, double *U,
const rocblas_int ldu, double *C, const rocblas_int ldc,
rocblas_int *info) {
rocblas_status rocsolver_dbdsqr(rocblas_handle handle, const rocblas_fill uplo,
const rocblas_int n, const rocblas_int nv,
const rocblas_int nu, const rocblas_int nc,
double *D, double *E, double *V,
const rocblas_int ldv, double *U,
const rocblas_int ldu, double *C,
const rocblas_int ldc, rocblas_int *info) {
return rocsolver_bdsqr_impl<double>(handle, uplo, n, nv, nu, nc, D, E, V, ldv,
U, ldu, C, ldc, info);
}

ROCSOLVER_EXPORT rocblas_status rocsolver_cbdsqr(
rocblas_handle handle, const rocblas_fill uplo, const rocblas_int n,
const rocblas_int nv, const rocblas_int nu, const rocblas_int nc, float *D,
float *E, rocblas_float_complex *V, const rocblas_int ldv,
rocblas_float_complex *U, const rocblas_int ldu, rocblas_float_complex *C,
const rocblas_int ldc, rocblas_int *info) {
rocblas_status rocsolver_cbdsqr(rocblas_handle handle, const rocblas_fill uplo,
const rocblas_int n, const rocblas_int nv,
const rocblas_int nu, const rocblas_int nc,
float *D, float *E, rocblas_float_complex *V,
const rocblas_int ldv, rocblas_float_complex *U,
const rocblas_int ldu, rocblas_float_complex *C,
const rocblas_int ldc, rocblas_int *info) {
return rocsolver_bdsqr_impl<rocblas_float_complex>(
handle, uplo, n, nv, nu, nc, D, E, V, ldv, U, ldu, C, ldc, info);
}

ROCSOLVER_EXPORT rocblas_status rocsolver_zbdsqr(
rocblas_status rocsolver_zbdsqr(
rocblas_handle handle, const rocblas_fill uplo, const rocblas_int n,
const rocblas_int nv, const rocblas_int nu, const rocblas_int nc, double *D,
double *E, rocblas_double_complex *V, const rocblas_int ldv,
Expand Down
38 changes: 21 additions & 17 deletions rocsolver/library/src/auxiliary/rocauxiliary_labrd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,35 +76,39 @@ rocsolver_labrd_impl(rocblas_handle handle, const rocblas_int m,

extern "C" {

ROCSOLVER_EXPORT rocblas_status rocsolver_slabrd(
rocblas_handle handle, const rocblas_int m, const rocblas_int n,
const rocblas_int k, float *A, const rocblas_int lda, float *D, float *E,
float *tauq, float *taup, float *X, const rocblas_int ldx, float *Y,
const rocblas_int ldy) {
rocblas_status rocsolver_slabrd(rocblas_handle handle, const rocblas_int m,
const rocblas_int n, const rocblas_int k,
float *A, const rocblas_int lda, float *D,
float *E, float *tauq, float *taup, float *X,
const rocblas_int ldx, float *Y,
const rocblas_int ldy) {
return rocsolver_labrd_impl<float, float>(handle, m, n, k, A, lda, D, E, tauq,
taup, X, ldx, Y, ldy);
}

ROCSOLVER_EXPORT rocblas_status rocsolver_dlabrd(
rocblas_handle handle, const rocblas_int m, const rocblas_int n,
const rocblas_int k, double *A, const rocblas_int lda, double *D, double *E,
double *tauq, double *taup, double *X, const rocblas_int ldx, double *Y,
const rocblas_int ldy) {
rocblas_status rocsolver_dlabrd(rocblas_handle handle, const rocblas_int m,
const rocblas_int n, const rocblas_int k,
double *A, const rocblas_int lda, double *D,
double *E, double *tauq, double *taup,
double *X, const rocblas_int ldx, double *Y,
const rocblas_int ldy) {
return rocsolver_labrd_impl<double, double>(handle, m, n, k, A, lda, D, E,
tauq, taup, X, ldx, Y, ldy);
}

ROCSOLVER_EXPORT rocblas_status rocsolver_clabrd(
rocblas_handle handle, const rocblas_int m, const rocblas_int n,
const rocblas_int k, rocblas_float_complex *A, const rocblas_int lda,
float *D, float *E, rocblas_float_complex *tauq,
rocblas_float_complex *taup, rocblas_float_complex *X,
const rocblas_int ldx, rocblas_float_complex *Y, const rocblas_int ldy) {
rocblas_status rocsolver_clabrd(rocblas_handle handle, const rocblas_int m,
const rocblas_int n, const rocblas_int k,
rocblas_float_complex *A, const rocblas_int lda,
float *D, float *E, rocblas_float_complex *tauq,
rocblas_float_complex *taup,
rocblas_float_complex *X, const rocblas_int ldx,
rocblas_float_complex *Y,
const rocblas_int ldy) {
return rocsolver_labrd_impl<float, rocblas_float_complex>(
handle, m, n, k, A, lda, D, E, tauq, taup, X, ldx, Y, ldy);
}

ROCSOLVER_EXPORT rocblas_status rocsolver_zlabrd(
rocblas_status rocsolver_zlabrd(
rocblas_handle handle, const rocblas_int m, const rocblas_int n,
const rocblas_int k, rocblas_double_complex *A, const rocblas_int lda,
double *D, double *E, rocblas_double_complex *tauq,
Expand Down
14 changes: 6 additions & 8 deletions rocsolver/library/src/auxiliary/rocauxiliary_lacgv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,15 @@ rocblas_status rocsolver_lacgv_impl(rocblas_handle handle, const rocblas_int n,

extern "C" {

ROCSOLVER_EXPORT rocblas_status rocsolver_clacgv(rocblas_handle handle,
const rocblas_int n,
rocblas_float_complex *x,
const rocblas_int incx) {
rocblas_status rocsolver_clacgv(rocblas_handle handle, const rocblas_int n,
rocblas_float_complex *x,
const rocblas_int incx) {
return rocsolver_lacgv_impl<rocblas_float_complex>(handle, n, x, incx);
}

ROCSOLVER_EXPORT rocblas_status rocsolver_zlacgv(rocblas_handle handle,
const rocblas_int n,
rocblas_double_complex *x,
const rocblas_int incx) {
rocblas_status rocsolver_zlacgv(rocblas_handle handle, const rocblas_int n,
rocblas_double_complex *x,
const rocblas_int incx) {
return rocsolver_lacgv_impl<rocblas_double_complex>(handle, n, x, incx);
}

Expand Down
41 changes: 23 additions & 18 deletions rocsolver/library/src/auxiliary/rocauxiliary_larf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,35 +66,40 @@ rocblas_status rocsolver_larf_impl(rocblas_handle handle,

extern "C" {

ROCSOLVER_EXPORT rocblas_status rocsolver_slarf(
rocblas_handle handle, const rocblas_side side, const rocblas_int m,
const rocblas_int n, float *x, const rocblas_int incx, const float *alpha,
float *A, const rocblas_int lda) {
rocblas_status rocsolver_slarf(rocblas_handle handle, const rocblas_side side,
const rocblas_int m, const rocblas_int n,
float *x, const rocblas_int incx,
const float *alpha, float *A,
const rocblas_int lda) {
return rocsolver_larf_impl<float>(handle, side, m, n, x, incx, alpha, A, lda);
}

ROCSOLVER_EXPORT rocblas_status rocsolver_dlarf(
rocblas_handle handle, const rocblas_side side, const rocblas_int m,
const rocblas_int n, double *x, const rocblas_int incx, const double *alpha,
double *A, const rocblas_int lda) {
rocblas_status rocsolver_dlarf(rocblas_handle handle, const rocblas_side side,
const rocblas_int m, const rocblas_int n,
double *x, const rocblas_int incx,
const double *alpha, double *A,
const rocblas_int lda) {
return rocsolver_larf_impl<double>(handle, side, m, n, x, incx, alpha, A,
lda);
}

ROCSOLVER_EXPORT rocblas_status rocsolver_clarf(
rocblas_handle handle, const rocblas_side side, const rocblas_int m,
const rocblas_int n, rocblas_float_complex *x, const rocblas_int incx,
const rocblas_float_complex *alpha, rocblas_float_complex *A,
const rocblas_int lda) {
rocblas_status rocsolver_clarf(rocblas_handle handle, const rocblas_side side,
const rocblas_int m, const rocblas_int n,
rocblas_float_complex *x, const rocblas_int incx,
const rocblas_float_complex *alpha,
rocblas_float_complex *A,
const rocblas_int lda) {
return rocsolver_larf_impl<rocblas_float_complex>(handle, side, m, n, x, incx,
alpha, A, lda);
}

ROCSOLVER_EXPORT rocblas_status rocsolver_zlarf(
rocblas_handle handle, const rocblas_side side, const rocblas_int m,
const rocblas_int n, rocblas_double_complex *x, const rocblas_int incx,
const rocblas_double_complex *alpha, rocblas_double_complex *A,
const rocblas_int lda) {
rocblas_status rocsolver_zlarf(rocblas_handle handle, const rocblas_side side,
const rocblas_int m, const rocblas_int n,
rocblas_double_complex *x,
const rocblas_int incx,
const rocblas_double_complex *alpha,
rocblas_double_complex *A,
const rocblas_int lda) {
return rocsolver_larf_impl<rocblas_double_complex>(handle, side, m, n, x,
incx, alpha, A, lda);
}
Expand Down
Loading

0 comments on commit c2cd214

Please sign in to comment.