From 91f7f2b290d4596291eeb8798446b6f255434da6 Mon Sep 17 00:00:00 2001 From: James Sandham <33790278+jsandham@users.noreply.github.com> Date: Mon, 6 Jul 2020 13:30:04 -0500 Subject: [PATCH] Bsrmm (#56) * started creating skeleton code for bsrmm * rebase bsrmm to squash commits clang formatting Allow library dependencies to be installed from CI (#49) csrgeam (#46) * csrgeam API added * csrgeam tests and benchmark added * flops, bandwidth and host implementation for csrgeam * csrgeam unit tests * removed webbase_1M test * csrgeam (functional) added * added tests for invalid sizes * typos and year * clang-format * csrgeam performance scripts bump version Replace host code in bsr2csr (#48) * removed host bsr2csr and csr2bsr code and replaced it with device calls * clang formatting Co-authored-by: jsandham bump version added some examples (#50) * added sparse level 1 examples * added examples for sparse level 2 and 3 * clang-format * added sparse extra examples * bump version hipclang related fixes (#51) * hipclang related fixes * bump version sanity check for matrix download (#52) added fallback for unit test matrix downloads (#53) examples fix (#54) * header fix for examples * bump version got bsrmm working for block dim less than 8 clang formatting fixing bugs and getting benchmark to work optimizing and working on kernels for block dimension greater than 8 kernels and code for block dimension greater than 8 and B matrix transposed expanded loop unrolling up to block dimension 16 clang formatting Remove gpg check for CI package CentOS install (#57) updated internal function names (#61) * renamed internal csrtr to trm * clang-format added missing header (#62) fixes to documentation remove compile time evaluation of direction to help reduce the number of kernels clang formatting small performance improvements to transpose kernel clang formatting increase transpose performance clang formatting re-ordering row pointer and column arrays for csr2csr_compress (#59) * re-ordering row pointer and column arrays for csr2csr_compress * fixing broken tests * fixing incorrect order in log_trace * moving deletion of temporary arry to ensure it is always called Co-authored-by: jsandham bump version Single thread compile in install script (#63) pyyaml package name fix for centos8 (#60) * pyyaml package name fix for centos8 * this should also account for rhel8 * bump version Update README.md pivot test fix (#65) * adding device sync in spin loop tests to not overwrite pivots before checking them * bump version Removing rock-dkms (#66) Revert "Single thread compile in install script (#63)" (#69) Fortran interface (#55) * fortran interface draft with examples added * example fix to properly work with return values * force cmake to add .f90 module to package * added some more missing level1, level3 and conversion routines * added few more missing functions to wrapper * csric0 and csrilu0 fortran examples * csrgemm_buffer_size binding name fixed * fortran example fix, stop allows only constant expressions * fix for string passing * added enums to fortran; example for aux functions; fixes to pointer arguments * more examples * updated fortran example output of csrilu0 and csric0 * updated install.sh script and dockerfiles to install gfortran dependencies * fix for device pointer mode * few changes to make it consistent with hipfort * bump version ddoti fortran fix (#71) bsrmv smem sync? (#70) bump version mtx pattern fix (#73) Added centos 8 dependency fixes (#74) bump version bsrsv (#72) * general working version of bsrsv for lower and upper non transposed matrices * fixing bsr_to_bsc order * added functionality for transposed matrix * enabling complex numbers * optimized bsrsv for BSR dimensions from 2x2 to 32x32 * gfx908 * fortran functions and example * disabling some unit diagonal tests with nos1 and nos2 * bump version fortran module fixes (#75) centos 6 (#76) * centos6 support * bump version Allow library dependencies to be installed from CI (#49) csrgeam (#46) * csrgeam API added * csrgeam tests and benchmark added * flops, bandwidth and host implementation for csrgeam * csrgeam unit tests * removed webbase_1M test * csrgeam (functional) added * added tests for invalid sizes * typos and year * clang-format * csrgeam performance scripts added some examples (#50) * added sparse level 1 examples * added examples for sparse level 2 and 3 * clang-format * added sparse extra examples * bump version examples fix (#54) * header fix for examples * bump version Remove gpg check for CI package CentOS install (#57) added missing header (#62) re-ordering row pointer and column arrays for csr2csr_compress (#59) * re-ordering row pointer and column arrays for csr2csr_compress * fixing broken tests * fixing incorrect order in log_trace * moving deletion of temporary arry to ensure it is always called Co-authored-by: jsandham Single thread compile in install script (#63) Update README.md Removing rock-dkms (#66) Revert "Single thread compile in install script (#63)" (#69) Fortran interface (#55) * fortran interface draft with examples added * example fix to properly work with return values * force cmake to add .f90 module to package * added some more missing level1, level3 and conversion routines * added few more missing functions to wrapper * csric0 and csrilu0 fortran examples * csrgemm_buffer_size binding name fixed * fortran example fix, stop allows only constant expressions * fix for string passing * added enums to fortran; example for aux functions; fixes to pointer arguments * more examples * updated fortran example output of csrilu0 and csric0 * updated install.sh script and dockerfiles to install gfortran dependencies * fix for device pointer mode * few changes to make it consistent with hipfort * bump version ddoti fortran fix (#71) bsrmv smem sync? (#70) bsrsv (#72) * general working version of bsrsv for lower and upper non transposed matrices * fixing bsr_to_bsc order * added functionality for transposed matrix * enabling complex numbers * optimized bsrsv for BSR dimensions from 2x2 to 32x32 * gfx908 * fortran functions and example * disabling some unit diagonal tests with nos1 and nos2 * bump version fortran module fixes (#75) centos 6 (#76) * centos6 support * bump version adding fortran example code fixing fortran compile error adding bsrmm to fortran_module.f90 fixing fortran example array order fix fortran compile error fix fortran compile error adding cpp example code for bsrmm clang formatting working on optimizing kernels working on optimizing kernels optimizing bsrmm optimizing bsrmm optimizing bsrmm optimizing bsrmm optimizing bsrmm optimizing bsrmm optimizing bsrmm optimizing bsrmm optimizing bsrmm optimizing bsrmm optimizing bsrmm optimizing bsrmm optimizing bsrmm optimizing bsrmm optimizing bsrmm optimizing bsrmm optimizing bsrmm optimizing bsrmm optimizing bsrmm optimizing bsrmm optimizing bsrmm optimizing bsrmm optimizing bsrmm optimizing bsrmm optimizing bsrmm optimizing bsrmm reverting back to original kernels optimizing bsrmm making test2 kernel active for block dim 8 optimizing bsrmm significant performance improvement for block dimensions 5 to 32 further performance improvements to transpose and non-transpose case reduce compile times and replaced general kernel optimizing for n <= 16 Correction to the cmake RUNPATH parameter (#79) Co-authored-by: Pruthvi Madugundu bump version cmake update (#80) * cmake update * disabling OpenMP until this is fixed within hipclang Csr2bsr optimization (#78) * optimized csr2bsr_nnz * rebase csr2bsr_optimization branch to squash commits Working on optimizing csr2bsr device code changed blocksize to 16 as this runs twice as fast clang formatting removing comments performance optimizations clang formatting improve performance clang formatting csr2bsr optimization added missing header (#62) re-ordering row pointer and column arrays for csr2csr_compress (#59) * re-ordering row pointer and column arrays for csr2csr_compress * fixing broken tests * fixing incorrect order in log_trace * moving deletion of temporary arry to ensure it is always called Co-authored-by: jsandham bump version Single thread compile in install script (#63) pyyaml package name fix for centos8 (#60) * pyyaml package name fix for centos8 * this should also account for rhel8 * bump version Update README.md pivot test fix (#65) * adding device sync in spin loop tests to not overwrite pivots before checking them * bump version Removing rock-dkms (#66) Revert "Single thread compile in install script (#63)" (#69) Fortran interface (#55) * fortran interface draft with examples added * example fix to properly work with return values * force cmake to add .f90 module to package * added some more missing level1, level3 and conversion routines * added few more missing functions to wrapper * csric0 and csrilu0 fortran examples * csrgemm_buffer_size binding name fixed * fortran example fix, stop allows only constant expressions * fix for string passing * added enums to fortran; example for aux functions; fixes to pointer arguments * more examples * updated fortran example output of csrilu0 and csric0 * updated install.sh script and dockerfiles to install gfortran dependencies * fix for device pointer mode * few changes to make it consistent with hipfort * bump version ddoti fortran fix (#71) bsrmv smem sync? (#70) bump version mtx pattern fix (#73) Added centos 8 dependency fixes (#74) bump version bsrsv (#72) * general working version of bsrsv for lower and upper non transposed matrices * fixing bsr_to_bsc order * added functionality for transposed matrix * enabling complex numbers * optimized bsrsv for BSR dimensions from 2x2 to 32x32 * gfx908 * fortran functions and example * disabling some unit diagonal tests with nos1 and nos2 * bump version fortran module fixes (#75) centos 6 (#76) * centos6 support * bump version added missing header (#62) re-ordering row pointer and column arrays for csr2csr_compress (#59) * re-ordering row pointer and column arrays for csr2csr_compress * fixing broken tests * fixing incorrect order in log_trace * moving deletion of temporary arry to ensure it is always called Co-authored-by: jsandham Single thread compile in install script (#63) Update README.md Removing rock-dkms (#66) Revert "Single thread compile in install script (#63)" (#69) Fortran interface (#55) * fortran interface draft with examples added * example fix to properly work with return values * force cmake to add .f90 module to package * added some more missing level1, level3 and conversion routines * added few more missing functions to wrapper * csric0 and csrilu0 fortran examples * csrgemm_buffer_size binding name fixed * fortran example fix, stop allows only constant expressions * fix for string passing * added enums to fortran; example for aux functions; fixes to pointer arguments * more examples * updated fortran example output of csrilu0 and csric0 * updated install.sh script and dockerfiles to install gfortran dependencies * fix for device pointer mode * few changes to make it consistent with hipfort * bump version ddoti fortran fix (#71) bsrmv smem sync? (#70) bsrsv (#72) * general working version of bsrsv for lower and upper non transposed matrices * fixing bsr_to_bsc order * added functionality for transposed matrix * enabling complex numbers * optimized bsrsv for BSR dimensions from 2x2 to 32x32 * gfx908 * fortran functions and example * disabling some unit diagonal tests with nos1 and nos2 * bump version fortran module fixes (#75) centos 6 (#76) * centos6 support * bump version Co-authored-by: jsandham * reducing number of tests * removing bank conflicts * removing duplicate code from rocsparse-functions header * fixing line in rocspasrse-functions header changed by bad merge * fix formating from merge * fix formatting errors from merge Co-authored-by: jsandham --- clients/benchmarks/client.cpp | 14 +- .../rocsparse_template_specialization.cpp | 169 ++++ clients/include/flops.hpp | 10 + clients/include/gbyte.hpp | 18 + clients/include/rocsparse.hpp | 22 + clients/include/rocsparse_host.hpp | 83 +- clients/include/rocsparse_template.yaml | 4 + clients/include/testing_bsrmm.hpp | 767 ++++++++++++++++ clients/samples/CMakeLists.txt | 2 + clients/samples/example_bsrmm.cpp | 216 +++++ clients/samples/example_fortran_bsrmm.f90 | 261 ++++++ clients/tests/CMakeLists.txt | 3 +- clients/tests/rocsparse_test.yaml | 1 + clients/tests/test_bsrmm.cpp | 117 +++ clients/tests/test_bsrmm.yaml | 207 +++++ docs/source/usermanual.rst | 12 + library/include/rocsparse-functions.h | 244 +++++ library/src/CMakeLists.txt | 1 + library/src/level3/bsrmm_device.h | 583 ++++++++++++ library/src/level3/csrmm_device.h | 12 +- library/src/level3/rocsparse_bsrmm.cpp | 195 ++++ library/src/level3/rocsparse_bsrmm.hpp | 868 ++++++++++++++++++ library/src/rocsparse_module.f90 | 108 +++ 23 files changed, 3906 insertions(+), 11 deletions(-) create mode 100644 clients/include/testing_bsrmm.hpp create mode 100644 clients/samples/example_bsrmm.cpp create mode 100644 clients/samples/example_fortran_bsrmm.f90 create mode 100644 clients/tests/test_bsrmm.cpp create mode 100644 clients/tests/test_bsrmm.yaml create mode 100644 library/src/level3/bsrmm_device.h create mode 100644 library/src/level3/rocsparse_bsrmm.cpp create mode 100644 library/src/level3/rocsparse_bsrmm.hpp diff --git a/clients/benchmarks/client.cpp b/clients/benchmarks/client.cpp index 747227ae..83c71c66 100644 --- a/clients/benchmarks/client.cpp +++ b/clients/benchmarks/client.cpp @@ -43,6 +43,7 @@ #include "testing_hybmv.hpp" // Level3 +#include "testing_bsrmm.hpp" #include "testing_csrmm.hpp" #include "testing_csrsm.hpp" @@ -209,7 +210,7 @@ int main(int argc, char* argv[]) "SPARSE function to test. Options:\n" " Level1: axpyi, doti, dotci, gthr, gthrz, roti, sctr\n" " Level2: bsrmv, bsrsv, coomv, csrmv, csrsv, ellmv, hybmv\n" - " Level3: csrmm, csrsm\n" + " Level3: bsrmm, csrmm, csrsm\n" " Extra: csrgeam, csrgemm\n" " Preconditioner: csric0, csrilu0\n" " Conversion: csr2coo, csr2csc, csr2ell, csr2hyb, csr2bsr\n" @@ -567,6 +568,17 @@ int main(int argc, char* argv[]) else if(precision == 'z') testing_hybmv(arg); } + else if(function == "bsrmm") + { + if(precision == 's') + testing_bsrmm(arg); + else if(precision == 'd') + testing_bsrmm(arg); + else if(precision == 'c') + testing_bsrmm(arg); + else if(precision == 'z') + testing_bsrmm(arg); + } else if(function == "csrmm") { if(precision == 's') diff --git a/clients/common/rocsparse_template_specialization.cpp b/clients/common/rocsparse_template_specialization.cpp index e1a6c775..3f8b80fd 100644 --- a/clients/common/rocsparse_template_specialization.cpp +++ b/clients/common/rocsparse_template_specialization.cpp @@ -1521,6 +1521,175 @@ rocsparse_status rocsparse_hybmv(rocsparse_handle handle, * level 3 SPARSE * =========================================================================== */ +// bsrmm +template <> +rocsparse_status rocsparse_bsrmm(rocsparse_handle handle, + rocsparse_direction dir, + rocsparse_operation trans_A, + rocsparse_operation trans_B, + rocsparse_int mb, + rocsparse_int n, + rocsparse_int kb, + rocsparse_int nnzb, + const float* alpha, + const rocsparse_mat_descr descr, + const float* bsr_val, + const rocsparse_int* bsr_row_ptr, + const rocsparse_int* bsr_col_ind, + rocsparse_int block_dim, + const float* B, + rocsparse_int ldb, + const float* beta, + float* C, + rocsparse_int ldc) +{ + return rocsparse_sbsrmm(handle, + dir, + trans_A, + trans_B, + mb, + n, + kb, + nnzb, + alpha, + descr, + bsr_val, + bsr_row_ptr, + bsr_col_ind, + block_dim, + B, + ldb, + beta, + C, + ldc); +} + +template <> +rocsparse_status rocsparse_bsrmm(rocsparse_handle handle, + rocsparse_direction dir, + rocsparse_operation trans_A, + rocsparse_operation trans_B, + rocsparse_int mb, + rocsparse_int n, + rocsparse_int kb, + rocsparse_int nnzb, + const double* alpha, + const rocsparse_mat_descr descr, + const double* bsr_val, + const rocsparse_int* bsr_row_ptr, + const rocsparse_int* bsr_col_ind, + rocsparse_int block_dim, + const double* B, + rocsparse_int ldb, + const double* beta, + double* C, + rocsparse_int ldc) +{ + return rocsparse_dbsrmm(handle, + dir, + trans_A, + trans_B, + mb, + n, + kb, + nnzb, + alpha, + descr, + bsr_val, + bsr_row_ptr, + bsr_col_ind, + block_dim, + B, + ldb, + beta, + C, + ldc); +} + +template <> +rocsparse_status rocsparse_bsrmm(rocsparse_handle handle, + rocsparse_direction dir, + rocsparse_operation trans_A, + rocsparse_operation trans_B, + rocsparse_int mb, + rocsparse_int n, + rocsparse_int kb, + rocsparse_int nnzb, + const rocsparse_float_complex* alpha, + const rocsparse_mat_descr descr, + const rocsparse_float_complex* bsr_val, + const rocsparse_int* bsr_row_ptr, + const rocsparse_int* bsr_col_ind, + rocsparse_int block_dim, + const rocsparse_float_complex* B, + rocsparse_int ldb, + const rocsparse_float_complex* beta, + rocsparse_float_complex* C, + rocsparse_int ldc) +{ + return rocsparse_cbsrmm(handle, + dir, + trans_A, + trans_B, + mb, + n, + kb, + nnzb, + alpha, + descr, + bsr_val, + bsr_row_ptr, + bsr_col_ind, + block_dim, + B, + ldb, + beta, + C, + ldc); +} + +template <> +rocsparse_status rocsparse_bsrmm(rocsparse_handle handle, + rocsparse_direction dir, + rocsparse_operation trans_A, + rocsparse_operation trans_B, + rocsparse_int mb, + rocsparse_int n, + rocsparse_int kb, + rocsparse_int nnzb, + const rocsparse_double_complex* alpha, + const rocsparse_mat_descr descr, + const rocsparse_double_complex* bsr_val, + const rocsparse_int* bsr_row_ptr, + const rocsparse_int* bsr_col_ind, + rocsparse_int block_dim, + const rocsparse_double_complex* B, + rocsparse_int ldb, + const rocsparse_double_complex* beta, + rocsparse_double_complex* C, + rocsparse_int ldc) +{ + return rocsparse_zbsrmm(handle, + dir, + trans_A, + trans_B, + mb, + n, + kb, + nnzb, + alpha, + descr, + bsr_val, + bsr_row_ptr, + bsr_col_ind, + block_dim, + B, + ldb, + beta, + C, + ldc); +} + // csrmm template <> rocsparse_status rocsparse_csrmm(rocsparse_handle handle, diff --git a/clients/include/flops.hpp b/clients/include/flops.hpp index 82c712aa..fcd06b5e 100644 --- a/clients/include/flops.hpp +++ b/clients/include/flops.hpp @@ -78,6 +78,16 @@ constexpr double csrsv_gflop_count(rocsparse_int M, rocsparse_int nnz, rocsparse * level 3 SPARSE * =========================================================================== */ +template +constexpr double bsrmm_gflop_count(rocsparse_int N, + rocsparse_int nnzb, + rocsparse_int block_dim, + rocsparse_int nnz_C, + bool beta = false) +{ + return (3.0 * nnzb * block_dim * block_dim * N + (beta ? nnz_C : 0)) / 1e9; +} + template constexpr double csrmm_gflop_count(rocsparse_int N, rocsparse_int nnz_A, rocsparse_int nnz_C, bool beta = false) diff --git a/clients/include/gbyte.hpp b/clients/include/gbyte.hpp index d785336a..91a5ab91 100644 --- a/clients/include/gbyte.hpp +++ b/clients/include/gbyte.hpp @@ -132,6 +132,24 @@ constexpr double * level 3 SPARSE * =========================================================================== */ +template +constexpr double bsrmm_gbyte_count(rocsparse_int Mb, + rocsparse_int nnzb, + rocsparse_int block_dim, + rocsparse_int nnz_B, + rocsparse_int nnz_C, + bool beta = false) +{ + //reads + size_t reads = (Mb + 1 + nnzb) * sizeof(rocsparse_int) + + (block_dim * block_dim * nnzb + nnz_B + (beta ? nnz_C : 0)) * sizeof(T); + + //writes + size_t writes = nnz_C * sizeof(T); + + return (reads + writes) / 1e9; +} + template constexpr double csrmm_gbyte_count(rocsparse_int M, rocsparse_int nnz_A, diff --git a/clients/include/rocsparse.hpp b/clients/include/rocsparse.hpp index 94a93ed4..951a26d1 100644 --- a/clients/include/rocsparse.hpp +++ b/clients/include/rocsparse.hpp @@ -302,6 +302,28 @@ rocsparse_status rocsparse_hybmv(rocsparse_handle handle, * level 3 SPARSE * =========================================================================== */ +// bsrmm +template +rocsparse_status rocsparse_bsrmm(rocsparse_handle handle, + rocsparse_direction dir, + rocsparse_operation trans_A, + rocsparse_operation trans_B, + rocsparse_int mb, + rocsparse_int n, + rocsparse_int kb, + rocsparse_int nnzb, + const T* alpha, + const rocsparse_mat_descr descr, + const T* bsr_val, + const rocsparse_int* bsr_row_ptr, + const rocsparse_int* bsr_col_ind, + rocsparse_int block_dim, + const T* B, + rocsparse_int ldb, + const T* beta, + T* C, + rocsparse_int ldc); + // csrmm template rocsparse_status rocsparse_csrmm(rocsparse_handle handle, diff --git a/clients/include/rocsparse_host.hpp b/clients/include/rocsparse_host.hpp index 3fb380bd..31027841 100644 --- a/clients/include/rocsparse_host.hpp +++ b/clients/include/rocsparse_host.hpp @@ -1233,6 +1233,83 @@ inline void host_hybmv(rocsparse_int M, * level 3 SPARSE * =========================================================================== */ +template +inline void host_bsrmm(rocsparse_int Mb, + rocsparse_int N, + rocsparse_int Kb, + rocsparse_int block_dim, + rocsparse_direction dir, + rocsparse_operation transA, + rocsparse_operation transB, + T alpha, + const std::vector& bsr_row_ptr_A, + const std::vector& bsr_col_ind_A, + const std::vector& bsr_val_A, + const std::vector& B, + rocsparse_int ldb, + T beta, + std::vector& C, + rocsparse_int ldc, + rocsparse_index_base base) +{ + if(transA != rocsparse_operation_none) + { + return; + } + + if(transB != rocsparse_operation_none && transB != rocsparse_operation_transpose) + { + return; + } + + rocsparse_int M = Mb * block_dim; + rocsparse_int K = Kb * block_dim; + +#ifdef _OPENMP +#pragma omp parallel for schedule(dynamic, 1024) +#endif + for(rocsparse_int i = 0; i < M; i++) + { + rocsparse_int local_row = i % block_dim; + + rocsparse_int row_begin = bsr_row_ptr_A[i / block_dim] - base; + rocsparse_int row_end = bsr_row_ptr_A[i / block_dim + 1] - base; + + for(rocsparse_int j = 0; j < N; j++) + { + rocsparse_int idx_C = i + j * ldc; + + T sum = static_cast(0); + + for(rocsparse_int s = row_begin; s < row_end; s++) + { + for(rocsparse_int t = 0; t < block_dim; t++) + { + rocsparse_int idx_A + = (dir == rocsparse_direction_row) + ? block_dim * block_dim * s + block_dim * local_row + t + : block_dim * block_dim * s + block_dim * t + local_row; + rocsparse_int idx_B + = (transB == rocsparse_operation_none) + ? j * ldb + block_dim * (bsr_col_ind_A[s] - base) + t + : (block_dim * (bsr_col_ind_A[s] - base) + t) * ldb + j; + + sum = std::fma(bsr_val_A[idx_A], B[idx_B], sum); + } + } + + if(beta == static_cast(0)) + { + C[idx_C] = alpha * sum; + } + else + { + C[idx_C] = std::fma(beta, C[idx_C], alpha * sum); + } + } + } +} + template inline void host_csrmm(rocsparse_int M, rocsparse_int N, @@ -1267,16 +1344,16 @@ inline void host_csrmm(rocsparse_int M, ? (csr_col_ind_A[k] - base + j * ldb) : (j + (csr_col_ind_A[k] - base) * ldb); - sum = std::fma(alpha * csr_val_A[k], B[idx_B], sum); + sum = std::fma(csr_val_A[k], B[idx_B], sum); } if(beta == static_cast(0)) { - C[idx_C] = sum; + C[idx_C] = alpha * sum; } else { - C[idx_C] = std::fma(beta, C[idx_C], sum); + C[idx_C] = std::fma(beta, C[idx_C], alpha * sum); } } } diff --git a/clients/include/rocsparse_template.yaml b/clients/include/rocsparse_template.yaml index ddd43a68..d4f6ee7c 100644 --- a/clients/include/rocsparse_template.yaml +++ b/clients/include/rocsparse_template.yaml @@ -105,6 +105,10 @@ Functions: rocsparse_chybmv: { function: hybmv, <<: *single_precision_complex } rocsparse_zhybmv: { function: hybmv, <<: *double_precision_complex } + rocsparse_sbsrmm: { function: bsrmm, <<: *single_precision } + rocsparse_dbsrmm: { function: bsrmm, <<: *double_precision } + rocsparse_cbsrmm: { function: bsrmm, <<: *single_precision_complex } + rocsparse_zbsrmm: { function: bsrmm, <<: *double_precision_complex } rocsparse_scsrmm: { function: csrmm, <<: *single_precision } rocsparse_dcsrmm: { function: csrmm, <<: *double_precision } rocsparse_scsrsm_buffer_size: { function: csrsm, <<: *single_precision } diff --git a/clients/include/testing_bsrmm.hpp b/clients/include/testing_bsrmm.hpp new file mode 100644 index 00000000..289f04ff --- /dev/null +++ b/clients/include/testing_bsrmm.hpp @@ -0,0 +1,767 @@ +/* ************************************************************************ + * Copyright (c) 2020 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + * ************************************************************************ */ + +#pragma once +#ifndef TESTING_BSRMM_HPP +#define TESTING_BSRMM_HPP + +#include + +#include "flops.hpp" +#include "gbyte.hpp" +#include "rocsparse_check.hpp" +#include "rocsparse_host.hpp" +#include "rocsparse_init.hpp" +#include "rocsparse_math.hpp" +#include "rocsparse_random.hpp" +#include "rocsparse_test.hpp" +#include "rocsparse_vector.hpp" +#include "utility.hpp" + +#include "testing_bsr2csr.hpp" + +template +void testing_bsrmm_bad_arg(const Arguments& arg) +{ + static const size_t safe_size = 100; + + T h_alpha = 0.6; + T h_beta = 0.1; + + // Create rocsparse handle + rocsparse_local_handle handle; + + // Create matrix descriptor + rocsparse_local_mat_descr descr; + + // Allocate memory on device + device_vector dbsr_row_ptr(safe_size); + device_vector dbsr_col_ind(safe_size); + device_vector dbsr_val(safe_size); + device_vector dB(safe_size); + device_vector dC(safe_size); + + if(!dbsr_row_ptr || !dbsr_col_ind || !dbsr_val || !dB || !dC) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + // Test invalid handle + EXPECT_ROCSPARSE_STATUS(rocsparse_bsrmm(nullptr, + rocsparse_direction_row, + rocsparse_operation_none, + rocsparse_operation_none, + safe_size, + safe_size, + safe_size, + safe_size, + &h_alpha, + descr, + dbsr_val, + dbsr_row_ptr, + dbsr_col_ind, + safe_size, + dB, + safe_size, + &h_beta, + dC, + safe_size), + rocsparse_status_invalid_handle); + + // Test invalid pointers + EXPECT_ROCSPARSE_STATUS(rocsparse_bsrmm(handle, + rocsparse_direction_row, + rocsparse_operation_none, + rocsparse_operation_none, + safe_size, + safe_size, + safe_size, + safe_size, + nullptr, + descr, + dbsr_val, + dbsr_row_ptr, + dbsr_col_ind, + safe_size, + dB, + safe_size, + &h_beta, + dC, + safe_size), + rocsparse_status_invalid_pointer); + EXPECT_ROCSPARSE_STATUS(rocsparse_bsrmm(handle, + rocsparse_direction_row, + rocsparse_operation_none, + rocsparse_operation_none, + safe_size, + safe_size, + safe_size, + safe_size, + &h_alpha, + nullptr, + dbsr_val, + dbsr_row_ptr, + dbsr_col_ind, + safe_size, + dB, + safe_size, + &h_beta, + dC, + safe_size), + rocsparse_status_invalid_pointer); + EXPECT_ROCSPARSE_STATUS(rocsparse_bsrmm(handle, + rocsparse_direction_row, + rocsparse_operation_none, + rocsparse_operation_none, + safe_size, + safe_size, + safe_size, + safe_size, + &h_alpha, + descr, + nullptr, + dbsr_row_ptr, + dbsr_col_ind, + safe_size, + dB, + safe_size, + &h_beta, + dC, + safe_size), + rocsparse_status_invalid_pointer); + EXPECT_ROCSPARSE_STATUS(rocsparse_bsrmm(handle, + rocsparse_direction_row, + rocsparse_operation_none, + rocsparse_operation_none, + safe_size, + safe_size, + safe_size, + safe_size, + &h_alpha, + descr, + dbsr_val, + nullptr, + dbsr_col_ind, + safe_size, + dB, + safe_size, + &h_beta, + dC, + safe_size), + rocsparse_status_invalid_pointer); + EXPECT_ROCSPARSE_STATUS(rocsparse_bsrmm(handle, + rocsparse_direction_row, + rocsparse_operation_none, + rocsparse_operation_none, + safe_size, + safe_size, + safe_size, + safe_size, + &h_alpha, + descr, + dbsr_val, + dbsr_row_ptr, + nullptr, + safe_size, + dB, + safe_size, + &h_beta, + dC, + safe_size), + rocsparse_status_invalid_pointer); + EXPECT_ROCSPARSE_STATUS(rocsparse_bsrmm(handle, + rocsparse_direction_row, + rocsparse_operation_none, + rocsparse_operation_none, + safe_size, + safe_size, + safe_size, + safe_size, + &h_alpha, + descr, + dbsr_val, + dbsr_row_ptr, + dbsr_col_ind, + safe_size, + nullptr, + safe_size, + &h_beta, + dC, + safe_size), + rocsparse_status_invalid_pointer); + EXPECT_ROCSPARSE_STATUS(rocsparse_bsrmm(handle, + rocsparse_direction_row, + rocsparse_operation_none, + rocsparse_operation_none, + safe_size, + safe_size, + safe_size, + safe_size, + &h_alpha, + descr, + dbsr_val, + dbsr_row_ptr, + dbsr_col_ind, + safe_size, + dB, + safe_size, + nullptr, + dC, + safe_size), + rocsparse_status_invalid_pointer); + EXPECT_ROCSPARSE_STATUS(rocsparse_bsrmm(handle, + rocsparse_direction_row, + rocsparse_operation_none, + rocsparse_operation_none, + safe_size, + safe_size, + safe_size, + safe_size, + &h_alpha, + descr, + dbsr_val, + dbsr_row_ptr, + dbsr_col_ind, + safe_size, + dB, + safe_size, + &h_beta, + nullptr, + safe_size), + rocsparse_status_invalid_pointer); + + // Test invalid size + EXPECT_ROCSPARSE_STATUS(rocsparse_bsrmm(handle, + rocsparse_direction_row, + rocsparse_operation_none, + rocsparse_operation_none, + -1, + safe_size, + safe_size, + safe_size, + &h_alpha, + descr, + dbsr_val, + dbsr_row_ptr, + dbsr_col_ind, + safe_size, + dB, + safe_size, + &h_beta, + dC, + safe_size), + rocsparse_status_invalid_size); + + EXPECT_ROCSPARSE_STATUS(rocsparse_bsrmm(handle, + rocsparse_direction_row, + rocsparse_operation_none, + rocsparse_operation_none, + safe_size, + -1, + safe_size, + safe_size, + &h_alpha, + descr, + dbsr_val, + dbsr_row_ptr, + dbsr_col_ind, + safe_size, + dB, + safe_size, + &h_beta, + dC, + safe_size), + rocsparse_status_invalid_size); + + EXPECT_ROCSPARSE_STATUS(rocsparse_bsrmm(handle, + rocsparse_direction_row, + rocsparse_operation_none, + rocsparse_operation_none, + safe_size, + safe_size, + -1, + safe_size, + &h_alpha, + descr, + dbsr_val, + dbsr_row_ptr, + dbsr_col_ind, + safe_size, + dB, + safe_size, + &h_beta, + dC, + safe_size), + rocsparse_status_invalid_size); + + EXPECT_ROCSPARSE_STATUS(rocsparse_bsrmm(handle, + rocsparse_direction_row, + rocsparse_operation_none, + rocsparse_operation_none, + safe_size, + safe_size, + safe_size, + -1, + &h_alpha, + descr, + dbsr_val, + dbsr_row_ptr, + dbsr_col_ind, + safe_size, + dB, + safe_size, + &h_beta, + dC, + safe_size), + rocsparse_status_invalid_size); + + EXPECT_ROCSPARSE_STATUS(rocsparse_bsrmm(handle, + rocsparse_direction_row, + rocsparse_operation_none, + rocsparse_operation_none, + safe_size, + safe_size, + safe_size, + safe_size, + &h_alpha, + descr, + dbsr_val, + dbsr_row_ptr, + dbsr_col_ind, + 0, + dB, + safe_size, + &h_beta, + dC, + safe_size), + rocsparse_status_invalid_size); + + // Test not implemented + EXPECT_ROCSPARSE_STATUS(rocsparse_bsrmm(handle, + rocsparse_direction_row, + rocsparse_operation_transpose, + rocsparse_operation_none, + safe_size, + safe_size, + safe_size, + safe_size, + &h_alpha, + descr, + dbsr_val, + dbsr_row_ptr, + dbsr_col_ind, + safe_size, + dB, + safe_size, + &h_beta, + dC, + safe_size), + rocsparse_status_not_implemented); + + EXPECT_ROCSPARSE_STATUS(rocsparse_bsrmm(handle, + rocsparse_direction_row, + rocsparse_operation_conjugate_transpose, + rocsparse_operation_none, + safe_size, + safe_size, + safe_size, + safe_size, + &h_alpha, + descr, + dbsr_val, + dbsr_row_ptr, + dbsr_col_ind, + safe_size, + dB, + safe_size, + &h_beta, + dC, + safe_size), + rocsparse_status_not_implemented); + + EXPECT_ROCSPARSE_STATUS(rocsparse_bsrmm(handle, + rocsparse_direction_row, + rocsparse_operation_none, + rocsparse_operation_conjugate_transpose, + safe_size, + safe_size, + safe_size, + safe_size, + &h_alpha, + descr, + dbsr_val, + dbsr_row_ptr, + dbsr_col_ind, + safe_size, + dB, + safe_size, + &h_beta, + dC, + safe_size), + rocsparse_status_not_implemented); +} + +template +void testing_bsrmm(const Arguments& arg) +{ + rocsparse_int M = arg.M; + rocsparse_int N = arg.N; + rocsparse_int K = arg.K; + rocsparse_int block_dim = arg.block_dim; + rocsparse_int dim_x = arg.dimx; + rocsparse_int dim_y = arg.dimy; + rocsparse_int dim_z = arg.dimz; + rocsparse_operation transA = arg.transA; + rocsparse_operation transB = arg.transB; + rocsparse_direction direction = arg.direction; + rocsparse_index_base base = arg.baseA; + rocsparse_matrix_init mat = arg.matrix; + bool full_rank = false; + std::string filename + = arg.timing ? arg.filename : rocsparse_exepath() + "../matrices/" + arg.filename + ".csr"; + + rocsparse_int Mb = -1; + rocsparse_int Kb = -1; + if(block_dim > 0) + { + Mb = (M + block_dim - 1) / block_dim; + Kb = (K + block_dim - 1) / block_dim; + } + + T h_alpha = arg.get_alpha(); + T h_beta = arg.get_beta(); + + // Create rocsparse handle + rocsparse_local_handle handle; + + // Create matrix descriptor + rocsparse_local_mat_descr descr; + + // Set matrix index base + CHECK_ROCSPARSE_ERROR(rocsparse_set_mat_index_base(descr, base)); + + CHECK_ROCSPARSE_ERROR(rocsparse_set_pointer_mode(handle, rocsparse_pointer_mode_host)); + + // Argument sanity check before allocating invalid memory + if(Mb <= 0 || N <= 0 || Kb <= 0 || block_dim <= 0) + { + static const size_t safe_size = 100; + + // Allocate memory on device + device_vector dbsr_row_ptr(safe_size); + device_vector dbsr_col_ind(safe_size); + device_vector dbsr_val(safe_size); + device_vector dB(safe_size); + device_vector dC(safe_size); + + if(!dbsr_row_ptr || !dbsr_col_ind || !dbsr_val || !dB || !dC) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + EXPECT_ROCSPARSE_STATUS(rocsparse_bsrmm(handle, + direction, + transA, + transB, + Mb, + N, + Kb, + safe_size, + &h_alpha, + descr, + dbsr_val, + dbsr_row_ptr, + dbsr_col_ind, + block_dim, + dB, + safe_size, + &h_beta, + dC, + safe_size), + (Mb < 0 || N < 0 || Kb < 0 || block_dim <= 0) + ? rocsparse_status_invalid_size + : rocsparse_status_success); + + return; + } + + // Allocate host memory for original CSR matrix + host_vector hcsr_row_ptr_orig; + host_vector hcsr_col_ind_orig; + host_vector hcsr_val_orig; + + // Allocate host memory for output BSR matrix + host_vector hbsr_row_ptr; + host_vector hbsr_col_ind; + host_vector hbsr_val; + + rocsparse_seedrand(); + + // Generate original host CSR matrix and then use it to fill in the host BSR matrix + rocsparse_int nnzb = 0; + rocsparse_int dummy = 0; + rocsparse_init_csr_and_bsr_matrix(hcsr_row_ptr_orig, + hcsr_col_ind_orig, + hcsr_val_orig, + M, + K, + base, + hbsr_row_ptr, + hbsr_col_ind, + hbsr_val, + direction, + Mb, + Kb, + block_dim, + dummy, + dim_x, + dim_y, + dim_z, + nnzb, + base, + mat, + filename.c_str(), + false, + full_rank); + + // M and K and Mb and Kb can be modified by rocsparse_init_csr_and_bsr_matrix + M = Mb * block_dim; + K = Kb * block_dim; + + // Some matrix properties + rocsparse_int ldb = (transB == rocsparse_operation_none) ? K : N; + rocsparse_int ldc = M; + + rocsparse_int ncol_B = (transB == rocsparse_operation_none ? N : K); + rocsparse_int nnz_B = ldb * ncol_B; + rocsparse_int nnz_C = ldc * N; + + // Allocate host memory for dense matrices + host_vector hB(nnz_B); + host_vector hC_1(nnz_C); + host_vector hC_2(nnz_C); + host_vector hC_gold(nnz_C); + + // Initialize data on CPU + rocsparse_init(hB, ldb, ncol_B, ldb); + rocsparse_init(hC_1, ldc, N, ldc); + hC_2 = hC_1; + hC_gold = hC_1; + + // Allocate device memory + device_vector dbsr_row_ptr(Mb + 1); + device_vector dbsr_col_ind(nnzb); + device_vector dbsr_val(nnzb * block_dim * block_dim); + device_vector dB(nnz_B); + device_vector dC_1(nnz_C); + device_vector dC_2(nnz_C); + device_vector d_alpha(1); + device_vector d_beta(1); + + if(!dbsr_row_ptr || !dbsr_col_ind || !dbsr_val || !dB || !dC_1 || !dC_2 || !d_alpha || !d_beta) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + // Copy data from CPU to device + CHECK_HIP_ERROR(hipMemcpy( + dbsr_row_ptr, hbsr_row_ptr, sizeof(rocsparse_int) * (Mb + 1), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR( + hipMemcpy(dbsr_col_ind, hbsr_col_ind, sizeof(rocsparse_int) * nnzb, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy( + dbsr_val, hbsr_val, sizeof(T) * nnzb * block_dim * block_dim, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dB, hB, sizeof(T) * nnz_B, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dC_1, hC_1, sizeof(T) * nnz_C, hipMemcpyHostToDevice)); + + if(arg.unit_check) + { + // Copy data from CPU to device + CHECK_HIP_ERROR(hipMemcpy(dC_2, hC_2, sizeof(T) * nnz_C, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(d_alpha, &h_alpha, sizeof(T), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(d_beta, &h_beta, sizeof(T), hipMemcpyHostToDevice)); + + // Pointer mode host + CHECK_ROCSPARSE_ERROR(rocsparse_set_pointer_mode(handle, rocsparse_pointer_mode_host)); + CHECK_ROCSPARSE_ERROR(rocsparse_bsrmm(handle, + direction, + transA, + transB, + Mb, + N, + Kb, + nnzb, + &h_alpha, + descr, + dbsr_val, + dbsr_row_ptr, + dbsr_col_ind, + block_dim, + dB, + ldb, + &h_beta, + dC_1, + ldc)); + + // Pointer mode device + CHECK_ROCSPARSE_ERROR(rocsparse_set_pointer_mode(handle, rocsparse_pointer_mode_device)); + CHECK_ROCSPARSE_ERROR(rocsparse_bsrmm(handle, + direction, + transA, + transB, + Mb, + N, + Kb, + nnzb, + d_alpha, + descr, + dbsr_val, + dbsr_row_ptr, + dbsr_col_ind, + block_dim, + dB, + ldb, + d_beta, + dC_2, + ldc)); + + // Copy output to host + CHECK_HIP_ERROR(hipMemcpy(hC_1, dC_1, sizeof(T) * nnz_C, hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(hC_2, dC_2, sizeof(T) * nnz_C, hipMemcpyDeviceToHost)); + + // CPU bsrmm + host_bsrmm(Mb, + N, + Kb, + block_dim, + direction, + transA, + transB, + h_alpha, + hbsr_row_ptr, + hbsr_col_ind, + hbsr_val, + hB, + ldb, + h_beta, + hC_gold, + ldc, + base); + + near_check_general(ldc, N, ldc, hC_gold, hC_1); + near_check_general(ldc, N, ldc, hC_gold, hC_2); + } + + if(arg.timing) + { + int number_cold_calls = 2; + int number_hot_calls = arg.iters; + + CHECK_ROCSPARSE_ERROR(rocsparse_set_pointer_mode(handle, rocsparse_pointer_mode_host)); + + // Warm up + for(int iter = 0; iter < number_cold_calls; ++iter) + { + CHECK_ROCSPARSE_ERROR(rocsparse_bsrmm(handle, + direction, + transA, + transB, + Mb, + N, + Kb, + nnzb, + &h_alpha, + descr, + dbsr_val, + dbsr_row_ptr, + dbsr_col_ind, + block_dim, + dB, + ldb, + &h_beta, + dC_1, + ldc)); + } + + double gpu_time_used = get_time_us(); + + // Performance run + for(int iter = 0; iter < number_hot_calls; ++iter) + { + CHECK_ROCSPARSE_ERROR(rocsparse_bsrmm(handle, + direction, + transA, + transB, + Mb, + N, + Kb, + nnzb, + &h_alpha, + descr, + dbsr_val, + dbsr_row_ptr, + dbsr_col_ind, + block_dim, + dB, + ldb, + &h_beta, + dC_1, + ldc)); + } + + gpu_time_used = (get_time_us() - gpu_time_used) / number_hot_calls; + + double gpu_gflops + = bsrmm_gflop_count(N, nnzb, block_dim, nnz_C, h_beta != static_cast(0)) + / gpu_time_used * 1e6; + double gpu_gbyte + = bsrmm_gbyte_count(Mb, nnzb, block_dim, nnz_B, nnz_C, h_beta != static_cast(0)) + / gpu_time_used * 1e6; + + std::cout.precision(2); + std::cout.setf(std::ios::fixed); + std::cout.setf(std::ios::left); + + std::cout << std::setw(12) << "M" << std::setw(12) << "N" << std::setw(12) << "K" + << std::setw(12) << "dir" << std::setw(12) << "transA" << std::setw(12) + << "transB" << std::setw(12) << "nnzb" << std::setw(12) << "block_dim" + << std::setw(12) << "nnz_B" << std::setw(12) << "nnz_C" << std::setw(12) + << "alpha" << std::setw(12) << "beta" << std::setw(12) << "GFlop/s" + << std::setw(12) << "GB/s" << std::setw(12) << "msec" << std::setw(12) << "iter" + << std::setw(12) << "verified" << std::endl; + + std::cout << std::setw(12) << M << std::setw(12) << N << std::setw(12) << K << std::setw(12) + << rocsparse_direction2string(direction) << std::setw(12) + << rocsparse_operation2string(transA) << std::setw(12) + << rocsparse_operation2string(transB) << std::setw(12) << nnzb << std::setw(12) + << block_dim << std::setw(12) << nnz_B << std::setw(12) << nnz_C << std::setw(12) + << h_alpha << std::setw(12) << h_beta << std::setw(12) << gpu_gflops + << std::setw(12) << gpu_gbyte << std::setw(12) << gpu_time_used / 1e3 + << std::setw(12) << number_hot_calls << std::setw(12) + << (arg.unit_check ? "yes" : "no") << std::endl; + } +} + +#endif // TESTING_BSRMM_HPP \ No newline at end of file diff --git a/clients/samples/CMakeLists.txt b/clients/samples/CMakeLists.txt index 77b0a067..b10c4f40 100644 --- a/clients/samples/CMakeLists.txt +++ b/clients/samples/CMakeLists.txt @@ -62,6 +62,8 @@ add_rocsparse_example(example_hybmv.cpp) add_rocsparse_example(example_csrsv.cpp) # Level 3 +add_rocsparse_example(example_bsrmm.cpp) +add_rocsparse_example(example_fortran_bsrmm.f90) add_rocsparse_example(example_csrmm.cpp) add_rocsparse_example(example_csrsm.cpp) diff --git a/clients/samples/example_bsrmm.cpp b/clients/samples/example_bsrmm.cpp new file mode 100644 index 00000000..c1eae5f7 --- /dev/null +++ b/clients/samples/example_bsrmm.cpp @@ -0,0 +1,216 @@ +/* ************************************************************************ + * Copyright (c) 2020 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + * ************************************************************************ */ + +#include +#include +#include + +#define HIP_CHECK(stat) \ + { \ + if(stat != hipSuccess) \ + { \ + std::cerr << "Error: hip error in line " << __LINE__ << std::endl; \ + return -1; \ + } \ + } + +#define ROCSPARSE_CHECK(stat) \ + { \ + if(stat != rocsparse_status_success) \ + { \ + std::cerr << "Error: rocsparse error in line " << __LINE__ << std::endl; \ + return -1; \ + } \ + } + +int main(int argc, char* argv[]) +{ + // Query device + int ndev; + HIP_CHECK(hipGetDeviceCount(&ndev)); + + if(ndev < 1) + { + std::cerr << "No HIP device found" << std::endl; + return -1; + } + + // Query device properties + hipDeviceProp_t prop; + HIP_CHECK(hipGetDeviceProperties(&prop, 0)); + + std::cout << "Device: " << prop.name << std::endl; + + // rocSPARSE handle + rocsparse_handle handle; + ROCSPARSE_CHECK(rocsparse_create_handle(&handle)); + + // Print rocSPARSE version and revision + int ver; + char rev[64]; + + ROCSPARSE_CHECK(rocsparse_get_version(handle, &ver)); + ROCSPARSE_CHECK(rocsparse_get_git_rev(handle, rev)); + + std::cout << "rocSPARSE version: " << ver / 100000 << "." << ver / 100 % 1000 << "." + << ver % 100 << "-" << rev << std::endl; + + // Input data + + // Matrix A (m x k) + // ( 1 2 0 3 0 0 ) + // A = ( 0 4 5 0 0 0 ) + // ( 0 0 0 7 8 0 ) + // ( 0 0 1 2 4 1 ) + + // Number of rows and columns + rocsparse_int block_dim = 2; + rocsparse_int mb = 2; + rocsparse_int kb = 3; + rocsparse_int n = 10; + rocsparse_int m = mb * block_dim; + rocsparse_int k = kb * block_dim; + + // Number of non-zero block entries + rocsparse_int nnzb = 4; + + // BSR row pointers + rocsparse_int hbsr_row_ptr[3] = {0, 2, 4}; + + // BSR column indices + rocsparse_int hbsr_col_ind[4] = {0, 1, 1, 2}; + + // BSR values + double hbsr_val[16] + = {1.0, 2.0, 0.0, 4.0, 0.0, 3.0, 5.0, 0.0, 0.0, 7.0, 1.0, 2.0, 8.0, 0.0, 4.0, 1.0}; + + // Transposition of the matrix + rocsparse_direction dir = rocsparse_direction_row; + rocsparse_operation transA = rocsparse_operation_none; + rocsparse_operation transB = rocsparse_operation_none; + + // Matrix B (k x n) column major order + // ( 9 11 13 15 17 10 12 14 16 18 ) + // ( 8 10 1 10 6 11 7 3 12 17 ) + // B = ( 11 11 0 4 6 12 2 9 13 2 ) + // ( 15 3 2 3 8 1 2 4 6 6 ) + // ( 2 5 7 0 1 15 9 4 10 1 ) + // ( 7 12 12 1 12 5 1 11 1 14 ) + + // Matrix B in column-major + rocsparse_int ldb = k; + double hB[6 * 10] + = {9, 8, 11, 15, 2, 7, 11, 10, 11, 3, 5, 12, 13, 1, 0, 2, 7, 12, 15, 10, + 4, 3, 0, 1, 17, 6, 6, 8, 1, 12, 10, 11, 12, 1, 15, 5, 12, 7, 2, 2, + 9, 1, 14, 3, 9, 4, 4, 11, 16, 12, 13, 6, 10, 1, 18, 17, 2, 6, 1, 14}; + + // Matrix C (m x n) column major order + // ( 0 0 0 0 0 0 0 0 0 0 ) + // C = ( 0 0 0 0 0 0 0 0 0 0 ) + // ( 0 0 0 0 0 0 0 0 0 0 ) + // ( 0 0 0 0 0 0 0 0 0 0 ) + + // Matrix C (m x n) in column-major + rocsparse_int ldc = m; + double hC[4 * 10] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + + // Scalar alpha and beta + double alpha = 1.0; + double beta = 0.0; + + // Matrix descriptor + rocsparse_mat_descr descr; + ROCSPARSE_CHECK(rocsparse_create_mat_descr(&descr)); + + // Offload data to device + rocsparse_int* dbsr_row_ptr; + rocsparse_int* dbsr_col_ind; + double* dbsr_val; + double* dB; + double* dC; + + HIP_CHECK(hipMalloc((void**)&dbsr_row_ptr, sizeof(rocsparse_int) * (mb + 1))); + HIP_CHECK(hipMalloc((void**)&dbsr_col_ind, sizeof(rocsparse_int) * nnzb)); + HIP_CHECK(hipMalloc((void**)&dbsr_val, sizeof(double) * nnzb * block_dim * block_dim)); + HIP_CHECK(hipMalloc((void**)&dB, sizeof(double) * k * n)); + HIP_CHECK(hipMalloc((void**)&dC, sizeof(double) * m * n)); + + HIP_CHECK(hipMemcpy( + dbsr_row_ptr, hbsr_row_ptr, sizeof(rocsparse_int) * (mb + 1), hipMemcpyHostToDevice)); + HIP_CHECK( + hipMemcpy(dbsr_col_ind, hbsr_col_ind, sizeof(rocsparse_int) * nnzb, hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy( + dbsr_val, hbsr_val, sizeof(double) * nnzb * block_dim * block_dim, hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(dB, hB, sizeof(double) * k * n, hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(dC, hC, sizeof(double) * m * n, hipMemcpyHostToDevice)); + + // Call dbsrmm + ROCSPARSE_CHECK(rocsparse_dbsrmm(handle, + dir, + transA, + transB, + mb, + n, + kb, + nnzb, + &alpha, + descr, + dbsr_val, + dbsr_row_ptr, + dbsr_col_ind, + block_dim, + dB, + ldb, + &beta, + dC, + ldc)); + + // Print result + HIP_CHECK(hipMemcpy(hC, dC, sizeof(double) * m * n, hipMemcpyDeviceToHost)); + + std::cout << "C:" << std::endl; + + for(int i = 0; i < m; ++i) + { + for(int j = 0; j < n; ++j) + { + std::cout << " " << hC[i + j * ldc]; + } + + std::cout << std::endl; + } + + // Clear rocSPARSE + ROCSPARSE_CHECK(rocsparse_destroy_mat_descr(descr)); + ROCSPARSE_CHECK(rocsparse_destroy_handle(handle)); + + // Clear device memory + HIP_CHECK(hipFree(dbsr_row_ptr)); + HIP_CHECK(hipFree(dbsr_col_ind)); + HIP_CHECK(hipFree(dbsr_val)); + HIP_CHECK(hipFree(dB)); + HIP_CHECK(hipFree(dC)); + + return 0; +} diff --git a/clients/samples/example_fortran_bsrmm.f90 b/clients/samples/example_fortran_bsrmm.f90 new file mode 100644 index 00000000..95b88a86 --- /dev/null +++ b/clients/samples/example_fortran_bsrmm.f90 @@ -0,0 +1,261 @@ +!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +! Copyright (c) 2020 Advanced Micro Devices, Inc. +! +! Permission is hereby granted, free of charge, to any person obtaining a copy +! of this software and associated documentation files (the "Software"), to deal +! in the Software without restriction, including without limitation the rights +! to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +! copies of the Software, and to permit persons to whom the Software is +! furnished to do so, subject to the following conditions: +! +! The above copyright notice and this permission notice shall be included in +! all copies or substantial portions of the Software. +! +! THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +! IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +! FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +! AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +! LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +! OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +! THE SOFTWARE. +! +!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +subroutine HIP_CHECK(stat) + use iso_c_binding + + implicit none + + integer(c_int) :: stat + + if(stat /= 0) then + write(*,*) 'Error: hip error' + stop + end if + +end subroutine HIP_CHECK + +subroutine ROCSPARSE_CHECK(stat) + use iso_c_binding + + implicit none + + integer(c_int) :: stat + + if(stat /= 0) then + write(*,*) 'Error: rocsparse error' + stop + end if + +end subroutine ROCSPARSE_CHECK + +program example_fortran_bsrmm + use iso_c_binding + use rocsparse + + implicit none + + interface + function hipMalloc(ptr, size) & + result(c_int) & + bind(c, name = 'hipMalloc') + use iso_c_binding + implicit none + type(c_ptr) :: ptr + integer(c_size_t), value :: size + end function hipMalloc + + function hipFree(ptr) & + result(c_int) & + bind(c, name = 'hipFree') + use iso_c_binding + implicit none + type(c_ptr), value :: ptr + end function hipFree + + function hipMemcpy(dst, src, size, kind) & + result(c_int) & + bind(c, name = 'hipMemcpy') + use iso_c_binding + implicit none + type(c_ptr), value :: dst + type(c_ptr), intent(in), value :: src + integer(c_size_t), value :: size + integer(c_int), value :: kind + end function hipMemcpy + + function hipMemset(dst, val, size) & + result(c_int) & + bind(c, name = 'hipMemset') + use iso_c_binding + implicit none + type(c_ptr), value :: dst + integer(c_int), value :: val + integer(c_size_t), value :: size + end function hipMemset + + function hipDeviceSynchronize() & + result(c_int) & + bind(c, name = 'hipDeviceSynchronize') + use iso_c_binding + implicit none + end function hipDeviceSynchronize + + function hipDeviceReset() & + result(c_int) & + bind(c, name = 'hipDeviceReset') + use iso_c_binding + implicit none + end function hipDeviceReset + end interface + + integer, target :: h_bsr_row_ptr(3), h_bsr_col_ind(4) + real(8), target :: h_bsr_val(16), h_B(6 * 10), h_C(4 * 10) + + type(c_ptr) :: d_bsr_row_ptr + type(c_ptr) :: d_bsr_col_ind + type(c_ptr) :: d_bsr_val + type(c_ptr) :: d_B + type(c_ptr) :: d_C + + integer :: i, j + integer(c_int) :: M, Mb, N, K, Kb, nnzb, block_dim + + real(c_double), target :: alpha, beta + + type(c_ptr) :: handle + type(c_ptr) :: descr + + integer :: version + + character(len=12) :: rev + +! Input data +! ( 1 2 0 3 0 0 ) +! A = ( 0 4 5 0 0 0 ) +! ( 0 0 0 7 8 0 ) +! ( 0 0 1 2 4 1 ) + +! ( 9 11 13 15 17 10 12 14 16 18 ) +! ( 8 10 1 10 6 11 7 3 12 17 ) +! B = ( 11 11 0 4 6 12 2 9 13 2 ) +! ( 15 3 2 3 8 1 2 4 6 6 ) +! ( 2 5 7 0 1 15 9 4 10 1 ) +! ( 7 12 12 1 12 5 1 11 1 14 ) + +! Number of rows and columns + block_dim = 2 + Mb = 2 + Kb = 3 + N = 10 + M = Mb * block_dim + K = Kb * block_dim + +! Number of non-zero blocks + nnzb = 4 + +! Fill BSR structure + h_bsr_row_ptr = (/0, 2, 4/) + h_bsr_col_ind = (/0, 1, 1, 2/) + h_bsr_val = (/1, 2, 0, 4, 0, 3, 5, 0, 0, 7, 1, 2, 8, 0, 4, 1/) + +! Scalar alpha and beta + alpha = 1.0 + beta = 0.0 + +! Fill B in column-major + h_B = (/9, 8, 11, 15, 2, 7, & + 11, 10, 11, 3, 5, 12, & + 13, 1, 0, 2, 7, 12, & + 15, 10, 4, 3, 0, 1, & + 17, 6, 6, 8, 1, 12, & + 10, 11, 12, 1, 15, 5, & + 12, 7, 2, 2, 9, 1, & + 14, 3, 9, 4, 4, 11, & + 16, 12, 13, 6, 10, 1, & + 18, 17, 2, 6, 1, 14/) + +! Fill C in column-major + h_C = (/0, 0, 0, 0, & + 0, 0, 0, 0, & + 0, 0, 0, 0, & + 0, 0, 0, 0, & + 0, 0, 0, 0, & + 0, 0, 0, 0, & + 0, 0, 0, 0, & + 0, 0, 0, 0, & + 0, 0, 0, 0, & + 0, 0, 0, 0/) + +! Allocate device memory + call HIP_CHECK(hipMalloc(d_bsr_row_ptr, (int(Mb, c_size_t) + 1) * 4)) + call HIP_CHECK(hipMalloc(d_bsr_col_ind, int(nnzb, c_size_t) * 4)) + call HIP_CHECK(hipMalloc(d_bsr_val, int(nnzb * block_dim * block_dim, c_size_t) * 8)) + call HIP_CHECK(hipMalloc(d_B, int(K * N, c_size_t) * 8)) + call HIP_CHECK(hipMalloc(d_C, int(M * N, c_size_t) * 8)) + +! Copy host data to device + call HIP_CHECK(hipMemcpy(d_bsr_row_ptr, c_loc(h_bsr_row_ptr), (int(Mb, c_size_t) + 1) * 4, 1)) + call HIP_CHECK(hipMemcpy(d_bsr_col_ind, c_loc(h_bsr_col_ind), int(nnzb, c_size_t) * 4, 1)) + call HIP_CHECK(hipMemcpy(d_bsr_val, c_loc(h_bsr_val), int(nnzb * block_dim * block_dim, c_size_t) * 8, 1)) + call HIP_CHECK(hipMemcpy(d_B, c_loc(h_B), int(K * N, c_size_t) * 8, 1)) + call HIP_CHECK(hipMemcpy(d_C, c_loc(h_C), int(M * N, c_size_t) * 8, 1)) + +! Create rocSPARSE handle + call ROCSPARSE_CHECK(rocsparse_create_handle(handle)) + +! Get rocSPARSE version + call ROCSPARSE_CHECK(rocsparse_get_version(handle, version)) + call ROCSPARSE_CHECK(rocsparse_get_git_rev(handle, rev)) + +! Print version on screen + write(*,fmt='(A,I0,A,I0,A,I0,A,A)') 'rocSPARSE version: ', version / 100000, '.', & + mod(version / 100, 1000), '.', mod(version, 100), '-', rev + +! Create matrix descriptor + call ROCSPARSE_CHECK(rocsparse_create_mat_descr(descr)) + +! Perform the matrix multiplication + call ROCSPARSE_CHECK(rocsparse_dbsrmm(handle, & + rocsparse_direction_row, & + rocsparse_operation_none, & + rocsparse_operation_none, & + Mb, & + N, & + Kb, & + nnzb, & + c_loc(alpha), & + descr, & + d_bsr_val, & + d_bsr_row_ptr, & + d_bsr_col_ind, & + block_dim, & + d_B, & + K, & + c_loc(beta), & + d_C, & + M)) + +! Print result + call HIP_CHECK(hipMemcpy(c_loc(h_C), d_C, int(M * N, c_size_t) * 8, 2)) + +! Note: C in column major ordering + do i = 1, M + do j = 1, N + write(*,fmt='(A,F6.2)',advance='no') ' ', h_C(M * (j - 1) + i) + end do + write(*,*) + end do + +! Clear rocSPARSE + call ROCSPARSE_CHECK(rocsparse_destroy_mat_descr(descr)) + call ROCSPARSE_CHECK(rocsparse_destroy_handle(handle)) + +! Clear device memory + call HIP_CHECK(hipFree(d_bsr_row_ptr)) + call HIP_CHECK(hipFree(d_bsr_col_ind)) + call HIP_CHECK(hipFree(d_bsr_val)) + call HIP_CHECK(hipFree(d_B)) + call HIP_CHECK(hipFree(d_C)) + +end program example_fortran_bsrmm diff --git a/clients/tests/CMakeLists.txt b/clients/tests/CMakeLists.txt index 993db90d..0112f92c 100644 --- a/clients/tests/CMakeLists.txt +++ b/clients/tests/CMakeLists.txt @@ -168,6 +168,7 @@ set(ROCSPARSE_TEST_SOURCES test_csrsv.cpp test_ellmv.cpp test_hybmv.cpp + test_bsrmm.cpp test_csrmm.cpp test_csrsm.cpp test_csrgeam.cpp @@ -229,7 +230,7 @@ set_target_properties(rocsparse-test PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${PROJ set(ROCSPARSE_TEST_DATA "${PROJECT_BINARY_DIR}/staging/rocsparse_test.data") add_custom_command(OUTPUT "${ROCSPARSE_TEST_DATA}" COMMAND ../common/rocsparse_gentest.py -I ../include rocsparse_test.yaml -o "${ROCSPARSE_TEST_DATA}" - DEPENDS ../common/rocsparse_gentest.py rocsparse_test.yaml ../include/rocsparse_common.yaml known_bugs.yaml test_axpyi.yaml test_doti.yaml test_dotci.yaml test_gthr.yaml test_gthrz.yaml test_roti.yaml test_sctr.yaml test_bsrmv.yaml test_bsrsv.yaml test_coomv.yaml test_csrmv.yaml test_csrsv.yaml test_ellmv.yaml test_hybmv.yaml test_csrmm.yaml test_csrsm.yaml test_csrgeam.yaml test_csrgemm.yaml test_csric0.yaml test_csrilu0.yaml test_csr2coo.yaml test_csr2csc.yaml test_csr2ell.yaml test_csr2hyb.yaml test_bsr2csr.yaml test_csr2bsr.yaml test_coo2csr.yaml test_ell2csr.yaml test_hyb2csr.yaml test_identity.yaml test_csrsort.yaml test_cscsort.yaml test_coosort.yaml test_csricsv.yaml test_csrilusv.yaml test_nnz.yaml test_dense2csr.yaml test_dense2csc.yaml test_csr2dense.yaml test_csc2dense.yaml test_csr2csr_compress.cpp + DEPENDS ../common/rocsparse_gentest.py rocsparse_test.yaml ../include/rocsparse_common.yaml known_bugs.yaml test_axpyi.yaml test_doti.yaml test_dotci.yaml test_gthr.yaml test_gthrz.yaml test_roti.yaml test_sctr.yaml test_bsrmv.yaml test_bsrsv.yaml test_coomv.yaml test_csrmv.yaml test_csrsv.yaml test_ellmv.yaml test_hybmv.yaml test_bsrmm.yaml test_csrmm.yaml test_csrsm.yaml test_csrgeam.yaml test_csrgemm.yaml test_csric0.yaml test_csrilu0.yaml test_csr2coo.yaml test_csr2csc.yaml test_csr2ell.yaml test_csr2hyb.yaml test_bsr2csr.yaml test_csr2bsr.yaml test_coo2csr.yaml test_ell2csr.yaml test_hyb2csr.yaml test_identity.yaml test_csrsort.yaml test_cscsort.yaml test_coosort.yaml test_csricsv.yaml test_csrilusv.yaml test_nnz.yaml test_dense2csr.yaml test_dense2csc.yaml test_csr2dense.yaml test_csc2dense.yaml test_csr2csr_compress.cpp WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}") add_custom_target(rocsparse-test-data DEPENDS "${ROCSPARSE_TEST_DATA}" ) diff --git a/clients/tests/rocsparse_test.yaml b/clients/tests/rocsparse_test.yaml index 6c1b6096..1513081d 100644 --- a/clients/tests/rocsparse_test.yaml +++ b/clients/tests/rocsparse_test.yaml @@ -35,6 +35,7 @@ include: test_csrmv.yaml include: test_csrsv.yaml include: test_ellmv.yaml include: test_hybmv.yaml +include: test_bsrmm.yaml include: test_csrmm.yaml include: test_csrsm.yaml include: test_csrgeam.yaml diff --git a/clients/tests/test_bsrmm.cpp b/clients/tests/test_bsrmm.cpp new file mode 100644 index 00000000..72cdf5a7 --- /dev/null +++ b/clients/tests/test_bsrmm.cpp @@ -0,0 +1,117 @@ +/* ************************************************************************ + * Copyright (c) 2020 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + * ************************************************************************ */ + +#include "rocsparse_data.hpp" +#include "rocsparse_datatype2string.hpp" +#include "rocsparse_test.hpp" +#include "testing_bsrmm.hpp" +#include "type_dispatch.hpp" + +#include +#include +#include + +namespace +{ + // By default, this test does not apply to any types. + // The unnamed second parameter is used for enable_if below. + template + struct bsrmm_testing : rocsparse_test_invalid + { + }; + + // When the condition in the second argument is satisfied, the type combination + // is valid. When the condition is false, this specialization does not apply. + template + struct bsrmm_testing< + T, + typename std::enable_if{} || std::is_same{} + || std::is_same{} + || std::is_same{}>::type> + { + explicit operator bool() + { + return true; + } + void operator()(const Arguments& arg) + { + if(!strcmp(arg.function, "bsrmm")) + testing_bsrmm(arg); + else if(!strcmp(arg.function, "bsrmm_bad_arg")) + testing_bsrmm_bad_arg(arg); + else + FAIL() << "Internal error: Test called with unknown function: " << arg.function; + } + }; + + struct bsrmm : RocSPARSE_Test + { + // Filter for which types apply to this suite + static bool type_filter(const Arguments& arg) + { + return rocsparse_simple_dispatch(arg); + } + + // Filter for which functions apply to this suite + static bool function_filter(const Arguments& arg) + { + return !strcmp(arg.function, "bsrmm") || !strcmp(arg.function, "bsrmm_bad_arg"); + } + + // Google Test name suffix based on parameters + static std::string name_suffix(const Arguments& arg) + { + if(arg.matrix == rocsparse_matrix_file_rocalution + || arg.matrix == rocsparse_matrix_file_mtx) + { + return RocSPARSE_TestName{} + << rocsparse_datatype2string(arg.compute_type) << '_' << arg.N << '_' + << arg.block_dim << '_' << rocsparse_direction2string(arg.direction) + << arg.alpha << '_' << arg.alphai << '_' << arg.beta << '_' << arg.betai + << '_' << rocsparse_operation2string(arg.transA) << '_' + << rocsparse_operation2string(arg.transB) << '_' + << rocsparse_indexbase2string(arg.baseA) << '_' + << rocsparse_matrix2string(arg.matrix) << '_' << arg.filename; + } + else + { + return RocSPARSE_TestName{} + << rocsparse_datatype2string(arg.compute_type) << '_' << arg.M << '_' + << arg.N << '_' << arg.K << '_' << arg.block_dim << '_' + << rocsparse_direction2string(arg.direction) << '_' << arg.alpha << '_' + << arg.alphai << '_' << arg.beta << '_' << arg.betai << '_' + << rocsparse_operation2string(arg.transA) << '_' + << rocsparse_operation2string(arg.transB) << '_' + << rocsparse_indexbase2string(arg.baseA) << '_' + << rocsparse_matrix2string(arg.matrix); + } + } + }; + + TEST_P(bsrmm, level3) + { + rocsparse_simple_dispatch(GetParam()); + } + INSTANTIATE_TEST_CATEGORIES(bsrmm); + +} // namespace diff --git a/clients/tests/test_bsrmm.yaml b/clients/tests/test_bsrmm.yaml new file mode 100644 index 00000000..f6517f9a --- /dev/null +++ b/clients/tests/test_bsrmm.yaml @@ -0,0 +1,207 @@ +# ######################################################################## +# Copyright (c) 2020 Advanced Micro Devices, Inc. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ######################################################################## + +--- +include: rocsparse_common.yaml +include: known_bugs.yaml + +Definitions: + - &alpha_beta_range_quick + - { alpha: 1.0, beta: -1.0, alphai: 1.0, betai: -0.5 } + - { alpha: -0.5, beta: 0.5, alphai: -0.5, betai: 1.0 } + + - &alpha_beta_range_checkin + - { alpha: 2.0, beta: 0.0, alphai: 0.5, betai: 0.5 } + - { alpha: 0.0, beta: 1.0, alphai: 1.5, betai: 0.5 } + - { alpha: 3.0, beta: 1.0, alphai: 0.0, betai: -0.5 } + + - &alpha_beta_range_nightly + - { alpha: 0.0, beta: 0.0, alphai: 1.5, betai: 0.5 } + - { alpha: 2.0, beta: 0.67, alphai: 0.0, betai: 1.5 } + - { alpha: 3.0, beta: 1.0, alphai: 1.5, betai: 0.0 } + - { alpha: -0.5, beta: 0.5, alphai: 1.0, betai: -0.5 } + +Tests: +- name: bsrmm_bad_arg + category: pre_checkin + function: bsrmm_bad_arg + precision: *single_double_precisions_complex_real + +- name: bsrmm + category: quick + function: bsrmm + precision: *single_double_precisions_complex_real + M: [-1, 275, 708] + N: [-1, 128, 628] + K: [-1, 173, 747] + block_dim: [0, 5, 7, 16] + alpha_beta: *alpha_beta_range_quick + transA: [rocsparse_operation_none] + transB: [rocsparse_operation_none, rocsparse_operation_transpose] + baseA: [rocsparse_index_base_zero] + direction: [rocsparse_direction_row] + matrix: [rocsparse_matrix_random] + +- name: bsrmm + category: pre_checkin + function: bsrmm + precision: *single_double_precisions_complex_real + M: [0, 511, 2059] + N: [0, 7, 33] + K: [0, 391, 1375] + block_dim: [17, 25] + alpha_beta: *alpha_beta_range_checkin + transA: [rocsparse_operation_none] + transB: [rocsparse_operation_none, rocsparse_operation_transpose] + baseA: [rocsparse_index_base_one] + direction: [rocsparse_direction_column] + matrix: [rocsparse_matrix_random] + +- name: bsrmm + category: nightly + function: bsrmm + precision: *single_double_precisions_complex_real + M: [3943, 94912] + N: [27, 49] + K: [4134, 73291] + block_dim: [2, 9] + alpha_beta: *alpha_beta_range_nightly + transA: [rocsparse_operation_none] + transB: [rocsparse_operation_none, rocsparse_operation_transpose] + baseA: [rocsparse_index_base_zero] + direction: [rocsparse_direction_row] + matrix: [rocsparse_matrix_random] + +- name: bsrmm_file + category: quick + function: bsrmm + precision: *single_double_precisions + M: 1 + N: [4, 19] + K: 1 + block_dim: [4] + alpha_beta: *alpha_beta_range_quick + transA: [rocsparse_operation_none] + transB: [rocsparse_operation_none, rocsparse_operation_transpose] + baseA: [rocsparse_index_base_one] + direction: [rocsparse_direction_column] + matrix: [rocsparse_matrix_file_rocalution] + filename: [mac_econ_fwd500, + nos2, + nos4, + nos6, + scircuit] + +- name: bsrmm_file + category: pre_checkin + function: bsrmm + precision: *single_double_precisions + M: 1 + N: [73] + K: 1 + block_dim: [8] + alpha_beta: *alpha_beta_range_checkin + transA: [rocsparse_operation_none] + transB: [rocsparse_operation_none] + baseA: [rocsparse_index_base_zero] + direction: [rocsparse_direction_row] + matrix: [rocsparse_matrix_file_rocalution] + filename: [rma10, + mc2depi, + ASIC_320k, + nos1, + nos3, + nos5, + nos7] + +- name: bsrmm_file + category: nightly + function: bsrmm + precision: *single_double_precisions + M: 1 + N: [38] + K: 1 + block_dim: [3] + alpha_beta: *alpha_beta_range_nightly + transA: [rocsparse_operation_none] + transB: [rocsparse_operation_transpose] + baseA: [rocsparse_index_base_zero] + direction: [rocsparse_direction_row] + matrix: [rocsparse_matrix_file_rocalution] + filename: [bibd_22_8, + bmwcra_1, + amazon0312, + Chebyshev4, + sme3Dc, + webbase-1M, + shipsec1] + +- name: bsrmm_file + category: quick + function: bsrmm + precision: *single_double_precisions_complex + M: 1 + N: [3, 21] + K: 1 + block_dim: [6] + alpha_beta: *alpha_beta_range_quick + transA: [rocsparse_operation_none] + transB: [rocsparse_operation_none] + baseA: [rocsparse_index_base_one] + direction: [rocsparse_direction_column] + matrix: [rocsparse_matrix_file_rocalution] + filename: [Chevron2, + qc2534] + +- name: bsrmm_file + category: pre_checkin + function: bsrmm + precision: *single_double_precisions_complex + M: 1 + N: [68] + K: 1 + block_dim: [11] + alpha_beta: *alpha_beta_range_checkin + transA: [rocsparse_operation_none] + transB: [rocsparse_operation_transpose] + baseA: [rocsparse_index_base_zero] + direction: [rocsparse_direction_row] + matrix: [rocsparse_matrix_file_rocalution] + filename: [mplate, + Chevron3] + +- name: bsrmm_file + category: nightly + function: bsrmm + precision: *single_double_precisions_complex + M: 1 + N: [40] + K: 1 + block_dim: [7] + alpha_beta: *alpha_beta_range_nightly + transA: [rocsparse_operation_none] + transB: [rocsparse_operation_none] + baseA: [rocsparse_index_base_one] + direction: [rocsparse_direction_column] + matrix: [rocsparse_matrix_file_rocalution] + filename: [Chevron4] \ No newline at end of file diff --git a/docs/source/usermanual.rst b/docs/source/usermanual.rst index d859652c..b29b466a 100644 --- a/docs/source/usermanual.rst +++ b/docs/source/usermanual.rst @@ -642,6 +642,7 @@ Sparse Level 3 Functions ========================================================================= ====== ====== ============== ============== Function name single double single complex double complex ========================================================================= ====== ====== ============== ============== +:cpp:func:`rocsparse_Xbsrmm() ` x x x x :cpp:func:`rocsparse_Xcsrmm() ` x x x x :cpp:func:`rocsparse_Xcsrsm_buffer_size() ` x x x x :cpp:func:`rocsparse_Xcsrsm_analysis() ` x x x x @@ -1144,6 +1145,17 @@ This module holds all sparse level 3 routines. The sparse level 3 routines describe operations between a matrix in sparse format and multiple vectors in dense format that can also be seen as a dense matrix. +rocsparse_bsrmm() +----------------- + +.. doxygenfunction:: rocsparse_sbsrmm + :outline: +.. doxygenfunction:: rocsparse_dbsrmm + :outline: +.. doxygenfunction:: rocsparse_cbsrmm + :outline: +.. doxygenfunction:: rocsparse_zbsrmm + rocsparse_csrmm() ----------------- diff --git a/library/include/rocsparse-functions.h b/library/include/rocsparse-functions.h index 1b7b3fda..cf330d50 100644 --- a/library/include/rocsparse-functions.h +++ b/library/include/rocsparse-functions.h @@ -2539,6 +2539,250 @@ rocsparse_status rocsparse_zhybmv(rocsparse_handle handle, * =========================================================================== */ +/*! \ingroup level3_module + * \brief Sparse matrix dense matrix multiplication using BSR storage format + * + * \details + * \p rocsparse_bsrmm multiplies the scalar \f$\alpha\f$ with a sparse \f$mb \times kb\f$ + * matrix \f$A\f$, defined in BSR storage format, and the dense \f$k \times n\f$ + * matrix \f$B\f$ (where \f$k = block\_dim \times kb\f$) and adds the result to the dense + * \f$m \times n\f$ matrix \f$C\f$ (where \f$m = block\_dim \times mb\f$) that + * is multiplied by the scalar \f$\beta\f$, such that + * \f[ + * C := \alpha \cdot op(A) \cdot op(B) + \beta \cdot C, + * \f] + * with + * \f[ + * op(A) = \left\{ + * \begin{array}{ll} + * A, & \text{if trans_A == rocsparse_operation_none} \\ + * \end{array} + * \right. + * \f] + * and + * \f[ + * op(B) = \left\{ + * \begin{array}{ll} + * B, & \text{if trans_B == rocsparse_operation_none} \\ + * B^T, & \text{if trans_B == rocsparse_operation_transpose} \\ + * \end{array} + * \right. + * \f] + * + * \note + * This function is non blocking and executed asynchronously with respect to the host. + * It may return before the actual computation has finished. + * + * \note + * Currently, only \p trans_A == \ref rocsparse_operation_none is supported. + * + * @param[in] + * handle handle to the rocsparse library context queue. + * @param[in] + * dir the storage format of the blocks. Can be \ref rocsparse_direction_row or \ref rocsparse_direction_column. + * @param[in] + * trans_A matrix \f$A\f$ operation type. Currently, only \ref rocsparse_operation_none is supported. + * @param[in] + * trans_B matrix \f$B\f$ operation type. Currently, only \ref rocsparse_operation_none and rocsparse_operation_transpose + * are supported. + * @param[in] + * mb number of block rows of the sparse BSR matrix \f$A\f$. + * @param[in] + * n number of columns of the dense matrix \f$op(B)\f$ and \f$C\f$. + * @param[in] + * kb number of block columns of the sparse BSR matrix \f$A\f$. + * @param[in] + * nnzb number of non-zero blocks of the sparse BSR matrix \f$A\f$. + * @param[in] + * alpha scalar \f$\alpha\f$. + * @param[in] + * descr descriptor of the sparse BSR matrix \f$A\f$. Currently, only + * \ref rocsparse_matrix_type_general is supported. + * @param[in] + * bsr_val array of \p nnzb*block_dim*block_dim elements of the sparse BSR matrix \f$A\f$. + * @param[in] + * bsr_row_ptr array of \p mb+1 elements that point to the start of every block row of the + * sparse BSR matrix \f$A\f$. + * @param[in] + * bsr_col_ind array of \p nnzb elements containing the block column indices of the sparse + * BSR matrix \f$A\f$. + * @param[in] + * block_dim size of the blocks in the sparse BSR matrix. + * @param[in] + * B array of dimension \f$ldb \times n\f$ (\f$op(B) == B\f$) or + * \f$ldb \times k\f$ (\f$op(B) == B^T\f$). + * @param[in] + * ldb leading dimension of \f$B\f$, must be at least \f$\max{(1, k)}\f$ where \f$k = block\_dim \times kb\f$. + * @param[in] + * beta scalar \f$\beta\f$. + * @param[inout] + * C array of dimension \f$ldc \times n\f$. + * @param[in] + * ldc leading dimension of \f$C\f$, must be at least \f$\max{(1, m)}\f$ where \f$m = block\_dim \times mb\f$. + * + * \retval rocsparse_status_success the operation completed successfully. + * \retval rocsparse_status_invalid_handle the library context was not initialized. + * \retval rocsparse_status_invalid_size \p mb, \p n, \p kb, \p nnzb, \p ldb or \p ldc + * is invalid. + * \retval rocsparse_status_invalid_pointer \p descr, \p alpha, \p bsr_val, + * \p bsr_row_ptr, \p bsr_col_ind, \p B, \p beta or \p C pointer is invalid. + * \retval rocsparse_status_arch_mismatch the device is not supported. + * \retval rocsparse_status_not_implemented + * \p trans_A != \ref rocsparse_operation_none or + * \p trans_B == \ref rocsparse_operation_conjugate_transpose or + * \ref rocsparse_matrix_type != \ref rocsparse_matrix_type_general. + * + * \par Example + * This example multiplies a BSR matrix with a dense matrix. + * \code{.c} + * // 1 2 0 3 0 0 + * // A = 0 4 5 0 0 0 + * // 0 0 0 7 8 0 + * // 0 0 1 2 4 1 + * + * rocsparse_int block_dim = 2; + * rocsparse_int mb = 2; + * rocsparse_int kb = 3; + * rocsparse_int nnzb = 4; + * rocsparse_direction dir = rocsparse_direction_row; + * + * bsr_row_ptr[mb+1] = {0, 2, 4}; // device memory + * bsr_col_ind[nnzb] = {0, 1, 1, 2}; // device memory + * bsr_val[nnzb*block_dim*block_dim] = {1, 2, 0, 4, 0, 3, 5, 0, 0, 7, 1, 2, 8, 0, 4, 1}; // device memory + * + * // Set dimension n of B + * rocsparse_int n = 64; + * rocsparse_int m = mb * block_dim; + * rocsparse_int k = kb * block_dim; + * + * // Allocate and generate dense matrix B + * std::vector hB(k * n); + * for(rocsparse_int i = 0; i < k * n; ++i) + * { + * hB[i] = static_cast(rand()) / RAND_MAX; + * } + * + * // Copy B to the device + * float* B; + * hipMalloc((void**)&B, sizeof(float) * k * n); + * hipMemcpy(B, hB.data(), sizeof(float) * k * n, hipMemcpyHostToDevice); + * + * // alpha and beta + * float alpha = 1.0f; + * float beta = 0.0f; + * + * // Allocate memory for the resulting matrix C + * float* C; + * hipMalloc((void**)&C, sizeof(float) * m * n); + * + * // Perform the matrix multiplication + * rocsparse_sbsrmm(handle, + * dir, + * rocsparse_operation_none, + * rocsparse_operation_none, + * mb, + * n, + * kb, + * nnzb, + * &alpha, + * descr, + * bsr_val, + * bsr_row_ptr, + * bsr_col_ind, + * block_dim, + * B, + * k, + * &beta, + * C, + * m); + * \endcode + */ +/**@{*/ +ROCSPARSE_EXPORT +rocsparse_status rocsparse_sbsrmm(rocsparse_handle handle, + rocsparse_direction dir, + rocsparse_operation trans_A, + rocsparse_operation trans_B, + rocsparse_int mb, + rocsparse_int n, + rocsparse_int kb, + rocsparse_int nnzb, + const float* alpha, + const rocsparse_mat_descr descr, + const float* bsr_val, + const rocsparse_int* bsr_row_ptr, + const rocsparse_int* bsr_col_ind, + rocsparse_int block_dim, + const float* B, + rocsparse_int ldb, + const float* beta, + float* C, + rocsparse_int ldc); + +ROCSPARSE_EXPORT +rocsparse_status rocsparse_dbsrmm(rocsparse_handle handle, + rocsparse_direction dir, + rocsparse_operation trans_A, + rocsparse_operation trans_B, + rocsparse_int mb, + rocsparse_int n, + rocsparse_int kb, + rocsparse_int nnzb, + const double* alpha, + const rocsparse_mat_descr descr, + const double* bsr_val, + const rocsparse_int* bsr_row_ptr, + const rocsparse_int* bsr_col_ind, + rocsparse_int block_dim, + const double* B, + rocsparse_int ldb, + const double* beta, + double* C, + rocsparse_int ldc); + +ROCSPARSE_EXPORT +rocsparse_status rocsparse_cbsrmm(rocsparse_handle handle, + rocsparse_direction dir, + rocsparse_operation trans_A, + rocsparse_operation trans_B, + rocsparse_int mb, + rocsparse_int n, + rocsparse_int kb, + rocsparse_int nnzb, + const rocsparse_float_complex* alpha, + const rocsparse_mat_descr descr, + const rocsparse_float_complex* bsr_val, + const rocsparse_int* bsr_row_ptr, + const rocsparse_int* bsr_col_ind, + rocsparse_int block_dim, + const rocsparse_float_complex* B, + rocsparse_int ldb, + const rocsparse_float_complex* beta, + rocsparse_float_complex* C, + rocsparse_int ldc); + +ROCSPARSE_EXPORT +rocsparse_status rocsparse_zbsrmm(rocsparse_handle handle, + rocsparse_direction dir, + rocsparse_operation trans_A, + rocsparse_operation trans_B, + rocsparse_int mb, + rocsparse_int n, + rocsparse_int kb, + rocsparse_int nnzb, + const rocsparse_double_complex* alpha, + const rocsparse_mat_descr descr, + const rocsparse_double_complex* bsr_val, + const rocsparse_int* bsr_row_ptr, + const rocsparse_int* bsr_col_ind, + rocsparse_int block_dim, + const rocsparse_double_complex* B, + rocsparse_int ldb, + const rocsparse_double_complex* beta, + rocsparse_double_complex* C, + rocsparse_int ldc); +/**@}*/ + /*! \ingroup level3_module * \brief Sparse matrix dense matrix multiplication using CSR storage format * diff --git a/library/src/CMakeLists.txt b/library/src/CMakeLists.txt index b22630c9..a30dc45d 100644 --- a/library/src/CMakeLists.txt +++ b/library/src/CMakeLists.txt @@ -51,6 +51,7 @@ set(rocsparse_source src/level2/rocsparse_hybmv.cpp # Level3 + src/level3/rocsparse_bsrmm.cpp src/level3/rocsparse_csrmm.cpp src/level3/rocsparse_csrsm.cpp diff --git a/library/src/level3/bsrmm_device.h b/library/src/level3/bsrmm_device.h new file mode 100644 index 00000000..315e0db2 --- /dev/null +++ b/library/src/level3/bsrmm_device.h @@ -0,0 +1,583 @@ +/* ************************************************************************ + * Copyright (c) 2020 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + * ************************************************************************ */ + +#pragma once +#ifndef BSRMM_DEVICE_H +#define BSRMM_DEVICE_H + +#include "common.h" + +#include + +template +static __device__ void bsrmmnn_small_blockdim_device(rocsparse_direction direction, + rocsparse_int Mb, + rocsparse_int N, + T alpha, + const rocsparse_int* __restrict__ bsr_row_ptr, + const rocsparse_int* __restrict__ bsr_col_ind, + const T* __restrict__ bsr_val, + const T* __restrict__ B, + rocsparse_int ldb, + T beta, + T* __restrict__ C, + rocsparse_int ldc, + rocsparse_index_base idx_base) +{ + constexpr rocsparse_int PADDED_BSR_BLOCK_DIM = (BSR_BLOCK_DIM + 1); + + rocsparse_int tid = hipThreadIdx_x; + rocsparse_int gid = hipBlockIdx_x * hipBlockDim_x + tid; + rocsparse_int lid = gid & (WF_SIZE - 1); + rocsparse_int wid = tid / WF_SIZE; + rocsparse_int nwfb = hipGridDim_x * hipBlockDim_x / (WF_SIZE * BSR_BLOCK_DIM); + rocsparse_int col = lid + hipBlockIdx_y * WF_SIZE; + + rocsparse_int colB = col * ldb; + rocsparse_int colC = col * ldc; + + // global row + rocsparse_int global_row = (gid / WF_SIZE); + + // local row within block row + rocsparse_int local_row = (gid / WF_SIZE) % BSR_BLOCK_DIM; + + __shared__ rocsparse_int shared_col[BLOCKSIZE / WF_SIZE][WF_SIZE]; + __shared__ T shared_val[BLOCKSIZE / WF_SIZE][WF_SIZE * PADDED_BSR_BLOCK_DIM]; + + for(rocsparse_int block_row = gid / (WF_SIZE * BSR_BLOCK_DIM); block_row < Mb; + block_row += nwfb) + { + rocsparse_int block_row_start = bsr_row_ptr[block_row] - idx_base; + rocsparse_int block_row_end = bsr_row_ptr[block_row + 1] - idx_base; + + T sum = static_cast(0); + + for(rocsparse_int j = block_row_start; j < block_row_end; j += WF_SIZE) + { + rocsparse_int k = j + lid; + + shared_col[wid][lid] = (k < block_row_end) ? BSR_BLOCK_DIM * (bsr_col_ind[k] - idx_base) : 0; + + if(direction == rocsparse_direction_row) + { + // Perform: + // for(rocsparse_int l = 0; l < BSR_BLOCK_DIM; l++) + // { + // shared_val[wid][PADDED_BSR_BLOCK_DIM * lid + l] + // = (k < block_row_end) ? bsr_val[BSR_BLOCK_DIM * BSR_BLOCK_DIM * k + // + BSR_BLOCK_DIM * local_row + l] + // : static_cast(0); + // } + // as unrolled loop. + shared_val[wid][PADDED_BSR_BLOCK_DIM * lid] + = (k < block_row_end) ? bsr_val[BSR_BLOCK_DIM * BSR_BLOCK_DIM * k + + BSR_BLOCK_DIM * local_row] + : static_cast(0); + if(BSR_BLOCK_DIM >= 2) + { + shared_val[wid][PADDED_BSR_BLOCK_DIM * lid + 1] + = (k < block_row_end) ? bsr_val[BSR_BLOCK_DIM * BSR_BLOCK_DIM * k + + BSR_BLOCK_DIM * local_row + 1] + : static_cast(0); + } + if(BSR_BLOCK_DIM >= 3) + { + shared_val[wid][PADDED_BSR_BLOCK_DIM * lid + 2] + = (k < block_row_end) ? bsr_val[BSR_BLOCK_DIM * BSR_BLOCK_DIM * k + + BSR_BLOCK_DIM * local_row + 2] + : static_cast(0); + } + if(BSR_BLOCK_DIM >= 4) + { + shared_val[wid][PADDED_BSR_BLOCK_DIM * lid + 3] + = (k < block_row_end) ? bsr_val[BSR_BLOCK_DIM * BSR_BLOCK_DIM * k + + BSR_BLOCK_DIM * local_row + 3] + : static_cast(0); + } + } + else + { + // Perform: + // for(rocsparse_int l = 0; l < BSR_BLOCK_DIM; l++) + // { + // shared_val[wid][BSR_BLOCK_DIM * lid + l] + // = (k < block_row_end) ? bsr_val[BSR_BLOCK_DIM * BSR_BLOCK_DIM * k + // + BSR_BLOCK_DIM * l + local_row] + // : static_cast(0); + // } + // as unrolled loop. + shared_val[wid][PADDED_BSR_BLOCK_DIM * lid] + = (k < block_row_end) ? bsr_val[BSR_BLOCK_DIM * BSR_BLOCK_DIM * k + + local_row] + : static_cast(0); + if(BSR_BLOCK_DIM >= 2) + { + shared_val[wid][PADDED_BSR_BLOCK_DIM * lid + 1] + = (k < block_row_end) ? bsr_val[BSR_BLOCK_DIM * BSR_BLOCK_DIM * k + + BSR_BLOCK_DIM * 1 + local_row] + : static_cast(0); + } + if(BSR_BLOCK_DIM >= 3) + { + shared_val[wid][PADDED_BSR_BLOCK_DIM * lid + 2] + = (k < block_row_end) ? bsr_val[BSR_BLOCK_DIM * BSR_BLOCK_DIM * k + + BSR_BLOCK_DIM * 2 + local_row] + : static_cast(0); + } + if(BSR_BLOCK_DIM >= 4) + { + shared_val[wid][PADDED_BSR_BLOCK_DIM * lid + 3] + = (k < block_row_end) ? bsr_val[BSR_BLOCK_DIM * BSR_BLOCK_DIM * k + + BSR_BLOCK_DIM * 3 + local_row] + : static_cast(0); + } + } + + __syncthreads(); + + if(col < N) + { + for(rocsparse_int i = 0; i < WF_SIZE; ++i) + { + // Perform: + // for(rocsparse_int l = 0; l < BSR_BLOCK_DIM; l++) + // { + // sum = rocsparse_fma(shared_val[wid][PADDED_BSR_BLOCK_DIM * i + l], + // B[shared_col[wid][i] + l], + // sum); + // } + // as unrolled loop. + sum = rocsparse_fma(shared_val[wid][PADDED_BSR_BLOCK_DIM * i], + B[shared_col[wid][i] + colB], + sum); + if(BSR_BLOCK_DIM >= 2) + { + sum = rocsparse_fma(shared_val[wid][PADDED_BSR_BLOCK_DIM * i + 1], + B[shared_col[wid][i] + 1 + colB], + sum); + } + if(BSR_BLOCK_DIM >= 3) + { + sum = rocsparse_fma(shared_val[wid][PADDED_BSR_BLOCK_DIM * i + 2], + B[shared_col[wid][i] + 2 + colB], + sum); + } + if(BSR_BLOCK_DIM >= 4) + { + sum = rocsparse_fma(shared_val[wid][PADDED_BSR_BLOCK_DIM * i + 3], + B[shared_col[wid][i] + 3 + colB], + sum); + } + } + } + } + + if(col < N) + { + if(beta == static_cast(0)) + { + C[global_row + colC] = alpha * sum; + } + else + { + C[global_row + colC] = rocsparse_fma(beta, C[global_row + colC], alpha * sum); + } + } + } +} + +template +static __device__ void bsrmmnt_small_blockdim_device(rocsparse_direction direction, + rocsparse_int Mb, + rocsparse_int N, + T alpha, + const rocsparse_int* __restrict__ bsr_row_ptr, + const rocsparse_int* __restrict__ bsr_col_ind, + const T* __restrict__ bsr_val, + const T* __restrict__ B, + rocsparse_int ldb, + T beta, + T* __restrict__ C, + rocsparse_int ldc, + rocsparse_index_base idx_base) +{ + constexpr rocsparse_int PADDED_BSR_BLOCK_DIM = (BSR_BLOCK_DIM + 1); + + rocsparse_int tid = hipThreadIdx_x; + rocsparse_int gid = hipBlockIdx_x * hipBlockDim_x + tid; + rocsparse_int block_row = gid / (WF_SIZE * BSR_BLOCK_DIM); + rocsparse_int global_row = gid / WF_SIZE; + rocsparse_int local_row = (gid / WF_SIZE) % BSR_BLOCK_DIM; + rocsparse_int lid = tid & (WF_SIZE - 1); + rocsparse_int wid = tid / WF_SIZE; + + if(block_row >= Mb) + { + return; + } + + __shared__ rocsparse_int shared_col[BLOCKSIZE / WF_SIZE][WF_SIZE]; + __shared__ T shared_val[BLOCKSIZE / WF_SIZE][WF_SIZE * PADDED_BSR_BLOCK_DIM]; + + rocsparse_int block_row_start = bsr_row_ptr[block_row] - idx_base; + rocsparse_int block_row_end = bsr_row_ptr[block_row + 1] - idx_base; + + for(rocsparse_int l = 0; l < N; l += WF_SIZE) + { + rocsparse_int col = l + lid; + T sum = static_cast(0); + + for(rocsparse_int j = block_row_start; j < block_row_end; j += WF_SIZE) + { + rocsparse_int k = j + lid; + + shared_col[wid][lid] = (k < block_row_end) ? N * BSR_BLOCK_DIM * (bsr_col_ind[k] - idx_base) : 0; + + if(direction == rocsparse_direction_row) + { + // Perform: + // for(rocsparse_int p = 0; p < BSR_BLOCK_DIM; p++) + // { + // shared_val[wid][PADDED_BSR_BLOCK_DIM * lid + p] + // = (k < block_row_end) ? bsr_val[BSR_BLOCK_DIM * BSR_BLOCK_DIM * k + // + BSR_BLOCK_DIM * local_row + p] + // : static_cast(0); + // } + // as unrolled loop. + shared_val[wid][PADDED_BSR_BLOCK_DIM * lid] + = (k < block_row_end) ? bsr_val[BSR_BLOCK_DIM * BSR_BLOCK_DIM * k + + BSR_BLOCK_DIM * local_row] + : static_cast(0); + if(BSR_BLOCK_DIM >= 2) + { + shared_val[wid][PADDED_BSR_BLOCK_DIM * lid + 1] + = (k < block_row_end) ? bsr_val[BSR_BLOCK_DIM * BSR_BLOCK_DIM * k + + BSR_BLOCK_DIM * local_row + 1] + : static_cast(0); + } + if(BSR_BLOCK_DIM >= 3) + { + shared_val[wid][PADDED_BSR_BLOCK_DIM * lid + 2] + = (k < block_row_end) ? bsr_val[BSR_BLOCK_DIM * BSR_BLOCK_DIM * k + + BSR_BLOCK_DIM * local_row + 2] + : static_cast(0); + } + if(BSR_BLOCK_DIM >= 4) + { + shared_val[wid][PADDED_BSR_BLOCK_DIM * lid + 3] + = (k < block_row_end) ? bsr_val[BSR_BLOCK_DIM * BSR_BLOCK_DIM * k + + BSR_BLOCK_DIM * local_row + 3] + : static_cast(0); + } + } + else + { + // Perform: + // for(rocsparse_int p = 0; p < BSR_BLOCK_DIM; p++) + // { + // shared_val[wid][PADDED_BSR_BLOCK_DIM * lid + p] + // = (k < block_row_end) ? bsr_val[BSR_BLOCK_DIM * BSR_BLOCK_DIM * k + // + BSR_BLOCK_DIM * p + local_row] + // : static_cast(0); + // } + // as unrolled loop. + shared_val[wid][PADDED_BSR_BLOCK_DIM * lid] + = (k < block_row_end) ? bsr_val[BSR_BLOCK_DIM * BSR_BLOCK_DIM * k + local_row] + : static_cast(0); + if(BSR_BLOCK_DIM >= 2) + { + shared_val[wid][PADDED_BSR_BLOCK_DIM * lid + 1] + = (k < block_row_end) ? bsr_val[BSR_BLOCK_DIM * BSR_BLOCK_DIM * k + + BSR_BLOCK_DIM * 1 + local_row] + : static_cast(0); + } + if(BSR_BLOCK_DIM >= 3) + { + shared_val[wid][PADDED_BSR_BLOCK_DIM * lid + 2] + = (k < block_row_end) ? bsr_val[BSR_BLOCK_DIM * BSR_BLOCK_DIM * k + + BSR_BLOCK_DIM * 2 + local_row] + : static_cast(0); + } + if(BSR_BLOCK_DIM >= 4) + { + shared_val[wid][PADDED_BSR_BLOCK_DIM * lid + 3] + = (k < block_row_end) ? bsr_val[BSR_BLOCK_DIM * BSR_BLOCK_DIM * k + + BSR_BLOCK_DIM * 3 + local_row] + : static_cast(0); + } + } + + __syncthreads(); + + if(col < N) + { + for(rocsparse_int i = 0; i < WF_SIZE; ++i) + { + // Perform: + // for(rocsparse_int p = 0; p < BSR_BLOCK_DIM; p++) + // { + // T val_B = rocsparse_ldg(B + col + N * p + shared_col[wid][i]); + // sum = rocsparse_fma(shared_val[wid][PADDED_BSR_BLOCK_DIM * i + p], val_B, sum); + // } + // as unrolled loop. + T val_B + = rocsparse_ldg(B + col + shared_col[wid][i]); + sum = rocsparse_fma(shared_val[wid][PADDED_BSR_BLOCK_DIM * i], val_B, sum); + if(BSR_BLOCK_DIM >= 2) + { + val_B + = rocsparse_ldg(B + col + N * 1 + shared_col[wid][i]); + sum = rocsparse_fma(shared_val[wid][PADDED_BSR_BLOCK_DIM * i + 1], val_B, sum); + } + if(BSR_BLOCK_DIM >= 3) + { + val_B + = rocsparse_ldg(B + col + N * 2 + shared_col[wid][i]); + sum = rocsparse_fma(shared_val[wid][PADDED_BSR_BLOCK_DIM * i + 2], val_B, sum); + } + if(BSR_BLOCK_DIM >= 4) + { + val_B + = rocsparse_ldg(B + col + N * 3 + shared_col[wid][i]); + sum = rocsparse_fma(shared_val[wid][PADDED_BSR_BLOCK_DIM * i + 3], val_B, sum); + } + } + } + } + + if(col < N) + { + if(beta == static_cast(0)) + { + C[global_row + col * ldc] = alpha * sum; + } + else + { + C[global_row + col * ldc] = rocsparse_fma(beta, C[global_row + col * ldc], alpha * sum); + } + } + } +} + +template +static __device__ void bsrmm_large_blockdim_device(rocsparse_direction direction, + rocsparse_operation trans_B, + rocsparse_int Mb, + rocsparse_int N, + T alpha, + const rocsparse_int* __restrict__ bsr_row_ptr, + const rocsparse_int* __restrict__ bsr_col_ind, + const T* __restrict__ bsr_val, + rocsparse_int block_dim, + const T* __restrict__ B, + rocsparse_int ldb, + T beta, + T* __restrict__ C, + rocsparse_int ldc, + rocsparse_index_base idx_base) +{ + rocsparse_int tidx = hipThreadIdx_x; + rocsparse_int tidy = hipThreadIdx_y; + + rocsparse_int global_row = tidx + hipBlockIdx_x * block_dim; + rocsparse_int global_col = tidy + hipBlockIdx_y * BLK_SIZE_Y; + + rocsparse_int block_row = hipBlockIdx_x; + + rocsparse_int block_row_start = 0; + rocsparse_int block_row_end = 0; + if(block_row < Mb) + { + block_row_start = bsr_row_ptr[block_row] - idx_base; + block_row_end = bsr_row_ptr[block_row + 1] - idx_base; + } + + rocsparse_int colB = global_col * ldb; + rocsparse_int colC = global_col * ldc; + + __shared__ T shared_B[BSR_BLOCK_DIM * BLK_SIZE_Y]; + __shared__ T shared_A[BSR_BLOCK_DIM * BSR_BLOCK_DIM]; + + T sum = static_cast(0); + + rocsparse_int index = BSR_BLOCK_DIM * tidy + tidx; + rocsparse_int block_dim_sqr = block_dim * block_dim; + + for(rocsparse_int k = block_row_start; k < block_row_end; k++) + { + rocsparse_int block_col = (bsr_col_ind[k] - idx_base); + + if(trans_B == rocsparse_operation_none) + { + shared_B[index] + = (global_col < N && tidx < block_dim) ? B[block_dim * block_col + tidx + colB] : static_cast(0); + } + else + { + shared_B[index] + = (global_col < N && tidx < block_dim) ? B[global_col + ldb * (block_dim * block_col + tidx)] : static_cast(0); + } + + if(direction == rocsparse_direction_row) + { + if(tidx < block_dim && tidy < block_dim) + { + shared_A[index] = bsr_val[block_dim_sqr * k + block_dim * tidx + tidy]; + } + } + else + { + if(tidx < block_dim && tidy < block_dim) + { + shared_A[index] + = bsr_val[block_dim_sqr * k + block_dim * tidy + tidx]; + } + } + + __syncthreads(); + + for(rocsparse_int j = 0; j < block_dim; j++) + { + sum = rocsparse_fma(shared_A[BSR_BLOCK_DIM * j + tidx], + shared_B[BSR_BLOCK_DIM * tidy + j], + sum); + } + + __syncthreads(); + } + + if(block_row < Mb && global_col < N && tidx < block_dim) + { + if(beta == static_cast(0)) + { + C[global_row + colC] = alpha * sum; + } + else + { + C[global_row + colC] = rocsparse_fma(beta, C[global_row + colC], alpha * sum); + } + } +} + +template +static __device__ void bsrmm_general_blockdim_device(rocsparse_direction direction, + rocsparse_operation trans_B, + rocsparse_int Mb, + rocsparse_int N, + T alpha, + const rocsparse_int* __restrict__ bsr_row_ptr, + const rocsparse_int* __restrict__ bsr_col_ind, + const T* __restrict__ bsr_val, + rocsparse_int block_dim, + const T* __restrict__ B, + rocsparse_int ldb, + T beta, + T* __restrict__ C, + rocsparse_int ldc, + rocsparse_index_base idx_base) +{ + rocsparse_int tidx = hipThreadIdx_x; + rocsparse_int tidy = hipThreadIdx_y; + + rocsparse_int block_row = hipBlockIdx_x; + + rocsparse_int block_row_start = 0; + rocsparse_int block_row_end = 0; + if(block_row < Mb) + { + block_row_start = bsr_row_ptr[block_row] - idx_base; + block_row_end = bsr_row_ptr[block_row + 1] - idx_base; + } + + __shared__ T shared_B[BSR_BLOCK_DIM * BLK_SIZE_Y]; + __shared__ T shared_A[BSR_BLOCK_DIM * BSR_BLOCK_DIM]; + + rocsparse_int global_col = tidy + hipBlockIdx_y * BLK_SIZE_Y; + + rocsparse_int colB = global_col * ldb; + rocsparse_int colC = global_col * ldc; + + for(rocsparse_int x = 0; x < block_dim; x += BSR_BLOCK_DIM) + { + rocsparse_int global_row = tidx + x + hipBlockIdx_x * block_dim; + + T sum = static_cast(0); + + for(rocsparse_int k = block_row_start; k < block_row_end; k++) + { + rocsparse_int block_col = (bsr_col_ind[k] - idx_base); + + for(rocsparse_int y = 0; y < block_dim; y += BLK_SIZE_Y) + { + if(trans_B == rocsparse_operation_none) + { + shared_B[BSR_BLOCK_DIM * tidy + tidx] + = (global_col < N && (tidx + y) < block_dim) ? B[block_dim * block_col + (tidx + y) + colB] : static_cast(0); + } + else + { + shared_B[BSR_BLOCK_DIM * tidy + tidx] + = (global_col < N && (tidx + y) < block_dim) ? B[global_col + ldb * (block_dim * block_col + (tidx + y))] : static_cast(0); + } + + + if(direction == rocsparse_direction_row) + { + shared_A[BSR_BLOCK_DIM * tidy + tidx] + = ((tidx + x) < block_dim && (tidy + y) < block_dim) ? bsr_val[block_dim * block_dim * k + block_dim * (tidx + x) + (tidy + y)] : static_cast(0); + } + else + { + shared_A[BSR_BLOCK_DIM * tidy + tidx] + = ((tidx + x) < block_dim && (tidy + y) < block_dim) ? bsr_val[block_dim * block_dim * k + block_dim * (tidy + y) + (tidx + x)] : static_cast(0); + } + + __syncthreads(); + + for(rocsparse_int j = 0; j < BSR_BLOCK_DIM; j++) + { + sum = rocsparse_fma(shared_A[BSR_BLOCK_DIM * j + tidx], + shared_B[BSR_BLOCK_DIM * tidy + j], + sum); + } + + __syncthreads(); + } + } + + if(block_row < Mb && global_col < N && (tidx + x) < block_dim) + { + if(beta == static_cast(0)) + { + C[global_row + colC] = alpha * sum; + } + else + { + C[global_row + colC] = rocsparse_fma(beta, C[global_row + colC], alpha * sum); + } + } + } +} + +#endif // BSRMM_DEVICE_H \ No newline at end of file diff --git a/library/src/level3/csrmm_device.h b/library/src/level3/csrmm_device.h index 1d042d59..99011412 100644 --- a/library/src/level3/csrmm_device.h +++ b/library/src/level3/csrmm_device.h @@ -72,7 +72,7 @@ static __device__ void csrmmnn_general_device(rocsparse_int M, __syncthreads(); shared_col[wid][lid] = (k < row_end) ? csr_col_ind[k] - idx_base : 0; - shared_val[wid][lid] = (k < row_end) ? alpha * csr_val[k] : static_cast(0); + shared_val[wid][lid] = (k < row_end) ? csr_val[k] : static_cast(0); __syncthreads(); @@ -86,11 +86,11 @@ static __device__ void csrmmnn_general_device(rocsparse_int M, { if(beta == static_cast(0)) { - C[row + colC] = sum; + C[row + colC] = alpha * sum; } else { - C[row + colC] = rocsparse_fma(beta, C[row + colC], sum); + C[row + colC] = rocsparse_fma(beta, C[row + colC], alpha * sum); } } } @@ -143,7 +143,7 @@ static __device__ void csrmmnt_general_device(rocsparse_int offset, __syncthreads(); shared_col[wid][lid] = (k < row_end) ? N * (csr_col_ind[k] - idx_base) : 0; - shared_val[wid][lid] = (k < row_end) ? alpha * csr_val[k] : static_cast(0); + shared_val[wid][lid] = (k < row_end) ? csr_val[k] : static_cast(0); __syncthreads(); @@ -159,11 +159,11 @@ static __device__ void csrmmnt_general_device(rocsparse_int offset, { if(beta == static_cast(0)) { - C[row + col * ldc] = sum; + C[row + col * ldc] = alpha * sum; } else { - C[row + col * ldc] = rocsparse_fma(beta, C[row + col * ldc], sum); + C[row + col * ldc] = rocsparse_fma(beta, C[row + col * ldc], alpha * sum); } } } diff --git a/library/src/level3/rocsparse_bsrmm.cpp b/library/src/level3/rocsparse_bsrmm.cpp new file mode 100644 index 00000000..1e481059 --- /dev/null +++ b/library/src/level3/rocsparse_bsrmm.cpp @@ -0,0 +1,195 @@ +/* ************************************************************************ + * Copyright (c) 2020 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + * ************************************************************************ */ + +#include "rocsparse.h" + +#include "rocsparse_bsrmm.hpp" + +/* + * =========================================================================== + * C wrapper + * =========================================================================== + */ +extern "C" rocsparse_status rocsparse_sbsrmm(rocsparse_handle handle, + rocsparse_direction dir, + rocsparse_operation trans_A, + rocsparse_operation trans_B, + rocsparse_int mb, + rocsparse_int n, + rocsparse_int kb, + rocsparse_int nnzb, + const float* alpha, + const rocsparse_mat_descr descr, + const float* bsr_val, + const rocsparse_int* bsr_row_ptr, + const rocsparse_int* bsr_col_ind, + rocsparse_int block_dim, + const float* B, + rocsparse_int ldb, + const float* beta, + float* C, + rocsparse_int ldc) +{ + return rocsparse_bsrmm_template(handle, + dir, + trans_A, + trans_B, + mb, + n, + kb, + nnzb, + alpha, + descr, + bsr_val, + bsr_row_ptr, + bsr_col_ind, + block_dim, + B, + ldb, + beta, + C, + ldc); +} + +extern "C" rocsparse_status rocsparse_dbsrmm(rocsparse_handle handle, + rocsparse_direction dir, + rocsparse_operation trans_A, + rocsparse_operation trans_B, + rocsparse_int mb, + rocsparse_int n, + rocsparse_int kb, + rocsparse_int nnzb, + const double* alpha, + const rocsparse_mat_descr descr, + const double* bsr_val, + const rocsparse_int* bsr_row_ptr, + const rocsparse_int* bsr_col_ind, + rocsparse_int block_dim, + const double* B, + rocsparse_int ldb, + const double* beta, + double* C, + rocsparse_int ldc) +{ + return rocsparse_bsrmm_template(handle, + dir, + trans_A, + trans_B, + mb, + n, + kb, + nnzb, + alpha, + descr, + bsr_val, + bsr_row_ptr, + bsr_col_ind, + block_dim, + B, + ldb, + beta, + C, + ldc); +} + +extern "C" rocsparse_status rocsparse_cbsrmm(rocsparse_handle handle, + rocsparse_direction dir, + rocsparse_operation trans_A, + rocsparse_operation trans_B, + rocsparse_int mb, + rocsparse_int n, + rocsparse_int kb, + rocsparse_int nnzb, + const rocsparse_float_complex* alpha, + const rocsparse_mat_descr descr, + const rocsparse_float_complex* bsr_val, + const rocsparse_int* bsr_row_ptr, + const rocsparse_int* bsr_col_ind, + rocsparse_int block_dim, + const rocsparse_float_complex* B, + rocsparse_int ldb, + const rocsparse_float_complex* beta, + rocsparse_float_complex* C, + rocsparse_int ldc) +{ + return rocsparse_bsrmm_template(handle, + dir, + trans_A, + trans_B, + mb, + n, + kb, + nnzb, + alpha, + descr, + bsr_val, + bsr_row_ptr, + bsr_col_ind, + block_dim, + B, + ldb, + beta, + C, + ldc); +} + +extern "C" rocsparse_status rocsparse_zbsrmm(rocsparse_handle handle, + rocsparse_direction dir, + rocsparse_operation trans_A, + rocsparse_operation trans_B, + rocsparse_int mb, + rocsparse_int n, + rocsparse_int kb, + rocsparse_int nnzb, + const rocsparse_double_complex* alpha, + const rocsparse_mat_descr descr, + const rocsparse_double_complex* bsr_val, + const rocsparse_int* bsr_row_ptr, + const rocsparse_int* bsr_col_ind, + rocsparse_int block_dim, + const rocsparse_double_complex* B, + rocsparse_int ldb, + const rocsparse_double_complex* beta, + rocsparse_double_complex* C, + rocsparse_int ldc) +{ + return rocsparse_bsrmm_template(handle, + dir, + trans_A, + trans_B, + mb, + n, + kb, + nnzb, + alpha, + descr, + bsr_val, + bsr_row_ptr, + bsr_col_ind, + block_dim, + B, + ldb, + beta, + C, + ldc); +} diff --git a/library/src/level3/rocsparse_bsrmm.hpp b/library/src/level3/rocsparse_bsrmm.hpp new file mode 100644 index 00000000..9cfcb7e7 --- /dev/null +++ b/library/src/level3/rocsparse_bsrmm.hpp @@ -0,0 +1,868 @@ +/* ************************************************************************ + * Copyright (c) 2020 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + * ************************************************************************ */ + +#pragma once +#ifndef ROCSPARSE_BSRMM_HPP +#define ROCSPARSE_BSRMM_HPP + +#include "bsrmm_device.h" +#include "handle.h" +#include "rocsparse.h" +#include "rocsparse_csrmm.hpp" +#include "../level2/rocsparse_bsrmv.hpp" +#include "utility.h" + +#include + +#define launch_bsrmmnn_small_blockdim_kernel_host_pointer(T, block_size, wf_size, bsr_block_dim) \ + hipLaunchKernelGGL( \ + (bsrmmnn_small_blockdim_kernel_host_pointer), \ + bsrmmnn_blocks, \ + bsrmmnn_threads, \ + 0, \ + stream, \ + dir, \ + mb, \ + n, \ + *alpha, \ + bsr_row_ptr, \ + bsr_col_ind, \ + bsr_val, \ + B, \ + ldb, \ + *beta, \ + C, \ + ldc, \ + descr->base); + +#define launch_bsrmmnn_small_blockdim_kernel_device_pointer(T, block_size, wf_size, bsr_block_dim) \ + hipLaunchKernelGGL( \ + (bsrmmnn_small_blockdim_kernel_device_pointer), \ + bsrmmnn_blocks, \ + bsrmmnn_threads, \ + 0, \ + stream, \ + dir, \ + mb, \ + n, \ + alpha, \ + bsr_row_ptr, \ + bsr_col_ind, \ + bsr_val, \ + B, \ + ldb, \ + beta, \ + C, \ + ldc, \ + descr->base); + +#define launch_bsrmmnt_small_blockdim_kernel_host_pointer(T, block_size, wf_size, bsr_block_dim) \ + hipLaunchKernelGGL( \ + (bsrmmnt_small_blockdim_kernel_host_pointer), \ + bsrmmnt_blocks, \ + bsrmmnt_threads, \ + 0, \ + stream, \ + dir, \ + mb, \ + n, \ + *alpha, \ + bsr_row_ptr, \ + bsr_col_ind, \ + bsr_val, \ + B, \ + ldb, \ + *beta, \ + C, \ + ldc, \ + descr->base); + +#define launch_bsrmmnt_small_blockdim_kernel_device_pointer(T, block_size, wf_size, bsr_block_dim) \ + hipLaunchKernelGGL( \ + (bsrmmnt_small_blockdim_kernel_device_pointer), \ + bsrmmnt_blocks, \ + bsrmmnt_threads, \ + 0, \ + stream, \ + dir, \ + mb, \ + n, \ + alpha, \ + bsr_row_ptr, \ + bsr_col_ind, \ + bsr_val, \ + B, \ + ldb, \ + beta, \ + C, \ + ldc, \ + descr->base); + + #define launch_bsrmm_large_blockdim_kernel_host_pointer(T, bsr_block_dim, blk_size_y) \ + hipLaunchKernelGGL( \ + (bsrmm_large_blockdim_kernel_host_pointer), \ + bsrmm_blocks, \ + bsrmm_threads, \ + 0, \ + stream, \ + dir, \ + trans_B, \ + mb, \ + n, \ + *alpha, \ + bsr_row_ptr, \ + bsr_col_ind, \ + bsr_val, \ + block_dim, \ + B, \ + ldb, \ + *beta, \ + C, \ + ldc, \ + descr->base); + +#define launch_bsrmm_large_blockdim_kernel_device_pointer(T, bsr_block_dim, blk_size_y) \ + hipLaunchKernelGGL( \ + (bsrmm_large_blockdim_kernel_device_pointer), \ + bsrmm_blocks, \ + bsrmm_threads, \ + 0, \ + stream, \ + dir, \ + trans_B, \ + mb, \ + n, \ + alpha, \ + bsr_row_ptr, \ + bsr_col_ind, \ + bsr_val, \ + block_dim, \ + B, \ + ldb, \ + beta, \ + C, \ + ldc, \ + descr->base); + +#define launch_bsrmm_general_blockdim_kernel_host_pointer(T, bsr_block_dim, blk_size_y) \ + hipLaunchKernelGGL( \ + (bsrmm_general_blockdim_kernel_host_pointer), \ + bsrmm_blocks, \ + bsrmm_threads, \ + 0, \ + stream, \ + dir, \ + trans_B, \ + mb, \ + n, \ + *alpha, \ + bsr_row_ptr, \ + bsr_col_ind, \ + bsr_val, \ + block_dim, \ + B, \ + ldb, \ + *beta, \ + C, \ + ldc, \ + descr->base); + +#define launch_bsrmm_general_blockdim_kernel_device_pointer(T, bsr_block_dim, blk_size_y) \ + hipLaunchKernelGGL( \ + (bsrmm_general_blockdim_kernel_device_pointer), \ + bsrmm_blocks, \ + bsrmm_threads, \ + 0, \ + stream, \ + dir, \ + trans_B, \ + mb, \ + n, \ + alpha, \ + bsr_row_ptr, \ + bsr_col_ind, \ + bsr_val, \ + block_dim, \ + B, \ + ldb, \ + beta, \ + C, \ + ldc, \ + descr->base); + +template +__launch_bounds__(BLOCKSIZE) __global__ + void bsrmmnn_small_blockdim_kernel_host_pointer(rocsparse_direction direction, + rocsparse_int mb, + rocsparse_int n, + T alpha, + const rocsparse_int* __restrict__ bsr_row_ptr, + const rocsparse_int* __restrict__ bsr_col_ind, + const T* __restrict__ bsr_val, + const T* __restrict__ B, + rocsparse_int ldb, + T beta, + T* __restrict__ C, + rocsparse_int ldc, + rocsparse_index_base idx_base) +{ + if(alpha == static_cast(0) && beta == static_cast(1)) + { + return; + } + + bsrmmnn_small_blockdim_device( + direction, mb, n, alpha, bsr_row_ptr, bsr_col_ind, bsr_val, B, ldb, beta, C, ldc, idx_base); +} + +template +__launch_bounds__(BLOCKSIZE) __global__ + void bsrmmnn_small_blockdim_kernel_device_pointer(rocsparse_direction direction, + rocsparse_int mb, + rocsparse_int n, + const T* alpha, + const rocsparse_int* __restrict__ bsr_row_ptr, + const rocsparse_int* __restrict__ bsr_col_ind, + const T* __restrict__ bsr_val, + const T* __restrict__ B, + rocsparse_int ldb, + const T* beta, + T* __restrict__ C, + rocsparse_int ldc, + rocsparse_index_base idx_base) +{ + if(*alpha == static_cast(0) && *beta == static_cast(1)) + { + return; + } + + bsrmmnn_small_blockdim_device(direction, + mb, + n, + *alpha, + bsr_row_ptr, + bsr_col_ind, + bsr_val, + B, + ldb, + *beta, + C, + ldc, + idx_base); +} + +template +__launch_bounds__(BLOCKSIZE) __global__ + void bsrmmnt_small_blockdim_kernel_host_pointer(rocsparse_direction direction, + rocsparse_int mb, + rocsparse_int n, + T alpha, + const rocsparse_int* __restrict__ bsr_row_ptr, + const rocsparse_int* __restrict__ bsr_col_ind, + const T* __restrict__ bsr_val, + const T* __restrict__ B, + rocsparse_int ldb, + T beta, + T* __restrict__ C, + rocsparse_int ldc, + rocsparse_index_base idx_base) +{ + if(alpha == static_cast(0) && beta == static_cast(1)) + { + return; + } + + bsrmmnt_small_blockdim_device(direction, + mb, + n, + alpha, + bsr_row_ptr, + bsr_col_ind, + bsr_val, + B, + ldb, + beta, + C, + ldc, + idx_base); +} + +template +__launch_bounds__(BLOCKSIZE) __global__ + void bsrmmnt_small_blockdim_kernel_device_pointer(rocsparse_direction direction, + rocsparse_int mb, + rocsparse_int n, + const T* alpha, + const rocsparse_int* __restrict__ bsr_row_ptr, + const rocsparse_int* __restrict__ bsr_col_ind, + const T* __restrict__ bsr_val, + const T* __restrict__ B, + rocsparse_int ldb, + const T* beta, + T* __restrict__ C, + rocsparse_int ldc, + rocsparse_index_base idx_base) +{ + if(*alpha == static_cast(0) && *beta == static_cast(1)) + { + return; + } + + bsrmmnt_small_blockdim_device(direction, + mb, + n, + *alpha, + bsr_row_ptr, + bsr_col_ind, + bsr_val, + B, + ldb, + *beta, + C, + ldc, + idx_base); +} + +template +__launch_bounds__(BSR_BLOCK_DIM * BLK_SIZE_Y) __global__ + void bsrmm_large_blockdim_kernel_host_pointer(rocsparse_direction direction, + rocsparse_operation trans_B, + rocsparse_int mb, + rocsparse_int n, + T alpha, + const rocsparse_int* __restrict__ bsr_row_ptr, + const rocsparse_int* __restrict__ bsr_col_ind, + const T* __restrict__ bsr_val, + rocsparse_int block_dim, + const T* __restrict__ B, + rocsparse_int ldb, + T beta, + T* __restrict__ C, + rocsparse_int ldc, + rocsparse_index_base idx_base) +{ + if(alpha == static_cast(0) && beta == static_cast(1)) + { + return; + } + + bsrmm_large_blockdim_device( + direction, trans_B, mb, n, alpha, bsr_row_ptr, bsr_col_ind, bsr_val, block_dim, B, ldb, beta, C, ldc, idx_base); +} + +template +__launch_bounds__(BSR_BLOCK_DIM * BLK_SIZE_Y) __global__ + void bsrmm_large_blockdim_kernel_device_pointer(rocsparse_direction direction, + rocsparse_operation trans_B, + rocsparse_int mb, + rocsparse_int n, + const T* alpha, + const rocsparse_int* __restrict__ bsr_row_ptr, + const rocsparse_int* __restrict__ bsr_col_ind, + const T* __restrict__ bsr_val, + rocsparse_int block_dim, + const T* __restrict__ B, + rocsparse_int ldb, + const T* beta, + T* __restrict__ C, + rocsparse_int ldc, + rocsparse_index_base idx_base) +{ + if(*alpha == static_cast(0) && *beta == static_cast(1)) + { + return; + } + + bsrmm_large_blockdim_device(direction, + trans_B, + mb, + n, + *alpha, + bsr_row_ptr, + bsr_col_ind, + bsr_val, + block_dim, + B, + ldb, + *beta, + C, + ldc, + idx_base); +} + +template +__launch_bounds__(BSR_BLOCK_DIM * BLK_SIZE_Y) __global__ + void bsrmm_general_blockdim_kernel_host_pointer(rocsparse_direction direction, + rocsparse_operation trans_B, + rocsparse_int mb, + rocsparse_int n, + T alpha, + const rocsparse_int* __restrict__ bsr_row_ptr, + const rocsparse_int* __restrict__ bsr_col_ind, + const T* __restrict__ bsr_val, + rocsparse_int block_dim, + const T* __restrict__ B, + rocsparse_int ldb, + T beta, + T* __restrict__ C, + rocsparse_int ldc, + rocsparse_index_base idx_base) +{ + if(alpha == static_cast(0) && beta == static_cast(1)) + { + return; + } + + bsrmm_general_blockdim_device( + direction, trans_B, mb, n, alpha, bsr_row_ptr, bsr_col_ind, bsr_val, block_dim, B, ldb, beta, C, ldc, idx_base); +} + +template +__launch_bounds__(BSR_BLOCK_DIM * BLK_SIZE_Y) __global__ + void bsrmm_general_blockdim_kernel_device_pointer(rocsparse_direction direction, + rocsparse_operation trans_B, + rocsparse_int mb, + rocsparse_int n, + const T* alpha, + const rocsparse_int* __restrict__ bsr_row_ptr, + const rocsparse_int* __restrict__ bsr_col_ind, + const T* __restrict__ bsr_val, + rocsparse_int block_dim, + const T* __restrict__ B, + rocsparse_int ldb, + const T* beta, + T* __restrict__ C, + rocsparse_int ldc, + rocsparse_index_base idx_base) +{ + if(*alpha == static_cast(0) && *beta == static_cast(1)) + { + return; + } + + bsrmm_general_blockdim_device(direction, + trans_B, + mb, + n, + *alpha, + bsr_row_ptr, + bsr_col_ind, + bsr_val, + block_dim, + B, + ldb, + *beta, + C, + ldc, + idx_base); +} + +template +rocsparse_status rocsparse_bsrmm_template(rocsparse_handle handle, + rocsparse_direction dir, + rocsparse_operation trans_A, + rocsparse_operation trans_B, + rocsparse_int mb, + rocsparse_int n, + rocsparse_int kb, + rocsparse_int nnzb, + const T* alpha, + const rocsparse_mat_descr descr, + const T* bsr_val, + const rocsparse_int* bsr_row_ptr, + const rocsparse_int* bsr_col_ind, + rocsparse_int block_dim, + const T* B, + rocsparse_int ldb, + const T* beta, + T* C, + rocsparse_int ldc) +{ + // Check for valid handle and matrix descriptor + if(handle == nullptr) + { + return rocsparse_status_invalid_handle; + } + else if(descr == nullptr) + { + return rocsparse_status_invalid_pointer; + } + + // Logging TODO bench logging + if(handle->pointer_mode == rocsparse_pointer_mode_host) + { + log_trace(handle, + replaceX("rocsparse_Xbsrmm"), + dir, + trans_A, + trans_B, + mb, + n, + kb, + nnzb, + *alpha, + (const void*&)descr, + (const void*&)bsr_val, + (const void*&)bsr_row_ptr, + (const void*&)bsr_col_ind, + block_dim, + (const void*&)B, + ldb, + *beta, + (const void*&)C, + ldc); + } + else + { + log_trace(handle, + replaceX("rocsparse_Xbsrmm"), + dir, + trans_A, + trans_B, + mb, + n, + kb, + nnzb, + (const void*&)alpha, + (const void*&)descr, + (const void*&)bsr_val, + (const void*&)bsr_row_ptr, + (const void*&)bsr_col_ind, + block_dim, + (const void*&)B, + ldb, + (const void*&)beta, + (const void*&)C, + ldc); + } + + // Check index base + if(descr->base != rocsparse_index_base_zero && descr->base != rocsparse_index_base_one) + { + return rocsparse_status_invalid_value; + } + + // Check matrix type + if(descr->type != rocsparse_matrix_type_general) + { + // TODO + return rocsparse_status_not_implemented; + } + + // Check operation + if(trans_A != rocsparse_operation_none) + { + return rocsparse_status_not_implemented; + } + else if(trans_B != rocsparse_operation_none && trans_B != rocsparse_operation_transpose) + { + return rocsparse_status_not_implemented; + } + + // Check sizes + if(mb < 0 || n < 0 || kb < 0 || nnzb < 0 || block_dim <= 0) + { + return rocsparse_status_invalid_size; + } + + // Quick return if possible + if(mb == 0 || n == 0 || kb == 0) + { + return rocsparse_status_success; + } + + // Check pointer arguments + if(bsr_val == nullptr || bsr_row_ptr == nullptr || bsr_col_ind == nullptr || B == nullptr + || C == nullptr || alpha == nullptr || beta == nullptr) + { + return rocsparse_status_invalid_pointer; + } + + // Check leading dimension of B + if(trans_B == rocsparse_operation_none) + { + if(ldb < kb) + { + return rocsparse_status_invalid_size; + } + } + else + { + if(ldb < n) + { + return rocsparse_status_invalid_size; + } + } + + // Check leading dimension of C + if(ldc < mb) + { + return rocsparse_status_invalid_size; + } + + // Stream + hipStream_t stream = handle->stream; + + rocsparse_int m = mb * block_dim; + rocsparse_int k = kb * block_dim; + rocsparse_int nnz = nnzb * block_dim; + + // If n is only 1 and B are non-transposed, then call bsrmv + if(n == 1) + { + if(trans_B == rocsparse_operation_none) + { + return rocsparse_bsrmv_template(handle, + dir, + trans_A, + mb, + kb, + nnzb, + alpha, + descr, + bsr_val, + bsr_row_ptr, + bsr_col_ind, + block_dim, + B, + beta, + C); + } + } + + // If block dimension is one we can simply call csrmm + if(block_dim == 1) + { + return rocsparse_csrmm_template(handle, + trans_A, + trans_B, + m, + n, + k, + nnz, + alpha, + descr, + bsr_val, + bsr_row_ptr, + bsr_col_ind, + B, + ldb, + beta, + C, + ldc); + } + + if(block_dim == 2) + { + if(handle->pointer_mode == rocsparse_pointer_mode_device) + { + if(trans_B == rocsparse_operation_none) + { + constexpr rocsparse_int BSRMMNN_DIM = 64; + constexpr rocsparse_int SUB_WF_SIZE = 8; + + dim3 bsrmmnn_blocks((SUB_WF_SIZE * m - 1) / BSRMMNN_DIM + 1, (n - 1) / SUB_WF_SIZE + 1); + dim3 bsrmmnn_threads(BSRMMNN_DIM); + launch_bsrmmnn_small_blockdim_kernel_device_pointer(T, BSRMMNN_DIM, SUB_WF_SIZE, 2); + } + else + { + constexpr rocsparse_int BSRMMNT_DIM = 64; + + // Average nnzb per row of A + rocsparse_int avg_row_nnzb = (nnzb - 1) / mb + 1; + + // Launch appropriate kernel depending on row nnz of A + if(avg_row_nnzb < 16) + { + dim3 bsrmmnt_blocks((8 * m - 1) / BSRMMNT_DIM + 1); + dim3 bsrmmnt_threads(BSRMMNT_DIM); + launch_bsrmmnt_small_blockdim_kernel_device_pointer(T, BSRMMNT_DIM, 8, 2); + } + else if(avg_row_nnzb < 32) + { + dim3 bsrmmnt_blocks((16 * m - 1) / BSRMMNT_DIM + 1); + dim3 bsrmmnt_threads(BSRMMNT_DIM); + launch_bsrmmnt_small_blockdim_kernel_device_pointer(T, BSRMMNT_DIM, 16, 2); + } + else if(avg_row_nnzb < 64 || handle->wavefront_size == 32) + { + dim3 bsrmmnt_blocks((32 * m - 1) / BSRMMNT_DIM + 1); + dim3 bsrmmnt_threads(BSRMMNT_DIM); + launch_bsrmmnt_small_blockdim_kernel_device_pointer(T, BSRMMNT_DIM, 32, 2); + } + else if(handle->wavefront_size == 64) + { + dim3 bsrmmnt_blocks((64 * m - 1) / BSRMMNT_DIM + 1); + dim3 bsrmmnt_threads(BSRMMNT_DIM); + launch_bsrmmnt_small_blockdim_kernel_device_pointer(T, BSRMMNT_DIM, 64, 2); + } + else + { + return rocsparse_status_arch_mismatch; + } + } + } + else + { + if(trans_B == rocsparse_operation_none) + { + constexpr rocsparse_int BSRMMNN_DIM = 64; + constexpr rocsparse_int SUB_WF_SIZE = 8; + + dim3 bsrmmnn_blocks((SUB_WF_SIZE * m - 1) / BSRMMNN_DIM + 1, (n - 1) / SUB_WF_SIZE + 1); + dim3 bsrmmnn_threads(BSRMMNN_DIM); + launch_bsrmmnn_small_blockdim_kernel_host_pointer(T, BSRMMNN_DIM, SUB_WF_SIZE, 2); + } + else + { + constexpr rocsparse_int BSRMMNT_DIM = 64; + + // Average nnzb per row of A + rocsparse_int avg_row_nnzb = (nnzb - 1) / mb + 1; + + // Launch appropriate kernel depending on row nnz of A + if(avg_row_nnzb < 16) + { + dim3 bsrmmnt_blocks((8 * m - 1) / BSRMMNT_DIM + 1); + dim3 bsrmmnt_threads(BSRMMNT_DIM); + launch_bsrmmnt_small_blockdim_kernel_host_pointer(T, BSRMMNT_DIM, 8, 2); + } + else if(avg_row_nnzb < 32) + { + dim3 bsrmmnt_blocks((16 * m - 1) / BSRMMNT_DIM + 1); + dim3 bsrmmnt_threads(BSRMMNT_DIM); + launch_bsrmmnt_small_blockdim_kernel_host_pointer(T, BSRMMNT_DIM, 16, 2); + } + else if(avg_row_nnzb < 64 || handle->wavefront_size == 32) + { + dim3 bsrmmnt_blocks((32 * m - 1) / BSRMMNT_DIM + 1); + dim3 bsrmmnt_threads(BSRMMNT_DIM); + launch_bsrmmnt_small_blockdim_kernel_host_pointer(T, BSRMMNT_DIM, 32, 2); + } + else if(handle->wavefront_size == 64) + { + dim3 bsrmmnt_blocks((64 * m - 1) / BSRMMNT_DIM + 1); + dim3 bsrmmnt_threads(BSRMMNT_DIM); + launch_bsrmmnt_small_blockdim_kernel_host_pointer(T, BSRMMNT_DIM, 64, 2); + } + else + { + return rocsparse_status_arch_mismatch; + } + } + } + + return rocsparse_status_success; + } + + // Run different bsrmm kernels for block dim > 2 + if(n <= 16 && block_dim > 4 && block_dim <= 8) + { + if(handle->pointer_mode == rocsparse_pointer_mode_device) + { + dim3 bsrmm_blocks((mb - 1) / 1 + 1, (n - 1) / 16 + 1); + dim3 bsrmm_threads(8, 16, 1); + launch_bsrmm_large_blockdim_kernel_device_pointer(T, 8, 16); + } + else + { + dim3 bsrmm_blocks((mb - 1) / 1 + 1, (n - 1) / 16 + 1); + dim3 bsrmm_threads(8, 16, 1); + launch_bsrmm_large_blockdim_kernel_host_pointer(T, 8, 16); + } + } + else + { + if(handle->pointer_mode == rocsparse_pointer_mode_device) + { + if(block_dim <= 4) + { + dim3 bsrmm_blocks((mb - 1) / 1 + 1, (n - 1) / 16 + 1); + dim3 bsrmm_threads(4, 16, 1); + launch_bsrmm_large_blockdim_kernel_device_pointer(T, 4, 16); + } + else if(block_dim <= 8) + { + dim3 bsrmm_blocks((mb - 1) / 1 + 1, (n - 1) / 32 + 1); + dim3 bsrmm_threads(8, 32, 1); + launch_bsrmm_large_blockdim_kernel_device_pointer(T, 8, 32); + } + else if(block_dim <= 16) + { + dim3 bsrmm_blocks((mb - 1) / 1 + 1, (n - 1) / 16 + 1); + dim3 bsrmm_threads(16, 16, 1); + launch_bsrmm_large_blockdim_kernel_device_pointer(T, 16, 16); + } + else if(block_dim <= 32) + { + dim3 bsrmm_blocks((mb - 1) / 1 + 1, (n - 1) / 32 + 1); + dim3 bsrmm_threads(32, 32, 1); + launch_bsrmm_large_blockdim_kernel_device_pointer(T, 32, 32); + } + else + { + dim3 bsrmm_blocks((mb - 1) / 1 + 1, (n - 1) / 32 + 1); + dim3 bsrmm_threads(32, 32, 1); + launch_bsrmm_general_blockdim_kernel_device_pointer(T, 32, 32); + } + } + else + { + if(block_dim <= 4) + { + dim3 bsrmm_blocks((mb - 1) / 1 + 1, (n - 1) / 16 + 1); + dim3 bsrmm_threads(4, 16, 1); + launch_bsrmm_large_blockdim_kernel_host_pointer(T, 4, 16); + } + else if(block_dim <= 8) + { + dim3 bsrmm_blocks((mb - 1) / 1 + 1, (n - 1) / 32 + 1); + dim3 bsrmm_threads(8, 32, 1); + launch_bsrmm_large_blockdim_kernel_host_pointer(T, 8, 32); + } + else if(block_dim <= 16) + { + dim3 bsrmm_blocks((mb - 1) / 1 + 1, (n - 1) / 16 + 1); + dim3 bsrmm_threads(16, 16, 1); + launch_bsrmm_large_blockdim_kernel_host_pointer(T, 16, 16); + } + else if(block_dim <= 32) + { + dim3 bsrmm_blocks((mb - 1) / 1 + 1, (n - 1) / 32 + 1); + dim3 bsrmm_threads(32, 32, 1); + launch_bsrmm_large_blockdim_kernel_host_pointer(T, 32, 32); + } + else + { + dim3 bsrmm_blocks((mb - 1) / 1 + 1, (n - 1) / 32 + 1); + dim3 bsrmm_threads(32, 32, 1); + launch_bsrmm_general_blockdim_kernel_host_pointer(T, 32, 32); + } + } + } + + return rocsparse_status_success; +} + +#endif // ROCSPARSE_BSRMM_HPP \ No newline at end of file diff --git a/library/src/rocsparse_module.f90 b/library/src/rocsparse_module.f90 index dc9c90ec..cc67ef89 100644 --- a/library/src/rocsparse_module.f90 +++ b/library/src/rocsparse_module.f90 @@ -1728,6 +1728,114 @@ end function rocsparse_zhybmv ! =========================================================================== ! level 3 SPARSE ! =========================================================================== +! rocsparse_bsrmm + function rocsparse_sbsrmm(handle, dir, trans_A, trans_B, mb, n, kb, nnzb, alpha, descr, & + bsr_val, bsr_row_ptr, bsr_col_ind, block_dim, B, ldb, beta, C, ldc) & + result(c_int) & + bind(c, name = 'rocsparse_sbsrmm') + use iso_c_binding + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: dir + integer(c_int), value :: trans_A + integer(c_int), value :: trans_B + integer(c_int), value :: mb + integer(c_int), value :: n + integer(c_int), value :: kb + integer(c_int), value :: nnzb + type(c_ptr), intent(in), value :: alpha + type(c_ptr), intent(in), value :: descr + type(c_ptr), intent(in), value :: bsr_val + type(c_ptr), intent(in), value :: bsr_row_ptr + type(c_ptr), intent(in), value :: bsr_col_ind + integer(c_int), value :: block_dim + type(c_ptr), intent(in), value :: B + integer(c_int), value :: ldb + type(c_ptr), intent(in), value :: beta + type(c_ptr), value :: C + integer(c_int), value :: ldc + end function rocsparse_sbsrmm + + function rocsparse_dbsrmm(handle, dir, trans_A, trans_B, mb, n, kb, nnzb, alpha, descr, & + bsr_val, bsr_row_ptr, bsr_col_ind, block_dim, B, ldb, beta, C, ldc) & + result(c_int) & + bind(c, name = 'rocsparse_dbsrmm') + use iso_c_binding + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: dir + integer(c_int), value :: trans_A + integer(c_int), value :: trans_B + integer(c_int), value :: mb + integer(c_int), value :: n + integer(c_int), value :: kb + integer(c_int), value :: nnzb + type(c_ptr), intent(in), value :: alpha + type(c_ptr), intent(in), value :: descr + type(c_ptr), intent(in), value :: bsr_val + type(c_ptr), intent(in), value :: bsr_row_ptr + type(c_ptr), intent(in), value :: bsr_col_ind + integer(c_int), value :: block_dim + type(c_ptr), intent(in), value :: B + integer(c_int), value :: ldb + type(c_ptr), intent(in), value :: beta + type(c_ptr), value :: C + integer(c_int), value :: ldc + end function rocsparse_dbsrmm + + function rocsparse_cbsrmm(handle, dir, trans_A, trans_B, mb, n, kb, nnzb, alpha, descr, & + bsr_val, bsr_row_ptr, bsr_col_ind, block_dim, B, ldb, beta, C, ldc) & + result(c_int) & + bind(c, name = 'rocsparse_cbsrmm') + use iso_c_binding + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: dir + integer(c_int), value :: trans_A + integer(c_int), value :: trans_B + integer(c_int), value :: mb + integer(c_int), value :: n + integer(c_int), value :: kb + integer(c_int), value :: nnzb + type(c_ptr), intent(in), value :: alpha + type(c_ptr), intent(in), value :: descr + type(c_ptr), intent(in), value :: bsr_val + type(c_ptr), intent(in), value :: bsr_row_ptr + type(c_ptr), intent(in), value :: bsr_col_ind + integer(c_int), value :: block_dim + type(c_ptr), intent(in), value :: B + integer(c_int), value :: ldb + type(c_ptr), intent(in), value :: beta + type(c_ptr), value :: C + integer(c_int), value :: ldc + end function rocsparse_cbsrmm + + function rocsparse_zbsrmm(handle, dir, trans_A, trans_B, mb, n, kb, nnzb, alpha, descr, & + bsr_val, bsr_row_ptr, bsr_col_ind, block_dim, B, ldb, beta, C, ldc) & + result(c_int) & + bind(c, name = 'rocsparse_zbsrmm') + use iso_c_binding + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: dir + integer(c_int), value :: trans_A + integer(c_int), value :: trans_B + integer(c_int), value :: mb + integer(c_int), value :: n + integer(c_int), value :: kb + integer(c_int), value :: nnzb + type(c_ptr), intent(in), value :: alpha + type(c_ptr), intent(in), value :: descr + type(c_ptr), intent(in), value :: bsr_val + type(c_ptr), intent(in), value :: bsr_row_ptr + type(c_ptr), intent(in), value :: bsr_col_ind + integer(c_int), value :: block_dim + type(c_ptr), intent(in), value :: B + integer(c_int), value :: ldb + type(c_ptr), intent(in), value :: beta + type(c_ptr), value :: C + integer(c_int), value :: ldc + end function rocsparse_zbsrmm ! rocsparse_csrmm function rocsparse_scsrmm(handle, trans_A, trans_B, m, n, k, nnz, alpha, descr, &