From 0e0f8befe3b965bd5564eac45f80760541a349e0 Mon Sep 17 00:00:00 2001 From: Nico <31079890+ntrost57@users.noreply.github.com> Date: Tue, 7 Jul 2020 07:12:48 +0200 Subject: [PATCH] level3/gemmi feature (#83) * gemmi benchmark * gemmi tests * gemmi samples * gemmi documentation * gemmi API * gemmi fortran binding and example * internal gemmi structure * gemmi kernel for transposed B * minor tweaks * bump version --- CMakeLists.txt | 2 +- clients/benchmarks/client.cpp | 14 +- .../rocsparse_template_specialization.cpp | 153 +++++ clients/include/rocsparse.hpp | 20 + clients/include/rocsparse_host.hpp | 42 ++ clients/include/rocsparse_template.yaml | 6 + clients/include/testing_gemmi.hpp | 630 ++++++++++++++++++ clients/samples/CMakeLists.txt | 2 + clients/samples/example_fortran_gemmi.f90 | 246 +++++++ clients/samples/example_gemmi.cpp | 203 ++++++ clients/tests/CMakeLists.txt | 3 +- clients/tests/rocsparse_test.yaml | 1 + clients/tests/test_gemmi.cpp | 115 ++++ clients/tests/test_gemmi.yaml | 189 ++++++ docs/source/usermanual.rst | 12 + library/include/rocsparse-functions.h | 216 ++++++ library/src/CMakeLists.txt | 1 + library/src/level3/gemmi_device.h | 95 +++ library/src/level3/rocsparse_gemmi.cpp | 180 +++++ library/src/level3/rocsparse_gemmi.hpp | 345 ++++++++++ library/src/rocsparse_module.f90 | 101 +++ 21 files changed, 2573 insertions(+), 3 deletions(-) create mode 100644 clients/include/testing_gemmi.hpp create mode 100644 clients/samples/example_fortran_gemmi.f90 create mode 100644 clients/samples/example_gemmi.cpp create mode 100644 clients/tests/test_gemmi.cpp create mode 100644 clients/tests/test_gemmi.yaml create mode 100644 library/src/level3/gemmi_device.h create mode 100644 library/src/level3/rocsparse_gemmi.cpp create mode 100644 library/src/level3/rocsparse_gemmi.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 52d9e6cd..04f8713b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -83,7 +83,7 @@ option(BUILD_VERBOSE "Output additional build information" OFF) include(cmake/Dependencies.cmake) # Setup version -set(VERSION_STRING "1.15.0") +set(VERSION_STRING "1.15.1") rocm_setup_version(VERSION ${VERSION_STRING}) set(rocsparse_SOVERSION 0.1) diff --git a/clients/benchmarks/client.cpp b/clients/benchmarks/client.cpp index 83c71c66..1cf4114c 100644 --- a/clients/benchmarks/client.cpp +++ b/clients/benchmarks/client.cpp @@ -46,6 +46,7 @@ #include "testing_bsrmm.hpp" #include "testing_csrmm.hpp" #include "testing_csrsm.hpp" +#include "testing_gemmi.hpp" // Extra #include "testing_csrgeam.hpp" @@ -210,7 +211,7 @@ int main(int argc, char* argv[]) "SPARSE function to test. Options:\n" " Level1: axpyi, doti, dotci, gthr, gthrz, roti, sctr\n" " Level2: bsrmv, bsrsv, coomv, csrmv, csrsv, ellmv, hybmv\n" - " Level3: bsrmm, csrmm, csrsm\n" + " Level3: bsrmm, csrmm, csrsm, gemmi\n" " Extra: csrgeam, csrgemm\n" " Preconditioner: csric0, csrilu0\n" " Conversion: csr2coo, csr2csc, csr2ell, csr2hyb, csr2bsr\n" @@ -601,6 +602,17 @@ int main(int argc, char* argv[]) else if(precision == 'z') testing_csrsm(arg); } + else if(function == "gemmi") + { + if(precision == 's') + testing_gemmi(arg); + else if(precision == 'd') + testing_gemmi(arg); + else if(precision == 'c') + testing_gemmi(arg); + else if(precision == 'z') + testing_gemmi(arg); + } else if(function == "csrgeam") { if(precision == 's') diff --git a/clients/common/rocsparse_template_specialization.cpp b/clients/common/rocsparse_template_specialization.cpp index 3f8b80fd..02d26cf3 100644 --- a/clients/common/rocsparse_template_specialization.cpp +++ b/clients/common/rocsparse_template_specialization.cpp @@ -2284,6 +2284,159 @@ rocsparse_status rocsparse_csrsm_solve(rocsparse_handle handle, temp_buffer); } +// gemmi +template <> +rocsparse_status rocsparse_gemmi(rocsparse_handle handle, + rocsparse_operation trans_A, + rocsparse_operation trans_B, + rocsparse_int m, + rocsparse_int n, + rocsparse_int k, + rocsparse_int nnz, + const float* alpha, + const float* A, + rocsparse_int lda, + const rocsparse_mat_descr descr, + const float* csr_val, + const rocsparse_int* csr_row_ptr, + const rocsparse_int* csr_col_ind, + const float* beta, + float* C, + rocsparse_int ldc) +{ + return rocsparse_sgemmi(handle, + trans_A, + trans_B, + m, + n, + k, + nnz, + alpha, + A, + lda, + descr, + csr_val, + csr_row_ptr, + csr_col_ind, + beta, + C, + ldc); +} + +template <> +rocsparse_status rocsparse_gemmi(rocsparse_handle handle, + rocsparse_operation trans_A, + rocsparse_operation trans_B, + rocsparse_int m, + rocsparse_int n, + rocsparse_int k, + rocsparse_int nnz, + const double* alpha, + const double* A, + rocsparse_int lda, + const rocsparse_mat_descr descr, + const double* csr_val, + const rocsparse_int* csr_row_ptr, + const rocsparse_int* csr_col_ind, + const double* beta, + double* C, + rocsparse_int ldc) +{ + return rocsparse_dgemmi(handle, + trans_A, + trans_B, + m, + n, + k, + nnz, + alpha, + A, + lda, + descr, + csr_val, + csr_row_ptr, + csr_col_ind, + beta, + C, + ldc); +} + +template <> +rocsparse_status rocsparse_gemmi(rocsparse_handle handle, + rocsparse_operation trans_A, + rocsparse_operation trans_B, + rocsparse_int m, + rocsparse_int n, + rocsparse_int k, + rocsparse_int nnz, + const rocsparse_float_complex* alpha, + const rocsparse_float_complex* A, + rocsparse_int lda, + const rocsparse_mat_descr descr, + const rocsparse_float_complex* csr_val, + const rocsparse_int* csr_row_ptr, + const rocsparse_int* csr_col_ind, + const rocsparse_float_complex* beta, + rocsparse_float_complex* C, + rocsparse_int ldc) +{ + return rocsparse_cgemmi(handle, + trans_A, + trans_B, + m, + n, + k, + nnz, + alpha, + A, + lda, + descr, + csr_val, + csr_row_ptr, + csr_col_ind, + beta, + C, + ldc); +} + +template <> +rocsparse_status rocsparse_gemmi(rocsparse_handle handle, + rocsparse_operation trans_A, + rocsparse_operation trans_B, + rocsparse_int m, + rocsparse_int n, + rocsparse_int k, + rocsparse_int nnz, + const rocsparse_double_complex* alpha, + const rocsparse_double_complex* A, + rocsparse_int lda, + const rocsparse_mat_descr descr, + const rocsparse_double_complex* csr_val, + const rocsparse_int* csr_row_ptr, + const rocsparse_int* csr_col_ind, + const rocsparse_double_complex* beta, + rocsparse_double_complex* C, + rocsparse_int ldc) +{ + return rocsparse_zgemmi(handle, + trans_A, + trans_B, + m, + n, + k, + nnz, + alpha, + A, + lda, + descr, + csr_val, + csr_row_ptr, + csr_col_ind, + beta, + C, + ldc); +} + /* * =========================================================================== * extra SPARSE diff --git a/clients/include/rocsparse.hpp b/clients/include/rocsparse.hpp index 951a26d1..95934fbc 100644 --- a/clients/include/rocsparse.hpp +++ b/clients/include/rocsparse.hpp @@ -400,6 +400,26 @@ rocsparse_status rocsparse_csrsm_solve(rocsparse_handle handle, rocsparse_solve_policy policy, void* temp_buffer); +// gemmi +template +rocsparse_status rocsparse_gemmi(rocsparse_handle handle, + rocsparse_operation trans_A, + rocsparse_operation trans_B, + rocsparse_int m, + rocsparse_int n, + rocsparse_int k, + rocsparse_int nnz, + const T* alpha, + const T* A, + rocsparse_int lda, + const rocsparse_mat_descr descr, + const T* csr_val, + const rocsparse_int* csr_row_ptr, + const rocsparse_int* csr_col_ind, + const T* beta, + T* C, + rocsparse_int ldc); + /* * =========================================================================== * extra SPARSE diff --git a/clients/include/rocsparse_host.hpp b/clients/include/rocsparse_host.hpp index 31027841..80d2b010 100644 --- a/clients/include/rocsparse_host.hpp +++ b/clients/include/rocsparse_host.hpp @@ -1653,6 +1653,48 @@ inline void host_csrsm(rocsparse_int M, *numeric_pivot = (*numeric_pivot == M + 1) ? -1 : *numeric_pivot; } +template +inline void host_gemmi(rocsparse_int M, + rocsparse_int N, + rocsparse_operation transA, + rocsparse_operation transB, + T alpha, + const T* A, + rocsparse_int lda, + const rocsparse_int* csr_row_ptr, + const rocsparse_int* csr_col_ind, + const T* csr_val, + T beta, + T* C, + rocsparse_int ldc, + rocsparse_index_base base) +{ + if(transB == rocsparse_operation_transpose) + { + for(rocsparse_int i = 0; i < M; ++i) + { + for(rocsparse_int j = 0; j < N; ++j) + { + T sum = static_cast(0); + + rocsparse_int row_begin = csr_row_ptr[j] - base; + rocsparse_int row_end = csr_row_ptr[j + 1] - base; + + for(rocsparse_int k = row_begin; k < row_end; ++k) + { + rocsparse_int col_B = csr_col_ind[k] - base; + T val_B = csr_val[k]; + T val_A = A[col_B * lda + i]; + + sum = std::fma(val_A, val_B, sum); + } + + C[j * ldc + i] = std::fma(beta, C[j * ldc + i], alpha * sum); + } + } + } +} + /* * =========================================================================== * extra SPARSE diff --git a/clients/include/rocsparse_template.yaml b/clients/include/rocsparse_template.yaml index d4f6ee7c..0c23bfa8 100644 --- a/clients/include/rocsparse_template.yaml +++ b/clients/include/rocsparse_template.yaml @@ -111,6 +111,8 @@ Functions: rocsparse_zbsrmm: { function: bsrmm, <<: *double_precision_complex } rocsparse_scsrmm: { function: csrmm, <<: *single_precision } rocsparse_dcsrmm: { function: csrmm, <<: *double_precision } + rocsparse_ccsrmm: { function: csrmm, <<: *single_precision_complex } + rocsparse_zcsrmm: { function: csrmm, <<: *double_precision_complex } rocsparse_scsrsm_buffer_size: { function: csrsm, <<: *single_precision } rocsparse_dcsrsm_buffer_size: { function: csrsm, <<: *double_precision } rocsparse_ccsrsm_buffer_size: { function: csrsm, <<: *single_precision_complex } @@ -125,6 +127,10 @@ Functions: rocsparse_zcsrsm_solve: { function: csrsm, <<: *double_precision_complex } rocsparse_csrsm_zero_pivot: {function: csrsm } rocsparse_csrsm_clear: {function: csrsm } + rocsparse_sgemmi: { function: gemmi, <<: *single_precision } + rocsparse_dgemmi: { function: gemmi, <<: *double_precision } + rocsparse_cgemmi: { function: gemmi, <<: *single_precision_complex } + rocsparse_zgemmi: { function: gemmi, <<: *double_precision_complex } rocsparse_csrgeam_nnz: { function: csrgeam } rocsparse_scsrgeam: { function: csrgeam, <<: *single_precision } diff --git a/clients/include/testing_gemmi.hpp b/clients/include/testing_gemmi.hpp new file mode 100644 index 00000000..7999c297 --- /dev/null +++ b/clients/include/testing_gemmi.hpp @@ -0,0 +1,630 @@ +/* ************************************************************************ + * Copyright (c) 2020 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + * ************************************************************************ */ + +#pragma once +#ifndef TESTING_GEMMI_HPP +#define TESTING_GEMMI_HPP + +#include + +#include "flops.hpp" +#include "gbyte.hpp" +#include "rocsparse_check.hpp" +#include "rocsparse_host.hpp" +#include "rocsparse_init.hpp" +#include "rocsparse_math.hpp" +#include "rocsparse_random.hpp" +#include "rocsparse_test.hpp" +#include "rocsparse_vector.hpp" +#include "utility.hpp" + +template +void testing_gemmi_bad_arg(const Arguments& arg) +{ + static const size_t safe_size = 100; + + T h_alpha = 0.6; + T h_beta = 0.1; + + // Create rocsparse handle + rocsparse_local_handle handle; + + // Create rocsparse mat descriptor + rocsparse_local_mat_descr descr; + + // rocsparse operations + rocsparse_operation trans_A = rocsparse_operation_none; + rocsparse_operation trans_B = rocsparse_operation_transpose; + + // Allocate memory on device + device_vector dcsr_row_ptr(safe_size); + device_vector dcsr_col_ind(safe_size); + device_vector dcsr_val(safe_size); + device_vector dA(safe_size); + device_vector dC(safe_size); + + if(!dcsr_row_ptr || !dcsr_col_ind || !dcsr_val || !dA || !dC) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + // Test rocsparse_gemmi() + EXPECT_ROCSPARSE_STATUS(rocsparse_gemmi(nullptr, + trans_A, + trans_B, + safe_size, + safe_size, + safe_size, + safe_size, + &h_alpha, + dA, + safe_size, + descr, + dcsr_val, + dcsr_row_ptr, + dcsr_col_ind, + &h_beta, + dC, + safe_size), + rocsparse_status_invalid_handle); + EXPECT_ROCSPARSE_STATUS(rocsparse_gemmi(handle, + trans_A, + trans_B, + safe_size, + safe_size, + safe_size, + safe_size, + nullptr, + dA, + safe_size, + descr, + dcsr_val, + dcsr_row_ptr, + dcsr_col_ind, + &h_beta, + dC, + safe_size), + rocsparse_status_invalid_pointer); + EXPECT_ROCSPARSE_STATUS(rocsparse_gemmi(handle, + trans_A, + trans_B, + safe_size, + safe_size, + safe_size, + safe_size, + &h_alpha, + nullptr, + safe_size, + descr, + dcsr_val, + dcsr_row_ptr, + dcsr_col_ind, + &h_beta, + dC, + safe_size), + rocsparse_status_invalid_pointer); + EXPECT_ROCSPARSE_STATUS(rocsparse_gemmi(handle, + trans_A, + trans_B, + safe_size, + safe_size, + safe_size, + safe_size, + &h_alpha, + dA, + safe_size, + nullptr, + dcsr_val, + dcsr_row_ptr, + dcsr_col_ind, + &h_beta, + dC, + safe_size), + rocsparse_status_invalid_pointer); + EXPECT_ROCSPARSE_STATUS(rocsparse_gemmi(handle, + trans_A, + trans_B, + safe_size, + safe_size, + safe_size, + safe_size, + &h_alpha, + dA, + safe_size, + descr, + nullptr, + dcsr_row_ptr, + dcsr_col_ind, + &h_beta, + dC, + safe_size), + rocsparse_status_invalid_pointer); + EXPECT_ROCSPARSE_STATUS(rocsparse_gemmi(handle, + trans_A, + trans_B, + safe_size, + safe_size, + safe_size, + safe_size, + &h_alpha, + dA, + safe_size, + descr, + dcsr_val, + nullptr, + dcsr_col_ind, + &h_beta, + dC, + safe_size), + rocsparse_status_invalid_pointer); + EXPECT_ROCSPARSE_STATUS(rocsparse_gemmi(handle, + trans_A, + trans_B, + safe_size, + safe_size, + safe_size, + safe_size, + &h_alpha, + dA, + safe_size, + descr, + dcsr_val, + dcsr_row_ptr, + nullptr, + &h_beta, + dC, + safe_size), + rocsparse_status_invalid_pointer); + EXPECT_ROCSPARSE_STATUS(rocsparse_gemmi(handle, + trans_A, + trans_B, + safe_size, + safe_size, + safe_size, + safe_size, + &h_alpha, + dA, + safe_size, + descr, + dcsr_val, + dcsr_row_ptr, + dcsr_col_ind, + nullptr, + dC, + safe_size), + rocsparse_status_invalid_pointer); + EXPECT_ROCSPARSE_STATUS(rocsparse_gemmi(handle, + trans_A, + trans_B, + safe_size, + safe_size, + safe_size, + safe_size, + &h_alpha, + dA, + safe_size, + descr, + dcsr_val, + dcsr_row_ptr, + dcsr_col_ind, + &h_beta, + nullptr, + safe_size), + rocsparse_status_invalid_pointer); + + // Test invalid sizes + EXPECT_ROCSPARSE_STATUS(rocsparse_gemmi(handle, + trans_A, + trans_B, + -1, + safe_size, + safe_size, + safe_size, + nullptr, + nullptr, + safe_size, + descr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + safe_size), + rocsparse_status_invalid_size); + EXPECT_ROCSPARSE_STATUS(rocsparse_gemmi(handle, + trans_A, + trans_B, + safe_size, + -1, + safe_size, + safe_size, + nullptr, + nullptr, + safe_size, + descr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + safe_size), + rocsparse_status_invalid_size); + EXPECT_ROCSPARSE_STATUS(rocsparse_gemmi(handle, + trans_A, + trans_B, + safe_size, + safe_size, + -1, + safe_size, + nullptr, + nullptr, + safe_size, + descr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + safe_size), + rocsparse_status_invalid_size); + EXPECT_ROCSPARSE_STATUS(rocsparse_gemmi(handle, + trans_A, + trans_B, + safe_size, + safe_size, + safe_size, + -1, + nullptr, + nullptr, + safe_size, + descr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + safe_size), + rocsparse_status_invalid_size); + + // Test invalid leading dimensions + EXPECT_ROCSPARSE_STATUS(rocsparse_gemmi(handle, + trans_A, + trans_B, + safe_size, + safe_size, + safe_size, + safe_size, + &h_alpha, + dA, + -1, + descr, + dcsr_val, + dcsr_row_ptr, + dcsr_col_ind, + &h_beta, + dC, + safe_size), + rocsparse_status_invalid_value); + EXPECT_ROCSPARSE_STATUS(rocsparse_gemmi(handle, + trans_A, + trans_B, + safe_size, + safe_size, + safe_size, + safe_size, + &h_alpha, + dA, + safe_size, + descr, + dcsr_val, + dcsr_row_ptr, + dcsr_col_ind, + &h_beta, + dC, + -1), + rocsparse_status_invalid_value); +} + +template +void testing_gemmi(const Arguments& arg) +{ + rocsparse_int M = arg.M; + rocsparse_int N = arg.N; + rocsparse_int K = arg.K; + rocsparse_int dim_x = arg.dimx; + rocsparse_int dim_y = arg.dimy; + rocsparse_int dim_z = arg.dimz; + rocsparse_operation transA = arg.transA; + rocsparse_operation transB = arg.transB; + rocsparse_index_base base = arg.baseA; + rocsparse_matrix_init mat = arg.matrix; + bool full_rank = false; + std::string filename + = arg.timing ? arg.filename : rocsparse_exepath() + "../matrices/" + arg.filename + ".csr"; + + T h_alpha = arg.get_alpha(); + T h_beta = arg.get_beta(); + + // Create rocsparse handle + rocsparse_local_handle handle; + + // Create rocsparse mat descriptor + rocsparse_local_mat_descr descr; + + // Set matrix index base + CHECK_ROCSPARSE_ERROR(rocsparse_set_mat_index_base(descr, base)); + + // Argument sanity check before allocating invalid memory + if(M <= 0 || N <= 0 || K < 0) + { + static const size_t safe_size = 100; + + EXPECT_ROCSPARSE_STATUS(rocsparse_gemmi(handle, + transA, + transB, + M, + N, + K, + safe_size, + nullptr, + nullptr, + safe_size, + descr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + safe_size), + (M < 0 || N < 0 || K < 0) ? rocsparse_status_invalid_size + : rocsparse_status_success); + + return; + } + + // Allocate host memory for matrix B + host_vector hcsr_row_ptr; + host_vector hcsr_col_ind; + host_vector hcsr_val; + + rocsparse_seedrand(); + + // Sample matrix B + rocsparse_int nnz_B; + rocsparse_init_csr_matrix(hcsr_row_ptr, + hcsr_col_ind, + hcsr_val, + (transB == rocsparse_operation_none) ? K : N, + (transB == rocsparse_operation_none) ? N : K, + M, + dim_x, + dim_y, + dim_z, + nnz_B, + base, + mat, + filename.c_str(), + false, + full_rank); + + rocsparse_int nrow_B = (transB == rocsparse_operation_none) ? K : N; + rocsparse_int ncol_B = (transB == rocsparse_operation_none) ? N : K; + + rocsparse_int lda = std::max(1, M); + rocsparse_int ldc = std::max(1, M); + + rocsparse_int nnz_A = lda * K; + rocsparse_int nnz_C = ldc * N; + + // Allocate host memory for matrix A and C + host_vector hA(nnz_A); + host_vector hC_1(nnz_C); + host_vector hC_2(nnz_C); + host_vector hC_gold(nnz_C); + + // Sample matrix A + rocsparse_init(hA, M, K, lda); + + // Sample matrix C + rocsparse_init(hC_gold, M, N, ldc); + + // Allocate device memory + device_vector dcsr_row_ptr(nrow_B + 1); + device_vector dcsr_col_ind(nnz_B); + device_vector dcsr_val(nnz_B); + device_vector dA(nnz_A); + device_vector dC_1(nnz_C); + device_vector dC_2(nnz_C); + device_vector d_alpha(1); + device_vector d_beta(1); + + if(!dcsr_row_ptr || !dcsr_col_ind || !dcsr_val || !dA || !dC_1 || !dC_2 || !d_alpha || !d_beta) + { + CHECK_HIP_ERROR(hipErrorOutOfMemory); + return; + } + + // Copy data from CPU to device + CHECK_HIP_ERROR(hipMemcpy( + dcsr_row_ptr, hcsr_row_ptr, sizeof(rocsparse_int) * (nrow_B + 1), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy( + dcsr_col_ind, hcsr_col_ind, sizeof(rocsparse_int) * nnz_B, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dcsr_val, hcsr_val, sizeof(T) * nnz_B, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dA, hA, sizeof(T) * nnz_A, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dC_1, hC_gold, sizeof(T) * nnz_C, hipMemcpyHostToDevice)); + + if(arg.unit_check) + { + // Copy data from CPU to device + CHECK_HIP_ERROR(hipMemcpy(dC_2, dC_1, sizeof(T) * nnz_C, hipMemcpyDeviceToDevice)); + CHECK_HIP_ERROR(hipMemcpy(d_alpha, &h_alpha, sizeof(T), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(d_beta, &h_beta, sizeof(T), hipMemcpyHostToDevice)); + + // Pointer mode host + CHECK_ROCSPARSE_ERROR(rocsparse_set_pointer_mode(handle, rocsparse_pointer_mode_host)); + CHECK_ROCSPARSE_ERROR(rocsparse_gemmi(handle, + transA, + transB, + M, + N, + K, + nnz_B, + &h_alpha, + dA, + lda, + descr, + dcsr_val, + dcsr_row_ptr, + dcsr_col_ind, + &h_beta, + dC_1, + ldc)); + + // Pointer mode device + CHECK_ROCSPARSE_ERROR(rocsparse_set_pointer_mode(handle, rocsparse_pointer_mode_device)); + CHECK_ROCSPARSE_ERROR(rocsparse_gemmi(handle, + transA, + transB, + M, + N, + K, + nnz_B, + d_alpha, + dA, + lda, + descr, + dcsr_val, + dcsr_row_ptr, + dcsr_col_ind, + d_beta, + dC_2, + ldc)); + + // Copy output to host + CHECK_HIP_ERROR(hipMemcpy(hC_1, dC_1, sizeof(T) * nnz_C, hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(hC_2, dC_2, sizeof(T) * nnz_C, hipMemcpyDeviceToHost)); + + // CPU gemmi + host_gemmi(M, + N, + transA, + transB, + h_alpha, + hA, + lda, + hcsr_row_ptr, + hcsr_col_ind, + hcsr_val, + h_beta, + hC_gold, + ldc, + base); + + unit_check_general(M, N, ldc, hC_gold, hC_1); + unit_check_general(M, N, ldc, hC_gold, hC_2); + } + + if(arg.timing) + { + int number_cold_calls = 2; + int number_hot_calls = arg.iters; + + CHECK_ROCSPARSE_ERROR(rocsparse_set_pointer_mode(handle, rocsparse_pointer_mode_host)); + + // Warm up + for(int iter = 0; iter < number_cold_calls; ++iter) + { + CHECK_ROCSPARSE_ERROR(rocsparse_gemmi(handle, + transA, + transB, + M, + N, + K, + nnz_B, + &h_alpha, + dA, + lda, + descr, + dcsr_val, + dcsr_row_ptr, + dcsr_col_ind, + &h_beta, + dC_1, + ldc)); + } + + double gpu_time_used = get_time_us(); + + // Performance run + for(int iter = 0; iter < number_hot_calls; ++iter) + { + CHECK_ROCSPARSE_ERROR(rocsparse_gemmi(handle, + transA, + transB, + M, + N, + K, + nnz_B, + &h_alpha, + dA, + lda, + descr, + dcsr_val, + dcsr_row_ptr, + dcsr_col_ind, + &h_beta, + dC_1, + ldc)); + } + + gpu_time_used = (get_time_us() - gpu_time_used) / number_hot_calls; + + double gpu_gflops = csrmm_gflop_count(M, nnz_B, nnz_C, h_beta != static_cast(0)) + / gpu_time_used * 1e6; + double gpu_gbyte + = csrmm_gbyte_count(nrow_B, nnz_B, nnz_A, nnz_C, h_beta != static_cast(0)) + / gpu_time_used * 1e6; + + std::cout.precision(2); + std::cout.setf(std::ios::fixed); + std::cout.setf(std::ios::left); + + std::cout << std::setw(12) << "M" << std::setw(12) << "N" << std::setw(12) << "K" + << std::setw(12) << "transA" << std::setw(12) << "transB" << std::setw(12) + << "nnz_A" << std::setw(12) << "nnz_B" << std::setw(12) << "nnz_C" + << std::setw(12) << "alpha" << std::setw(12) << "beta" << std::setw(12) + << "GFlop/s" << std::setw(12) << "GB/s" << std::setw(12) << "msec" + << std::setw(12) << "iter" << std::setw(12) << "verified" << std::endl; + + std::cout << std::setw(12) << M << std::setw(12) << N << std::setw(12) << K << std::setw(12) + << rocsparse_operation2string(transA) << std::setw(12) + << rocsparse_operation2string(transB) << std::setw(12) << nnz_A << std::setw(12) + << nnz_B << std::setw(12) << nnz_C << std::setw(12) << h_alpha << std::setw(12) + << h_beta << std::setw(12) << gpu_gflops << std::setw(12) << gpu_gbyte + << std::setw(12) << gpu_time_used / 1e3 << std::setw(12) << number_hot_calls + << std::setw(12) << (arg.unit_check ? "yes" : "no") << std::endl; + } +} + +#endif // TESTING_GEMMI_HPP diff --git a/clients/samples/CMakeLists.txt b/clients/samples/CMakeLists.txt index b10c4f40..7f00171b 100644 --- a/clients/samples/CMakeLists.txt +++ b/clients/samples/CMakeLists.txt @@ -66,6 +66,7 @@ add_rocsparse_example(example_bsrmm.cpp) add_rocsparse_example(example_fortran_bsrmm.f90) add_rocsparse_example(example_csrmm.cpp) add_rocsparse_example(example_csrsm.cpp) +add_rocsparse_example(example_gemmi.cpp) # Extra add_rocsparse_example(example_csrgeam.cpp) @@ -81,6 +82,7 @@ if(TARGET rocsparse) add_rocsparse_example(example_fortran_csrsv.f90) add_rocsparse_example(example_fortran_spmv.f90) add_rocsparse_example(example_fortran_csrsm.f90) + add_rocsparse_example(example_fortran_gemmi.f90) add_rocsparse_example(example_fortran_auxiliary.f90) add_rocsparse_example(example_fortran_dotci.f90) add_rocsparse_example(example_fortran_roti.f90) diff --git a/clients/samples/example_fortran_gemmi.f90 b/clients/samples/example_fortran_gemmi.f90 new file mode 100644 index 00000000..52df4107 --- /dev/null +++ b/clients/samples/example_fortran_gemmi.f90 @@ -0,0 +1,246 @@ +!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +! Copyright (c) 2020 Advanced Micro Devices, Inc. +! +! Permission is hereby granted, free of charge, to any person obtaining a copy +! of this software and associated documentation files (the "Software"), to deal +! in the Software without restriction, including without limitation the rights +! to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +! copies of the Software, and to permit persons to whom the Software is +! furnished to do so, subject to the following conditions: +! +! The above copyright notice and this permission notice shall be included in +! all copies or substantial portions of the Software. +! +! THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +! IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +! FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +! AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +! LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +! OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +! THE SOFTWARE. +! +!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +subroutine HIP_CHECK(stat) + use iso_c_binding + + implicit none + + integer(c_int) :: stat + + if(stat /= 0) then + write(*,*) 'Error: hip error' + stop + end if + +end subroutine HIP_CHECK + +subroutine ROCSPARSE_CHECK(stat) + use iso_c_binding + + implicit none + + integer(c_int) :: stat + + if(stat /= 0) then + write(*,*) 'Error: rocsparse error' + stop + end if + +end subroutine ROCSPARSE_CHECK + +program example_fortran_gemmi + use iso_c_binding + use rocsparse + + implicit none + + interface + function hipMalloc(ptr, size) & + result(c_int) & + bind(c, name = 'hipMalloc') + use iso_c_binding + implicit none + type(c_ptr) :: ptr + integer(c_size_t), value :: size + end function hipMalloc + + function hipFree(ptr) & + result(c_int) & + bind(c, name = 'hipFree') + use iso_c_binding + implicit none + type(c_ptr), value :: ptr + end function hipFree + + function hipMemcpy(dst, src, size, kind) & + result(c_int) & + bind(c, name = 'hipMemcpy') + use iso_c_binding + implicit none + type(c_ptr), value :: dst + type(c_ptr), intent(in), value :: src + integer(c_size_t), value :: size + integer(c_int), value :: kind + end function hipMemcpy + + function hipMemset(dst, val, size) & + result(c_int) & + bind(c, name = 'hipMemset') + use iso_c_binding + implicit none + type(c_ptr), value :: dst + integer(c_int), value :: val + integer(c_size_t), value :: size + end function hipMemset + + function hipDeviceSynchronize() & + result(c_int) & + bind(c, name = 'hipDeviceSynchronize') + use iso_c_binding + implicit none + end function hipDeviceSynchronize + + function hipDeviceReset() & + result(c_int) & + bind(c, name = 'hipDeviceReset') + use iso_c_binding + implicit none + end function hipDeviceReset + end interface + + integer, target :: h_csr_row_ptr(6), h_csr_col_ind(8) + real(8), target :: h_csr_val(8), h_A(6), h_C(10) + + type(c_ptr) :: d_csr_row_ptr + type(c_ptr) :: d_csr_col_ind + type(c_ptr) :: d_csr_val + type(c_ptr) :: d_A + type(c_ptr) :: d_C + + integer :: i, j + integer(c_int) :: M, N, K, nnz + integer(c_int) :: lda, ldc + + real(c_double), target :: alpha + real(c_double), target :: beta + + type(c_ptr) :: handle + type(c_ptr) :: descr + + integer :: version + + character(len=12) :: rev + +! This example is going to compute C = alpha * A * B^T + beta * C + +! Input data + +! Number of rows and columns + M = 2 + N = 5 + K = 3 + + lda = M + ldc = M + +! Number of non-zero entries + nnz = 8 + +! A = ( 9 10 11 ) +! ( 12 13 14 ) + +! ( 1 0 6 ) +! ( 2 4 0 ) +! B = ( 0 5 0 ) +! ( 3 0 7 ) +! ( 0 0 8 ) + +! C = ( 15 16 17 18 19 ) +! ( 20 21 22 23 24 ) + +! Fill A + h_A = (/9, 12, 10, 13, 11, 14/) + +! Fill CSR structure of B + h_csr_row_ptr = (/0, 2, 4, 5, 7, 8/) + h_csr_col_ind = (/0, 2, 0, 1, 1, 0, 2, 2/) + h_csr_val = (/1, 6, 2, 4, 5, 3, 7, 8/) + +! Fill C + h_C = (/15, 20, 16, 21, 17, 22, 18, 23, 19, 24/) + +! Scalar alpha and beta + alpha = 3.7 + beta = 1.3 + +! Allocate device memory + call HIP_CHECK(hipMalloc(d_A, int(M * K, c_size_t) * 8)) + call HIP_CHECK(hipMalloc(d_csr_row_ptr, (int(N, c_size_t) + 1) * 4)) + call HIP_CHECK(hipMalloc(d_csr_col_ind, int(nnz, c_size_t) * 4)) + call HIP_CHECK(hipMalloc(d_csr_val, int(nnz, c_size_t) * 8)) + call HIP_CHECK(hipMalloc(d_C, int(M * N, c_size_t) * 8)) + +! Copy host data to device + call HIP_CHECK(hipMemcpy(d_A, c_loc(h_A), int(M * K, c_size_t) * 8, 1)) + call HIP_CHECK(hipMemcpy(d_csr_row_ptr, c_loc(h_csr_row_ptr), (int(N, c_size_t) + 1) * 4, 1)) + call HIP_CHECK(hipMemcpy(d_csr_col_ind, c_loc(h_csr_col_ind), int(nnz, c_size_t) * 4, 1)) + call HIP_CHECK(hipMemcpy(d_csr_val, c_loc(h_csr_val), int(nnz, c_size_t) * 8, 1)) + call HIP_CHECK(hipMemcpy(d_C, c_loc(h_C), int(M * N, c_size_t) * 8, 1)) + +! Create rocSPARSE handle + call ROCSPARSE_CHECK(rocsparse_create_handle(handle)) + +! Get rocSPARSE version + call ROCSPARSE_CHECK(rocsparse_get_version(handle, version)) + call ROCSPARSE_CHECK(rocsparse_get_git_rev(handle, rev)) + +! Print version on screen + write(*,fmt='(A,I0,A,I0,A,I0,A,A)') 'rocSPARSE version: ', version / 100000, '.', & + mod(version / 100, 1000), '.', mod(version, 100), '-', rev + +! Create matrix descriptor + call ROCSPARSE_CHECK(rocsparse_create_mat_descr(descr)) + +! Call dgemmi to perform matrix multiplication + call ROCSPARSE_CHECK(rocsparse_dgemmi(handle, & + rocsparse_operation_none, & + rocsparse_operation_transpose, & + M, & + N, & + K, & + nnz, & + c_loc(alpha), & + d_A, & + lda, & + descr, & + d_csr_val, & + d_csr_row_ptr, & + d_csr_col_ind, & + c_loc(beta), & + d_C, & + ldc)) + +! Print result + call HIP_CHECK(hipMemcpy(c_loc(h_C), d_C, int(M * N, c_size_t) * 8, 2)) + + write(*,*) 'C:' + do i = 1, M + do j = 0, N - 1 + write(*,fmt='(A,F0.5)', advance='no') ' ', h_C(i + j * ldc) + end do + write(*,*) + end do + +! Clear rocSPARSE + call ROCSPARSE_CHECK(rocsparse_destroy_mat_descr(descr)) + call ROCSPARSE_CHECK(rocsparse_destroy_handle(handle)) + +! Clear device memory + call HIP_CHECK(hipFree(d_A)) + call HIP_CHECK(hipFree(d_csr_row_ptr)) + call HIP_CHECK(hipFree(d_csr_col_ind)) + call HIP_CHECK(hipFree(d_csr_val)) + call HIP_CHECK(hipFree(d_C)) + +end program example_fortran_gemmi diff --git a/clients/samples/example_gemmi.cpp b/clients/samples/example_gemmi.cpp new file mode 100644 index 00000000..4d204d09 --- /dev/null +++ b/clients/samples/example_gemmi.cpp @@ -0,0 +1,203 @@ +/* ************************************************************************ + * Copyright (c) 2020 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + * ************************************************************************ */ + +#include +#include +#include +#include + +#define HIP_CHECK(stat) \ + { \ + if(stat != hipSuccess) \ + { \ + std::cerr << "Error: hip error in line " << __LINE__ << std::endl; \ + return -1; \ + } \ + } + +#define ROCSPARSE_CHECK(stat) \ + { \ + if(stat != rocsparse_status_success) \ + { \ + std::cerr << "Error: rocsparse error in line " << __LINE__ << std::endl; \ + return -1; \ + } \ + } + +int main(int argc, char* argv[]) +{ + // Query device + int ndev; + HIP_CHECK(hipGetDeviceCount(&ndev)); + + if(ndev < 1) + { + std::cerr << "No HIP device found" << std::endl; + return -1; + } + + // Query device properties + hipDeviceProp_t prop; + HIP_CHECK(hipGetDeviceProperties(&prop, 0)); + + std::cout << "Device: " << prop.name << std::endl; + + // rocSPARSE handle + rocsparse_handle handle; + ROCSPARSE_CHECK(rocsparse_create_handle(&handle)); + + rocsparse_mat_descr descr; + ROCSPARSE_CHECK(rocsparse_create_mat_descr(&descr)); + + // Print rocSPARSE version and revision + int ver; + char rev[64]; + + ROCSPARSE_CHECK(rocsparse_get_version(handle, &ver)); + ROCSPARSE_CHECK(rocsparse_get_git_rev(handle, rev)); + + std::cout << "rocSPARSE version: " << ver / 100000 << "." << ver / 100 % 1000 << "." + << ver % 100 << "-" << rev << std::endl; + + // Input data + + // Number of rows and columns + rocsparse_int m = 3; + rocsparse_int n = 5; + rocsparse_int k = 3; + + // Matrix A (m x k) + // ( 9.0 10.0 11.0 ) + // ( 12.0 13.0 14.0 ) + // ( 15.0 16.0 17.0 ) + + // Matrix A in column-major + rocsparse_int lda = m; + double hA[9] = {9.0, 12.0, 15, 10.0, 13.0, 16.0, 11.0, 14.0, 17.0}; + + // Matrix B (n x k) + // ( 1.0 0.0 6.0 ) + // ( 2.0 4.0 0.0 ) + // ( 0.0 5.0 0.0 ) + // ( 3.0 0.0 7.0 ) + // ( 0.0 0.0 8.0 ) + + // Number of non-zero entries + rocsparse_int nnz = 8; + + // CSR column pointers + rocsparse_int hcsr_row_ptr[6] = {0, 2, 4, 5, 7, 8}; + + // CSR row indices + rocsparse_int hcsr_col_ind[8] = {0, 2, 0, 1, 1, 0, 2, 2}; + + // CSR values + double hcsr_val[8] = {1.0, 6.0, 2.0, 4.0, 5.0, 3.0, 7.0, 8.0}; + + // Matrix C (m x n) + // ( 18.0 19.0 20.0 21.0 22.0 ) + // ( 23.0 24.0 25.0 26.0 27.0 ) + // ( 28.0 29.0 30.0 31.0 32.0 ) + + // Matrix C (m x n) in column-major + rocsparse_int ldc = m; + double hC[15] = { + 18.0, 23.0, 28.0, 19.0, 24.0, 29.0, 20.0, 25.0, 30.0, 21.0, 26.0, 31.0, 22.0, 27.0, 32.0}; + + // Scalar alpha and beta + double alpha = 3.7; + double beta = 1.3; + + // Matrix operations + rocsparse_operation trans_A = rocsparse_operation_none; + rocsparse_operation trans_B = rocsparse_operation_transpose; + + // Offload data to device + double* dA; + rocsparse_int* dcsr_row_ptr; + rocsparse_int* dcsr_col_ind; + double* dcsr_val; + double* dC; + + HIP_CHECK(hipMalloc((void**)&dA, sizeof(double) * m * k)); + HIP_CHECK(hipMalloc((void**)&dcsr_row_ptr, sizeof(rocsparse_int) * (n + 1))); + HIP_CHECK(hipMalloc((void**)&dcsr_col_ind, sizeof(rocsparse_int) * nnz)); + HIP_CHECK(hipMalloc((void**)&dcsr_val, sizeof(double) * nnz)); + HIP_CHECK(hipMalloc((void**)&dC, sizeof(double) * m * n)); + + HIP_CHECK(hipMemcpy(dA, hA, sizeof(double) * m * k, hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy( + dcsr_row_ptr, hcsr_row_ptr, sizeof(rocsparse_int) * (n + 1), hipMemcpyHostToDevice)); + HIP_CHECK( + hipMemcpy(dcsr_col_ind, hcsr_col_ind, sizeof(rocsparse_int) * nnz, hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(dcsr_val, hcsr_val, sizeof(double) * nnz, hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(dC, hC, sizeof(double) * m * n, hipMemcpyHostToDevice)); + + // Call dgemmi + ROCSPARSE_CHECK(rocsparse_dgemmi(handle, + trans_A, + trans_B, + m, + n, + k, + nnz, + &alpha, + dA, + lda, + descr, + dcsr_val, + dcsr_row_ptr, + dcsr_col_ind, + &beta, + dC, + ldc)); + + // Print result + HIP_CHECK(hipMemcpy(hC, dC, sizeof(double) * m * n, hipMemcpyDeviceToHost)); + + std::cout.precision(2); + std::cout << "C:" << std::endl; + + for(int i = 0; i < m; ++i) + { + for(int j = 0; j < n; ++j) + { + std::cout << std::scientific << " " << hC[i + j * ldc]; + } + + std::cout << std::endl; + } + + // Clear rocSPARSE + ROCSPARSE_CHECK(rocsparse_destroy_handle(handle)); + ROCSPARSE_CHECK(rocsparse_destroy_mat_descr(descr)); + + // Clear device memory + HIP_CHECK(hipFree(dA)); + HIP_CHECK(hipFree(dcsr_row_ptr)); + HIP_CHECK(hipFree(dcsr_col_ind)); + HIP_CHECK(hipFree(dcsr_val)); + HIP_CHECK(hipFree(dC)); + + return 0; +} diff --git a/clients/tests/CMakeLists.txt b/clients/tests/CMakeLists.txt index 0112f92c..fd908275 100644 --- a/clients/tests/CMakeLists.txt +++ b/clients/tests/CMakeLists.txt @@ -171,6 +171,7 @@ set(ROCSPARSE_TEST_SOURCES test_bsrmm.cpp test_csrmm.cpp test_csrsm.cpp + test_gemmi.cpp test_csrgeam.cpp test_csrgemm.cpp test_csric0.cpp @@ -230,7 +231,7 @@ set_target_properties(rocsparse-test PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${PROJ set(ROCSPARSE_TEST_DATA "${PROJECT_BINARY_DIR}/staging/rocsparse_test.data") add_custom_command(OUTPUT "${ROCSPARSE_TEST_DATA}" COMMAND ../common/rocsparse_gentest.py -I ../include rocsparse_test.yaml -o "${ROCSPARSE_TEST_DATA}" - DEPENDS ../common/rocsparse_gentest.py rocsparse_test.yaml ../include/rocsparse_common.yaml known_bugs.yaml test_axpyi.yaml test_doti.yaml test_dotci.yaml test_gthr.yaml test_gthrz.yaml test_roti.yaml test_sctr.yaml test_bsrmv.yaml test_bsrsv.yaml test_coomv.yaml test_csrmv.yaml test_csrsv.yaml test_ellmv.yaml test_hybmv.yaml test_bsrmm.yaml test_csrmm.yaml test_csrsm.yaml test_csrgeam.yaml test_csrgemm.yaml test_csric0.yaml test_csrilu0.yaml test_csr2coo.yaml test_csr2csc.yaml test_csr2ell.yaml test_csr2hyb.yaml test_bsr2csr.yaml test_csr2bsr.yaml test_coo2csr.yaml test_ell2csr.yaml test_hyb2csr.yaml test_identity.yaml test_csrsort.yaml test_cscsort.yaml test_coosort.yaml test_csricsv.yaml test_csrilusv.yaml test_nnz.yaml test_dense2csr.yaml test_dense2csc.yaml test_csr2dense.yaml test_csc2dense.yaml test_csr2csr_compress.cpp + DEPENDS ../common/rocsparse_gentest.py rocsparse_test.yaml ../include/rocsparse_common.yaml known_bugs.yaml test_axpyi.yaml test_doti.yaml test_dotci.yaml test_gthr.yaml test_gthrz.yaml test_roti.yaml test_sctr.yaml test_bsrmv.yaml test_bsrsv.yaml test_coomv.yaml test_csrmv.yaml test_csrsv.yaml test_ellmv.yaml test_hybmv.yaml test_bsrmm.yaml test_csrmm.yaml test_csrsm.yaml test_gemmi.yaml test_csrgeam.yaml test_csrgemm.yaml test_csric0.yaml test_csrilu0.yaml test_csr2coo.yaml test_csr2csc.yaml test_csr2ell.yaml test_csr2hyb.yaml test_bsr2csr.yaml test_csr2bsr.yaml test_coo2csr.yaml test_ell2csr.yaml test_hyb2csr.yaml test_identity.yaml test_csrsort.yaml test_cscsort.yaml test_coosort.yaml test_csricsv.yaml test_csrilusv.yaml test_nnz.yaml test_dense2csr.yaml test_dense2csc.yaml test_csr2dense.yaml test_csc2dense.yaml test_csr2csr_compress.cpp WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}") add_custom_target(rocsparse-test-data DEPENDS "${ROCSPARSE_TEST_DATA}" ) diff --git a/clients/tests/rocsparse_test.yaml b/clients/tests/rocsparse_test.yaml index 1513081d..90bae453 100644 --- a/clients/tests/rocsparse_test.yaml +++ b/clients/tests/rocsparse_test.yaml @@ -38,6 +38,7 @@ include: test_hybmv.yaml include: test_bsrmm.yaml include: test_csrmm.yaml include: test_csrsm.yaml +include: test_gemmi.yaml include: test_csrgeam.yaml include: test_csrgemm.yaml include: test_csric0.yaml diff --git a/clients/tests/test_gemmi.cpp b/clients/tests/test_gemmi.cpp new file mode 100644 index 00000000..a2b1deb2 --- /dev/null +++ b/clients/tests/test_gemmi.cpp @@ -0,0 +1,115 @@ +/* ************************************************************************ + * Copyright (c) 2020 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + * ************************************************************************ */ + +#include "rocsparse_data.hpp" +#include "rocsparse_datatype2string.hpp" +#include "rocsparse_test.hpp" +#include "testing_gemmi.hpp" +#include "type_dispatch.hpp" + +#include +#include +#include + +namespace +{ + // By default, this test does not apply to any types. + // The unnamed second parameter is used for enable_if below. + template + struct gemmi_testing : rocsparse_test_invalid + { + }; + + // When the condition in the second argument is satisfied, the type combination + // is valid. When the condition is false, this specialization does not apply. + template + struct gemmi_testing< + T, + typename std::enable_if{} || std::is_same{} + || std::is_same{} + || std::is_same{}>::type> + { + explicit operator bool() + { + return true; + } + void operator()(const Arguments& arg) + { + if(!strcmp(arg.function, "gemmi")) + testing_gemmi(arg); + else if(!strcmp(arg.function, "gemmi_bad_arg")) + testing_gemmi_bad_arg(arg); + else + FAIL() << "Internal error: Test called with unknown function: " << arg.function; + } + }; + + struct gemmi : RocSPARSE_Test + { + // Filter for which types apply to this suite + static bool type_filter(const Arguments& arg) + { + return rocsparse_simple_dispatch(arg); + } + + // Filter for which functions apply to this suite + static bool function_filter(const Arguments& arg) + { + return !strcmp(arg.function, "gemmi") || !strcmp(arg.function, "gemmi_bad_arg"); + } + + // Google Test name suffix based on parameters + static std::string name_suffix(const Arguments& arg) + { + if(arg.matrix == rocsparse_matrix_file_rocalution + || arg.matrix == rocsparse_matrix_file_mtx) + { + return RocSPARSE_TestName{} + << rocsparse_datatype2string(arg.compute_type) << '_' << arg.N << '_' + << arg.alpha << '_' << arg.alphai << '_' << arg.beta << '_' << arg.betai + << '_' << rocsparse_operation2string(arg.transA) << '_' + << rocsparse_operation2string(arg.transB) << '_' + << rocsparse_indexbase2string(arg.baseA) << '_' + << rocsparse_matrix2string(arg.matrix) << '_' << arg.filename; + } + else + { + return RocSPARSE_TestName{} << rocsparse_datatype2string(arg.compute_type) + << '_' << arg.M << '_' << arg.N << '_' << arg.K + << '_' << arg.alpha << '_' << arg.alphai << '_' + << arg.beta << '_' << arg.betai << '_' + << rocsparse_operation2string(arg.transA) << '_' + << rocsparse_operation2string(arg.transB) << '_' + << rocsparse_indexbase2string(arg.baseA) << '_' + << rocsparse_matrix2string(arg.matrix); + } + } + }; + + TEST_P(gemmi, level3) + { + rocsparse_simple_dispatch(GetParam()); + } + INSTANTIATE_TEST_CATEGORIES(gemmi); + +} // namespace diff --git a/clients/tests/test_gemmi.yaml b/clients/tests/test_gemmi.yaml new file mode 100644 index 00000000..6729ba61 --- /dev/null +++ b/clients/tests/test_gemmi.yaml @@ -0,0 +1,189 @@ +# ######################################################################## +# Copyright (c) 2020 Advanced Micro Devices, Inc. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ######################################################################## + +--- +include: rocsparse_common.yaml +include: known_bugs.yaml + +Definitions: + - &alpha_beta_range_quick + - { alpha: 1.0, beta: -1.0, alphai: 1.0, betai: -0.5 } + - { alpha: -0.5, beta: 0.5, alphai: -0.5, betai: 1.0 } + + - &alpha_beta_range_checkin + - { alpha: 2.0, beta: 0.0, alphai: 0.5, betai: 0.5 } + - { alpha: 0.0, beta: 1.0, alphai: 1.5, betai: 0.5 } + - { alpha: 3.0, beta: 1.0, alphai: 0.0, betai: -0.5 } + + - &alpha_beta_range_nightly + - { alpha: 0.0, beta: 0.0, alphai: 1.5, betai: 0.5 } + - { alpha: 2.0, beta: 0.67, alphai: 0.0, betai: 1.5 } + - { alpha: 3.0, beta: 1.0, alphai: 1.5, betai: 0.0 } + - { alpha: -0.5, beta: 0.5, alphai: 1.0, betai: -0.5 } + +Tests: +- name: gemmi_bad_arg + category: pre_checkin + function: gemmi_bad_arg + precision: *single_double_precisions_complex_real + +- name: gemmi + category: quick + function: gemmi + precision: *single_double_precisions_complex_real + M: [-1, 0, 42, 275] + N: [-1, 0, 7, 19, 143] + K: [-1, 0, 50, 173] + alpha_beta: *alpha_beta_range_quick + transA: [rocsparse_operation_none] + transB: [rocsparse_operation_transpose] + baseA: [rocsparse_index_base_zero] + matrix: [rocsparse_matrix_random] + +- name: gemmi + category: pre_checkin + function: gemmi + precision: *single_double_precisions_complex_real + M: [-1, 0, 511, 2059] + N: [-1, 0, 7, 33, 64, 78] + K: [-1, 0, 391, 1375] + alpha_beta: *alpha_beta_range_checkin + transA: [rocsparse_operation_none] + transB: [rocsparse_operation_transpose] + baseA: [rocsparse_index_base_one] + matrix: [rocsparse_matrix_random] + +- name: gemmi + category: nightly + function: gemmi + precision: *single_double_precisions_complex_real + M: [1943, 4912] + N: [2, 27, 49] + K: [1134, 3291] + alpha_beta: *alpha_beta_range_nightly + transA: [rocsparse_operation_none] + transB: [rocsparse_operation_transpose] + baseA: [rocsparse_index_base_zero] + matrix: [rocsparse_matrix_random] + +- name: gemmi_file + category: quick + function: gemmi + precision: *single_double_precisions + M: 1 + N: [4, 19] + K: 1 + alpha_beta: *alpha_beta_range_quick + transA: [rocsparse_operation_none] + transB: [rocsparse_operation_transpose] + baseA: [rocsparse_index_base_one] + matrix: [rocsparse_matrix_file_rocalution] + filename: [mac_econ_fwd500, + nos2, + ASIC_320k, + nos4, + nos6, + scircuit] + +- name: gemmi_file + category: pre_checkin + function: gemmi + precision: *single_double_precisions + M: 1 + N: [-1, 0, 35, 73] + K: 1 + alpha_beta: *alpha_beta_range_checkin + transA: [rocsparse_operation_none] + transB: [rocsparse_operation_transpose] + baseA: [rocsparse_index_base_zero] + matrix: [rocsparse_matrix_file_rocalution] + filename: [rma10, + mc2depi, + nos1, + nos3, + nos5, + nos7] + +- name: gemmi_file + category: nightly + function: gemmi + precision: *single_double_precisions + M: 1 + N: [16, 22, 38] + K: 1 + alpha_beta: *alpha_beta_range_nightly + transA: [rocsparse_operation_none] + transB: [rocsparse_operation_transpose] + baseA: [rocsparse_index_base_one] + matrix: [rocsparse_matrix_file_rocalution] + filename: [bibd_22_8, + bmwcra_1, + amazon0312, + Chebyshev4, + sme3Dc, + webbase-1M, + shipsec1] + +- name: gemmi_file + category: quick + function: gemmi + precision: *single_double_precisions_complex + M: 1 + N: [3, 21] + K: 1 + alpha_beta: *alpha_beta_range_quick + transA: [rocsparse_operation_none] + transB: [rocsparse_operation_transpose] + baseA: [rocsparse_index_base_zero] + matrix: [rocsparse_matrix_file_rocalution] + filename: [Chevron2, + qc2534] + +- name: gemmi_file + category: pre_checkin + function: gemmi + precision: *single_double_precisions_complex + M: 1 + N: [-1, 0, 32, 68] + K: 1 + alpha_beta: *alpha_beta_range_checkin + transA: [rocsparse_operation_none] + transB: [rocsparse_operation_transpose] + baseA: [rocsparse_index_base_one] + matrix: [rocsparse_matrix_file_rocalution] + filename: [mplate, + Chevron3] + +- name: gemmi_file + category: nightly + function: gemmi + precision: *single_double_precisions_complex + M: 1 + N: [12, 31, 40] + K: 1 + alpha_beta: *alpha_beta_range_nightly + transA: [rocsparse_operation_none] + transB: [rocsparse_operation_transpose] + baseA: [rocsparse_index_base_zero] + matrix: [rocsparse_matrix_file_rocalution] + filename: [Chevron4] diff --git a/docs/source/usermanual.rst b/docs/source/usermanual.rst index b29b466a..a2f27ae7 100644 --- a/docs/source/usermanual.rst +++ b/docs/source/usermanual.rst @@ -649,6 +649,7 @@ Function name single :cpp:func:`rocsparse_csrsm_zero_pivot` :cpp:func:`rocsparse_csrsm_clear` :cpp:func:`rocsparse_Xcsrsm_solve() ` x x x x +:cpp:func:`rocsparse_Xgemmi() ` x x x x ========================================================================= ====== ====== ============== ============== Sparse Extra Functions @@ -1212,6 +1213,17 @@ rocsparse_csrsm_clear() .. _rocsparse_extra_functions_: +rocsparse_gemmi() +----------------- + +.. doxygenfunction:: rocsparse_sgemmi + :outline: +.. doxygenfunction:: rocsparse_dgemmi + :outline: +.. doxygenfunction:: rocsparse_cgemmi + :outline: +.. doxygenfunction:: rocsparse_zgemmi + Sparse Extra Functions ====================== diff --git a/library/include/rocsparse-functions.h b/library/include/rocsparse-functions.h index cf330d50..3140d0ad 100644 --- a/library/include/rocsparse-functions.h +++ b/library/include/rocsparse-functions.h @@ -3653,6 +3653,222 @@ rocsparse_status rocsparse_zcsrsm_solve(rocsparse_handle handle, void* temp_buffer); /**@}*/ +/*! \ingroup level3_module + * \brief Dense matrix sparse matrix multiplication using CSR storage format + * + * \details + * \p rocsparse_gemmi multiplies the scalar \f$\alpha\f$ with a dense \f$m \times k\f$ + * matrix \f$A\f$ and the sparse \f$k \times n\f$ matrix \f$B\f$, defined in CSR + * storage format and adds the result to the dense \f$m \times n\f$ matrix \f$C\f$ that + * is multiplied by the scalar \f$\beta\f$, such that + * \f[ + * C := \alpha \cdot op(A) \cdot op(B) + \beta \cdot C + * \f] + * with + * \f[ + * op(A) = \left\{ + * \begin{array}{ll} + * A, & \text{if trans_A == rocsparse_operation_none} \\ + * A^T, & \text{if trans_A == rocsparse_operation_transpose} \\ + * A^H, & \text{if trans_A == rocsparse_operation_conjugate_transpose} + * \end{array} + * \right. + * \f] + * and + * \f[ + * op(B) = \left\{ + * \begin{array}{ll} + * B, & \text{if trans_B == rocsparse_operation_none} \\ + * B^T, & \text{if trans_B == rocsparse_operation_transpose} \\ + * B^H, & \text{if trans_B == rocsparse_operation_conjugate_transpose} + * \end{array} + * \right. + * \f] + * + * \note + * This function is non blocking and executed asynchronously with respect to the host. + * It may return before the actual computation has finished. + * + * @param[in] + * handle handle to the rocsparse library context queue. + * @param[in] + * trans_A matrix \f$A\f$ operation type. + * @param[in] + * trans_B matrix \f$B\f$ operation type. + * @param[in] + * m number of rows of the dense matrix \f$A\f$. + * @param[in] + * n number of columns of the sparse CSR matrix \f$op(B)\f$ and \f$C\f$. + * @param[in] + * k number of columns of the dense matrix \f$A\f$. + * @param[in] + * nnz number of non-zero entries of the sparse CSR matrix \f$B\f$. + * @param[in] + * alpha scalar \f$\alpha\f$. + * @param[in] + * A array of dimension \f$lda \times k\f$ (\f$op(A) == A\f$) or + * \f$lda \times m\f$ (\f$op(A) == A^T\f$ or \f$op(A) == A^H\f$). + * @param[in] + * lda leading dimension of \f$A\f$, must be at least \f$m\f$ + * (\f$op(A) == A\f$) or \f$k\f$ (\f$op(A) == A^T\f$ or + * \f$op(A) == A^H\f$). + * @param[in] + * descr descriptor of the sparse CSR matrix \f$B\f$. Currently, only + * \ref rocsparse_matrix_type_general is supported. + * @param[in] + * csr_val array of \p nnz elements of the sparse CSR matrix \f$B\f$. + * @param[in] + * csr_row_ptr array of \p m+1 elements that point to the start of every row of the + * sparse CSR matrix \f$B\f$. + * @param[in] + * csr_col_ind array of \p nnz elements containing the column indices of the sparse CSR + * matrix \f$B\f$. + * @param[in] + * beta scalar \f$\beta\f$. + * @param[inout] + * C array of dimension \f$ldc \times n\f$ that holds the values of \f$C\f$. + * @param[in] + * ldc leading dimension of \f$C\f$, must be at least \f$m\f$. + * + * \retval rocsparse_status_success the operation completed successfully. + * \retval rocsparse_status_invalid_handle the library context was not initialized. + * \retval rocsparse_status_invalid_size \p m, \p n, \p k, \p nnz, \p lda or \p ldc + * is invalid. + * \retval rocsparse_status_invalid_pointer \p alpha, \p A, \p csr_val, + * \p csr_row_ptr, \p csr_col_ind, \p beta or \p C pointer is invalid. + * + * \par Example + * This example multiplies a dense matrix with a CSC matrix. + * \code{.c} + * rocsparse_int m = 2; + * rocsparse_int n = 5; + * rocsparse_int k = 3; + * rocsparse_int nnz = 8; + * rocsparse_int lda = m; + * rocsparse_int ldc = m; + * + * // Matrix A (m x k) + * // ( 9.0 10.0 11.0 ) + * // ( 12.0 13.0 14.0 ) + * + * // Matrix B (k x n) + * // ( 1.0 2.0 0.0 3.0 0.0 ) + * // ( 0.0 4.0 5.0 0.0 0.0 ) + * // ( 6.0 0.0 0.0 7.0 8.0 ) + * + * // Matrix C (m x n) + * // ( 15.0 16.0 17.0 18.0 19.0 ) + * // ( 20.0 21.0 22.0 23.0 24.0 ) + * + * A[lda * k] = {9.0, 12.0, 10.0, 13.0, 11.0, 14.0}; // device memory + * csc_col_ptr_B[n + 1] = {0, 2, 4, 5, 7, 8}; // device memory + * csc_row_ind_B[nnz] = {0, 0, 1, 1, 2, 3, 3, 4}; // device memory + * csc_val_B[nnz] = {1.0, 6.0, 2.0, 4.0, 5.0, 3.0, 7.0, 8.0}; // device memory + * C[ldc * n] = {15.0, 20.0, 16.0, 21.0, 17.0, 22.0, // device memory + * 18.0, 23.0, 19.0, 24.0}; + * + * // alpha and beta + * float alpha = 1.0f; + * float beta = 0.0f; + * + * // Perform the matrix multiplication + * rocsparse_sgemmi(handle, + * rocsparse_operation_none, + * rocsparse_operation_transpose, + * m, + * n, + * k, + * nnz, + * &alpha, + * A, + * lda, + * descr_B, + * csc_val_B, + * csc_col_ptr_B, + * csc_row_ind_B, + * &beta, + * C, + * ldc); + * \endcode + */ +/**@{*/ +ROCSPARSE_EXPORT +rocsparse_status rocsparse_sgemmi(rocsparse_handle handle, + rocsparse_operation trans_A, + rocsparse_operation trans_B, + rocsparse_int m, + rocsparse_int n, + rocsparse_int k, + rocsparse_int nnz, + const float* alpha, + const float* A, + rocsparse_int lda, + const rocsparse_mat_descr descr, + const float* csr_val, + const rocsparse_int* csr_row_ptr, + const rocsparse_int* csr_col_ind, + const float* beta, + float* C, + rocsparse_int ldc); + +ROCSPARSE_EXPORT +rocsparse_status rocsparse_dgemmi(rocsparse_handle handle, + rocsparse_operation trans_A, + rocsparse_operation trans_B, + rocsparse_int m, + rocsparse_int n, + rocsparse_int k, + rocsparse_int nnz, + const double* alpha, + const double* A, + rocsparse_int lda, + const rocsparse_mat_descr descr, + const double* csr_val, + const rocsparse_int* csr_row_ptr, + const rocsparse_int* csr_col_ind, + const double* beta, + double* C, + rocsparse_int ldc); + +ROCSPARSE_EXPORT +rocsparse_status rocsparse_cgemmi(rocsparse_handle handle, + rocsparse_operation trans_A, + rocsparse_operation trans_B, + rocsparse_int m, + rocsparse_int n, + rocsparse_int k, + rocsparse_int nnz, + const rocsparse_float_complex* alpha, + const rocsparse_float_complex* A, + rocsparse_int lda, + const rocsparse_mat_descr descr, + const rocsparse_float_complex* csr_val, + const rocsparse_int* csr_row_ptr, + const rocsparse_int* csr_col_ind, + const rocsparse_float_complex* beta, + rocsparse_float_complex* C, + rocsparse_int ldc); + +ROCSPARSE_EXPORT +rocsparse_status rocsparse_zgemmi(rocsparse_handle handle, + rocsparse_operation trans_A, + rocsparse_operation trans_B, + rocsparse_int m, + rocsparse_int n, + rocsparse_int k, + rocsparse_int nnz, + const rocsparse_double_complex* alpha, + const rocsparse_double_complex* A, + rocsparse_int lda, + const rocsparse_mat_descr descr, + const rocsparse_double_complex* csr_val, + const rocsparse_int* csr_row_ptr, + const rocsparse_int* csr_col_ind, + const rocsparse_double_complex* beta, + rocsparse_double_complex* C, + rocsparse_int ldc); +/**@}*/ + /* * =========================================================================== * extra SPARSE diff --git a/library/src/CMakeLists.txt b/library/src/CMakeLists.txt index a30dc45d..8a84a79f 100644 --- a/library/src/CMakeLists.txt +++ b/library/src/CMakeLists.txt @@ -54,6 +54,7 @@ set(rocsparse_source src/level3/rocsparse_bsrmm.cpp src/level3/rocsparse_csrmm.cpp src/level3/rocsparse_csrsm.cpp + src/level3/rocsparse_gemmi.cpp # Extra src/extra/rocsparse_csrgeam.cpp diff --git a/library/src/level3/gemmi_device.h b/library/src/level3/gemmi_device.h new file mode 100644 index 00000000..22f6e188 --- /dev/null +++ b/library/src/level3/gemmi_device.h @@ -0,0 +1,95 @@ +/* ************************************************************************ + * Copyright (c) 2020 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + * ************************************************************************ */ + +#pragma once +#ifndef GEMMI_DEVICE_H +#define GEMMI_DEVICE_H + +#include "common.h" + +#include + +template +__device__ void gemmi_scale_kernel(rocsparse_int size, T alpha, T* __restrict__ data) +{ + rocsparse_int idx = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x; + + if(idx >= size) + { + return; + } + + data[idx] *= alpha; +} + +template +__device__ void gemmit_kernel(rocsparse_int m, + T alpha, + const T* __restrict__ A, + rocsparse_int lda, + const rocsparse_int* __restrict__ csr_row_ptr, + const rocsparse_int* __restrict__ csr_col_ind, + const T* __restrict__ csr_val, + T beta, + T* __restrict__ C, + rocsparse_int ldc, + rocsparse_index_base base) +{ + rocsparse_int row = hipBlockIdx_y; + rocsparse_int col = hipBlockIdx_x * BLOCKSIZE + hipThreadIdx_x; + + // Do not run out of bounds + if(col >= m) + { + return; + } + + // Row entry into B + rocsparse_int row_begin = csr_row_ptr[row] - base; + rocsparse_int row_end = csr_row_ptr[row + 1] - base; + + // Accumulator + T sum = static_cast(0); + + // Loop over the column indices of B of the current row + for(rocsparse_int k = row_begin; k < row_end; ++k) + { + rocsparse_int col_B = csr_col_ind[k] - base; + T val_B = csr_val[k]; + T val_A = A[col_B * lda + col]; + + sum = rocsparse_fma(val_A, val_B, sum); + } + + // Write result back to C + if(beta != static_cast(0)) + { + C[row * ldc + col] = rocsparse_fma(beta, C[row * ldc + col], alpha * sum); + } + else + { + C[row * ldc + col] = alpha * sum; + } +} + +#endif // GEMMI_DEVICE_H diff --git a/library/src/level3/rocsparse_gemmi.cpp b/library/src/level3/rocsparse_gemmi.cpp new file mode 100644 index 00000000..b35c149d --- /dev/null +++ b/library/src/level3/rocsparse_gemmi.cpp @@ -0,0 +1,180 @@ +/* ************************************************************************ + * Copyright (c) 2020 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + * ************************************************************************ */ + +#include "rocsparse.h" + +#include "rocsparse_gemmi.hpp" + +/* + * =========================================================================== + * C wrapper + * =========================================================================== + */ + +extern "C" rocsparse_status rocsparse_sgemmi(rocsparse_handle handle, + rocsparse_operation trans_A, + rocsparse_operation trans_B, + rocsparse_int m, + rocsparse_int n, + rocsparse_int k, + rocsparse_int nnz, + const float* alpha, + const float* A, + rocsparse_int lda, + const rocsparse_mat_descr descr, + const float* csr_val, + const rocsparse_int* csr_row_ptr, + const rocsparse_int* csr_col_ind, + const float* beta, + float* C, + rocsparse_int ldc) +{ + return rocsparse_gemmi_template(handle, + trans_A, + trans_B, + m, + n, + k, + nnz, + alpha, + A, + lda, + descr, + csr_val, + csr_row_ptr, + csr_col_ind, + beta, + C, + ldc); +} + +extern "C" rocsparse_status rocsparse_dgemmi(rocsparse_handle handle, + rocsparse_operation trans_A, + rocsparse_operation trans_B, + rocsparse_int m, + rocsparse_int n, + rocsparse_int k, + rocsparse_int nnz, + const double* alpha, + const double* A, + rocsparse_int lda, + const rocsparse_mat_descr descr, + const double* csr_val, + const rocsparse_int* csr_row_ptr, + const rocsparse_int* csr_col_ind, + const double* beta, + double* C, + rocsparse_int ldc) +{ + return rocsparse_gemmi_template(handle, + trans_A, + trans_B, + m, + n, + k, + nnz, + alpha, + A, + lda, + descr, + csr_val, + csr_row_ptr, + csr_col_ind, + beta, + C, + ldc); +} + +extern "C" rocsparse_status rocsparse_cgemmi(rocsparse_handle handle, + rocsparse_operation trans_A, + rocsparse_operation trans_B, + rocsparse_int m, + rocsparse_int n, + rocsparse_int k, + rocsparse_int nnz, + const rocsparse_float_complex* alpha, + const rocsparse_float_complex* A, + rocsparse_int lda, + const rocsparse_mat_descr descr, + const rocsparse_float_complex* csr_val, + const rocsparse_int* csr_row_ptr, + const rocsparse_int* csr_col_ind, + const rocsparse_float_complex* beta, + rocsparse_float_complex* C, + rocsparse_int ldc) +{ + return rocsparse_gemmi_template(handle, + trans_A, + trans_B, + m, + n, + k, + nnz, + alpha, + A, + lda, + descr, + csr_val, + csr_row_ptr, + csr_col_ind, + beta, + C, + ldc); +} + +extern "C" rocsparse_status rocsparse_zgemmi(rocsparse_handle handle, + rocsparse_operation trans_A, + rocsparse_operation trans_B, + rocsparse_int m, + rocsparse_int n, + rocsparse_int k, + rocsparse_int nnz, + const rocsparse_double_complex* alpha, + const rocsparse_double_complex* A, + rocsparse_int lda, + const rocsparse_mat_descr descr, + const rocsparse_double_complex* csr_val, + const rocsparse_int* csr_row_ptr, + const rocsparse_int* csr_col_ind, + const rocsparse_double_complex* beta, + rocsparse_double_complex* C, + rocsparse_int ldc) +{ + return rocsparse_gemmi_template(handle, + trans_A, + trans_B, + m, + n, + k, + nnz, + alpha, + A, + lda, + descr, + csr_val, + csr_row_ptr, + csr_col_ind, + beta, + C, + ldc); +} diff --git a/library/src/level3/rocsparse_gemmi.hpp b/library/src/level3/rocsparse_gemmi.hpp new file mode 100644 index 00000000..eb509f6c --- /dev/null +++ b/library/src/level3/rocsparse_gemmi.hpp @@ -0,0 +1,345 @@ +/* ************************************************************************ + * Copyright (c) 2020 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + * ************************************************************************ */ + +#pragma once +#ifndef ROCSPARSE_GEMMI_HPP +#define ROCSPARSE_GEMMI_HPP + +#include "definitions.h" +#include "gemmi_device.h" +#include "handle.h" +#include "rocsparse.h" +#include "utility.h" + +#include + +template +__launch_bounds__(256) __global__ + void gemmi_scale_kernel_host_pointer(rocsparse_int size, T alpha, T* __restrict__ data) +{ + gemmi_scale_kernel(size, alpha, data); +} + +template +__launch_bounds__(256) __global__ + void gemmi_scale_kernel_device_pointer(rocsparse_int size, + const T* __restrict__ alpha, + T* __restrict__ data) +{ + gemmi_scale_kernel(size, *alpha, data); +} + +template +__launch_bounds__(BLOCKSIZE) __global__ + void gemmit_kernel_host_pointer(rocsparse_int m, + T alpha, + const T* __restrict__ A, + rocsparse_int lda, + const rocsparse_int* __restrict__ csr_row_ptr, + const rocsparse_int* __restrict__ csr_col_ind, + const T* __restrict__ csr_val, + T beta, + T* __restrict__ C, + rocsparse_int ldc, + rocsparse_index_base base) +{ + gemmit_kernel( + m, alpha, A, lda, csr_row_ptr, csr_col_ind, csr_val, beta, C, ldc, base); +} + +template +__launch_bounds__(BLOCKSIZE) __global__ + void gemmit_kernel_device_pointer(rocsparse_int m, + const T* alpha, + const T* __restrict__ A, + rocsparse_int lda, + const rocsparse_int* __restrict__ csr_row_ptr, + const rocsparse_int* __restrict__ csr_col_ind, + const T* __restrict__ csr_val, + const T* beta, + T* __restrict__ C, + rocsparse_int ldc, + rocsparse_index_base base) +{ + if(*alpha == static_cast(0) && *beta == static_cast(1)) + { + return; + } + + gemmit_kernel( + m, *alpha, A, lda, csr_row_ptr, csr_col_ind, csr_val, *beta, C, ldc, base); +} + +template +rocsparse_status rocsparse_gemmi_template(rocsparse_handle handle, + rocsparse_operation trans_A, + rocsparse_operation trans_B, + rocsparse_int m, + rocsparse_int n, + rocsparse_int k, + rocsparse_int nnz, + const T* alpha, + const T* A, + rocsparse_int lda, + const rocsparse_mat_descr descr, + const T* csr_val, + const rocsparse_int* csr_row_ptr, + const rocsparse_int* csr_col_ind, + const T* beta, + T* C, + rocsparse_int ldc) +{ + // Check for valid handle and matrix descriptor + if(handle == nullptr) + { + return rocsparse_status_invalid_handle; + } + else if(descr == nullptr) + { + return rocsparse_status_invalid_pointer; + } + + // Logging + if(handle->pointer_mode == rocsparse_pointer_mode_host) + { + log_trace(handle, + replaceX("rocsparse_Xgemmi"), + trans_A, + trans_B, + m, + n, + k, + nnz, + *alpha, + (const void*&)A, + lda, + (const void*&)descr, + (const void*&)csr_val, + (const void*&)csr_row_ptr, + (const void*&)csr_col_ind, + *beta, + (const void*&)C, + ldc); + } + else + { + log_trace(handle, + replaceX("rocsparse_Xgemmi"), + trans_A, + trans_B, + m, + n, + k, + nnz, + (const void*&)alpha, + (const void*&)A, + lda, + (const void*&)descr, + (const void*&)csr_val, + (const void*&)csr_row_ptr, + (const void*&)csr_col_ind, + (const void*&)beta, + (const void*&)C, + ldc); + } + + // Check index base + if(descr->base != rocsparse_index_base_zero && descr->base != rocsparse_index_base_one) + { + return rocsparse_status_invalid_value; + } + if(descr->type != rocsparse_matrix_type_general) + { + return rocsparse_status_not_implemented; + } + if(trans_A != rocsparse_operation_none) + { + return rocsparse_status_not_implemented; + } + if(trans_B != rocsparse_operation_transpose) + { + return rocsparse_status_not_implemented; + } + + // Check sizes + if(m < 0 || n < 0 || k < 0 || nnz < 0) + { + return rocsparse_status_invalid_size; + } + + // Quick return if possible + if(m == 0 || n == 0) + { + return rocsparse_status_success; + } + + // Check pointer arguments + + // beta and C is always required + if(beta == nullptr || C == nullptr) + { + return rocsparse_status_invalid_pointer; + } + + // A is only required if k != 0 + if(k != 0 && (alpha == nullptr || A == nullptr)) + { + return rocsparse_status_invalid_pointer; + } + + // B is only required if k != 0 and nnz != 0 + if(k != 0 && nnz != 0 + && (csr_val == nullptr || csr_row_ptr == nullptr || csr_col_ind == nullptr)) + { + return rocsparse_status_invalid_pointer; + } + + // Check leading dimensions + if(lda < std::max(1, m) || ldc < std::max(1, m)) + { + return rocsparse_status_invalid_value; + } + + // Stream + hipStream_t stream = handle->stream; + + // If k == 0, scale C with beta + if(k == 0) + { +#define SCALE_DIM 256 + dim3 scale_blocks((m * n - 1) / SCALE_DIM + 1); + dim3 scale_threads(SCALE_DIM); +#undef SCALE_DIM + + if(handle->pointer_mode == rocsparse_pointer_mode_device) + { + hipLaunchKernelGGL((gemmi_scale_kernel_device_pointer), + scale_blocks, + scale_threads, + 0, + stream, + m * n, + beta, + C); + } + else + { + if(*beta == static_cast(0)) + { + RETURN_IF_HIP_ERROR(hipMemsetAsync(C, 0, sizeof(T) * m * n, stream)); + } + else if(*beta != static_cast(1)) + { + hipLaunchKernelGGL((gemmi_scale_kernel_host_pointer), + scale_blocks, + scale_threads, + 0, + stream, + m * n, + *beta, + C); + } + } + + return rocsparse_status_success; + } + +#define GEMMIT_DIM 256 + dim3 gemmit_blocks((m - 1) / GEMMIT_DIM + 1, n); + dim3 gemmit_threads(GEMMIT_DIM); + + if(handle->pointer_mode == rocsparse_pointer_mode_device) + { + hipLaunchKernelGGL((gemmit_kernel_device_pointer), + gemmit_blocks, + gemmit_threads, + 0, + stream, + m, + alpha, + A, + lda, + csr_row_ptr, + csr_col_ind, + csr_val, + beta, + C, + ldc, + descr->base); + } + else + { + // Quick return + if(*alpha == static_cast(0) && *beta == static_cast(1)) + { + return rocsparse_status_success; + } + else if(*alpha == static_cast(0)) + { + if(*beta == static_cast(0)) + { + RETURN_IF_HIP_ERROR(hipMemsetAsync(C, 0, sizeof(T) * m * n, stream)); + } + else + { +#define SCALE_DIM 256 + dim3 scale_blocks((m * n - 1) / SCALE_DIM + 1); + dim3 scale_threads(SCALE_DIM); +#undef SCALE_DIM + + hipLaunchKernelGGL((gemmi_scale_kernel_host_pointer), + scale_blocks, + scale_threads, + 0, + stream, + m * n, + *beta, + C); + } + + return rocsparse_status_success; + } + + hipLaunchKernelGGL((gemmit_kernel_host_pointer), + gemmit_blocks, + gemmit_threads, + 0, + stream, + m, + *alpha, + A, + lda, + csr_row_ptr, + csr_col_ind, + csr_val, + *beta, + C, + ldc, + descr->base); + } +#undef GEMMIT_DIM + + return rocsparse_status_success; +} + +#endif // ROCSPARSE_GEMMI_HPP diff --git a/library/src/rocsparse_module.f90 b/library/src/rocsparse_module.f90 index cc67ef89..6512761e 100644 --- a/library/src/rocsparse_module.f90 +++ b/library/src/rocsparse_module.f90 @@ -2266,6 +2266,107 @@ function rocsparse_zcsrsm_solve(handle, trans_A, trans_B, m, nrhs, nnz, alpha, & type(c_ptr), value :: temp_buffer end function rocsparse_zcsrsm_solve +! rocsparse_gemmi + function rocsparse_sgemmi(handle, trans_A, trans_B, m, n, k, nnz, alpha, A, & + lda, descr, csr_val, csr_row_ptr, csr_col_ind, beta, C, ldc) & + result(c_int) & + bind(c, name = 'rocsparse_sgemmi') + use iso_c_binding + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: trans_A + integer(c_int), value :: trans_B + integer(c_int), value :: m + integer(c_int), value :: n + integer(c_int), value :: k + integer(c_int), value :: nnz + type(c_ptr), intent(in), value :: alpha + type(c_ptr), intent(in), value :: A + integer(c_int), value :: lda + type(c_ptr), intent(in), value :: descr + type(c_ptr), intent(in), value :: csr_val + type(c_ptr), intent(in), value :: csr_row_ptr + type(c_ptr), intent(in), value :: csr_col_ind + type(c_ptr), intent(in), value :: beta + type(c_ptr), value :: C + integer(c_int), value :: ldc + end function rocsparse_sgemmi + + function rocsparse_dgemmi(handle, trans_A, trans_B, m, n, k, nnz, alpha, A, & + lda, descr, csr_val, csr_row_ptr, csr_col_ind, beta, C, ldc) & + result(c_int) & + bind(c, name = 'rocsparse_dgemmi') + use iso_c_binding + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: trans_A + integer(c_int), value :: trans_B + integer(c_int), value :: m + integer(c_int), value :: n + integer(c_int), value :: k + integer(c_int), value :: nnz + type(c_ptr), intent(in), value :: alpha + type(c_ptr), intent(in), value :: A + integer(c_int), value :: lda + type(c_ptr), intent(in), value :: descr + type(c_ptr), intent(in), value :: csr_val + type(c_ptr), intent(in), value :: csr_row_ptr + type(c_ptr), intent(in), value :: csr_col_ind + type(c_ptr), intent(in), value :: beta + type(c_ptr), value :: C + integer(c_int), value :: ldc + end function rocsparse_dgemmi + + function rocsparse_cgemmi(handle, trans_A, trans_B, m, n, k, nnz, alpha, A, & + lda, descr, csr_val, csr_row_ptr, csr_col_ind, beta, C, ldc) & + result(c_int) & + bind(c, name = 'rocsparse_cgemmi') + use iso_c_binding + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: trans_A + integer(c_int), value :: trans_B + integer(c_int), value :: m + integer(c_int), value :: n + integer(c_int), value :: k + integer(c_int), value :: nnz + type(c_ptr), intent(in), value :: alpha + type(c_ptr), intent(in), value :: A + integer(c_int), value :: lda + type(c_ptr), intent(in), value :: descr + type(c_ptr), intent(in), value :: csr_val + type(c_ptr), intent(in), value :: csr_row_ptr + type(c_ptr), intent(in), value :: csr_col_ind + type(c_ptr), intent(in), value :: beta + type(c_ptr), value :: C + integer(c_int), value :: ldc + end function rocsparse_cgemmi + + function rocsparse_zgemmi(handle, trans_A, trans_B, m, n, k, nnz, alpha, A, & + lda, descr, csr_val, csr_row_ptr, csr_col_ind, beta, C, ldc) & + result(c_int) & + bind(c, name = 'rocsparse_zgemmi') + use iso_c_binding + implicit none + type(c_ptr), value :: handle + integer(c_int), value :: trans_A + integer(c_int), value :: trans_B + integer(c_int), value :: m + integer(c_int), value :: n + integer(c_int), value :: k + integer(c_int), value :: nnz + type(c_ptr), intent(in), value :: alpha + type(c_ptr), intent(in), value :: A + integer(c_int), value :: lda + type(c_ptr), intent(in), value :: descr + type(c_ptr), intent(in), value :: csr_val + type(c_ptr), intent(in), value :: csr_row_ptr + type(c_ptr), intent(in), value :: csr_col_ind + type(c_ptr), intent(in), value :: beta + type(c_ptr), value :: C + integer(c_int), value :: ldc + end function rocsparse_zgemmi + ! =========================================================================== ! extra SPARSE ! ===========================================================================