diff --git a/clients/benchmarks/client.cpp b/clients/benchmarks/client.cpp index 747227ae..83c71c66 100644 --- a/clients/benchmarks/client.cpp +++ b/clients/benchmarks/client.cpp @@ -43,6 +43,7 @@ #include "testing_hybmv.hpp" // Level3 +#include "testing_bsrmm.hpp" #include "testing_csrmm.hpp" #include "testing_csrsm.hpp" @@ -209,7 +210,7 @@ int main(int argc, char* argv[]) "SPARSE function to test. Options:\n" " Level1: axpyi, doti, dotci, gthr, gthrz, roti, sctr\n" " Level2: bsrmv, bsrsv, coomv, csrmv, csrsv, ellmv, hybmv\n" - " Level3: csrmm, csrsm\n" + " Level3: bsrmm, csrmm, csrsm\n" " Extra: csrgeam, csrgemm\n" " Preconditioner: csric0, csrilu0\n" " Conversion: csr2coo, csr2csc, csr2ell, csr2hyb, csr2bsr\n" @@ -567,6 +568,17 @@ int main(int argc, char* argv[]) else if(precision == 'z') testing_hybmv(arg); } + else if(function == "bsrmm") + { + if(precision == 's') + testing_bsrmm(arg); + else if(precision == 'd') + testing_bsrmm(arg); + else if(precision == 'c') + testing_bsrmm(arg); + else if(precision == 'z') + testing_bsrmm(arg); + } else if(function == "csrmm") { if(precision == 's') diff --git a/clients/common/rocsparse_template_specialization.cpp b/clients/common/rocsparse_template_specialization.cpp index e1a6c775..3f8b80fd 100644 --- a/clients/common/rocsparse_template_specialization.cpp +++ b/clients/common/rocsparse_template_specialization.cpp @@ -1521,6 +1521,175 @@ rocsparse_status rocsparse_hybmv(rocsparse_handle handle, * level 3 SPARSE * =========================================================================== */ +// bsrmm +template <> +rocsparse_status rocsparse_bsrmm(rocsparse_handle handle, + rocsparse_direction dir, + rocsparse_operation trans_A, + rocsparse_operation trans_B, + rocsparse_int mb, + rocsparse_int n, + rocsparse_int kb, + rocsparse_int nnzb, + const float* alpha, + const rocsparse_mat_descr descr, + const float* bsr_val, + const rocsparse_int* bsr_row_ptr, + const rocsparse_int* bsr_col_ind, + rocsparse_int block_dim, + 
const float* B, + rocsparse_int ldb, + const float* beta, + float* C, + rocsparse_int ldc) +{ + return rocsparse_sbsrmm(handle, + dir, + trans_A, + trans_B, + mb, + n, + kb, + nnzb, + alpha, + descr, + bsr_val, + bsr_row_ptr, + bsr_col_ind, + block_dim, + B, + ldb, + beta, + C, + ldc); +} + +template <> +rocsparse_status rocsparse_bsrmm(rocsparse_handle handle, + rocsparse_direction dir, + rocsparse_operation trans_A, + rocsparse_operation trans_B, + rocsparse_int mb, + rocsparse_int n, + rocsparse_int kb, + rocsparse_int nnzb, + const double* alpha, + const rocsparse_mat_descr descr, + const double* bsr_val, + const rocsparse_int* bsr_row_ptr, + const rocsparse_int* bsr_col_ind, + rocsparse_int block_dim, + const double* B, + rocsparse_int ldb, + const double* beta, + double* C, + rocsparse_int ldc) +{ + return rocsparse_dbsrmm(handle, + dir, + trans_A, + trans_B, + mb, + n, + kb, + nnzb, + alpha, + descr, + bsr_val, + bsr_row_ptr, + bsr_col_ind, + block_dim, + B, + ldb, + beta, + C, + ldc); +} + +template <> +rocsparse_status rocsparse_bsrmm(rocsparse_handle handle, + rocsparse_direction dir, + rocsparse_operation trans_A, + rocsparse_operation trans_B, + rocsparse_int mb, + rocsparse_int n, + rocsparse_int kb, + rocsparse_int nnzb, + const rocsparse_float_complex* alpha, + const rocsparse_mat_descr descr, + const rocsparse_float_complex* bsr_val, + const rocsparse_int* bsr_row_ptr, + const rocsparse_int* bsr_col_ind, + rocsparse_int block_dim, + const rocsparse_float_complex* B, + rocsparse_int ldb, + const rocsparse_float_complex* beta, + rocsparse_float_complex* C, + rocsparse_int ldc) +{ + return rocsparse_cbsrmm(handle, + dir, + trans_A, + trans_B, + mb, + n, + kb, + nnzb, + alpha, + descr, + bsr_val, + bsr_row_ptr, + bsr_col_ind, + block_dim, + B, + ldb, + beta, + C, + ldc); +} + +template <> +rocsparse_status rocsparse_bsrmm(rocsparse_handle handle, + rocsparse_direction dir, + rocsparse_operation trans_A, + rocsparse_operation trans_B, + rocsparse_int mb, 
+ rocsparse_int n, + rocsparse_int kb, + rocsparse_int nnzb, + const rocsparse_double_complex* alpha, + const rocsparse_mat_descr descr, + const rocsparse_double_complex* bsr_val, + const rocsparse_int* bsr_row_ptr, + const rocsparse_int* bsr_col_ind, + rocsparse_int block_dim, + const rocsparse_double_complex* B, + rocsparse_int ldb, + const rocsparse_double_complex* beta, + rocsparse_double_complex* C, + rocsparse_int ldc) +{ + return rocsparse_zbsrmm(handle, + dir, + trans_A, + trans_B, + mb, + n, + kb, + nnzb, + alpha, + descr, + bsr_val, + bsr_row_ptr, + bsr_col_ind, + block_dim, + B, + ldb, + beta, + C, + ldc); +} + // csrmm template <> rocsparse_status rocsparse_csrmm(rocsparse_handle handle, diff --git a/clients/include/flops.hpp b/clients/include/flops.hpp index 82c712aa..fcd06b5e 100644 --- a/clients/include/flops.hpp +++ b/clients/include/flops.hpp @@ -78,6 +78,16 @@ constexpr double csrsv_gflop_count(rocsparse_int M, rocsparse_int nnz, rocsparse * level 3 SPARSE * =========================================================================== */ +template +constexpr double bsrmm_gflop_count(rocsparse_int N, + rocsparse_int nnzb, + rocsparse_int block_dim, + rocsparse_int nnz_C, + bool beta = false) +{ + return (3.0 * nnzb * block_dim * block_dim * N + (beta ? 
nnz_C : 0)) / 1e9; +} + template constexpr double csrmm_gflop_count(rocsparse_int N, rocsparse_int nnz_A, rocsparse_int nnz_C, bool beta = false) diff --git a/clients/include/gbyte.hpp b/clients/include/gbyte.hpp index d785336a..91a5ab91 100644 --- a/clients/include/gbyte.hpp +++ b/clients/include/gbyte.hpp @@ -132,6 +132,24 @@ constexpr double * level 3 SPARSE * =========================================================================== */ +template +constexpr double bsrmm_gbyte_count(rocsparse_int Mb, + rocsparse_int nnzb, + rocsparse_int block_dim, + rocsparse_int nnz_B, + rocsparse_int nnz_C, + bool beta = false) +{ + //reads + size_t reads = (Mb + 1 + nnzb) * sizeof(rocsparse_int) + + (block_dim * block_dim * nnzb + nnz_B + (beta ? nnz_C : 0)) * sizeof(T); + + //writes + size_t writes = nnz_C * sizeof(T); + + return (reads + writes) / 1e9; +} + template constexpr double csrmm_gbyte_count(rocsparse_int M, rocsparse_int nnz_A, diff --git a/clients/include/rocsparse.hpp b/clients/include/rocsparse.hpp index 94a93ed4..951a26d1 100644 --- a/clients/include/rocsparse.hpp +++ b/clients/include/rocsparse.hpp @@ -302,6 +302,28 @@ rocsparse_status rocsparse_hybmv(rocsparse_handle handle, * level 3 SPARSE * =========================================================================== */ +// bsrmm +template +rocsparse_status rocsparse_bsrmm(rocsparse_handle handle, + rocsparse_direction dir, + rocsparse_operation trans_A, + rocsparse_operation trans_B, + rocsparse_int mb, + rocsparse_int n, + rocsparse_int kb, + rocsparse_int nnzb, + const T* alpha, + const rocsparse_mat_descr descr, + const T* bsr_val, + const rocsparse_int* bsr_row_ptr, + const rocsparse_int* bsr_col_ind, + rocsparse_int block_dim, + const T* B, + rocsparse_int ldb, + const T* beta, + T* C, + rocsparse_int ldc); + // csrmm template rocsparse_status rocsparse_csrmm(rocsparse_handle handle, diff --git a/clients/include/rocsparse_host.hpp b/clients/include/rocsparse_host.hpp index 3fb380bd..31027841 100644 
--- a/clients/include/rocsparse_host.hpp +++ b/clients/include/rocsparse_host.hpp @@ -1233,6 +1233,83 @@ inline void host_hybmv(rocsparse_int M, * level 3 SPARSE * =========================================================================== */ +template +inline void host_bsrmm(rocsparse_int Mb, + rocsparse_int N, + rocsparse_int Kb, + rocsparse_int block_dim, + rocsparse_direction dir, + rocsparse_operation transA, + rocsparse_operation transB, + T alpha, + const std::vector& bsr_row_ptr_A, + const std::vector& bsr_col_ind_A, + const std::vector& bsr_val_A, + const std::vector& B, + rocsparse_int ldb, + T beta, + std::vector& C, + rocsparse_int ldc, + rocsparse_index_base base) +{ + if(transA != rocsparse_operation_none) + { + return; + } + + if(transB != rocsparse_operation_none && transB != rocsparse_operation_transpose) + { + return; + } + + rocsparse_int M = Mb * block_dim; + rocsparse_int K = Kb * block_dim; + +#ifdef _OPENMP +#pragma omp parallel for schedule(dynamic, 1024) +#endif + for(rocsparse_int i = 0; i < M; i++) + { + rocsparse_int local_row = i % block_dim; + + rocsparse_int row_begin = bsr_row_ptr_A[i / block_dim] - base; + rocsparse_int row_end = bsr_row_ptr_A[i / block_dim + 1] - base; + + for(rocsparse_int j = 0; j < N; j++) + { + rocsparse_int idx_C = i + j * ldc; + + T sum = static_cast(0); + + for(rocsparse_int s = row_begin; s < row_end; s++) + { + for(rocsparse_int t = 0; t < block_dim; t++) + { + rocsparse_int idx_A + = (dir == rocsparse_direction_row) + ? block_dim * block_dim * s + block_dim * local_row + t + : block_dim * block_dim * s + block_dim * t + local_row; + rocsparse_int idx_B + = (transB == rocsparse_operation_none) + ? 
j * ldb + block_dim * (bsr_col_ind_A[s] - base) + t + : (block_dim * (bsr_col_ind_A[s] - base) + t) * ldb + j; + + sum = std::fma(bsr_val_A[idx_A], B[idx_B], sum); + } + } + + if(beta == static_cast(0)) + { + C[idx_C] = alpha * sum; + } + else + { + C[idx_C] = std::fma(beta, C[idx_C], alpha * sum); + } + } + } +} + template inline void host_csrmm(rocsparse_int M, rocsparse_int N, @@ -1267,16 +1344,16 @@ inline void host_csrmm(rocsparse_int M, ? (csr_col_ind_A[k] - base + j * ldb) : (j + (csr_col_ind_A[k] - base) * ldb); - sum = std::fma(alpha * csr_val_A[k], B[idx_B], sum); + sum = std::fma(csr_val_A[k], B[idx_B], sum); } if(beta == static_cast(0)) { - C[idx_C] = sum; + C[idx_C] = alpha * sum; } else { - C[idx_C] = std::fma(beta, C[idx_C], sum); + C[idx_C] = std::fma(beta, C[idx_C], alpha * sum); } } } diff --git a/clients/include/rocsparse_template.yaml b/clients/include/rocsparse_template.yaml index ddd43a68..d4f6ee7c 100644 --- a/clients/include/rocsparse_template.yaml +++ b/clients/include/rocsparse_template.yaml @@ -105,6 +105,10 @@ Functions: rocsparse_chybmv: { function: hybmv, <<: *single_precision_complex } rocsparse_zhybmv: { function: hybmv, <<: *double_precision_complex } + rocsparse_sbsrmm: { function: bsrmm, <<: *single_precision } + rocsparse_dbsrmm: { function: bsrmm, <<: *double_precision } + rocsparse_cbsrmm: { function: bsrmm, <<: *single_precision_complex } + rocsparse_zbsrmm: { function: bsrmm, <<: *double_precision_complex } rocsparse_scsrmm: { function: csrmm, <<: *single_precision } rocsparse_dcsrmm: { function: csrmm, <<: *double_precision } rocsparse_scsrsm_buffer_size: { function: csrsm, <<: *single_precision } diff --git a/clients/include/testing_bsrmm.hpp b/clients/include/testing_bsrmm.hpp new file mode 100644 index 00000000..289f04ff --- /dev/null +++ b/clients/include/testing_bsrmm.hpp @@ -0,0 +1,767 @@ +/* ************************************************************************ + * Copyright (c) 2020 Advanced Micro Devices, Inc. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ * + * ************************************************************************ */ + +#pragma once +#ifndef TESTING_BSRMM_HPP +#define TESTING_BSRMM_HPP + +#include + +#include "flops.hpp" +#include "gbyte.hpp" +#include "rocsparse_check.hpp" +#include "rocsparse_host.hpp" +#include "rocsparse_init.hpp" +#include "rocsparse_math.hpp" +#include "rocsparse_random.hpp" +#include "rocsparse_test.hpp" +#include "rocsparse_vector.hpp" +#include "utility.hpp" + +#include "testing_bsr2csr.hpp" + +template +void testing_bsrmm_bad_arg(const Arguments& arg) +{ + static const size_t safe_size = 100; + + T h_alpha = 0.6; + T h_beta = 0.1; + + // Create rocsparse handle + rocsparse_local_handle handle; + + // Create matrix descriptor + rocsparse_local_mat_descr descr; + + // Allocate memory on device + device_vector dbsr_row_ptr(safe_size); + device_vector dbsr_col_ind(safe_size); + device_vector dbsr_val(safe_size); + device_vector dB(safe_size); + device_vector dC(safe_size); + + if(!dbsr_row_ptr || !dbsr_col_ind || !dbsr_val || !dB || !dC) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + // Test invalid handle + EXPECT_ROCSPARSE_STATUS(rocsparse_bsrmm(nullptr, + rocsparse_direction_row, + rocsparse_operation_none, + rocsparse_operation_none, + safe_size, + safe_size, + safe_size, + safe_size, + &h_alpha, + descr, + dbsr_val, + dbsr_row_ptr, + dbsr_col_ind, + safe_size, + dB, + safe_size, + &h_beta, + dC, + safe_size), + rocsparse_status_invalid_handle); + + // Test invalid pointers + EXPECT_ROCSPARSE_STATUS(rocsparse_bsrmm(handle, + rocsparse_direction_row, + rocsparse_operation_none, + rocsparse_operation_none, + safe_size, + safe_size, + safe_size, + safe_size, + nullptr, + descr, + dbsr_val, + dbsr_row_ptr, + dbsr_col_ind, + safe_size, + dB, + safe_size, + &h_beta, + dC, + safe_size), + rocsparse_status_invalid_pointer); + EXPECT_ROCSPARSE_STATUS(rocsparse_bsrmm(handle, + rocsparse_direction_row, + rocsparse_operation_none, + rocsparse_operation_none, + 
safe_size, + safe_size, + safe_size, + safe_size, + &h_alpha, + nullptr, + dbsr_val, + dbsr_row_ptr, + dbsr_col_ind, + safe_size, + dB, + safe_size, + &h_beta, + dC, + safe_size), + rocsparse_status_invalid_pointer); + EXPECT_ROCSPARSE_STATUS(rocsparse_bsrmm(handle, + rocsparse_direction_row, + rocsparse_operation_none, + rocsparse_operation_none, + safe_size, + safe_size, + safe_size, + safe_size, + &h_alpha, + descr, + nullptr, + dbsr_row_ptr, + dbsr_col_ind, + safe_size, + dB, + safe_size, + &h_beta, + dC, + safe_size), + rocsparse_status_invalid_pointer); + EXPECT_ROCSPARSE_STATUS(rocsparse_bsrmm(handle, + rocsparse_direction_row, + rocsparse_operation_none, + rocsparse_operation_none, + safe_size, + safe_size, + safe_size, + safe_size, + &h_alpha, + descr, + dbsr_val, + nullptr, + dbsr_col_ind, + safe_size, + dB, + safe_size, + &h_beta, + dC, + safe_size), + rocsparse_status_invalid_pointer); + EXPECT_ROCSPARSE_STATUS(rocsparse_bsrmm(handle, + rocsparse_direction_row, + rocsparse_operation_none, + rocsparse_operation_none, + safe_size, + safe_size, + safe_size, + safe_size, + &h_alpha, + descr, + dbsr_val, + dbsr_row_ptr, + nullptr, + safe_size, + dB, + safe_size, + &h_beta, + dC, + safe_size), + rocsparse_status_invalid_pointer); + EXPECT_ROCSPARSE_STATUS(rocsparse_bsrmm(handle, + rocsparse_direction_row, + rocsparse_operation_none, + rocsparse_operation_none, + safe_size, + safe_size, + safe_size, + safe_size, + &h_alpha, + descr, + dbsr_val, + dbsr_row_ptr, + dbsr_col_ind, + safe_size, + nullptr, + safe_size, + &h_beta, + dC, + safe_size), + rocsparse_status_invalid_pointer); + EXPECT_ROCSPARSE_STATUS(rocsparse_bsrmm(handle, + rocsparse_direction_row, + rocsparse_operation_none, + rocsparse_operation_none, + safe_size, + safe_size, + safe_size, + safe_size, + &h_alpha, + descr, + dbsr_val, + dbsr_row_ptr, + dbsr_col_ind, + safe_size, + dB, + safe_size, + nullptr, + dC, + safe_size), + rocsparse_status_invalid_pointer); + 
EXPECT_ROCSPARSE_STATUS(rocsparse_bsrmm(handle, + rocsparse_direction_row, + rocsparse_operation_none, + rocsparse_operation_none, + safe_size, + safe_size, + safe_size, + safe_size, + &h_alpha, + descr, + dbsr_val, + dbsr_row_ptr, + dbsr_col_ind, + safe_size, + dB, + safe_size, + &h_beta, + nullptr, + safe_size), + rocsparse_status_invalid_pointer); + + // Test invalid size + EXPECT_ROCSPARSE_STATUS(rocsparse_bsrmm(handle, + rocsparse_direction_row, + rocsparse_operation_none, + rocsparse_operation_none, + -1, + safe_size, + safe_size, + safe_size, + &h_alpha, + descr, + dbsr_val, + dbsr_row_ptr, + dbsr_col_ind, + safe_size, + dB, + safe_size, + &h_beta, + dC, + safe_size), + rocsparse_status_invalid_size); + + EXPECT_ROCSPARSE_STATUS(rocsparse_bsrmm(handle, + rocsparse_direction_row, + rocsparse_operation_none, + rocsparse_operation_none, + safe_size, + -1, + safe_size, + safe_size, + &h_alpha, + descr, + dbsr_val, + dbsr_row_ptr, + dbsr_col_ind, + safe_size, + dB, + safe_size, + &h_beta, + dC, + safe_size), + rocsparse_status_invalid_size); + + EXPECT_ROCSPARSE_STATUS(rocsparse_bsrmm(handle, + rocsparse_direction_row, + rocsparse_operation_none, + rocsparse_operation_none, + safe_size, + safe_size, + -1, + safe_size, + &h_alpha, + descr, + dbsr_val, + dbsr_row_ptr, + dbsr_col_ind, + safe_size, + dB, + safe_size, + &h_beta, + dC, + safe_size), + rocsparse_status_invalid_size); + + EXPECT_ROCSPARSE_STATUS(rocsparse_bsrmm(handle, + rocsparse_direction_row, + rocsparse_operation_none, + rocsparse_operation_none, + safe_size, + safe_size, + safe_size, + -1, + &h_alpha, + descr, + dbsr_val, + dbsr_row_ptr, + dbsr_col_ind, + safe_size, + dB, + safe_size, + &h_beta, + dC, + safe_size), + rocsparse_status_invalid_size); + + EXPECT_ROCSPARSE_STATUS(rocsparse_bsrmm(handle, + rocsparse_direction_row, + rocsparse_operation_none, + rocsparse_operation_none, + safe_size, + safe_size, + safe_size, + safe_size, + &h_alpha, + descr, + dbsr_val, + dbsr_row_ptr, + dbsr_col_ind, + 
0, + dB, + safe_size, + &h_beta, + dC, + safe_size), + rocsparse_status_invalid_size); + + // Test not implemented + EXPECT_ROCSPARSE_STATUS(rocsparse_bsrmm(handle, + rocsparse_direction_row, + rocsparse_operation_transpose, + rocsparse_operation_none, + safe_size, + safe_size, + safe_size, + safe_size, + &h_alpha, + descr, + dbsr_val, + dbsr_row_ptr, + dbsr_col_ind, + safe_size, + dB, + safe_size, + &h_beta, + dC, + safe_size), + rocsparse_status_not_implemented); + + EXPECT_ROCSPARSE_STATUS(rocsparse_bsrmm(handle, + rocsparse_direction_row, + rocsparse_operation_conjugate_transpose, + rocsparse_operation_none, + safe_size, + safe_size, + safe_size, + safe_size, + &h_alpha, + descr, + dbsr_val, + dbsr_row_ptr, + dbsr_col_ind, + safe_size, + dB, + safe_size, + &h_beta, + dC, + safe_size), + rocsparse_status_not_implemented); + + EXPECT_ROCSPARSE_STATUS(rocsparse_bsrmm(handle, + rocsparse_direction_row, + rocsparse_operation_none, + rocsparse_operation_conjugate_transpose, + safe_size, + safe_size, + safe_size, + safe_size, + &h_alpha, + descr, + dbsr_val, + dbsr_row_ptr, + dbsr_col_ind, + safe_size, + dB, + safe_size, + &h_beta, + dC, + safe_size), + rocsparse_status_not_implemented); +} + +template +void testing_bsrmm(const Arguments& arg) +{ + rocsparse_int M = arg.M; + rocsparse_int N = arg.N; + rocsparse_int K = arg.K; + rocsparse_int block_dim = arg.block_dim; + rocsparse_int dim_x = arg.dimx; + rocsparse_int dim_y = arg.dimy; + rocsparse_int dim_z = arg.dimz; + rocsparse_operation transA = arg.transA; + rocsparse_operation transB = arg.transB; + rocsparse_direction direction = arg.direction; + rocsparse_index_base base = arg.baseA; + rocsparse_matrix_init mat = arg.matrix; + bool full_rank = false; + std::string filename + = arg.timing ? 
arg.filename : rocsparse_exepath() + "../matrices/" + arg.filename + ".csr"; + + rocsparse_int Mb = -1; + rocsparse_int Kb = -1; + if(block_dim > 0) + { + Mb = (M + block_dim - 1) / block_dim; + Kb = (K + block_dim - 1) / block_dim; + } + + T h_alpha = arg.get_alpha(); + T h_beta = arg.get_beta(); + + // Create rocsparse handle + rocsparse_local_handle handle; + + // Create matrix descriptor + rocsparse_local_mat_descr descr; + + // Set matrix index base + CHECK_ROCSPARSE_ERROR(rocsparse_set_mat_index_base(descr, base)); + + CHECK_ROCSPARSE_ERROR(rocsparse_set_pointer_mode(handle, rocsparse_pointer_mode_host)); + + // Argument sanity check before allocating invalid memory + if(Mb <= 0 || N <= 0 || Kb <= 0 || block_dim <= 0) + { + static const size_t safe_size = 100; + + // Allocate memory on device + device_vector dbsr_row_ptr(safe_size); + device_vector dbsr_col_ind(safe_size); + device_vector dbsr_val(safe_size); + device_vector dB(safe_size); + device_vector dC(safe_size); + + if(!dbsr_row_ptr || !dbsr_col_ind || !dbsr_val || !dB || !dC) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + EXPECT_ROCSPARSE_STATUS(rocsparse_bsrmm(handle, + direction, + transA, + transB, + Mb, + N, + Kb, + safe_size, + &h_alpha, + descr, + dbsr_val, + dbsr_row_ptr, + dbsr_col_ind, + block_dim, + dB, + safe_size, + &h_beta, + dC, + safe_size), + (Mb < 0 || N < 0 || Kb < 0 || block_dim <= 0) + ? 
rocsparse_status_invalid_size + : rocsparse_status_success); + + return; + } + + // Allocate host memory for original CSR matrix + host_vector hcsr_row_ptr_orig; + host_vector hcsr_col_ind_orig; + host_vector hcsr_val_orig; + + // Allocate host memory for output BSR matrix + host_vector hbsr_row_ptr; + host_vector hbsr_col_ind; + host_vector hbsr_val; + + rocsparse_seedrand(); + + // Generate original host CSR matrix and then use it to fill in the host BSR matrix + rocsparse_int nnzb = 0; + rocsparse_int dummy = 0; + rocsparse_init_csr_and_bsr_matrix(hcsr_row_ptr_orig, + hcsr_col_ind_orig, + hcsr_val_orig, + M, + K, + base, + hbsr_row_ptr, + hbsr_col_ind, + hbsr_val, + direction, + Mb, + Kb, + block_dim, + dummy, + dim_x, + dim_y, + dim_z, + nnzb, + base, + mat, + filename.c_str(), + false, + full_rank); + + // M and K and Mb and Kb can be modified by rocsparse_init_csr_and_bsr_matrix + M = Mb * block_dim; + K = Kb * block_dim; + + // Some matrix properties + rocsparse_int ldb = (transB == rocsparse_operation_none) ? K : N; + rocsparse_int ldc = M; + + rocsparse_int ncol_B = (transB == rocsparse_operation_none ? 
N : K); + rocsparse_int nnz_B = ldb * ncol_B; + rocsparse_int nnz_C = ldc * N; + + // Allocate host memory for dense matrices + host_vector hB(nnz_B); + host_vector hC_1(nnz_C); + host_vector hC_2(nnz_C); + host_vector hC_gold(nnz_C); + + // Initialize data on CPU + rocsparse_init(hB, ldb, ncol_B, ldb); + rocsparse_init(hC_1, ldc, N, ldc); + hC_2 = hC_1; + hC_gold = hC_1; + + // Allocate device memory + device_vector dbsr_row_ptr(Mb + 1); + device_vector dbsr_col_ind(nnzb); + device_vector dbsr_val(nnzb * block_dim * block_dim); + device_vector dB(nnz_B); + device_vector dC_1(nnz_C); + device_vector dC_2(nnz_C); + device_vector d_alpha(1); + device_vector d_beta(1); + + if(!dbsr_row_ptr || !dbsr_col_ind || !dbsr_val || !dB || !dC_1 || !dC_2 || !d_alpha || !d_beta) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + // Copy data from CPU to device + CHECK_HIP_ERROR(hipMemcpy( + dbsr_row_ptr, hbsr_row_ptr, sizeof(rocsparse_int) * (Mb + 1), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR( + hipMemcpy(dbsr_col_ind, hbsr_col_ind, sizeof(rocsparse_int) * nnzb, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy( + dbsr_val, hbsr_val, sizeof(T) * nnzb * block_dim * block_dim, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dB, hB, sizeof(T) * nnz_B, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dC_1, hC_1, sizeof(T) * nnz_C, hipMemcpyHostToDevice)); + + if(arg.unit_check) + { + // Copy data from CPU to device + CHECK_HIP_ERROR(hipMemcpy(dC_2, hC_2, sizeof(T) * nnz_C, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(d_alpha, &h_alpha, sizeof(T), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(d_beta, &h_beta, sizeof(T), hipMemcpyHostToDevice)); + + // Pointer mode host + CHECK_ROCSPARSE_ERROR(rocsparse_set_pointer_mode(handle, rocsparse_pointer_mode_host)); + CHECK_ROCSPARSE_ERROR(rocsparse_bsrmm(handle, + direction, + transA, + transB, + Mb, + N, + Kb, + nnzb, + &h_alpha, + descr, + dbsr_val, + dbsr_row_ptr, + dbsr_col_ind, + block_dim, + dB, + 
ldb, + &h_beta, + dC_1, + ldc)); + + // Pointer mode device + CHECK_ROCSPARSE_ERROR(rocsparse_set_pointer_mode(handle, rocsparse_pointer_mode_device)); + CHECK_ROCSPARSE_ERROR(rocsparse_bsrmm(handle, + direction, + transA, + transB, + Mb, + N, + Kb, + nnzb, + d_alpha, + descr, + dbsr_val, + dbsr_row_ptr, + dbsr_col_ind, + block_dim, + dB, + ldb, + d_beta, + dC_2, + ldc)); + + // Copy output to host + CHECK_HIP_ERROR(hipMemcpy(hC_1, dC_1, sizeof(T) * nnz_C, hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(hC_2, dC_2, sizeof(T) * nnz_C, hipMemcpyDeviceToHost)); + + // CPU bsrmm + host_bsrmm(Mb, + N, + Kb, + block_dim, + direction, + transA, + transB, + h_alpha, + hbsr_row_ptr, + hbsr_col_ind, + hbsr_val, + hB, + ldb, + h_beta, + hC_gold, + ldc, + base); + + near_check_general(ldc, N, ldc, hC_gold, hC_1); + near_check_general(ldc, N, ldc, hC_gold, hC_2); + } + + if(arg.timing) + { + int number_cold_calls = 2; + int number_hot_calls = arg.iters; + + CHECK_ROCSPARSE_ERROR(rocsparse_set_pointer_mode(handle, rocsparse_pointer_mode_host)); + + // Warm up + for(int iter = 0; iter < number_cold_calls; ++iter) + { + CHECK_ROCSPARSE_ERROR(rocsparse_bsrmm(handle, + direction, + transA, + transB, + Mb, + N, + Kb, + nnzb, + &h_alpha, + descr, + dbsr_val, + dbsr_row_ptr, + dbsr_col_ind, + block_dim, + dB, + ldb, + &h_beta, + dC_1, + ldc)); + } + + double gpu_time_used = get_time_us(); + + // Performance run + for(int iter = 0; iter < number_hot_calls; ++iter) + { + CHECK_ROCSPARSE_ERROR(rocsparse_bsrmm(handle, + direction, + transA, + transB, + Mb, + N, + Kb, + nnzb, + &h_alpha, + descr, + dbsr_val, + dbsr_row_ptr, + dbsr_col_ind, + block_dim, + dB, + ldb, + &h_beta, + dC_1, + ldc)); + } + + gpu_time_used = (get_time_us() - gpu_time_used) / number_hot_calls; + + double gpu_gflops + = bsrmm_gflop_count(N, nnzb, block_dim, nnz_C, h_beta != static_cast(0)) + / gpu_time_used * 1e6; + double gpu_gbyte + = bsrmm_gbyte_count(Mb, nnzb, block_dim, nnz_B, nnz_C, h_beta != 
static_cast(0)) + / gpu_time_used * 1e6; + + std::cout.precision(2); + std::cout.setf(std::ios::fixed); + std::cout.setf(std::ios::left); + + std::cout << std::setw(12) << "M" << std::setw(12) << "N" << std::setw(12) << "K" + << std::setw(12) << "dir" << std::setw(12) << "transA" << std::setw(12) + << "transB" << std::setw(12) << "nnzb" << std::setw(12) << "block_dim" + << std::setw(12) << "nnz_B" << std::setw(12) << "nnz_C" << std::setw(12) + << "alpha" << std::setw(12) << "beta" << std::setw(12) << "GFlop/s" + << std::setw(12) << "GB/s" << std::setw(12) << "msec" << std::setw(12) << "iter" + << std::setw(12) << "verified" << std::endl; + + std::cout << std::setw(12) << M << std::setw(12) << N << std::setw(12) << K << std::setw(12) + << rocsparse_direction2string(direction) << std::setw(12) + << rocsparse_operation2string(transA) << std::setw(12) + << rocsparse_operation2string(transB) << std::setw(12) << nnzb << std::setw(12) + << block_dim << std::setw(12) << nnz_B << std::setw(12) << nnz_C << std::setw(12) + << h_alpha << std::setw(12) << h_beta << std::setw(12) << gpu_gflops + << std::setw(12) << gpu_gbyte << std::setw(12) << gpu_time_used / 1e3 + << std::setw(12) << number_hot_calls << std::setw(12) + << (arg.unit_check ? 
"yes" : "no") << std::endl; + } +} + +#endif // TESTING_BSRMM_HPP \ No newline at end of file diff --git a/clients/samples/CMakeLists.txt b/clients/samples/CMakeLists.txt index 77b0a067..b10c4f40 100644 --- a/clients/samples/CMakeLists.txt +++ b/clients/samples/CMakeLists.txt @@ -62,6 +62,8 @@ add_rocsparse_example(example_hybmv.cpp) add_rocsparse_example(example_csrsv.cpp) # Level 3 +add_rocsparse_example(example_bsrmm.cpp) +add_rocsparse_example(example_fortran_bsrmm.f90) add_rocsparse_example(example_csrmm.cpp) add_rocsparse_example(example_csrsm.cpp) diff --git a/clients/samples/example_bsrmm.cpp b/clients/samples/example_bsrmm.cpp new file mode 100644 index 00000000..c1eae5f7 --- /dev/null +++ b/clients/samples/example_bsrmm.cpp @@ -0,0 +1,216 @@ +/* ************************************************************************ + * Copyright (c) 2020 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ * + * ************************************************************************ */ + +#include +#include +#include + +#define HIP_CHECK(stat) \ + { \ + if(stat != hipSuccess) \ + { \ + std::cerr << "Error: hip error in line " << __LINE__ << std::endl; \ + return -1; \ + } \ + } + +#define ROCSPARSE_CHECK(stat) \ + { \ + if(stat != rocsparse_status_success) \ + { \ + std::cerr << "Error: rocsparse error in line " << __LINE__ << std::endl; \ + return -1; \ + } \ + } + +int main(int argc, char* argv[]) +{ + // Query device + int ndev; + HIP_CHECK(hipGetDeviceCount(&ndev)); + + if(ndev < 1) + { + std::cerr << "No HIP device found" << std::endl; + return -1; + } + + // Query device properties + hipDeviceProp_t prop; + HIP_CHECK(hipGetDeviceProperties(&prop, 0)); + + std::cout << "Device: " << prop.name << std::endl; + + // rocSPARSE handle + rocsparse_handle handle; + ROCSPARSE_CHECK(rocsparse_create_handle(&handle)); + + // Print rocSPARSE version and revision + int ver; + char rev[64]; + + ROCSPARSE_CHECK(rocsparse_get_version(handle, &ver)); + ROCSPARSE_CHECK(rocsparse_get_git_rev(handle, rev)); + + std::cout << "rocSPARSE version: " << ver / 100000 << "." << ver / 100 % 1000 << "." 
+ << ver % 100 << "-" << rev << std::endl; + + // Input data + + // Matrix A (m x k) + // ( 1 2 0 3 0 0 ) + // A = ( 0 4 5 0 0 0 ) + // ( 0 0 0 7 8 0 ) + // ( 0 0 1 2 4 1 ) + + // Number of rows and columns + rocsparse_int block_dim = 2; + rocsparse_int mb = 2; + rocsparse_int kb = 3; + rocsparse_int n = 10; + rocsparse_int m = mb * block_dim; + rocsparse_int k = kb * block_dim; + + // Number of non-zero block entries + rocsparse_int nnzb = 4; + + // BSR row pointers + rocsparse_int hbsr_row_ptr[3] = {0, 2, 4}; + + // BSR column indices + rocsparse_int hbsr_col_ind[4] = {0, 1, 1, 2}; + + // BSR values + double hbsr_val[16] + = {1.0, 2.0, 0.0, 4.0, 0.0, 3.0, 5.0, 0.0, 0.0, 7.0, 1.0, 2.0, 8.0, 0.0, 4.0, 1.0}; + + // Transposition of the matrix + rocsparse_direction dir = rocsparse_direction_row; + rocsparse_operation transA = rocsparse_operation_none; + rocsparse_operation transB = rocsparse_operation_none; + + // Matrix B (k x n) column major order + // ( 9 11 13 15 17 10 12 14 16 18 ) + // ( 8 10 1 10 6 11 7 3 12 17 ) + // B = ( 11 11 0 4 6 12 2 9 13 2 ) + // ( 15 3 2 3 8 1 2 4 6 6 ) + // ( 2 5 7 0 1 15 9 4 10 1 ) + // ( 7 12 12 1 12 5 1 11 1 14 ) + + // Matrix B in column-major + rocsparse_int ldb = k; + double hB[6 * 10] + = {9, 8, 11, 15, 2, 7, 11, 10, 11, 3, 5, 12, 13, 1, 0, 2, 7, 12, 15, 10, + 4, 3, 0, 1, 17, 6, 6, 8, 1, 12, 10, 11, 12, 1, 15, 5, 12, 7, 2, 2, + 9, 1, 14, 3, 9, 4, 4, 11, 16, 12, 13, 6, 10, 1, 18, 17, 2, 6, 1, 14}; + + // Matrix C (m x n) column major order + // ( 0 0 0 0 0 0 0 0 0 0 ) + // C = ( 0 0 0 0 0 0 0 0 0 0 ) + // ( 0 0 0 0 0 0 0 0 0 0 ) + // ( 0 0 0 0 0 0 0 0 0 0 ) + + // Matrix C (m x n) in column-major + rocsparse_int ldc = m; + double hC[4 * 10] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + + // Scalar alpha and beta + double alpha = 1.0; + double beta = 0.0; + + // Matrix descriptor + rocsparse_mat_descr descr; + 
ROCSPARSE_CHECK(rocsparse_create_mat_descr(&descr)); + + // Offload data to device + rocsparse_int* dbsr_row_ptr; + rocsparse_int* dbsr_col_ind; + double* dbsr_val; + double* dB; + double* dC; + + HIP_CHECK(hipMalloc((void**)&dbsr_row_ptr, sizeof(rocsparse_int) * (mb + 1))); + HIP_CHECK(hipMalloc((void**)&dbsr_col_ind, sizeof(rocsparse_int) * nnzb)); + HIP_CHECK(hipMalloc((void**)&dbsr_val, sizeof(double) * nnzb * block_dim * block_dim)); + HIP_CHECK(hipMalloc((void**)&dB, sizeof(double) * k * n)); + HIP_CHECK(hipMalloc((void**)&dC, sizeof(double) * m * n)); + + HIP_CHECK(hipMemcpy( + dbsr_row_ptr, hbsr_row_ptr, sizeof(rocsparse_int) * (mb + 1), hipMemcpyHostToDevice)); + HIP_CHECK( + hipMemcpy(dbsr_col_ind, hbsr_col_ind, sizeof(rocsparse_int) * nnzb, hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy( + dbsr_val, hbsr_val, sizeof(double) * nnzb * block_dim * block_dim, hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(dB, hB, sizeof(double) * k * n, hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(dC, hC, sizeof(double) * m * n, hipMemcpyHostToDevice)); + + // Call dbsrmm + ROCSPARSE_CHECK(rocsparse_dbsrmm(handle, + dir, + transA, + transB, + mb, + n, + kb, + nnzb, + &alpha, + descr, + dbsr_val, + dbsr_row_ptr, + dbsr_col_ind, + block_dim, + dB, + ldb, + &beta, + dC, + ldc)); + + // Print result + HIP_CHECK(hipMemcpy(hC, dC, sizeof(double) * m * n, hipMemcpyDeviceToHost)); + + std::cout << "C:" << std::endl; + + for(int i = 0; i < m; ++i) + { + for(int j = 0; j < n; ++j) + { + std::cout << " " << hC[i + j * ldc]; + } + + std::cout << std::endl; + } + + // Clear rocSPARSE + ROCSPARSE_CHECK(rocsparse_destroy_mat_descr(descr)); + ROCSPARSE_CHECK(rocsparse_destroy_handle(handle)); + + // Clear device memory + HIP_CHECK(hipFree(dbsr_row_ptr)); + HIP_CHECK(hipFree(dbsr_col_ind)); + HIP_CHECK(hipFree(dbsr_val)); + HIP_CHECK(hipFree(dB)); + HIP_CHECK(hipFree(dC)); + + return 0; +} diff --git a/clients/samples/example_fortran_bsrmm.f90 b/clients/samples/example_fortran_bsrmm.f90 
new file mode 100644 index 00000000..95b88a86 --- /dev/null +++ b/clients/samples/example_fortran_bsrmm.f90 @@ -0,0 +1,261 @@ +!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +! Copyright (c) 2020 Advanced Micro Devices, Inc. +! +! Permission is hereby granted, free of charge, to any person obtaining a copy +! of this software and associated documentation files (the "Software"), to deal +! in the Software without restriction, including without limitation the rights +! to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +! copies of the Software, and to permit persons to whom the Software is +! furnished to do so, subject to the following conditions: +! +! The above copyright notice and this permission notice shall be included in +! all copies or substantial portions of the Software. +! +! THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +! IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +! FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +! AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +! LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +! OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +! THE SOFTWARE. +! +!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 
+ +subroutine HIP_CHECK(stat) + use iso_c_binding + + implicit none + + integer(c_int) :: stat + + if(stat /= 0) then + write(*,*) 'Error: hip error' + stop + end if + +end subroutine HIP_CHECK + +subroutine ROCSPARSE_CHECK(stat) + use iso_c_binding + + implicit none + + integer(c_int) :: stat + + if(stat /= 0) then + write(*,*) 'Error: rocsparse error' + stop + end if + +end subroutine ROCSPARSE_CHECK + +program example_fortran_bsrmm + use iso_c_binding + use rocsparse + + implicit none + + interface + function hipMalloc(ptr, size) & + result(c_int) & + bind(c, name = 'hipMalloc') + use iso_c_binding + implicit none + type(c_ptr) :: ptr + integer(c_size_t), value :: size + end function hipMalloc + + function hipFree(ptr) & + result(c_int) & + bind(c, name = 'hipFree') + use iso_c_binding + implicit none + type(c_ptr), value :: ptr + end function hipFree + + function hipMemcpy(dst, src, size, kind) & + result(c_int) & + bind(c, name = 'hipMemcpy') + use iso_c_binding + implicit none + type(c_ptr), value :: dst + type(c_ptr), intent(in), value :: src + integer(c_size_t), value :: size + integer(c_int), value :: kind + end function hipMemcpy + + function hipMemset(dst, val, size) & + result(c_int) & + bind(c, name = 'hipMemset') + use iso_c_binding + implicit none + type(c_ptr), value :: dst + integer(c_int), value :: val + integer(c_size_t), value :: size + end function hipMemset + + function hipDeviceSynchronize() & + result(c_int) & + bind(c, name = 'hipDeviceSynchronize') + use iso_c_binding + implicit none + end function hipDeviceSynchronize + + function hipDeviceReset() & + result(c_int) & + bind(c, name = 'hipDeviceReset') + use iso_c_binding + implicit none + end function hipDeviceReset + end interface + + integer, target :: h_bsr_row_ptr(3), h_bsr_col_ind(4) + real(8), target :: h_bsr_val(16), h_B(6 * 10), h_C(4 * 10) + + type(c_ptr) :: d_bsr_row_ptr + type(c_ptr) :: d_bsr_col_ind + type(c_ptr) :: d_bsr_val + type(c_ptr) :: d_B + type(c_ptr) :: d_C + + 
integer :: i, j + integer(c_int) :: M, Mb, N, K, Kb, nnzb, block_dim + + real(c_double), target :: alpha, beta + + type(c_ptr) :: handle + type(c_ptr) :: descr + + integer :: version + + character(len=12) :: rev + +! Input data +! ( 1 2 0 3 0 0 ) +! A = ( 0 4 5 0 0 0 ) +! ( 0 0 0 7 8 0 ) +! ( 0 0 1 2 4 1 ) + +! ( 9 11 13 15 17 10 12 14 16 18 ) +! ( 8 10 1 10 6 11 7 3 12 17 ) +! B = ( 11 11 0 4 6 12 2 9 13 2 ) +! ( 15 3 2 3 8 1 2 4 6 6 ) +! ( 2 5 7 0 1 15 9 4 10 1 ) +! ( 7 12 12 1 12 5 1 11 1 14 ) + +! Number of rows and columns + block_dim = 2 + Mb = 2 + Kb = 3 + N = 10 + M = Mb * block_dim + K = Kb * block_dim + +! Number of non-zero blocks + nnzb = 4 + +! Fill BSR structure + h_bsr_row_ptr = (/0, 2, 4/) + h_bsr_col_ind = (/0, 1, 1, 2/) + h_bsr_val = (/1, 2, 0, 4, 0, 3, 5, 0, 0, 7, 1, 2, 8, 0, 4, 1/) + +! Scalar alpha and beta + alpha = 1.0 + beta = 0.0 + +! Fill B in column-major + h_B = (/9, 8, 11, 15, 2, 7, & + 11, 10, 11, 3, 5, 12, & + 13, 1, 0, 2, 7, 12, & + 15, 10, 4, 3, 0, 1, & + 17, 6, 6, 8, 1, 12, & + 10, 11, 12, 1, 15, 5, & + 12, 7, 2, 2, 9, 1, & + 14, 3, 9, 4, 4, 11, & + 16, 12, 13, 6, 10, 1, & + 18, 17, 2, 6, 1, 14/) + +! Fill C in column-major + h_C = (/0, 0, 0, 0, & + 0, 0, 0, 0, & + 0, 0, 0, 0, & + 0, 0, 0, 0, & + 0, 0, 0, 0, & + 0, 0, 0, 0, & + 0, 0, 0, 0, & + 0, 0, 0, 0, & + 0, 0, 0, 0, & + 0, 0, 0, 0/) + +! Allocate device memory + call HIP_CHECK(hipMalloc(d_bsr_row_ptr, (int(Mb, c_size_t) + 1) * 4)) + call HIP_CHECK(hipMalloc(d_bsr_col_ind, int(nnzb, c_size_t) * 4)) + call HIP_CHECK(hipMalloc(d_bsr_val, int(nnzb * block_dim * block_dim, c_size_t) * 8)) + call HIP_CHECK(hipMalloc(d_B, int(K * N, c_size_t) * 8)) + call HIP_CHECK(hipMalloc(d_C, int(M * N, c_size_t) * 8)) + +! 
Copy host data to device + call HIP_CHECK(hipMemcpy(d_bsr_row_ptr, c_loc(h_bsr_row_ptr), (int(Mb, c_size_t) + 1) * 4, 1)) + call HIP_CHECK(hipMemcpy(d_bsr_col_ind, c_loc(h_bsr_col_ind), int(nnzb, c_size_t) * 4, 1)) + call HIP_CHECK(hipMemcpy(d_bsr_val, c_loc(h_bsr_val), int(nnzb * block_dim * block_dim, c_size_t) * 8, 1)) + call HIP_CHECK(hipMemcpy(d_B, c_loc(h_B), int(K * N, c_size_t) * 8, 1)) + call HIP_CHECK(hipMemcpy(d_C, c_loc(h_C), int(M * N, c_size_t) * 8, 1)) + +! Create rocSPARSE handle + call ROCSPARSE_CHECK(rocsparse_create_handle(handle)) + +! Get rocSPARSE version + call ROCSPARSE_CHECK(rocsparse_get_version(handle, version)) + call ROCSPARSE_CHECK(rocsparse_get_git_rev(handle, rev)) + +! Print version on screen + write(*,fmt='(A,I0,A,I0,A,I0,A,A)') 'rocSPARSE version: ', version / 100000, '.', & + mod(version / 100, 1000), '.', mod(version, 100), '-', rev + +! Create matrix descriptor + call ROCSPARSE_CHECK(rocsparse_create_mat_descr(descr)) + +! Perform the matrix multiplication + call ROCSPARSE_CHECK(rocsparse_dbsrmm(handle, & + rocsparse_direction_row, & + rocsparse_operation_none, & + rocsparse_operation_none, & + Mb, & + N, & + Kb, & + nnzb, & + c_loc(alpha), & + descr, & + d_bsr_val, & + d_bsr_row_ptr, & + d_bsr_col_ind, & + block_dim, & + d_B, & + K, & + c_loc(beta), & + d_C, & + M)) + +! Print result + call HIP_CHECK(hipMemcpy(c_loc(h_C), d_C, int(M * N, c_size_t) * 8, 2)) + +! Note: C in column major ordering + do i = 1, M + do j = 1, N + write(*,fmt='(A,F6.2)',advance='no') ' ', h_C(M * (j - 1) + i) + end do + write(*,*) + end do + +! Clear rocSPARSE + call ROCSPARSE_CHECK(rocsparse_destroy_mat_descr(descr)) + call ROCSPARSE_CHECK(rocsparse_destroy_handle(handle)) + +! 
Clear device memory + call HIP_CHECK(hipFree(d_bsr_row_ptr)) + call HIP_CHECK(hipFree(d_bsr_col_ind)) + call HIP_CHECK(hipFree(d_bsr_val)) + call HIP_CHECK(hipFree(d_B)) + call HIP_CHECK(hipFree(d_C)) + +end program example_fortran_bsrmm diff --git a/clients/tests/CMakeLists.txt b/clients/tests/CMakeLists.txt index 993db90d..0112f92c 100644 --- a/clients/tests/CMakeLists.txt +++ b/clients/tests/CMakeLists.txt @@ -168,6 +168,7 @@ set(ROCSPARSE_TEST_SOURCES test_csrsv.cpp test_ellmv.cpp test_hybmv.cpp + test_bsrmm.cpp test_csrmm.cpp test_csrsm.cpp test_csrgeam.cpp @@ -229,7 +230,7 @@ set_target_properties(rocsparse-test PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${PROJ set(ROCSPARSE_TEST_DATA "${PROJECT_BINARY_DIR}/staging/rocsparse_test.data") add_custom_command(OUTPUT "${ROCSPARSE_TEST_DATA}" COMMAND ../common/rocsparse_gentest.py -I ../include rocsparse_test.yaml -o "${ROCSPARSE_TEST_DATA}" - DEPENDS ../common/rocsparse_gentest.py rocsparse_test.yaml ../include/rocsparse_common.yaml known_bugs.yaml test_axpyi.yaml test_doti.yaml test_dotci.yaml test_gthr.yaml test_gthrz.yaml test_roti.yaml test_sctr.yaml test_bsrmv.yaml test_bsrsv.yaml test_coomv.yaml test_csrmv.yaml test_csrsv.yaml test_ellmv.yaml test_hybmv.yaml test_csrmm.yaml test_csrsm.yaml test_csrgeam.yaml test_csrgemm.yaml test_csric0.yaml test_csrilu0.yaml test_csr2coo.yaml test_csr2csc.yaml test_csr2ell.yaml test_csr2hyb.yaml test_bsr2csr.yaml test_csr2bsr.yaml test_coo2csr.yaml test_ell2csr.yaml test_hyb2csr.yaml test_identity.yaml test_csrsort.yaml test_cscsort.yaml test_coosort.yaml test_csricsv.yaml test_csrilusv.yaml test_nnz.yaml test_dense2csr.yaml test_dense2csc.yaml test_csr2dense.yaml test_csc2dense.yaml test_csr2csr_compress.cpp + DEPENDS ../common/rocsparse_gentest.py rocsparse_test.yaml ../include/rocsparse_common.yaml known_bugs.yaml test_axpyi.yaml test_doti.yaml test_dotci.yaml test_gthr.yaml test_gthrz.yaml test_roti.yaml test_sctr.yaml test_bsrmv.yaml test_bsrsv.yaml test_coomv.yaml 
test_csrmv.yaml test_csrsv.yaml test_ellmv.yaml test_hybmv.yaml test_bsrmm.yaml test_csrmm.yaml test_csrsm.yaml test_csrgeam.yaml test_csrgemm.yaml test_csric0.yaml test_csrilu0.yaml test_csr2coo.yaml test_csr2csc.yaml test_csr2ell.yaml test_csr2hyb.yaml test_bsr2csr.yaml test_csr2bsr.yaml test_coo2csr.yaml test_ell2csr.yaml test_hyb2csr.yaml test_identity.yaml test_csrsort.yaml test_cscsort.yaml test_coosort.yaml test_csricsv.yaml test_csrilusv.yaml test_nnz.yaml test_dense2csr.yaml test_dense2csc.yaml test_csr2dense.yaml test_csc2dense.yaml test_csr2csr_compress.cpp WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}") add_custom_target(rocsparse-test-data DEPENDS "${ROCSPARSE_TEST_DATA}" ) diff --git a/clients/tests/rocsparse_test.yaml b/clients/tests/rocsparse_test.yaml index 6c1b6096..1513081d 100644 --- a/clients/tests/rocsparse_test.yaml +++ b/clients/tests/rocsparse_test.yaml @@ -35,6 +35,7 @@ include: test_csrmv.yaml include: test_csrsv.yaml include: test_ellmv.yaml include: test_hybmv.yaml +include: test_bsrmm.yaml include: test_csrmm.yaml include: test_csrsm.yaml include: test_csrgeam.yaml diff --git a/clients/tests/test_bsrmm.cpp b/clients/tests/test_bsrmm.cpp new file mode 100644 index 00000000..72cdf5a7 --- /dev/null +++ b/clients/tests/test_bsrmm.cpp @@ -0,0 +1,117 @@ +/* ************************************************************************ + * Copyright (c) 2020 Advanced Micro Devices, Inc. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + * ************************************************************************ */ + +#include "rocsparse_data.hpp" +#include "rocsparse_datatype2string.hpp" +#include "rocsparse_test.hpp" +#include "testing_bsrmm.hpp" +#include "type_dispatch.hpp" + +#include +#include +#include + +namespace +{ + // By default, this test does not apply to any types. + // The unnamed second parameter is used for enable_if below. + template + struct bsrmm_testing : rocsparse_test_invalid + { + }; + + // When the condition in the second argument is satisfied, the type combination + // is valid. When the condition is false, this specialization does not apply. 
+ template + struct bsrmm_testing< + T, + typename std::enable_if{} || std::is_same{} + || std::is_same{} + || std::is_same{}>::type> + { + explicit operator bool() + { + return true; + } + void operator()(const Arguments& arg) + { + if(!strcmp(arg.function, "bsrmm")) + testing_bsrmm(arg); + else if(!strcmp(arg.function, "bsrmm_bad_arg")) + testing_bsrmm_bad_arg(arg); + else + FAIL() << "Internal error: Test called with unknown function: " << arg.function; + } + }; + + struct bsrmm : RocSPARSE_Test + { + // Filter for which types apply to this suite + static bool type_filter(const Arguments& arg) + { + return rocsparse_simple_dispatch(arg); + } + + // Filter for which functions apply to this suite + static bool function_filter(const Arguments& arg) + { + return !strcmp(arg.function, "bsrmm") || !strcmp(arg.function, "bsrmm_bad_arg"); + } + + // Google Test name suffix based on parameters + static std::string name_suffix(const Arguments& arg) + { + if(arg.matrix == rocsparse_matrix_file_rocalution + || arg.matrix == rocsparse_matrix_file_mtx) + { + return RocSPARSE_TestName{} + << rocsparse_datatype2string(arg.compute_type) << '_' << arg.N << '_' + << arg.block_dim << '_' << rocsparse_direction2string(arg.direction) + << arg.alpha << '_' << arg.alphai << '_' << arg.beta << '_' << arg.betai + << '_' << rocsparse_operation2string(arg.transA) << '_' + << rocsparse_operation2string(arg.transB) << '_' + << rocsparse_indexbase2string(arg.baseA) << '_' + << rocsparse_matrix2string(arg.matrix) << '_' << arg.filename; + } + else + { + return RocSPARSE_TestName{} + << rocsparse_datatype2string(arg.compute_type) << '_' << arg.M << '_' + << arg.N << '_' << arg.K << '_' << arg.block_dim << '_' + << rocsparse_direction2string(arg.direction) << '_' << arg.alpha << '_' + << arg.alphai << '_' << arg.beta << '_' << arg.betai << '_' + << rocsparse_operation2string(arg.transA) << '_' + << rocsparse_operation2string(arg.transB) << '_' + << rocsparse_indexbase2string(arg.baseA) << '_' + 
<< rocsparse_matrix2string(arg.matrix); + } + } + }; + + TEST_P(bsrmm, level3) + { + rocsparse_simple_dispatch(GetParam()); + } + INSTANTIATE_TEST_CATEGORIES(bsrmm); + +} // namespace diff --git a/clients/tests/test_bsrmm.yaml b/clients/tests/test_bsrmm.yaml new file mode 100644 index 00000000..f6517f9a --- /dev/null +++ b/clients/tests/test_bsrmm.yaml @@ -0,0 +1,207 @@ +# ######################################################################## +# Copyright (c) 2020 Advanced Micro Devices, Inc. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+# +# ######################################################################## + +--- +include: rocsparse_common.yaml +include: known_bugs.yaml + +Definitions: + - &alpha_beta_range_quick + - { alpha: 1.0, beta: -1.0, alphai: 1.0, betai: -0.5 } + - { alpha: -0.5, beta: 0.5, alphai: -0.5, betai: 1.0 } + + - &alpha_beta_range_checkin + - { alpha: 2.0, beta: 0.0, alphai: 0.5, betai: 0.5 } + - { alpha: 0.0, beta: 1.0, alphai: 1.5, betai: 0.5 } + - { alpha: 3.0, beta: 1.0, alphai: 0.0, betai: -0.5 } + + - &alpha_beta_range_nightly + - { alpha: 0.0, beta: 0.0, alphai: 1.5, betai: 0.5 } + - { alpha: 2.0, beta: 0.67, alphai: 0.0, betai: 1.5 } + - { alpha: 3.0, beta: 1.0, alphai: 1.5, betai: 0.0 } + - { alpha: -0.5, beta: 0.5, alphai: 1.0, betai: -0.5 } + +Tests: +- name: bsrmm_bad_arg + category: pre_checkin + function: bsrmm_bad_arg + precision: *single_double_precisions_complex_real + +- name: bsrmm + category: quick + function: bsrmm + precision: *single_double_precisions_complex_real + M: [-1, 275, 708] + N: [-1, 128, 628] + K: [-1, 173, 747] + block_dim: [0, 5, 7, 16] + alpha_beta: *alpha_beta_range_quick + transA: [rocsparse_operation_none] + transB: [rocsparse_operation_none, rocsparse_operation_transpose] + baseA: [rocsparse_index_base_zero] + direction: [rocsparse_direction_row] + matrix: [rocsparse_matrix_random] + +- name: bsrmm + category: pre_checkin + function: bsrmm + precision: *single_double_precisions_complex_real + M: [0, 511, 2059] + N: [0, 7, 33] + K: [0, 391, 1375] + block_dim: [17, 25] + alpha_beta: *alpha_beta_range_checkin + transA: [rocsparse_operation_none] + transB: [rocsparse_operation_none, rocsparse_operation_transpose] + baseA: [rocsparse_index_base_one] + direction: [rocsparse_direction_column] + matrix: [rocsparse_matrix_random] + +- name: bsrmm + category: nightly + function: bsrmm + precision: *single_double_precisions_complex_real + M: [3943, 94912] + N: [27, 49] + K: [4134, 73291] + block_dim: [2, 9] + alpha_beta: 
*alpha_beta_range_nightly + transA: [rocsparse_operation_none] + transB: [rocsparse_operation_none, rocsparse_operation_transpose] + baseA: [rocsparse_index_base_zero] + direction: [rocsparse_direction_row] + matrix: [rocsparse_matrix_random] + +- name: bsrmm_file + category: quick + function: bsrmm + precision: *single_double_precisions + M: 1 + N: [4, 19] + K: 1 + block_dim: [4] + alpha_beta: *alpha_beta_range_quick + transA: [rocsparse_operation_none] + transB: [rocsparse_operation_none, rocsparse_operation_transpose] + baseA: [rocsparse_index_base_one] + direction: [rocsparse_direction_column] + matrix: [rocsparse_matrix_file_rocalution] + filename: [mac_econ_fwd500, + nos2, + nos4, + nos6, + scircuit] + +- name: bsrmm_file + category: pre_checkin + function: bsrmm + precision: *single_double_precisions + M: 1 + N: [73] + K: 1 + block_dim: [8] + alpha_beta: *alpha_beta_range_checkin + transA: [rocsparse_operation_none] + transB: [rocsparse_operation_none] + baseA: [rocsparse_index_base_zero] + direction: [rocsparse_direction_row] + matrix: [rocsparse_matrix_file_rocalution] + filename: [rma10, + mc2depi, + ASIC_320k, + nos1, + nos3, + nos5, + nos7] + +- name: bsrmm_file + category: nightly + function: bsrmm + precision: *single_double_precisions + M: 1 + N: [38] + K: 1 + block_dim: [3] + alpha_beta: *alpha_beta_range_nightly + transA: [rocsparse_operation_none] + transB: [rocsparse_operation_transpose] + baseA: [rocsparse_index_base_zero] + direction: [rocsparse_direction_row] + matrix: [rocsparse_matrix_file_rocalution] + filename: [bibd_22_8, + bmwcra_1, + amazon0312, + Chebyshev4, + sme3Dc, + webbase-1M, + shipsec1] + +- name: bsrmm_file + category: quick + function: bsrmm + precision: *single_double_precisions_complex + M: 1 + N: [3, 21] + K: 1 + block_dim: [6] + alpha_beta: *alpha_beta_range_quick + transA: [rocsparse_operation_none] + transB: [rocsparse_operation_none] + baseA: [rocsparse_index_base_one] + direction: [rocsparse_direction_column] + matrix: 
[rocsparse_matrix_file_rocalution] + filename: [Chevron2, + qc2534] + +- name: bsrmm_file + category: pre_checkin + function: bsrmm + precision: *single_double_precisions_complex + M: 1 + N: [68] + K: 1 + block_dim: [11] + alpha_beta: *alpha_beta_range_checkin + transA: [rocsparse_operation_none] + transB: [rocsparse_operation_transpose] + baseA: [rocsparse_index_base_zero] + direction: [rocsparse_direction_row] + matrix: [rocsparse_matrix_file_rocalution] + filename: [mplate, + Chevron3] + +- name: bsrmm_file + category: nightly + function: bsrmm + precision: *single_double_precisions_complex + M: 1 + N: [40] + K: 1 + block_dim: [7] + alpha_beta: *alpha_beta_range_nightly + transA: [rocsparse_operation_none] + transB: [rocsparse_operation_none] + baseA: [rocsparse_index_base_one] + direction: [rocsparse_direction_column] + matrix: [rocsparse_matrix_file_rocalution] + filename: [Chevron4] \ No newline at end of file diff --git a/docs/source/usermanual.rst b/docs/source/usermanual.rst index d859652c..b29b466a 100644 --- a/docs/source/usermanual.rst +++ b/docs/source/usermanual.rst @@ -642,6 +642,7 @@ Sparse Level 3 Functions ========================================================================= ====== ====== ============== ============== Function name single double single complex double complex ========================================================================= ====== ====== ============== ============== +:cpp:func:`rocsparse_Xbsrmm() ` x x x x :cpp:func:`rocsparse_Xcsrmm() ` x x x x :cpp:func:`rocsparse_Xcsrsm_buffer_size() ` x x x x :cpp:func:`rocsparse_Xcsrsm_analysis() ` x x x x @@ -1144,6 +1145,17 @@ This module holds all sparse level 3 routines. The sparse level 3 routines describe operations between a matrix in sparse format and multiple vectors in dense format that can also be seen as a dense matrix. +rocsparse_bsrmm() +----------------- + +.. doxygenfunction:: rocsparse_sbsrmm + :outline: +.. doxygenfunction:: rocsparse_dbsrmm + :outline: +.. 
doxygenfunction:: rocsparse_cbsrmm + :outline: +.. doxygenfunction:: rocsparse_zbsrmm + rocsparse_csrmm() ----------------- diff --git a/library/include/rocsparse-functions.h b/library/include/rocsparse-functions.h index 1b7b3fda..cf330d50 100644 --- a/library/include/rocsparse-functions.h +++ b/library/include/rocsparse-functions.h @@ -2539,6 +2539,250 @@ rocsparse_status rocsparse_zhybmv(rocsparse_handle handle, * =========================================================================== */ +/*! \ingroup level3_module + * \brief Sparse matrix dense matrix multiplication using BSR storage format + * + * \details + * \p rocsparse_bsrmm multiplies the scalar \f$\alpha\f$ with a sparse \f$mb \times kb\f$ + * matrix \f$A\f$, defined in BSR storage format, and the dense \f$k \times n\f$ + * matrix \f$B\f$ (where \f$k = block\_dim \times kb\f$) and adds the result to the dense + * \f$m \times n\f$ matrix \f$C\f$ (where \f$m = block\_dim \times mb\f$) that + * is multiplied by the scalar \f$\beta\f$, such that + * \f[ + * C := \alpha \cdot op(A) \cdot op(B) + \beta \cdot C, + * \f] + * with + * \f[ + * op(A) = \left\{ + * \begin{array}{ll} + * A, & \text{if trans_A == rocsparse_operation_none} \\ + * \end{array} + * \right. + * \f] + * and + * \f[ + * op(B) = \left\{ + * \begin{array}{ll} + * B, & \text{if trans_B == rocsparse_operation_none} \\ + * B^T, & \text{if trans_B == rocsparse_operation_transpose} \\ + * \end{array} + * \right. + * \f] + * + * \note + * This function is non blocking and executed asynchronously with respect to the host. + * It may return before the actual computation has finished. + * + * \note + * Currently, only \p trans_A == \ref rocsparse_operation_none is supported. + * + * @param[in] + * handle handle to the rocsparse library context queue. + * @param[in] + * dir the storage format of the blocks. Can be \ref rocsparse_direction_row or \ref rocsparse_direction_column. + * @param[in] + * trans_A matrix \f$A\f$ operation type. 
Currently, only \ref rocsparse_operation_none is supported. + * @param[in] + * trans_B matrix \f$B\f$ operation type. Currently, only \ref rocsparse_operation_none and \ref rocsparse_operation_transpose + * are supported. + * @param[in] + * mb number of block rows of the sparse BSR matrix \f$A\f$. + * @param[in] + * n number of columns of the dense matrix \f$op(B)\f$ and \f$C\f$. + * @param[in] + * kb number of block columns of the sparse BSR matrix \f$A\f$. + * @param[in] + * nnzb number of non-zero blocks of the sparse BSR matrix \f$A\f$. + * @param[in] + * alpha scalar \f$\alpha\f$. + * @param[in] + * descr descriptor of the sparse BSR matrix \f$A\f$. Currently, only + * \ref rocsparse_matrix_type_general is supported. + * @param[in] + * bsr_val array of \p nnzb*block_dim*block_dim elements of the sparse BSR matrix \f$A\f$. + * @param[in] + * bsr_row_ptr array of \p mb+1 elements that point to the start of every block row of the + * sparse BSR matrix \f$A\f$. + * @param[in] + * bsr_col_ind array of \p nnzb elements containing the block column indices of the sparse + * BSR matrix \f$A\f$. + * @param[in] + * block_dim size of the blocks in the sparse BSR matrix. + * @param[in] + * B array of dimension \f$ldb \times n\f$ (\f$op(B) == B\f$) or + * \f$ldb \times k\f$ (\f$op(B) == B^T\f$). + * @param[in] + * ldb leading dimension of \f$B\f$, must be at least \f$\max{(1, k)}\f$ (\f$op(B) == B\f$) where \f$k = block\_dim \times kb\f$, \f$\max{(1, n)}\f$ otherwise. + * @param[in] + * beta scalar \f$\beta\f$. + * @param[inout] + * C array of dimension \f$ldc \times n\f$. + * @param[in] + * ldc leading dimension of \f$C\f$, must be at least \f$\max{(1, m)}\f$ where \f$m = block\_dim \times mb\f$. + * + * \retval rocsparse_status_success the operation completed successfully. + * \retval rocsparse_status_invalid_handle the library context was not initialized. + * \retval rocsparse_status_invalid_size \p mb, \p n, \p kb, \p nnzb, \p ldb or \p ldc + * is invalid.
+ * \retval rocsparse_status_invalid_pointer \p descr, \p alpha, \p bsr_val, + * \p bsr_row_ptr, \p bsr_col_ind, \p B, \p beta or \p C pointer is invalid. + * \retval rocsparse_status_arch_mismatch the device is not supported. + * \retval rocsparse_status_not_implemented + * \p trans_A != \ref rocsparse_operation_none or + * \p trans_B == \ref rocsparse_operation_conjugate_transpose or + * \ref rocsparse_matrix_type != \ref rocsparse_matrix_type_general. + * + * \par Example + * This example multiplies a BSR matrix with a dense matrix. + * \code{.c} + * // 1 2 0 3 0 0 + * // A = 0 4 5 0 0 0 + * // 0 0 0 7 8 0 + * // 0 0 1 2 4 1 + * + * rocsparse_int block_dim = 2; + * rocsparse_int mb = 2; + * rocsparse_int kb = 3; + * rocsparse_int nnzb = 4; + * rocsparse_direction dir = rocsparse_direction_row; + * + * bsr_row_ptr[mb+1] = {0, 2, 4}; // device memory + * bsr_col_ind[nnzb] = {0, 1, 1, 2}; // device memory + * bsr_val[nnzb*block_dim*block_dim] = {1, 2, 0, 4, 0, 3, 5, 0, 0, 7, 1, 2, 8, 0, 4, 1}; // device memory + * + * // Set dimension n of B + * rocsparse_int n = 64; + * rocsparse_int m = mb * block_dim; + * rocsparse_int k = kb * block_dim; + * + * // Allocate and generate dense matrix B + * std::vector<float> hB(k * n); + * for(rocsparse_int i = 0; i < k * n; ++i) + * { + * hB[i] = static_cast<float>(rand()) / RAND_MAX; + * } + * + * // Copy B to the device + * float* B; + * hipMalloc((void**)&B, sizeof(float) * k * n); + * hipMemcpy(B, hB.data(), sizeof(float) * k * n, hipMemcpyHostToDevice); + * + * // alpha and beta + * float alpha = 1.0f; + * float beta = 0.0f; + * + * // Allocate memory for the resulting matrix C + * float* C; + * hipMalloc((void**)&C, sizeof(float) * m * n); + * + * // Perform the matrix multiplication + * rocsparse_sbsrmm(handle, + * dir, + * rocsparse_operation_none, + * rocsparse_operation_none, + * mb, + * n, + * kb, + * nnzb, + * &alpha, + * descr, + * bsr_val, + * bsr_row_ptr, + * bsr_col_ind, + * block_dim, + * B, + * k, + * &beta, + * C, + * m);
+ * \endcode + */ +/**@{*/ +ROCSPARSE_EXPORT +rocsparse_status rocsparse_sbsrmm(rocsparse_handle handle, + rocsparse_direction dir, + rocsparse_operation trans_A, + rocsparse_operation trans_B, + rocsparse_int mb, + rocsparse_int n, + rocsparse_int kb, + rocsparse_int nnzb, + const float* alpha, + const rocsparse_mat_descr descr, + const float* bsr_val, + const rocsparse_int* bsr_row_ptr, + const rocsparse_int* bsr_col_ind, + rocsparse_int block_dim, + const float* B, + rocsparse_int ldb, + const float* beta, + float* C, + rocsparse_int ldc); + +ROCSPARSE_EXPORT +rocsparse_status rocsparse_dbsrmm(rocsparse_handle handle, + rocsparse_direction dir, + rocsparse_operation trans_A, + rocsparse_operation trans_B, + rocsparse_int mb, + rocsparse_int n, + rocsparse_int kb, + rocsparse_int nnzb, + const double* alpha, + const rocsparse_mat_descr descr, + const double* bsr_val, + const rocsparse_int* bsr_row_ptr, + const rocsparse_int* bsr_col_ind, + rocsparse_int block_dim, + const double* B, + rocsparse_int ldb, + const double* beta, + double* C, + rocsparse_int ldc); + +ROCSPARSE_EXPORT +rocsparse_status rocsparse_cbsrmm(rocsparse_handle handle, + rocsparse_direction dir, + rocsparse_operation trans_A, + rocsparse_operation trans_B, + rocsparse_int mb, + rocsparse_int n, + rocsparse_int kb, + rocsparse_int nnzb, + const rocsparse_float_complex* alpha, + const rocsparse_mat_descr descr, + const rocsparse_float_complex* bsr_val, + const rocsparse_int* bsr_row_ptr, + const rocsparse_int* bsr_col_ind, + rocsparse_int block_dim, + const rocsparse_float_complex* B, + rocsparse_int ldb, + const rocsparse_float_complex* beta, + rocsparse_float_complex* C, + rocsparse_int ldc); + +ROCSPARSE_EXPORT +rocsparse_status rocsparse_zbsrmm(rocsparse_handle handle, + rocsparse_direction dir, + rocsparse_operation trans_A, + rocsparse_operation trans_B, + rocsparse_int mb, + rocsparse_int n, + rocsparse_int kb, + rocsparse_int nnzb, + const rocsparse_double_complex* alpha, + const 
rocsparse_mat_descr descr, + const rocsparse_double_complex* bsr_val, + const rocsparse_int* bsr_row_ptr, + const rocsparse_int* bsr_col_ind, + rocsparse_int block_dim, + const rocsparse_double_complex* B, + rocsparse_int ldb, + const rocsparse_double_complex* beta, + rocsparse_double_complex* C, + rocsparse_int ldc); +/**@}*/ + /*! \ingroup level3_module * \brief Sparse matrix dense matrix multiplication using CSR storage format * diff --git a/library/src/CMakeLists.txt b/library/src/CMakeLists.txt index b22630c9..a30dc45d 100644 --- a/library/src/CMakeLists.txt +++ b/library/src/CMakeLists.txt @@ -51,6 +51,7 @@ set(rocsparse_source src/level2/rocsparse_hybmv.cpp # Level3 + src/level3/rocsparse_bsrmm.cpp src/level3/rocsparse_csrmm.cpp src/level3/rocsparse_csrsm.cpp diff --git a/library/src/level3/bsrmm_device.h b/library/src/level3/bsrmm_device.h new file mode 100644 index 00000000..315e0db2 --- /dev/null +++ b/library/src/level3/bsrmm_device.h @@ -0,0 +1,583 @@ +/* ************************************************************************ + * Copyright (c) 2020 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + * ************************************************************************ */ + +#pragma once +#ifndef BSRMM_DEVICE_H +#define BSRMM_DEVICE_H + +#include "common.h" + +#include + +template +static __device__ void bsrmmnn_small_blockdim_device(rocsparse_direction direction, + rocsparse_int Mb, + rocsparse_int N, + T alpha, + const rocsparse_int* __restrict__ bsr_row_ptr, + const rocsparse_int* __restrict__ bsr_col_ind, + const T* __restrict__ bsr_val, + const T* __restrict__ B, + rocsparse_int ldb, + T beta, + T* __restrict__ C, + rocsparse_int ldc, + rocsparse_index_base idx_base) +{ + constexpr rocsparse_int PADDED_BSR_BLOCK_DIM = (BSR_BLOCK_DIM + 1); + + rocsparse_int tid = hipThreadIdx_x; + rocsparse_int gid = hipBlockIdx_x * hipBlockDim_x + tid; + rocsparse_int lid = gid & (WF_SIZE - 1); + rocsparse_int wid = tid / WF_SIZE; + rocsparse_int nwfb = hipGridDim_x * hipBlockDim_x / (WF_SIZE * BSR_BLOCK_DIM); + rocsparse_int col = lid + hipBlockIdx_y * WF_SIZE; + + rocsparse_int colB = col * ldb; + rocsparse_int colC = col * ldc; + + // global row + rocsparse_int global_row = (gid / WF_SIZE); + + // local row within block row + rocsparse_int local_row = (gid / WF_SIZE) % BSR_BLOCK_DIM; + + __shared__ rocsparse_int shared_col[BLOCKSIZE / WF_SIZE][WF_SIZE]; + __shared__ T shared_val[BLOCKSIZE / WF_SIZE][WF_SIZE * PADDED_BSR_BLOCK_DIM]; + + for(rocsparse_int block_row = gid / (WF_SIZE * BSR_BLOCK_DIM); block_row < Mb; + block_row += nwfb) + { + rocsparse_int block_row_start = bsr_row_ptr[block_row] - idx_base; + rocsparse_int block_row_end = bsr_row_ptr[block_row + 1] - idx_base; + + T sum = static_cast(0); + + for(rocsparse_int j = block_row_start; j < block_row_end; j += WF_SIZE) + { + 
rocsparse_int k = j + lid; + + shared_col[wid][lid] = (k < block_row_end) ? BSR_BLOCK_DIM * (bsr_col_ind[k] - idx_base) : 0; + + if(direction == rocsparse_direction_row) + { + // Perform: + // for(rocsparse_int l = 0; l < BSR_BLOCK_DIM; l++) + // { + // shared_val[wid][PADDED_BSR_BLOCK_DIM * lid + l] + // = (k < block_row_end) ? bsr_val[BSR_BLOCK_DIM * BSR_BLOCK_DIM * k + // + BSR_BLOCK_DIM * local_row + l] + // : static_cast(0); + // } + // as unrolled loop. + shared_val[wid][PADDED_BSR_BLOCK_DIM * lid] + = (k < block_row_end) ? bsr_val[BSR_BLOCK_DIM * BSR_BLOCK_DIM * k + + BSR_BLOCK_DIM * local_row] + : static_cast(0); + if(BSR_BLOCK_DIM >= 2) + { + shared_val[wid][PADDED_BSR_BLOCK_DIM * lid + 1] + = (k < block_row_end) ? bsr_val[BSR_BLOCK_DIM * BSR_BLOCK_DIM * k + + BSR_BLOCK_DIM * local_row + 1] + : static_cast(0); + } + if(BSR_BLOCK_DIM >= 3) + { + shared_val[wid][PADDED_BSR_BLOCK_DIM * lid + 2] + = (k < block_row_end) ? bsr_val[BSR_BLOCK_DIM * BSR_BLOCK_DIM * k + + BSR_BLOCK_DIM * local_row + 2] + : static_cast(0); + } + if(BSR_BLOCK_DIM >= 4) + { + shared_val[wid][PADDED_BSR_BLOCK_DIM * lid + 3] + = (k < block_row_end) ? bsr_val[BSR_BLOCK_DIM * BSR_BLOCK_DIM * k + + BSR_BLOCK_DIM * local_row + 3] + : static_cast(0); + } + } + else + { + // Perform: + // for(rocsparse_int l = 0; l < BSR_BLOCK_DIM; l++) + // { + // shared_val[wid][BSR_BLOCK_DIM * lid + l] + // = (k < block_row_end) ? bsr_val[BSR_BLOCK_DIM * BSR_BLOCK_DIM * k + // + BSR_BLOCK_DIM * l + local_row] + // : static_cast(0); + // } + // as unrolled loop. + shared_val[wid][PADDED_BSR_BLOCK_DIM * lid] + = (k < block_row_end) ? bsr_val[BSR_BLOCK_DIM * BSR_BLOCK_DIM * k + + local_row] + : static_cast(0); + if(BSR_BLOCK_DIM >= 2) + { + shared_val[wid][PADDED_BSR_BLOCK_DIM * lid + 1] + = (k < block_row_end) ? 
bsr_val[BSR_BLOCK_DIM * BSR_BLOCK_DIM * k + + BSR_BLOCK_DIM * 1 + local_row] + : static_cast(0); + } + if(BSR_BLOCK_DIM >= 3) + { + shared_val[wid][PADDED_BSR_BLOCK_DIM * lid + 2] + = (k < block_row_end) ? bsr_val[BSR_BLOCK_DIM * BSR_BLOCK_DIM * k + + BSR_BLOCK_DIM * 2 + local_row] + : static_cast(0); + } + if(BSR_BLOCK_DIM >= 4) + { + shared_val[wid][PADDED_BSR_BLOCK_DIM * lid + 3] + = (k < block_row_end) ? bsr_val[BSR_BLOCK_DIM * BSR_BLOCK_DIM * k + + BSR_BLOCK_DIM * 3 + local_row] + : static_cast(0); + } + } + + __syncthreads(); + + if(col < N) + { + for(rocsparse_int i = 0; i < WF_SIZE; ++i) + { + // Perform: + // for(rocsparse_int l = 0; l < BSR_BLOCK_DIM; l++) + // { + // sum = rocsparse_fma(shared_val[wid][PADDED_BSR_BLOCK_DIM * i + l], + // B[shared_col[wid][i] + l], + // sum); + // } + // as unrolled loop. + sum = rocsparse_fma(shared_val[wid][PADDED_BSR_BLOCK_DIM * i], + B[shared_col[wid][i] + colB], + sum); + if(BSR_BLOCK_DIM >= 2) + { + sum = rocsparse_fma(shared_val[wid][PADDED_BSR_BLOCK_DIM * i + 1], + B[shared_col[wid][i] + 1 + colB], + sum); + } + if(BSR_BLOCK_DIM >= 3) + { + sum = rocsparse_fma(shared_val[wid][PADDED_BSR_BLOCK_DIM * i + 2], + B[shared_col[wid][i] + 2 + colB], + sum); + } + if(BSR_BLOCK_DIM >= 4) + { + sum = rocsparse_fma(shared_val[wid][PADDED_BSR_BLOCK_DIM * i + 3], + B[shared_col[wid][i] + 3 + colB], + sum); + } + } + } + } + + if(col < N) + { + if(beta == static_cast(0)) + { + C[global_row + colC] = alpha * sum; + } + else + { + C[global_row + colC] = rocsparse_fma(beta, C[global_row + colC], alpha * sum); + } + } + } +} + +template +static __device__ void bsrmmnt_small_blockdim_device(rocsparse_direction direction, + rocsparse_int Mb, + rocsparse_int N, + T alpha, + const rocsparse_int* __restrict__ bsr_row_ptr, + const rocsparse_int* __restrict__ bsr_col_ind, + const T* __restrict__ bsr_val, + const T* __restrict__ B, + rocsparse_int ldb, + T beta, + T* __restrict__ C, + rocsparse_int ldc, + rocsparse_index_base idx_base) +{ + 
constexpr rocsparse_int PADDED_BSR_BLOCK_DIM = (BSR_BLOCK_DIM + 1); + + rocsparse_int tid = hipThreadIdx_x; + rocsparse_int gid = hipBlockIdx_x * hipBlockDim_x + tid; + rocsparse_int block_row = gid / (WF_SIZE * BSR_BLOCK_DIM); + rocsparse_int global_row = gid / WF_SIZE; + rocsparse_int local_row = (gid / WF_SIZE) % BSR_BLOCK_DIM; + rocsparse_int lid = tid & (WF_SIZE - 1); + rocsparse_int wid = tid / WF_SIZE; + + if(block_row >= Mb) + { + return; + } + + __shared__ rocsparse_int shared_col[BLOCKSIZE / WF_SIZE][WF_SIZE]; + __shared__ T shared_val[BLOCKSIZE / WF_SIZE][WF_SIZE * PADDED_BSR_BLOCK_DIM]; + + rocsparse_int block_row_start = bsr_row_ptr[block_row] - idx_base; + rocsparse_int block_row_end = bsr_row_ptr[block_row + 1] - idx_base; + + for(rocsparse_int l = 0; l < N; l += WF_SIZE) + { + rocsparse_int col = l + lid; + T sum = static_cast(0); + + for(rocsparse_int j = block_row_start; j < block_row_end; j += WF_SIZE) + { + rocsparse_int k = j + lid; + + shared_col[wid][lid] = (k < block_row_end) ? N * BSR_BLOCK_DIM * (bsr_col_ind[k] - idx_base) : 0; + + if(direction == rocsparse_direction_row) + { + // Perform: + // for(rocsparse_int p = 0; p < BSR_BLOCK_DIM; p++) + // { + // shared_val[wid][PADDED_BSR_BLOCK_DIM * lid + p] + // = (k < block_row_end) ? bsr_val[BSR_BLOCK_DIM * BSR_BLOCK_DIM * k + // + BSR_BLOCK_DIM * local_row + p] + // : static_cast(0); + // } + // as unrolled loop. + shared_val[wid][PADDED_BSR_BLOCK_DIM * lid] + = (k < block_row_end) ? bsr_val[BSR_BLOCK_DIM * BSR_BLOCK_DIM * k + + BSR_BLOCK_DIM * local_row] + : static_cast(0); + if(BSR_BLOCK_DIM >= 2) + { + shared_val[wid][PADDED_BSR_BLOCK_DIM * lid + 1] + = (k < block_row_end) ? bsr_val[BSR_BLOCK_DIM * BSR_BLOCK_DIM * k + + BSR_BLOCK_DIM * local_row + 1] + : static_cast(0); + } + if(BSR_BLOCK_DIM >= 3) + { + shared_val[wid][PADDED_BSR_BLOCK_DIM * lid + 2] + = (k < block_row_end) ? 
bsr_val[BSR_BLOCK_DIM * BSR_BLOCK_DIM * k + + BSR_BLOCK_DIM * local_row + 2] + : static_cast(0); + } + if(BSR_BLOCK_DIM >= 4) + { + shared_val[wid][PADDED_BSR_BLOCK_DIM * lid + 3] + = (k < block_row_end) ? bsr_val[BSR_BLOCK_DIM * BSR_BLOCK_DIM * k + + BSR_BLOCK_DIM * local_row + 3] + : static_cast(0); + } + } + else + { + // Perform: + // for(rocsparse_int p = 0; p < BSR_BLOCK_DIM; p++) + // { + // shared_val[wid][PADDED_BSR_BLOCK_DIM * lid + p] + // = (k < block_row_end) ? bsr_val[BSR_BLOCK_DIM * BSR_BLOCK_DIM * k + // + BSR_BLOCK_DIM * p + local_row] + // : static_cast(0); + // } + // as unrolled loop. + shared_val[wid][PADDED_BSR_BLOCK_DIM * lid] + = (k < block_row_end) ? bsr_val[BSR_BLOCK_DIM * BSR_BLOCK_DIM * k + local_row] + : static_cast(0); + if(BSR_BLOCK_DIM >= 2) + { + shared_val[wid][PADDED_BSR_BLOCK_DIM * lid + 1] + = (k < block_row_end) ? bsr_val[BSR_BLOCK_DIM * BSR_BLOCK_DIM * k + + BSR_BLOCK_DIM * 1 + local_row] + : static_cast(0); + } + if(BSR_BLOCK_DIM >= 3) + { + shared_val[wid][PADDED_BSR_BLOCK_DIM * lid + 2] + = (k < block_row_end) ? bsr_val[BSR_BLOCK_DIM * BSR_BLOCK_DIM * k + + BSR_BLOCK_DIM * 2 + local_row] + : static_cast(0); + } + if(BSR_BLOCK_DIM >= 4) + { + shared_val[wid][PADDED_BSR_BLOCK_DIM * lid + 3] + = (k < block_row_end) ? bsr_val[BSR_BLOCK_DIM * BSR_BLOCK_DIM * k + + BSR_BLOCK_DIM * 3 + local_row] + : static_cast(0); + } + } + + __syncthreads(); + + if(col < N) + { + for(rocsparse_int i = 0; i < WF_SIZE; ++i) + { + // Perform: + // for(rocsparse_int p = 0; p < BSR_BLOCK_DIM; p++) + // { + // T val_B = rocsparse_ldg(B + col + N * p + shared_col[wid][i]); + // sum = rocsparse_fma(shared_val[wid][PADDED_BSR_BLOCK_DIM * i + p], val_B, sum); + // } + // as unrolled loop. 
+ T val_B + = rocsparse_ldg(B + col + shared_col[wid][i]); + sum = rocsparse_fma(shared_val[wid][PADDED_BSR_BLOCK_DIM * i], val_B, sum); + if(BSR_BLOCK_DIM >= 2) + { + val_B + = rocsparse_ldg(B + col + N * 1 + shared_col[wid][i]); + sum = rocsparse_fma(shared_val[wid][PADDED_BSR_BLOCK_DIM * i + 1], val_B, sum); + } + if(BSR_BLOCK_DIM >= 3) + { + val_B + = rocsparse_ldg(B + col + N * 2 + shared_col[wid][i]); + sum = rocsparse_fma(shared_val[wid][PADDED_BSR_BLOCK_DIM * i + 2], val_B, sum); + } + if(BSR_BLOCK_DIM >= 4) + { + val_B + = rocsparse_ldg(B + col + N * 3 + shared_col[wid][i]); + sum = rocsparse_fma(shared_val[wid][PADDED_BSR_BLOCK_DIM * i + 3], val_B, sum); + } + } + } + } + + if(col < N) + { + if(beta == static_cast(0)) + { + C[global_row + col * ldc] = alpha * sum; + } + else + { + C[global_row + col * ldc] = rocsparse_fma(beta, C[global_row + col * ldc], alpha * sum); + } + } + } +} + +template +static __device__ void bsrmm_large_blockdim_device(rocsparse_direction direction, + rocsparse_operation trans_B, + rocsparse_int Mb, + rocsparse_int N, + T alpha, + const rocsparse_int* __restrict__ bsr_row_ptr, + const rocsparse_int* __restrict__ bsr_col_ind, + const T* __restrict__ bsr_val, + rocsparse_int block_dim, + const T* __restrict__ B, + rocsparse_int ldb, + T beta, + T* __restrict__ C, + rocsparse_int ldc, + rocsparse_index_base idx_base) +{ + rocsparse_int tidx = hipThreadIdx_x; + rocsparse_int tidy = hipThreadIdx_y; + + rocsparse_int global_row = tidx + hipBlockIdx_x * block_dim; + rocsparse_int global_col = tidy + hipBlockIdx_y * BLK_SIZE_Y; + + rocsparse_int block_row = hipBlockIdx_x; + + rocsparse_int block_row_start = 0; + rocsparse_int block_row_end = 0; + if(block_row < Mb) + { + block_row_start = bsr_row_ptr[block_row] - idx_base; + block_row_end = bsr_row_ptr[block_row + 1] - idx_base; + } + + rocsparse_int colB = global_col * ldb; + rocsparse_int colC = global_col * ldc; + + __shared__ T shared_B[BSR_BLOCK_DIM * BLK_SIZE_Y]; + __shared__ T 
shared_A[BSR_BLOCK_DIM * BSR_BLOCK_DIM]; + + T sum = static_cast(0); + + rocsparse_int index = BSR_BLOCK_DIM * tidy + tidx; + rocsparse_int block_dim_sqr = block_dim * block_dim; + + for(rocsparse_int k = block_row_start; k < block_row_end; k++) + { + rocsparse_int block_col = (bsr_col_ind[k] - idx_base); + + if(trans_B == rocsparse_operation_none) + { + shared_B[index] + = (global_col < N && tidx < block_dim) ? B[block_dim * block_col + tidx + colB] : static_cast(0); + } + else + { + shared_B[index] + = (global_col < N && tidx < block_dim) ? B[global_col + ldb * (block_dim * block_col + tidx)] : static_cast(0); + } + + if(direction == rocsparse_direction_row) + { + if(tidx < block_dim && tidy < block_dim) + { + shared_A[index] = bsr_val[block_dim_sqr * k + block_dim * tidx + tidy]; + } + } + else + { + if(tidx < block_dim && tidy < block_dim) + { + shared_A[index] + = bsr_val[block_dim_sqr * k + block_dim * tidy + tidx]; + } + } + + __syncthreads(); + + for(rocsparse_int j = 0; j < block_dim; j++) + { + sum = rocsparse_fma(shared_A[BSR_BLOCK_DIM * j + tidx], + shared_B[BSR_BLOCK_DIM * tidy + j], + sum); + } + + __syncthreads(); + } + + if(block_row < Mb && global_col < N && tidx < block_dim) + { + if(beta == static_cast(0)) + { + C[global_row + colC] = alpha * sum; + } + else + { + C[global_row + colC] = rocsparse_fma(beta, C[global_row + colC], alpha * sum); + } + } +} + +template +static __device__ void bsrmm_general_blockdim_device(rocsparse_direction direction, + rocsparse_operation trans_B, + rocsparse_int Mb, + rocsparse_int N, + T alpha, + const rocsparse_int* __restrict__ bsr_row_ptr, + const rocsparse_int* __restrict__ bsr_col_ind, + const T* __restrict__ bsr_val, + rocsparse_int block_dim, + const T* __restrict__ B, + rocsparse_int ldb, + T beta, + T* __restrict__ C, + rocsparse_int ldc, + rocsparse_index_base idx_base) +{ + rocsparse_int tidx = hipThreadIdx_x; + rocsparse_int tidy = hipThreadIdx_y; + + rocsparse_int block_row = hipBlockIdx_x; + + 
rocsparse_int block_row_start = 0; + rocsparse_int block_row_end = 0; + if(block_row < Mb) + { + block_row_start = bsr_row_ptr[block_row] - idx_base; + block_row_end = bsr_row_ptr[block_row + 1] - idx_base; + } + + __shared__ T shared_B[BSR_BLOCK_DIM * BLK_SIZE_Y]; + __shared__ T shared_A[BSR_BLOCK_DIM * BSR_BLOCK_DIM]; + + rocsparse_int global_col = tidy + hipBlockIdx_y * BLK_SIZE_Y; + + rocsparse_int colB = global_col * ldb; + rocsparse_int colC = global_col * ldc; + + for(rocsparse_int x = 0; x < block_dim; x += BSR_BLOCK_DIM) + { + rocsparse_int global_row = tidx + x + hipBlockIdx_x * block_dim; + + T sum = static_cast(0); + + for(rocsparse_int k = block_row_start; k < block_row_end; k++) + { + rocsparse_int block_col = (bsr_col_ind[k] - idx_base); + + for(rocsparse_int y = 0; y < block_dim; y += BLK_SIZE_Y) + { + if(trans_B == rocsparse_operation_none) + { + shared_B[BSR_BLOCK_DIM * tidy + tidx] + = (global_col < N && (tidx + y) < block_dim) ? B[block_dim * block_col + (tidx + y) + colB] : static_cast(0); + } + else + { + shared_B[BSR_BLOCK_DIM * tidy + tidx] + = (global_col < N && (tidx + y) < block_dim) ? B[global_col + ldb * (block_dim * block_col + (tidx + y))] : static_cast(0); + } + + + if(direction == rocsparse_direction_row) + { + shared_A[BSR_BLOCK_DIM * tidy + tidx] + = ((tidx + x) < block_dim && (tidy + y) < block_dim) ? bsr_val[block_dim * block_dim * k + block_dim * (tidx + x) + (tidy + y)] : static_cast(0); + } + else + { + shared_A[BSR_BLOCK_DIM * tidy + tidx] + = ((tidx + x) < block_dim && (tidy + y) < block_dim) ? 
bsr_val[block_dim * block_dim * k + block_dim * (tidy + y) + (tidx + x)] : static_cast(0); + } + + __syncthreads(); + + for(rocsparse_int j = 0; j < BSR_BLOCK_DIM; j++) + { + sum = rocsparse_fma(shared_A[BSR_BLOCK_DIM * j + tidx], + shared_B[BSR_BLOCK_DIM * tidy + j], + sum); + } + + __syncthreads(); + } + } + + if(block_row < Mb && global_col < N && (tidx + x) < block_dim) + { + if(beta == static_cast(0)) + { + C[global_row + colC] = alpha * sum; + } + else + { + C[global_row + colC] = rocsparse_fma(beta, C[global_row + colC], alpha * sum); + } + } + } +} + +#endif // BSRMM_DEVICE_H \ No newline at end of file diff --git a/library/src/level3/csrmm_device.h b/library/src/level3/csrmm_device.h index 1d042d59..99011412 100644 --- a/library/src/level3/csrmm_device.h +++ b/library/src/level3/csrmm_device.h @@ -72,7 +72,7 @@ static __device__ void csrmmnn_general_device(rocsparse_int M, __syncthreads(); shared_col[wid][lid] = (k < row_end) ? csr_col_ind[k] - idx_base : 0; - shared_val[wid][lid] = (k < row_end) ? alpha * csr_val[k] : static_cast(0); + shared_val[wid][lid] = (k < row_end) ? csr_val[k] : static_cast(0); __syncthreads(); @@ -86,11 +86,11 @@ static __device__ void csrmmnn_general_device(rocsparse_int M, { if(beta == static_cast(0)) { - C[row + colC] = sum; + C[row + colC] = alpha * sum; } else { - C[row + colC] = rocsparse_fma(beta, C[row + colC], sum); + C[row + colC] = rocsparse_fma(beta, C[row + colC], alpha * sum); } } } @@ -143,7 +143,7 @@ static __device__ void csrmmnt_general_device(rocsparse_int offset, __syncthreads(); shared_col[wid][lid] = (k < row_end) ? N * (csr_col_ind[k] - idx_base) : 0; - shared_val[wid][lid] = (k < row_end) ? alpha * csr_val[k] : static_cast(0); + shared_val[wid][lid] = (k < row_end) ? 
csr_val[k] : static_cast(0); __syncthreads(); @@ -159,11 +159,11 @@ static __device__ void csrmmnt_general_device(rocsparse_int offset, { if(beta == static_cast(0)) { - C[row + col * ldc] = sum; + C[row + col * ldc] = alpha * sum; } else { - C[row + col * ldc] = rocsparse_fma(beta, C[row + col * ldc], sum); + C[row + col * ldc] = rocsparse_fma(beta, C[row + col * ldc], alpha * sum); } } } diff --git a/library/src/level3/rocsparse_bsrmm.cpp b/library/src/level3/rocsparse_bsrmm.cpp new file mode 100644 index 00000000..1e481059 --- /dev/null +++ b/library/src/level3/rocsparse_bsrmm.cpp @@ -0,0 +1,195 @@ +/* ************************************************************************ + * Copyright (c) 2020 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ * + * ************************************************************************ */ + +#include "rocsparse.h" + +#include "rocsparse_bsrmm.hpp" + +/* + * =========================================================================== + * C wrapper + * =========================================================================== + */ +extern "C" rocsparse_status rocsparse_sbsrmm(rocsparse_handle handle, + rocsparse_direction dir, + rocsparse_operation trans_A, + rocsparse_operation trans_B, + rocsparse_int mb, + rocsparse_int n, + rocsparse_int kb, + rocsparse_int nnzb, + const float* alpha, + const rocsparse_mat_descr descr, + const float* bsr_val, + const rocsparse_int* bsr_row_ptr, + const rocsparse_int* bsr_col_ind, + rocsparse_int block_dim, + const float* B, + rocsparse_int ldb, + const float* beta, + float* C, + rocsparse_int ldc) +{ + return rocsparse_bsrmm_template(handle, + dir, + trans_A, + trans_B, + mb, + n, + kb, + nnzb, + alpha, + descr, + bsr_val, + bsr_row_ptr, + bsr_col_ind, + block_dim, + B, + ldb, + beta, + C, + ldc); +} + +extern "C" rocsparse_status rocsparse_dbsrmm(rocsparse_handle handle, + rocsparse_direction dir, + rocsparse_operation trans_A, + rocsparse_operation trans_B, + rocsparse_int mb, + rocsparse_int n, + rocsparse_int kb, + rocsparse_int nnzb, + const double* alpha, + const rocsparse_mat_descr descr, + const double* bsr_val, + const rocsparse_int* bsr_row_ptr, + const rocsparse_int* bsr_col_ind, + rocsparse_int block_dim, + const double* B, + rocsparse_int ldb, + const double* beta, + double* C, + rocsparse_int ldc) +{ + return rocsparse_bsrmm_template(handle, + dir, + trans_A, + trans_B, + mb, + n, + kb, + nnzb, + alpha, + descr, + bsr_val, + bsr_row_ptr, + bsr_col_ind, + block_dim, + B, + ldb, + beta, + C, + ldc); +} + +extern "C" rocsparse_status rocsparse_cbsrmm(rocsparse_handle handle, + rocsparse_direction dir, + rocsparse_operation trans_A, + rocsparse_operation trans_B, + rocsparse_int mb, + rocsparse_int n, + rocsparse_int kb, 
+ rocsparse_int nnzb, + const rocsparse_float_complex* alpha, + const rocsparse_mat_descr descr, + const rocsparse_float_complex* bsr_val, + const rocsparse_int* bsr_row_ptr, + const rocsparse_int* bsr_col_ind, + rocsparse_int block_dim, + const rocsparse_float_complex* B, + rocsparse_int ldb, + const rocsparse_float_complex* beta, + rocsparse_float_complex* C, + rocsparse_int ldc) +{ + return rocsparse_bsrmm_template(handle, + dir, + trans_A, + trans_B, + mb, + n, + kb, + nnzb, + alpha, + descr, + bsr_val, + bsr_row_ptr, + bsr_col_ind, + block_dim, + B, + ldb, + beta, + C, + ldc); +} + +extern "C" rocsparse_status rocsparse_zbsrmm(rocsparse_handle handle, + rocsparse_direction dir, + rocsparse_operation trans_A, + rocsparse_operation trans_B, + rocsparse_int mb, + rocsparse_int n, + rocsparse_int kb, + rocsparse_int nnzb, + const rocsparse_double_complex* alpha, + const rocsparse_mat_descr descr, + const rocsparse_double_complex* bsr_val, + const rocsparse_int* bsr_row_ptr, + const rocsparse_int* bsr_col_ind, + rocsparse_int block_dim, + const rocsparse_double_complex* B, + rocsparse_int ldb, + const rocsparse_double_complex* beta, + rocsparse_double_complex* C, + rocsparse_int ldc) +{ + return rocsparse_bsrmm_template(handle, + dir, + trans_A, + trans_B, + mb, + n, + kb, + nnzb, + alpha, + descr, + bsr_val, + bsr_row_ptr, + bsr_col_ind, + block_dim, + B, + ldb, + beta, + C, + ldc); +} diff --git a/library/src/level3/rocsparse_bsrmm.hpp b/library/src/level3/rocsparse_bsrmm.hpp new file mode 100644 index 00000000..9cfcb7e7 --- /dev/null +++ b/library/src/level3/rocsparse_bsrmm.hpp @@ -0,0 +1,868 @@ +/* ************************************************************************ + * Copyright (c) 2020 Advanced Micro Devices, Inc. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ * + * ************************************************************************ */ + +#pragma once +#ifndef ROCSPARSE_BSRMM_HPP +#define ROCSPARSE_BSRMM_HPP + +#include "bsrmm_device.h" +#include "handle.h" +#include "rocsparse.h" +#include "rocsparse_csrmm.hpp" +#include "../level2/rocsparse_bsrmv.hpp" +#include "utility.h" + +#include + +#define launch_bsrmmnn_small_blockdim_kernel_host_pointer(T, block_size, wf_size, bsr_block_dim) \ + hipLaunchKernelGGL( \ + (bsrmmnn_small_blockdim_kernel_host_pointer), \ + bsrmmnn_blocks, \ + bsrmmnn_threads, \ + 0, \ + stream, \ + dir, \ + mb, \ + n, \ + *alpha, \ + bsr_row_ptr, \ + bsr_col_ind, \ + bsr_val, \ + B, \ + ldb, \ + *beta, \ + C, \ + ldc, \ + descr->base); + +#define launch_bsrmmnn_small_blockdim_kernel_device_pointer(T, block_size, wf_size, bsr_block_dim) \ + hipLaunchKernelGGL( \ + (bsrmmnn_small_blockdim_kernel_device_pointer), \ + bsrmmnn_blocks, \ + bsrmmnn_threads, \ + 0, \ + stream, \ + dir, \ + mb, \ + n, \ + alpha, \ + bsr_row_ptr, \ + bsr_col_ind, \ + bsr_val, \ + B, \ + ldb, \ + beta, \ + C, \ + ldc, \ + descr->base); + +#define launch_bsrmmnt_small_blockdim_kernel_host_pointer(T, block_size, wf_size, bsr_block_dim) \ + hipLaunchKernelGGL( \ + (bsrmmnt_small_blockdim_kernel_host_pointer), \ + bsrmmnt_blocks, \ + bsrmmnt_threads, \ + 0, \ + stream, \ + dir, \ + mb, \ + n, \ + *alpha, \ + bsr_row_ptr, \ + bsr_col_ind, \ + bsr_val, \ + B, \ + ldb, \ + *beta, \ + C, \ + ldc, \ + descr->base); + +#define launch_bsrmmnt_small_blockdim_kernel_device_pointer(T, block_size, wf_size, bsr_block_dim) \ + hipLaunchKernelGGL( \ + (bsrmmnt_small_blockdim_kernel_device_pointer), \ + bsrmmnt_blocks, \ + bsrmmnt_threads, \ + 0, \ + stream, \ + dir, \ + mb, \ + n, \ + alpha, \ + bsr_row_ptr, \ + bsr_col_ind, \ + bsr_val, \ + B, \ + ldb, \ + beta, \ + C, \ + ldc, \ + descr->base); + + #define launch_bsrmm_large_blockdim_kernel_host_pointer(T, bsr_block_dim, blk_size_y) \ + hipLaunchKernelGGL( \ + 
(bsrmm_large_blockdim_kernel_host_pointer), \ + bsrmm_blocks, \ + bsrmm_threads, \ + 0, \ + stream, \ + dir, \ + trans_B, \ + mb, \ + n, \ + *alpha, \ + bsr_row_ptr, \ + bsr_col_ind, \ + bsr_val, \ + block_dim, \ + B, \ + ldb, \ + *beta, \ + C, \ + ldc, \ + descr->base); + +#define launch_bsrmm_large_blockdim_kernel_device_pointer(T, bsr_block_dim, blk_size_y) \ + hipLaunchKernelGGL( \ + (bsrmm_large_blockdim_kernel_device_pointer), \ + bsrmm_blocks, \ + bsrmm_threads, \ + 0, \ + stream, \ + dir, \ + trans_B, \ + mb, \ + n, \ + alpha, \ + bsr_row_ptr, \ + bsr_col_ind, \ + bsr_val, \ + block_dim, \ + B, \ + ldb, \ + beta, \ + C, \ + ldc, \ + descr->base); + +#define launch_bsrmm_general_blockdim_kernel_host_pointer(T, bsr_block_dim, blk_size_y) \ + hipLaunchKernelGGL( \ + (bsrmm_general_blockdim_kernel_host_pointer), \ + bsrmm_blocks, \ + bsrmm_threads, \ + 0, \ + stream, \ + dir, \ + trans_B, \ + mb, \ + n, \ + *alpha, \ + bsr_row_ptr, \ + bsr_col_ind, \ + bsr_val, \ + block_dim, \ + B, \ + ldb, \ + *beta, \ + C, \ + ldc, \ + descr->base); + +#define launch_bsrmm_general_blockdim_kernel_device_pointer(T, bsr_block_dim, blk_size_y) \ + hipLaunchKernelGGL( \ + (bsrmm_general_blockdim_kernel_device_pointer), \ + bsrmm_blocks, \ + bsrmm_threads, \ + 0, \ + stream, \ + dir, \ + trans_B, \ + mb, \ + n, \ + alpha, \ + bsr_row_ptr, \ + bsr_col_ind, \ + bsr_val, \ + block_dim, \ + B, \ + ldb, \ + beta, \ + C, \ + ldc, \ + descr->base); + +template +__launch_bounds__(BLOCKSIZE) __global__ + void bsrmmnn_small_blockdim_kernel_host_pointer(rocsparse_direction direction, + rocsparse_int mb, + rocsparse_int n, + T alpha, + const rocsparse_int* __restrict__ bsr_row_ptr, + const rocsparse_int* __restrict__ bsr_col_ind, + const T* __restrict__ bsr_val, + const T* __restrict__ B, + rocsparse_int ldb, + T beta, + T* __restrict__ C, + rocsparse_int ldc, + rocsparse_index_base idx_base) +{ + if(alpha == static_cast(0) && beta == static_cast(1)) + { + return; + } + + 
bsrmmnn_small_blockdim_device( + direction, mb, n, alpha, bsr_row_ptr, bsr_col_ind, bsr_val, B, ldb, beta, C, ldc, idx_base); +} + +template +__launch_bounds__(BLOCKSIZE) __global__ + void bsrmmnn_small_blockdim_kernel_device_pointer(rocsparse_direction direction, + rocsparse_int mb, + rocsparse_int n, + const T* alpha, + const rocsparse_int* __restrict__ bsr_row_ptr, + const rocsparse_int* __restrict__ bsr_col_ind, + const T* __restrict__ bsr_val, + const T* __restrict__ B, + rocsparse_int ldb, + const T* beta, + T* __restrict__ C, + rocsparse_int ldc, + rocsparse_index_base idx_base) +{ + if(*alpha == static_cast(0) && *beta == static_cast(1)) + { + return; + } + + bsrmmnn_small_blockdim_device(direction, + mb, + n, + *alpha, + bsr_row_ptr, + bsr_col_ind, + bsr_val, + B, + ldb, + *beta, + C, + ldc, + idx_base); +} + +template +__launch_bounds__(BLOCKSIZE) __global__ + void bsrmmnt_small_blockdim_kernel_host_pointer(rocsparse_direction direction, + rocsparse_int mb, + rocsparse_int n, + T alpha, + const rocsparse_int* __restrict__ bsr_row_ptr, + const rocsparse_int* __restrict__ bsr_col_ind, + const T* __restrict__ bsr_val, + const T* __restrict__ B, + rocsparse_int ldb, + T beta, + T* __restrict__ C, + rocsparse_int ldc, + rocsparse_index_base idx_base) +{ + if(alpha == static_cast(0) && beta == static_cast(1)) + { + return; + } + + bsrmmnt_small_blockdim_device(direction, + mb, + n, + alpha, + bsr_row_ptr, + bsr_col_ind, + bsr_val, + B, + ldb, + beta, + C, + ldc, + idx_base); +} + +template +__launch_bounds__(BLOCKSIZE) __global__ + void bsrmmnt_small_blockdim_kernel_device_pointer(rocsparse_direction direction, + rocsparse_int mb, + rocsparse_int n, + const T* alpha, + const rocsparse_int* __restrict__ bsr_row_ptr, + const rocsparse_int* __restrict__ bsr_col_ind, + const T* __restrict__ bsr_val, + const T* __restrict__ B, + rocsparse_int ldb, + const T* beta, + T* __restrict__ C, + rocsparse_int ldc, + rocsparse_index_base idx_base) +{ + if(*alpha == 
static_cast(0) && *beta == static_cast(1)) + { + return; + } + + bsrmmnt_small_blockdim_device(direction, + mb, + n, + *alpha, + bsr_row_ptr, + bsr_col_ind, + bsr_val, + B, + ldb, + *beta, + C, + ldc, + idx_base); +} + +template +__launch_bounds__(BSR_BLOCK_DIM * BLK_SIZE_Y) __global__ + void bsrmm_large_blockdim_kernel_host_pointer(rocsparse_direction direction, + rocsparse_operation trans_B, + rocsparse_int mb, + rocsparse_int n, + T alpha, + const rocsparse_int* __restrict__ bsr_row_ptr, + const rocsparse_int* __restrict__ bsr_col_ind, + const T* __restrict__ bsr_val, + rocsparse_int block_dim, + const T* __restrict__ B, + rocsparse_int ldb, + T beta, + T* __restrict__ C, + rocsparse_int ldc, + rocsparse_index_base idx_base) +{ + if(alpha == static_cast(0) && beta == static_cast(1)) + { + return; + } + + bsrmm_large_blockdim_device( + direction, trans_B, mb, n, alpha, bsr_row_ptr, bsr_col_ind, bsr_val, block_dim, B, ldb, beta, C, ldc, idx_base); +} + +template +__launch_bounds__(BSR_BLOCK_DIM * BLK_SIZE_Y) __global__ + void bsrmm_large_blockdim_kernel_device_pointer(rocsparse_direction direction, + rocsparse_operation trans_B, + rocsparse_int mb, + rocsparse_int n, + const T* alpha, + const rocsparse_int* __restrict__ bsr_row_ptr, + const rocsparse_int* __restrict__ bsr_col_ind, + const T* __restrict__ bsr_val, + rocsparse_int block_dim, + const T* __restrict__ B, + rocsparse_int ldb, + const T* beta, + T* __restrict__ C, + rocsparse_int ldc, + rocsparse_index_base idx_base) +{ + if(*alpha == static_cast(0) && *beta == static_cast(1)) + { + return; + } + + bsrmm_large_blockdim_device(direction, + trans_B, + mb, + n, + *alpha, + bsr_row_ptr, + bsr_col_ind, + bsr_val, + block_dim, + B, + ldb, + *beta, + C, + ldc, + idx_base); +} + +template +__launch_bounds__(BSR_BLOCK_DIM * BLK_SIZE_Y) __global__ + void bsrmm_general_blockdim_kernel_host_pointer(rocsparse_direction direction, + rocsparse_operation trans_B, + rocsparse_int mb, + rocsparse_int n, + T alpha, + const 
rocsparse_int* __restrict__ bsr_row_ptr, + const rocsparse_int* __restrict__ bsr_col_ind, + const T* __restrict__ bsr_val, + rocsparse_int block_dim, + const T* __restrict__ B, + rocsparse_int ldb, + T beta, + T* __restrict__ C, + rocsparse_int ldc, + rocsparse_index_base idx_base) +{ + if(alpha == static_cast(0) && beta == static_cast(1)) + { + return; + } + + bsrmm_general_blockdim_device( + direction, trans_B, mb, n, alpha, bsr_row_ptr, bsr_col_ind, bsr_val, block_dim, B, ldb, beta, C, ldc, idx_base); +} + +template +__launch_bounds__(BSR_BLOCK_DIM * BLK_SIZE_Y) __global__ + void bsrmm_general_blockdim_kernel_device_pointer(rocsparse_direction direction, + rocsparse_operation trans_B, + rocsparse_int mb, + rocsparse_int n, + const T* alpha, + const rocsparse_int* __restrict__ bsr_row_ptr, + const rocsparse_int* __restrict__ bsr_col_ind, + const T* __restrict__ bsr_val, + rocsparse_int block_dim, + const T* __restrict__ B, + rocsparse_int ldb, + const T* beta, + T* __restrict__ C, + rocsparse_int ldc, + rocsparse_index_base idx_base) +{ + if(*alpha == static_cast(0) && *beta == static_cast(1)) + { + return; + } + + bsrmm_general_blockdim_device(direction, + trans_B, + mb, + n, + *alpha, + bsr_row_ptr, + bsr_col_ind, + bsr_val, + block_dim, + B, + ldb, + *beta, + C, + ldc, + idx_base); +} + +template +rocsparse_status rocsparse_bsrmm_template(rocsparse_handle handle, + rocsparse_direction dir, + rocsparse_operation trans_A, + rocsparse_operation trans_B, + rocsparse_int mb, + rocsparse_int n, + rocsparse_int kb, + rocsparse_int nnzb, + const T* alpha, + const rocsparse_mat_descr descr, + const T* bsr_val, + const rocsparse_int* bsr_row_ptr, + const rocsparse_int* bsr_col_ind, + rocsparse_int block_dim, + const T* B, + rocsparse_int ldb, + const T* beta, + T* C, + rocsparse_int ldc) +{ + // Check for valid handle and matrix descriptor + if(handle == nullptr) + { + return rocsparse_status_invalid_handle; + } + else if(descr == nullptr) + { + return 
rocsparse_status_invalid_pointer; + } + + // Logging TODO bench logging + if(handle->pointer_mode == rocsparse_pointer_mode_host) + { + log_trace(handle, + replaceX("rocsparse_Xbsrmm"), + dir, + trans_A, + trans_B, + mb, + n, + kb, + nnzb, + *alpha, + (const void*&)descr, + (const void*&)bsr_val, + (const void*&)bsr_row_ptr, + (const void*&)bsr_col_ind, + block_dim, + (const void*&)B, + ldb, + *beta, + (const void*&)C, + ldc); + } + else + { + log_trace(handle, + replaceX("rocsparse_Xbsrmm"), + dir, + trans_A, + trans_B, + mb, + n, + kb, + nnzb, + (const void*&)alpha, + (const void*&)descr, + (const void*&)bsr_val, + (const void*&)bsr_row_ptr, + (const void*&)bsr_col_ind, + block_dim, + (const void*&)B, + ldb, + (const void*&)beta, + (const void*&)C, + ldc); + } + + // Check index base + if(descr->base != rocsparse_index_base_zero && descr->base != rocsparse_index_base_one) + { + return rocsparse_status_invalid_value; + } + + // Check matrix type + if(descr->type != rocsparse_matrix_type_general) + { + // TODO + return rocsparse_status_not_implemented; + } + + // Check operation + if(trans_A != rocsparse_operation_none) + { + return rocsparse_status_not_implemented; + } + else if(trans_B != rocsparse_operation_none && trans_B != rocsparse_operation_transpose) + { + return rocsparse_status_not_implemented; + } + + // Check sizes + if(mb < 0 || n < 0 || kb < 0 || nnzb < 0 || block_dim <= 0) + { + return rocsparse_status_invalid_size; + } + + // Quick return if possible + if(mb == 0 || n == 0 || kb == 0) + { + return rocsparse_status_success; + } + + // Check pointer arguments + if(bsr_val == nullptr || bsr_row_ptr == nullptr || bsr_col_ind == nullptr || B == nullptr + || C == nullptr || alpha == nullptr || beta == nullptr) + { + return rocsparse_status_invalid_pointer; + } + + // Check leading dimension of B + if(trans_B == rocsparse_operation_none) + { + if(ldb < kb) + { + return rocsparse_status_invalid_size; + } + } + else + { + if(ldb < n) + { + return 
rocsparse_status_invalid_size; + } + } + + // Check leading dimension of C + if(ldc < mb) + { + return rocsparse_status_invalid_size; + } + + // Stream + hipStream_t stream = handle->stream; + + rocsparse_int m = mb * block_dim; + rocsparse_int k = kb * block_dim; + rocsparse_int nnz = nnzb * block_dim; + + // If n is only 1 and B are non-transposed, then call bsrmv + if(n == 1) + { + if(trans_B == rocsparse_operation_none) + { + return rocsparse_bsrmv_template(handle, + dir, + trans_A, + mb, + kb, + nnzb, + alpha, + descr, + bsr_val, + bsr_row_ptr, + bsr_col_ind, + block_dim, + B, + beta, + C); + } + } + + // If block dimension is one we can simply call csrmm + if(block_dim == 1) + { + return rocsparse_csrmm_template(handle, + trans_A, + trans_B, + m, + n, + k, + nnz, + alpha, + descr, + bsr_val, + bsr_row_ptr, + bsr_col_ind, + B, + ldb, + beta, + C, + ldc); + } + + if(block_dim == 2) + { + if(handle->pointer_mode == rocsparse_pointer_mode_device) + { + if(trans_B == rocsparse_operation_none) + { + constexpr rocsparse_int BSRMMNN_DIM = 64; + constexpr rocsparse_int SUB_WF_SIZE = 8; + + dim3 bsrmmnn_blocks((SUB_WF_SIZE * m - 1) / BSRMMNN_DIM + 1, (n - 1) / SUB_WF_SIZE + 1); + dim3 bsrmmnn_threads(BSRMMNN_DIM); + launch_bsrmmnn_small_blockdim_kernel_device_pointer(T, BSRMMNN_DIM, SUB_WF_SIZE, 2); + } + else + { + constexpr rocsparse_int BSRMMNT_DIM = 64; + + // Average nnzb per row of A + rocsparse_int avg_row_nnzb = (nnzb - 1) / mb + 1; + + // Launch appropriate kernel depending on row nnz of A + if(avg_row_nnzb < 16) + { + dim3 bsrmmnt_blocks((8 * m - 1) / BSRMMNT_DIM + 1); + dim3 bsrmmnt_threads(BSRMMNT_DIM); + launch_bsrmmnt_small_blockdim_kernel_device_pointer(T, BSRMMNT_DIM, 8, 2); + } + else if(avg_row_nnzb < 32) + { + dim3 bsrmmnt_blocks((16 * m - 1) / BSRMMNT_DIM + 1); + dim3 bsrmmnt_threads(BSRMMNT_DIM); + launch_bsrmmnt_small_blockdim_kernel_device_pointer(T, BSRMMNT_DIM, 16, 2); + } + else if(avg_row_nnzb < 64 || handle->wavefront_size == 32) + { + dim3 
bsrmmnt_blocks((32 * m - 1) / BSRMMNT_DIM + 1); + dim3 bsrmmnt_threads(BSRMMNT_DIM); + launch_bsrmmnt_small_blockdim_kernel_device_pointer(T, BSRMMNT_DIM, 32, 2); + } + else if(handle->wavefront_size == 64) + { + dim3 bsrmmnt_blocks((64 * m - 1) / BSRMMNT_DIM + 1); + dim3 bsrmmnt_threads(BSRMMNT_DIM); + launch_bsrmmnt_small_blockdim_kernel_device_pointer(T, BSRMMNT_DIM, 64, 2); + } + else + { + return rocsparse_status_arch_mismatch; + } + } + } + else + { + if(trans_B == rocsparse_operation_none) + { + constexpr rocsparse_int BSRMMNN_DIM = 64; + constexpr rocsparse_int SUB_WF_SIZE = 8; + + dim3 bsrmmnn_blocks((SUB_WF_SIZE * m - 1) / BSRMMNN_DIM + 1, (n - 1) / SUB_WF_SIZE + 1); + dim3 bsrmmnn_threads(BSRMMNN_DIM); + launch_bsrmmnn_small_blockdim_kernel_host_pointer(T, BSRMMNN_DIM, SUB_WF_SIZE, 2); + } + else + { + constexpr rocsparse_int BSRMMNT_DIM = 64; + + // Average nnzb per row of A + rocsparse_int avg_row_nnzb = (nnzb - 1) / mb + 1; + + // Launch appropriate kernel depending on row nnz of A + if(avg_row_nnzb < 16) + { + dim3 bsrmmnt_blocks((8 * m - 1) / BSRMMNT_DIM + 1); + dim3 bsrmmnt_threads(BSRMMNT_DIM); + launch_bsrmmnt_small_blockdim_kernel_host_pointer(T, BSRMMNT_DIM, 8, 2); + } + else if(avg_row_nnzb < 32) + { + dim3 bsrmmnt_blocks((16 * m - 1) / BSRMMNT_DIM + 1); + dim3 bsrmmnt_threads(BSRMMNT_DIM); + launch_bsrmmnt_small_blockdim_kernel_host_pointer(T, BSRMMNT_DIM, 16, 2); + } + else if(avg_row_nnzb < 64 || handle->wavefront_size == 32) + { + dim3 bsrmmnt_blocks((32 * m - 1) / BSRMMNT_DIM + 1); + dim3 bsrmmnt_threads(BSRMMNT_DIM); + launch_bsrmmnt_small_blockdim_kernel_host_pointer(T, BSRMMNT_DIM, 32, 2); + } + else if(handle->wavefront_size == 64) + { + dim3 bsrmmnt_blocks((64 * m - 1) / BSRMMNT_DIM + 1); + dim3 bsrmmnt_threads(BSRMMNT_DIM); + launch_bsrmmnt_small_blockdim_kernel_host_pointer(T, BSRMMNT_DIM, 64, 2); + } + else + { + return rocsparse_status_arch_mismatch; + } + } + } + + return rocsparse_status_success; + } + + // Run different bsrmm 
kernels for block dim > 2 + if(n <= 16 && block_dim > 4 && block_dim <= 8) + { + if(handle->pointer_mode == rocsparse_pointer_mode_device) + { + dim3 bsrmm_blocks((mb - 1) / 1 + 1, (n - 1) / 16 + 1); + dim3 bsrmm_threads(8, 16, 1); + launch_bsrmm_large_blockdim_kernel_device_pointer(T, 8, 16); + } + else + { + dim3 bsrmm_blocks((mb - 1) / 1 + 1, (n - 1) / 16 + 1); + dim3 bsrmm_threads(8, 16, 1); + launch_bsrmm_large_blockdim_kernel_host_pointer(T, 8, 16); + } + } + else + { + if(handle->pointer_mode == rocsparse_pointer_mode_device) + { + if(block_dim <= 4) + { + dim3 bsrmm_blocks((mb - 1) / 1 + 1, (n - 1) / 16 + 1); + dim3 bsrmm_threads(4, 16, 1); + launch_bsrmm_large_blockdim_kernel_device_pointer(T, 4, 16); + } + else if(block_dim <= 8) + { + dim3 bsrmm_blocks((mb - 1) / 1 + 1, (n - 1) / 32 + 1); + dim3 bsrmm_threads(8, 32, 1); + launch_bsrmm_large_blockdim_kernel_device_pointer(T, 8, 32); + } + else if(block_dim <= 16) + { + dim3 bsrmm_blocks((mb - 1) / 1 + 1, (n - 1) / 16 + 1); + dim3 bsrmm_threads(16, 16, 1); + launch_bsrmm_large_blockdim_kernel_device_pointer(T, 16, 16); + } + else if(block_dim <= 32) + { + dim3 bsrmm_blocks((mb - 1) / 1 + 1, (n - 1) / 32 + 1); + dim3 bsrmm_threads(32, 32, 1); + launch_bsrmm_large_blockdim_kernel_device_pointer(T, 32, 32); + } + else + { + dim3 bsrmm_blocks((mb - 1) / 1 + 1, (n - 1) / 32 + 1); + dim3 bsrmm_threads(32, 32, 1); + launch_bsrmm_general_blockdim_kernel_device_pointer(T, 32, 32); + } + } + else + { + if(block_dim <= 4) + { + dim3 bsrmm_blocks((mb - 1) / 1 + 1, (n - 1) / 16 + 1); + dim3 bsrmm_threads(4, 16, 1); + launch_bsrmm_large_blockdim_kernel_host_pointer(T, 4, 16); + } + else if(block_dim <= 8) + { + dim3 bsrmm_blocks((mb - 1) / 1 + 1, (n - 1) / 32 + 1); + dim3 bsrmm_threads(8, 32, 1); + launch_bsrmm_large_blockdim_kernel_host_pointer(T, 8, 32); + } + else if(block_dim <= 16) + { + dim3 bsrmm_blocks((mb - 1) / 1 + 1, (n - 1) / 16 + 1); + dim3 bsrmm_threads(16, 16, 1); + 
launch_bsrmm_large_blockdim_kernel_host_pointer(T, 16, 16); + } + else if(block_dim <= 32) + { + dim3 bsrmm_blocks((mb - 1) / 1 + 1, (n - 1) / 32 + 1); + dim3 bsrmm_threads(32, 32, 1); + launch_bsrmm_large_blockdim_kernel_host_pointer(T, 32, 32); + } + else + { + dim3 bsrmm_blocks((mb - 1) / 1 + 1, (n - 1) / 32 + 1); + dim3 bsrmm_threads(32, 32, 1); + launch_bsrmm_general_blockdim_kernel_host_pointer(T, 32, 32); + } + } + } + + return rocsparse_status_success; +} + +#endif // ROCSPARSE_BSRMM_HPP \ No newline at end of file diff --git a/library/src/rocsparse_module.f90 b/library/src/rocsparse_module.f90 index dc9c90ec..cc67ef89 100644 --- a/library/src/rocsparse_module.f90 +++ b/library/src/rocsparse_module.f90 @@ -1728,6 +1728,114 @@ end function rocsparse_zhybmv ! =========================================================================== ! level 3 SPARSE ! =========================================================================== +! rocsparse_bsrmm + function rocsparse_sbsrmm(handle, dir, trans_A, trans_B, mb, n, kb, nnzb, alpha, descr, & + bsr_val, bsr_row_ptr, bsr_col_ind, block_dim, B, ldb, beta, C, ldc) & + result(c_int) & + bind(c, name = 'rocsparse_sbsrmm') + use iso_c_binding + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: dir + integer(c_int), value :: trans_A + integer(c_int), value :: trans_B + integer(c_int), value :: mb + integer(c_int), value :: n + integer(c_int), value :: kb + integer(c_int), value :: nnzb + type(c_ptr), intent(in), value :: alpha + type(c_ptr), intent(in), value :: descr + type(c_ptr), intent(in), value :: bsr_val + type(c_ptr), intent(in), value :: bsr_row_ptr + type(c_ptr), intent(in), value :: bsr_col_ind + integer(c_int), value :: block_dim + type(c_ptr), intent(in), value :: B + integer(c_int), value :: ldb + type(c_ptr), intent(in), value :: beta + type(c_ptr), value :: C + integer(c_int), value :: ldc + end function rocsparse_sbsrmm + + function rocsparse_dbsrmm(handle, dir, trans_A, trans_B, mb, 
n, kb, nnzb, alpha, descr, & + bsr_val, bsr_row_ptr, bsr_col_ind, block_dim, B, ldb, beta, C, ldc) & + result(c_int) & + bind(c, name = 'rocsparse_dbsrmm') + use iso_c_binding + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: dir + integer(c_int), value :: trans_A + integer(c_int), value :: trans_B + integer(c_int), value :: mb + integer(c_int), value :: n + integer(c_int), value :: kb + integer(c_int), value :: nnzb + type(c_ptr), intent(in), value :: alpha + type(c_ptr), intent(in), value :: descr + type(c_ptr), intent(in), value :: bsr_val + type(c_ptr), intent(in), value :: bsr_row_ptr + type(c_ptr), intent(in), value :: bsr_col_ind + integer(c_int), value :: block_dim + type(c_ptr), intent(in), value :: B + integer(c_int), value :: ldb + type(c_ptr), intent(in), value :: beta + type(c_ptr), value :: C + integer(c_int), value :: ldc + end function rocsparse_dbsrmm + + function rocsparse_cbsrmm(handle, dir, trans_A, trans_B, mb, n, kb, nnzb, alpha, descr, & + bsr_val, bsr_row_ptr, bsr_col_ind, block_dim, B, ldb, beta, C, ldc) & + result(c_int) & + bind(c, name = 'rocsparse_cbsrmm') + use iso_c_binding + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: dir + integer(c_int), value :: trans_A + integer(c_int), value :: trans_B + integer(c_int), value :: mb + integer(c_int), value :: n + integer(c_int), value :: kb + integer(c_int), value :: nnzb + type(c_ptr), intent(in), value :: alpha + type(c_ptr), intent(in), value :: descr + type(c_ptr), intent(in), value :: bsr_val + type(c_ptr), intent(in), value :: bsr_row_ptr + type(c_ptr), intent(in), value :: bsr_col_ind + integer(c_int), value :: block_dim + type(c_ptr), intent(in), value :: B + integer(c_int), value :: ldb + type(c_ptr), intent(in), value :: beta + type(c_ptr), value :: C + integer(c_int), value :: ldc + end function rocsparse_cbsrmm + + function rocsparse_zbsrmm(handle, dir, trans_A, trans_B, mb, n, kb, nnzb, alpha, descr, & + bsr_val, bsr_row_ptr, 
bsr_col_ind, block_dim, B, ldb, beta, C, ldc) & + result(c_int) & + bind(c, name = 'rocsparse_zbsrmm') + use iso_c_binding + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: dir + integer(c_int), value :: trans_A + integer(c_int), value :: trans_B + integer(c_int), value :: mb + integer(c_int), value :: n + integer(c_int), value :: kb + integer(c_int), value :: nnzb + type(c_ptr), intent(in), value :: alpha + type(c_ptr), intent(in), value :: descr + type(c_ptr), intent(in), value :: bsr_val + type(c_ptr), intent(in), value :: bsr_row_ptr + type(c_ptr), intent(in), value :: bsr_col_ind + integer(c_int), value :: block_dim + type(c_ptr), intent(in), value :: B + integer(c_int), value :: ldb + type(c_ptr), intent(in), value :: beta + type(c_ptr), value :: C + integer(c_int), value :: ldc + end function rocsparse_zbsrmm ! rocsparse_csrmm function rocsparse_scsrmm(handle, trans_A, trans_B, m, n, k, nnz, alpha, descr, &