Skip to content

Commit

Permalink
Merge pull request ROCm#5 from iotamudelta/master
Browse files Browse the repository at this point in the history
sgetrs/dgetrs functions
  • Loading branch information
iotamudelta authored Jun 19, 2018
2 parents f3475e1 + 6b45238 commit 198e32e
Show file tree
Hide file tree
Showing 22 changed files with 877 additions and 38 deletions.
57 changes: 57 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,63 @@ include( ROCMCreatePackage )
include( ROCMInstallTargets )
include( ROCMPackageConfigHelpers )
include( ROCMInstallSymlinks )
include( ROCMClangTidy )
rocm_enable_clang_tidy(
CHECKS
*
-cert-env33-c
-android-cloexec-fopen
-cert-msc50-cpp
-clang-analyzer-alpha.core.CastToStruct
-clang-analyzer-optin.performance.Padding
-clang-diagnostic-deprecated-declarations
-clang-diagnostic-extern-c-compat
-clang-diagnostic-unused-command-line-argument
-cppcoreguidelines-pro-bounds-array-to-pointer-decay
-cppcoreguidelines-pro-bounds-constant-array-index
-cppcoreguidelines-pro-bounds-pointer-arithmetic
-cppcoreguidelines-pro-type-member-init
-cppcoreguidelines-pro-type-reinterpret-cast
-cppcoreguidelines-pro-type-union-access
-cppcoreguidelines-pro-type-vararg
-cppcoreguidelines-special-member-functions
-fuchsia-*
-google-readability-braces-around-statements
-google-readability-todo
-google-runtime-int
-google-runtime-references
-hicpp-braces-around-statements
-hicpp-explicit-conversions
-hicpp-no-array-decay
-hicpp-special-member-functions
-hicpp-use-override
# This check is broken
-hicpp-use-auto
-llvm-header-guard
-llvm-include-order
-misc-macro-parentheses
-modernize-use-auto
-modernize-use-override
-modernize-pass-by-value
-modernize-use-default-member-init
-modernize-use-transparent-functors
-performance-unnecessary-value-param
-readability-braces-around-statements
-readability-else-after-return
-readability-named-parameter
-*-explicit-constructor
-*-use-emplace
-*-use-equals-default
ERRORS
*
-readability-inconsistent-declaration-parameter-name
HEADER_FILTER
".*hpp"
EXTRA_ARGS
-DROCSOLVER_USE_CLANG_TIDY
ANALYZE_TEMPORARY_DTORS ON

)

rocm_setup_version( VERSION 0.1.0 NO_GIT_TAG_VERSION )

Expand Down
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ CXX=/opt/rocm/bin/hcc cmake ..
make
```
# Implemented functions in LAPACK notation
Cholesky decomposition: `rocsolver_spotf2() rocsolver_dpotf2()`
unblocked LU decomposition: `rocsolver_sgetf2() rocsolver_dgetf2()`
blocked LU decomposition: `rocsolver_sgetrf() rocsolver_dgetrf()`
Cholesky decomposition: `rocsolver_spotf2() rocsolver_dpotf2()`
unblocked LU decomposition: `rocsolver_sgetf2() rocsolver_dgetf2()`
blocked LU decomposition: `rocsolver_sgetrf() rocsolver_dgetrf()`
solution of system of linear equations: `rocsolver_sgetrs() rocsolver_dgetrs()`
2 changes: 2 additions & 0 deletions clients/benchmarks/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ set( rocsolver_benchmark_common
add_executable( rocsolver-bench client.cpp ${rocsolver_benchmark_common} )
target_compile_features( rocsolver-bench PRIVATE cxx_static_assert cxx_nullptr cxx_auto_type )

rocm_clang_tidy_check(rocsolver-bench)

if( BUILD_WITH_TENSILE )
target_compile_definitions( rocsolver-bench PRIVATE BUILD_WITH_TENSILE=1 )
else()
Expand Down
8 changes: 7 additions & 1 deletion clients/benchmarks/client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "testing_getf2.hpp"
#include "testing_getrf.hpp"
#include "testing_getrs.hpp"
#include "testing_potf2.hpp"
#include "utility.h"

Expand Down Expand Up @@ -99,7 +100,7 @@ int main(int argc, char *argv[]) {

("function,f",
po::value<std::string>(&function)->default_value("potf2"),
"LAPACK function to test. Options: potf2, getf2, getrf")
"LAPACK function to test. Options: potf2, getf2, getrf, getrs")

("precision,r",
po::value<char>(&precision)->default_value('s'), "Options: h,s,d,c,z")
Expand Down Expand Up @@ -191,6 +192,11 @@ int main(int argc, char *argv[]) {
testing_getrf<float>(argus);
else if (precision == 'd')
testing_getrf<double>(argus);
} else if (function == "getrs") {
if (precision == 's')
testing_getrs<float>(argus);
else if (precision == 'd')
testing_getrs<double>(argus);
} else {
printf("Invalid value for --function \n");
return -1;
Expand Down
21 changes: 21 additions & 0 deletions clients/common/arg_check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,27 @@ void getrf_arg_check(rocblas_status status, rocblas_int M, rocblas_int N) {
#endif
}

void getrs_arg_check(rocblas_status status, rocblas_int M, rocblas_int nhrs,
rocblas_int lda, rocblas_int ldb) {
#ifdef GOOGLE_TEST
if (M < 0 || nhrs < 0 || lda < std::max(1, M) || ldb < std::max(1, M)) {
ASSERT_EQ(status, rocblas_status_invalid_size);
} else {
ASSERT_EQ(status, rocblas_status_success);
}
#else
if (M < 0 || nhrs < 0 || lda < std::max(1, M) || ldb < std::max(1, M)) {
if (status != rocblas_status_invalid_size)
std::cerr << "result should be invalid size for size " << M << " and "
<< nhrs << std::endl;
} else {
if (status != rocblas_status_success)
std::cerr << "result should be success for size " << M << " and " << nhrs
<< std::endl;
}
#endif
}

void verify_rocblas_status_invalid_pointer(rocblas_status status,
const char *message) {
#ifdef GOOGLE_TEST
Expand Down
50 changes: 50 additions & 0 deletions clients/common/cblas_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,16 @@ void cgetf2_(int *m, int *n, rocblas_float_complex *A, int *lda, int *ipiv,
void zgetf2_(int *m, int *n, rocblas_double_complex *A, int *lda, int *ipiv,
int *info);

void sgetrs_(char *trans, int *n, int *nrhs, float *A, int *lda, int *ipiv,
float *B, int *ldb, int *info);
void dgetrs_(char *trans, int *n, int *nrhs, double *A, int *lda, int *ipiv,
double *B, int *ldb, int *info);
void cgetrs_(char *trans, int *n, int *nrhs, rocblas_float_complex *A, int *lda,
int *ipiv, rocblas_float_complex *B, int *ldb, int *info);
void zgetrs_(char *trans, int *n, int *nrhs, rocblas_double_complex *A,
int *lda, int *ipiv, rocblas_double_complex *B, int *ldb,
int *info);

#ifdef __cplusplus
}
#endif
Expand Down Expand Up @@ -743,3 +753,43 @@ rocblas_int cblas_potrf<rocblas_double_complex>(char uplo, rocblas_int m,
zpotrf_(&uplo, &m, A, &lda, &info);
return info;
}

// getrs
template <>
rocblas_int cblas_getrs<float>(char trans, rocblas_int n, rocblas_int nrhs,
float *A, rocblas_int lda, rocblas_int *ipiv,
float *B, rocblas_int ldb) {
rocblas_int info;
sgetrs_(&trans, &n, &nrhs, A, &lda, ipiv, B, &ldb, &info);
return info;
}

template <>
rocblas_int cblas_getrs<double>(char trans, rocblas_int n, rocblas_int nrhs,
double *A, rocblas_int lda, rocblas_int *ipiv,
double *B, rocblas_int ldb) {
rocblas_int info;
dgetrs_(&trans, &n, &nrhs, A, &lda, ipiv, B, &ldb, &info);
return info;
}

template <>
rocblas_int
cblas_getrs<rocblas_float_complex>(char trans, rocblas_int n, rocblas_int nrhs,
rocblas_float_complex *A, rocblas_int lda,
rocblas_int *ipiv, rocblas_float_complex *B,
rocblas_int ldb) {
rocblas_int info;
cgetrs_(&trans, &n, &nrhs, A, &lda, ipiv, B, &ldb, &info);
return info;
}

template <>
rocblas_int cblas_getrs<rocblas_double_complex>(
char trans, rocblas_int n, rocblas_int nrhs, rocblas_double_complex *A,
rocblas_int lda, rocblas_int *ipiv, rocblas_double_complex *B,
rocblas_int ldb) {
rocblas_int info;
zgetrs_(&trans, &n, &nrhs, A, &lda, ipiv, B, &ldb, &info);
return info;
}
16 changes: 16 additions & 0 deletions clients/common/unit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,4 +182,20 @@ void getrf_err_res_check(double max_error, rocblas_int M, rocblas_int N,
#ifdef GOOGLE_TEST
ASSERT_LE(max_error, forward_tolerance * eps);
#endif
}

template <>
void getrs_err_res_check(float max_error, rocblas_int M, rocblas_int nhrs,
float forward_tolerance, float eps) {
#ifdef GOOGLE_TEST
ASSERT_LE(max_error, forward_tolerance * eps);
#endif
}

template <>
void getrs_err_res_check(double max_error, rocblas_int M, rocblas_int nhrs,
double forward_tolerance, double eps) {
#ifdef GOOGLE_TEST
ASSERT_LE(max_error, forward_tolerance * eps);
#endif
}
3 changes: 3 additions & 0 deletions clients/gtest/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ find_package( Threads REQUIRED )
set(roclapack_test_source
getf2_gtest.cpp
getrf_gtest.cpp
getrs_gtest.cpp
potf2_gtest.cpp
)

Expand All @@ -54,6 +55,8 @@ set( rocsolver_benchmark_common
add_executable( rocsolver-test ${roclapack_test_source} ${rocsolver_test_source} ${rocsolver_benchmark_common} )
target_compile_features( rocsolver-test PRIVATE cxx_static_assert cxx_nullptr cxx_auto_type )

rocm_clang_tidy_check(rocsolver-test)

target_compile_definitions( rocsolver-test PRIVATE BUILD_WITH_TENSILE=0 GOOGLE_TEST )

# Internal header includes
Expand Down
Loading

0 comments on commit 198e32e

Please sign in to comment.