Skip to content

Commit

Permalink
level3/gemmi feature (#83)
Browse files Browse the repository at this point in the history
* gemmi benchmark

* gemmi tests

* gemmi samples

* gemmi documentation

* gemmi API

* gemmi fortran binding and example

* internal gemmi structure

* gemmi kernel for transposed B

* minor tweaks

* bump version
  • Loading branch information
ntrost57 authored Jul 7, 2020
1 parent feb8986 commit 0e0f8be
Show file tree
Hide file tree
Showing 21 changed files with 2,573 additions and 3 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ option(BUILD_VERBOSE "Output additional build information" OFF)
include(cmake/Dependencies.cmake)

# Setup version
set(VERSION_STRING "1.15.0")
set(VERSION_STRING "1.15.1")
rocm_setup_version(VERSION ${VERSION_STRING})
set(rocsparse_SOVERSION 0.1)

Expand Down
14 changes: 13 additions & 1 deletion clients/benchmarks/client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
#include "testing_bsrmm.hpp"
#include "testing_csrmm.hpp"
#include "testing_csrsm.hpp"
#include "testing_gemmi.hpp"

// Extra
#include "testing_csrgeam.hpp"
Expand Down Expand Up @@ -210,7 +211,7 @@ int main(int argc, char* argv[])
"SPARSE function to test. Options:\n"
" Level1: axpyi, doti, dotci, gthr, gthrz, roti, sctr\n"
" Level2: bsrmv, bsrsv, coomv, csrmv, csrsv, ellmv, hybmv\n"
" Level3: bsrmm, csrmm, csrsm\n"
" Level3: bsrmm, csrmm, csrsm, gemmi\n"
" Extra: csrgeam, csrgemm\n"
" Preconditioner: csric0, csrilu0\n"
" Conversion: csr2coo, csr2csc, csr2ell, csr2hyb, csr2bsr\n"
Expand Down Expand Up @@ -601,6 +602,17 @@ int main(int argc, char* argv[])
else if(precision == 'z')
testing_csrsm<rocsparse_double_complex>(arg);
}
else if(function == "gemmi")
{
if(precision == 's')
testing_gemmi<float>(arg);
else if(precision == 'd')
testing_gemmi<double>(arg);
else if(precision == 'c')
testing_gemmi<rocsparse_float_complex>(arg);
else if(precision == 'z')
testing_gemmi<rocsparse_double_complex>(arg);
}
else if(function == "csrgeam")
{
if(precision == 's')
Expand Down
153 changes: 153 additions & 0 deletions clients/common/rocsparse_template_specialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2284,6 +2284,159 @@ rocsparse_status rocsparse_csrsm_solve(rocsparse_handle handle,
temp_buffer);
}

// gemmi
/// Single-precision real (float) specialization of the type-generic
/// rocsparse_gemmi wrapper.
///
/// Passes every argument through, unchanged and in the same order, to
/// rocsparse_sgemmi and returns the status that call produces.
template <>
rocsparse_status rocsparse_gemmi(rocsparse_handle          handle,
                                 rocsparse_operation       trans_A,
                                 rocsparse_operation       trans_B,
                                 rocsparse_int             m,
                                 rocsparse_int             n,
                                 rocsparse_int             k,
                                 rocsparse_int             nnz,
                                 const float*              alpha,
                                 const float*              A,
                                 rocsparse_int             lda,
                                 const rocsparse_mat_descr descr,
                                 const float*              csr_val,
                                 const rocsparse_int*      csr_row_ptr,
                                 const rocsparse_int*      csr_col_ind,
                                 const float*              beta,
                                 float*                    C,
                                 rocsparse_int             ldc)
{
    // Argument rows: handle + operations | sizes | alpha, A, lda |
    // descriptor + CSR arrays | beta, C, ldc.
    return rocsparse_sgemmi(handle, trans_A, trans_B,
                            m, n, k, nnz,
                            alpha, A, lda,
                            descr, csr_val, csr_row_ptr, csr_col_ind,
                            beta, C, ldc);
}

/// Double-precision real (double) specialization of the type-generic
/// rocsparse_gemmi wrapper.
///
/// Passes every argument through, unchanged and in the same order, to
/// rocsparse_dgemmi and returns the status that call produces.
template <>
rocsparse_status rocsparse_gemmi(rocsparse_handle          handle,
                                 rocsparse_operation       trans_A,
                                 rocsparse_operation       trans_B,
                                 rocsparse_int             m,
                                 rocsparse_int             n,
                                 rocsparse_int             k,
                                 rocsparse_int             nnz,
                                 const double*             alpha,
                                 const double*             A,
                                 rocsparse_int             lda,
                                 const rocsparse_mat_descr descr,
                                 const double*             csr_val,
                                 const rocsparse_int*      csr_row_ptr,
                                 const rocsparse_int*      csr_col_ind,
                                 const double*             beta,
                                 double*                   C,
                                 rocsparse_int             ldc)
{
    // Argument rows: handle + operations | sizes | alpha, A, lda |
    // descriptor + CSR arrays | beta, C, ldc.
    return rocsparse_dgemmi(handle, trans_A, trans_B,
                            m, n, k, nnz,
                            alpha, A, lda,
                            descr, csr_val, csr_row_ptr, csr_col_ind,
                            beta, C, ldc);
}

/// Single-precision complex (rocsparse_float_complex) specialization of the
/// type-generic rocsparse_gemmi wrapper.
///
/// Passes every argument through, unchanged and in the same order, to
/// rocsparse_cgemmi and returns the status that call produces.
template <>
rocsparse_status rocsparse_gemmi(rocsparse_handle               handle,
                                 rocsparse_operation            trans_A,
                                 rocsparse_operation            trans_B,
                                 rocsparse_int                  m,
                                 rocsparse_int                  n,
                                 rocsparse_int                  k,
                                 rocsparse_int                  nnz,
                                 const rocsparse_float_complex* alpha,
                                 const rocsparse_float_complex* A,
                                 rocsparse_int                  lda,
                                 const rocsparse_mat_descr      descr,
                                 const rocsparse_float_complex* csr_val,
                                 const rocsparse_int*           csr_row_ptr,
                                 const rocsparse_int*           csr_col_ind,
                                 const rocsparse_float_complex* beta,
                                 rocsparse_float_complex*       C,
                                 rocsparse_int                  ldc)
{
    // Argument rows: handle + operations | sizes | alpha, A, lda |
    // descriptor + CSR arrays | beta, C, ldc.
    return rocsparse_cgemmi(handle, trans_A, trans_B,
                            m, n, k, nnz,
                            alpha, A, lda,
                            descr, csr_val, csr_row_ptr, csr_col_ind,
                            beta, C, ldc);
}

/// Double-precision complex (rocsparse_double_complex) specialization of the
/// type-generic rocsparse_gemmi wrapper.
///
/// Passes every argument through, unchanged and in the same order, to
/// rocsparse_zgemmi and returns the status that call produces.
template <>
rocsparse_status rocsparse_gemmi(rocsparse_handle                handle,
                                 rocsparse_operation             trans_A,
                                 rocsparse_operation             trans_B,
                                 rocsparse_int                   m,
                                 rocsparse_int                   n,
                                 rocsparse_int                   k,
                                 rocsparse_int                   nnz,
                                 const rocsparse_double_complex* alpha,
                                 const rocsparse_double_complex* A,
                                 rocsparse_int                   lda,
                                 const rocsparse_mat_descr       descr,
                                 const rocsparse_double_complex* csr_val,
                                 const rocsparse_int*            csr_row_ptr,
                                 const rocsparse_int*            csr_col_ind,
                                 const rocsparse_double_complex* beta,
                                 rocsparse_double_complex*       C,
                                 rocsparse_int                   ldc)
{
    // Argument rows: handle + operations | sizes | alpha, A, lda |
    // descriptor + CSR arrays | beta, C, ldc.
    return rocsparse_zgemmi(handle, trans_A, trans_B,
                            m, n, k, nnz,
                            alpha, A, lda,
                            descr, csr_val, csr_row_ptr, csr_col_ind,
                            beta, C, ldc);
}

/*
* ===========================================================================
* extra SPARSE
Expand Down
20 changes: 20 additions & 0 deletions clients/include/rocsparse.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,26 @@ rocsparse_status rocsparse_csrsm_solve(rocsparse_handle handle,
rocsparse_solve_policy policy,
void* temp_buffer);

// gemmi
//
// Type-generic front end for the gemmi routine. Only the primary template is
// declared here; explicit specializations for float, double,
// rocsparse_float_complex and rocsparse_double_complex are expected to forward
// to rocsparse_sgemmi / rocsparse_dgemmi / rocsparse_cgemmi / rocsparse_zgemmi
// respectively, so a T without a specialization fails at link time.
//
// Parameters (meanings inferred from names; confirm against the rocSPARSE API
// documentation for rocsparse_Xgemmi):
//   handle               - library context handle.
//   trans_A, trans_B     - operations applied to A and to the CSR matrix.
//   m, n, k              - problem dimensions.
//   nnz                  - presumably the number of stored entries of the CSR matrix.
//   alpha, beta          - pointers to the two scalars.
//   A, lda               - input array A and its leading dimension.
//   descr                - descriptor of the CSR matrix.
//   csr_val, csr_row_ptr,
//   csr_col_ind          - CSR value, row-pointer and column-index arrays.
//   C, ldc               - output array C (not const, so written by the call)
//                          and its leading dimension.
//
// Returns the rocsparse_status reported by the underlying routine.
template <typename T>
rocsparse_status rocsparse_gemmi(rocsparse_handle handle,
rocsparse_operation trans_A,
rocsparse_operation trans_B,
rocsparse_int m,
rocsparse_int n,
rocsparse_int k,
rocsparse_int nnz,
const T* alpha,
const T* A,
rocsparse_int lda,
const rocsparse_mat_descr descr,
const T* csr_val,
const rocsparse_int* csr_row_ptr,
const rocsparse_int* csr_col_ind,
const T* beta,
T* C,
rocsparse_int ldc);

/*
* ===========================================================================
* extra SPARSE
Expand Down
42 changes: 42 additions & 0 deletions clients/include/rocsparse_host.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1653,6 +1653,48 @@ inline void host_csrsm(rocsparse_int M,
*numeric_pivot = (*numeric_pivot == M + 1) ? -1 : *numeric_pivot;
}

/// Host-side reference implementation of gemmi.
///
/// When transB == rocsparse_operation_transpose, every entry (i, j) with
/// 0 <= i < M and 0 <= j < N is updated as
///
///     C[j * ldc + i] = beta * C[j * ldc + i]
///                    + alpha * sum_k A[(csr_col_ind[k] - base) * lda + i] * csr_val[k]
///
/// where k runs over [csr_row_ptr[j] - base, csr_row_ptr[j + 1] - base), i.e.
/// the stored entries of CSR row j.
///
/// For any other value of transB the function returns without touching C.
/// transA is accepted for interface symmetry but is not consulted.
///
/// @param M            Number of i-indices (rows of A / C as indexed here).
/// @param N            Number of j-indices; CSR rows 0..N-1 are read.
/// @param transA       Unused.
/// @param transB       Only rocsparse_operation_transpose triggers any work.
/// @param alpha        Scalar applied to the accumulated product.
/// @param A            Input array, addressed as A[col * lda + i].
/// @param lda          Stride between consecutive columns of A.
/// @param csr_row_ptr  CSR row-pointer array (offset by @p base).
/// @param csr_col_ind  CSR column-index array (offset by @p base).
/// @param csr_val      CSR value array.
/// @param beta         Scalar applied to the existing content of C.
/// @param C            In/out array, addressed as C[j * ldc + i].
/// @param ldc          Stride between consecutive columns of C.
/// @param base         Index base subtracted from the CSR row pointers and column indices.
template <typename T>
inline void host_gemmi(rocsparse_int        M,
                       rocsparse_int        N,
                       rocsparse_operation  transA,
                       rocsparse_operation  transB,
                       T                    alpha,
                       const T*             A,
                       rocsparse_int        lda,
                       const rocsparse_int* csr_row_ptr,
                       const rocsparse_int* csr_col_ind,
                       const T*             csr_val,
                       T                    beta,
                       T*                   C,
                       rocsparse_int        ldc,
                       rocsparse_index_base base)
{
    // Only the transposed-B variant is implemented by this reference.
    if(transB != rocsparse_operation_transpose)
    {
        return;
    }

    for(rocsparse_int i = 0; i < M; ++i)
    {
        for(rocsparse_int j = 0; j < N; ++j)
        {
            const rocsparse_int row_begin = csr_row_ptr[j] - base;
            const rocsparse_int row_end   = csr_row_ptr[j + 1] - base;

            T sum = static_cast<T>(0);

            for(rocsparse_int k = row_begin; k < row_end; ++k)
            {
                // Flat offsets are formed in 64-bit arithmetic: col * lda can
                // exceed the range of a 32-bit rocsparse_int for large inputs,
                // and signed overflow is undefined behaviour.
                const long long col_B = csr_col_ind[k] - base;
                const long long idx_A = col_B * lda + i;

                const T val_B = csr_val[k];
                const T val_A = A[idx_A];

                sum = std::fma(val_A, val_B, sum);
            }

            // Same widening for the output offset j * ldc + i.
            const long long idx_C = static_cast<long long>(j) * ldc + i;

            C[idx_C] = std::fma(beta, C[idx_C], alpha * sum);
        }
    }
}

/*
* ===========================================================================
* extra SPARSE
Expand Down
6 changes: 6 additions & 0 deletions clients/include/rocsparse_template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ Functions:
rocsparse_zbsrmm: { function: bsrmm, <<: *double_precision_complex }
rocsparse_scsrmm: { function: csrmm, <<: *single_precision }
rocsparse_dcsrmm: { function: csrmm, <<: *double_precision }
rocsparse_ccsrmm: { function: csrmm, <<: *single_precision_complex }
rocsparse_zcsrmm: { function: csrmm, <<: *double_precision_complex }
rocsparse_scsrsm_buffer_size: { function: csrsm, <<: *single_precision }
rocsparse_dcsrsm_buffer_size: { function: csrsm, <<: *double_precision }
rocsparse_ccsrsm_buffer_size: { function: csrsm, <<: *single_precision_complex }
Expand All @@ -125,6 +127,10 @@ Functions:
rocsparse_zcsrsm_solve: { function: csrsm, <<: *double_precision_complex }
rocsparse_csrsm_zero_pivot: {function: csrsm }
rocsparse_csrsm_clear: {function: csrsm }
rocsparse_sgemmi: { function: gemmi, <<: *single_precision }
rocsparse_dgemmi: { function: gemmi, <<: *double_precision }
rocsparse_cgemmi: { function: gemmi, <<: *single_precision_complex }
rocsparse_zgemmi: { function: gemmi, <<: *double_precision_complex }

rocsparse_csrgeam_nnz: { function: csrgeam }
rocsparse_scsrgeam: { function: csrgeam, <<: *single_precision }
Expand Down
Loading

0 comments on commit 0e0f8be

Please sign in to comment.