From dd3500e0d8e9e8ed1279928203f21c706c4bf53a Mon Sep 17 00:00:00 2001 From: Juan Zuniga-Anaya <50754207+jzuniga-amd@users.noreply.github.com> Date: Thu, 14 Jan 2021 13:04:37 -0700 Subject: [PATCH] Second round of optimizations for LU (GETRF) (#202) * new block sizes * fix no pivot cases * changelog * remove rocblas_initialize from bench client --- CHANGELOG.md | 32 ++--- rocsolver/clients/benchmarks/client.cpp | 4 +- rocsolver/library/src/include/ideal_sizes.hpp | 18 ++- .../library/src/lapack/roclapack_getf2.cpp | 29 ++-- .../library/src/lapack/roclapack_getf2.hpp | 132 +++++++++++++----- .../src/lapack/roclapack_getf2_batched.cpp | 43 +++--- .../roclapack_getf2_strided_batched.cpp | 45 +++--- .../library/src/lapack/roclapack_getrf.cpp | 31 ++-- .../library/src/lapack/roclapack_getrf.hpp | 85 +++++++---- .../src/lapack/roclapack_getrf_batched.cpp | 45 +++--- .../roclapack_getrf_strided_batched.cpp | 47 +++---- 11 files changed, 302 insertions(+), 209 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b299bc4b6..67c558990 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,27 +2,7 @@ Full documentation for rocSOLVER is available at [rocsolver.readthedocs.io](https://rocsolver.readthedocs.io/en/latest/). -## [(Unreleased) rocSOLVER for ROCm 4.1.0] -### Added -- Sample code and unit test for unified memory model/Heterogeneous Memory Management (HMM) - -### Optimizations - -### Changed - -### Deprecated - -### Removed - -### Fixed -- Fixed runtime errors in debug mode caused by incorrect kernel launch bounds -- Fixed complex unit test bug caused by incorrect zaxpy function signature -- Eliminated a small memory transfer that was being done on the default stream -- Fixed GESVD right singular vectors for 1x1 matrices - - - -## [(Unreleased) rocSOLVER 3.11.0 for ROCm 4.0.0] +## [(Unreleased) rocSOLVER 3.11.0 for ROCm 4.1.0] ### Added - Eigensolver routines for symmetric/hermitian matrices: - STERF, STEQR @@ -36,15 +16,25 @@ Full documentation for rocSOLVER is available at [rocsolver.readthedocs.io](http - LATRD - SYTD2, SYTRD (with batched and strided\_batched versions) - HETD2, HETRD (with batched and strided\_batched versions) +- Sample code and unit test for unified memory model/Heterogeneous Memory Management (HMM) + +### Optimizations +- Improved performance of LU factorization of small and mid-size matrices (n <= 2048) ### Changed - Raised minimum requirement for building rocSOLVER from source to CMake 3.8 - Switched to use semantic versioning for the library +- Enabled automatic reallocation of memory workspace in rocsolver clients ### Removed - Removed `-DOPTIMAL` from the `roc::rocsolver` CMake usage requirements. This is an internal rocSOLVER definition, and does not need to be defined by library users +### Fixed +- Fixed runtime errors in debug mode caused by incorrect kernel launch bounds +- Fixed complex unit test bug caused by incorrect zaxpy function signature +- Eliminated a small memory transfer that was being done on the default stream +- Fixed GESVD right singular vectors for 1x1 matrices ## [rocSOLVER 3.10.0 for ROCm 3.10.0] diff --git a/rocsolver/clients/benchmarks/client.cpp b/rocsolver/clients/benchmarks/client.cpp index 010c50840..5ec44a655 100644 --- a/rocsolver/clients/benchmarks/client.cpp +++ b/rocsolver/clients/benchmarks/client.cpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (c) 2016-2020 Advanced Micro Devices, Inc. + * Copyright (c) 2016-2021 Advanced Micro Devices, Inc. * ************************************************************************ */ #include "testing_bdsqr.hpp" @@ -42,8 +42,6 @@ namespace po = boost::program_options; int main(int argc, char* argv[]) try { - rocblas_initialize(); - Arguments argus; // disable unit_check in client benchmark, it is only diff --git a/rocsolver/library/src/include/ideal_sizes.hpp b/rocsolver/library/src/include/ideal_sizes.hpp index 25a50d3be..c85ef43f3 100644 --- a/rocsolver/library/src/include/ideal_sizes.hpp +++ b/rocsolver/library/src/include/ideal_sizes.hpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (c) 2019-2020 Advanced Micro Devices, Inc. + * Copyright (c) 2019-2021 Advanced Micro Devices, Inc. * ************************************************************************ */ #pragma once @@ -23,13 +23,23 @@ #define ORMxx_ORMxx_BLOCKSIZE 32 // getf2/getfr -#define GETRF_GETF2_SWITCHSIZE 64 #define GETF2_MAX_THDS 256 -#define GETRF_GETF2_BLOCKSIZE 64 #define GETF2_OPTIM_NGRP \ 16, 15, 8, 8, 8, 8, 8, 8, 6, 6, 4, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2 -#define GETF2_BATCH_OPTIM_MAX_SIZE 2048 +#define GETF2_BATCH_OPTIM_MAX_SIZE 1024 #define GETF2_OPTIM_MAX_SIZE 1024 +#define GETRF_NUM_INTERVALS_NORMAL 4 +#define GETRF_INTERVALS_NORMAL 65, 657, 1217, 5249 +#define GETRF_BLKSIZES_NORMAL 1, 32, 1, 128, 192 +#define GETRF_NUM_INTERVALS_BATCH 3 +#define GETRF_INTERVALS_BATCH 65, 497, 2049 +#define GETRF_BLKSIZES_BATCH 1, 16, 32, 64 +#define GETRF_NPVT_NUM_INTERVALS_NORMAL 3 +#define GETRF_NPVT_INTERVALS_NORMAL 65, 3073, 4609 +#define GETRF_NPVT_BLKSIZES_NORMAL 1, 32, 64, 192 +#define GETRF_NPVT_NUM_INTERVALS_BATCH 3 +#define GETRF_NPVT_INTERVALS_BATCH 45, 181, 2049 +#define GETRF_NPVT_BLKSIZES_BATCH 1, 16, 32, 64 // getri #define GETRI_SWITCHSIZE_MID 64 diff --git a/rocsolver/library/src/lapack/roclapack_getf2.cpp b/rocsolver/library/src/lapack/roclapack_getf2.cpp index 8f1227df4..da2c6860a 100644 --- a/rocsolver/library/src/lapack/roclapack_getf2.cpp +++ b/rocsolver/library/src/lapack/roclapack_getf2.cpp @@ -1,18 +1,17 @@ /* ************************************************************************ - * Copyright (c) 2019-2020 Advanced Micro Devices, Inc. + * Copyright (c) 2019-2021 Advanced Micro Devices, Inc. * ************************************************************************ */ #include "roclapack_getf2.hpp" -template +template rocblas_status rocsolver_getf2_impl(rocblas_handle handle, const rocblas_int m, const rocblas_int n, U A, const rocblas_int lda, rocblas_int* ipiv, - rocblas_int* info, - const int pivot) + rocblas_int* info) { using S = decltype(std::real(T{})); @@ -22,7 +21,7 @@ rocblas_status rocsolver_getf2_impl(rocblas_handle handle, // logging is missing ??? // argument checking - rocblas_status st = rocsolver_getf2_getrf_argCheck(handle, m, n, lda, A, ipiv, info, pivot); + rocblas_status st = rocsolver_getf2_getrf_argCheck(handle, m, n, lda, A, ipiv, info, PIVOT); if(st != rocblas_status_continue) return st; @@ -65,8 +64,8 @@ rocblas_status rocsolver_getf2_impl(rocblas_handle handle, init_scalars(handle, (T*)scalars); // execution - return rocsolver_getf2_template( - handle, m, n, A, shiftA, lda, strideA, ipiv, shiftP, strideP, info, batch_count, pivot, + return rocsolver_getf2_template( + handle, m, n, A, shiftA, lda, strideA, ipiv, shiftP, strideP, info, batch_count, (T*)scalars, (rocblas_index_value_t*)work, (T*)pivotval, (rocblas_int*)pivotidx); } @@ -86,7 +85,7 @@ rocblas_status rocsolver_sgetf2(rocblas_handle handle, rocblas_int* ipiv, rocblas_int* info) { - return rocsolver_getf2_impl(handle, m, n, A, lda, ipiv, info, 1); + return rocsolver_getf2_impl(handle, m, n, A, lda, ipiv, info); } rocblas_status rocsolver_dgetf2(rocblas_handle handle, @@ -97,7 +96,7 @@ rocblas_status rocsolver_dgetf2(rocblas_handle handle, rocblas_int* ipiv, rocblas_int* info) { - return rocsolver_getf2_impl(handle, m, n, A, lda, ipiv, info, 1); + return rocsolver_getf2_impl(handle, m, n, A, lda, ipiv, info); } rocblas_status rocsolver_cgetf2(rocblas_handle handle, @@ -108,7 +107,7 @@ rocblas_status rocsolver_cgetf2(rocblas_handle handle, rocblas_int* ipiv, rocblas_int* info) { - return rocsolver_getf2_impl(handle, m, n, A, lda, ipiv, info, 1); + return rocsolver_getf2_impl(handle, m, n, A, lda, ipiv, info); } rocblas_status rocsolver_zgetf2(rocblas_handle handle, @@ -119,7 +118,7 @@ rocblas_status rocsolver_zgetf2(rocblas_handle handle, rocblas_int* ipiv, rocblas_int* info) { - return rocsolver_getf2_impl(handle, m, n, A, lda, ipiv, info, 1); + return rocsolver_getf2_impl(handle, m, n, A, lda, ipiv, info); } rocblas_status rocsolver_sgetf2_npvt(rocblas_handle handle, @@ -130,7 +129,7 @@ rocblas_status rocsolver_sgetf2_npvt(rocblas_handle handle, rocblas_int* info) { rocblas_int* ipiv = nullptr; - return rocsolver_getf2_impl(handle, m, n, A, lda, ipiv, info, 0); + return rocsolver_getf2_impl(handle, m, n, A, lda, ipiv, info); } rocblas_status rocsolver_dgetf2_npvt(rocblas_handle handle, @@ -141,7 +140,7 @@ rocblas_status rocsolver_dgetf2_npvt(rocblas_handle handle, rocblas_int* info) { rocblas_int* ipiv = nullptr; - return rocsolver_getf2_impl(handle, m, n, A, lda, ipiv, info, 0); + return rocsolver_getf2_impl(handle, m, n, A, lda, ipiv, info); } rocblas_status rocsolver_cgetf2_npvt(rocblas_handle handle, @@ -152,7 +151,7 @@ rocblas_status rocsolver_cgetf2_npvt(rocblas_handle handle, rocblas_int* info) { rocblas_int* ipiv = nullptr; - return rocsolver_getf2_impl(handle, m, n, A, lda, ipiv, info, 0); + return rocsolver_getf2_impl(handle, m, n, A, lda, ipiv, info); } rocblas_status rocsolver_zgetf2_npvt(rocblas_handle handle, @@ -163,7 +162,7 @@ rocblas_status rocsolver_zgetf2_npvt(rocblas_handle handle, rocblas_int* info) { rocblas_int* ipiv = nullptr; - return rocsolver_getf2_impl(handle, m, n, A, lda, ipiv, info, 0); + return rocsolver_getf2_impl(handle, m, n, A, lda, ipiv, info); } } // extern C diff --git a/rocsolver/library/src/lapack/roclapack_getf2.hpp b/rocsolver/library/src/lapack/roclapack_getf2.hpp index 9fb8118a0..c7a293dfc 100644 --- a/rocsolver/library/src/lapack/roclapack_getf2.hpp +++ b/rocsolver/library/src/lapack/roclapack_getf2.hpp @@ -170,7 +170,7 @@ __global__ void __launch_bounds__(GETF2_MAX_THDS) LUfact_panel_kernel(const rocb GETF2_MAX_THDS <= m <= GETF2_OPTIM_MAX_SIZE and n = WAVESIZE (to be used by GETRF if block size = WAVESIZE) *******************************************************************/ -template +template __global__ void __launch_bounds__(GETF2_MAX_THDS) LUfact_panel_kernel_blk(const rocblas_int m, U AA, @@ -212,9 +212,9 @@ __global__ void __launch_bounds__(GETF2_MAX_THDS) int tmp; int pivot_index; int myinfo = 0; // to build info - int mypivs[DIM]; // to build ipiv - int myrows[DIM]; // to store this-thread active-rows-indices - T rA[DIM][WAVESIZE]; // to store this-thread active-rows-values + int mypivs[DIM1]; // to build ipiv + int myrows[DIM1]; // to store this-thread active-rows-indices + T rA[DIM1][DIM2]; // to store this-thread active-rows-values // initialization for(int i = 0; i < nrows; ++i) @@ -226,14 +226,16 @@ __global__ void __launch_bounds__(GETF2_MAX_THDS) // read corresponding rows from global memory into local array for(int i = 0; i < nrows; ++i) { -#pragma unroll WAVESIZE - for(int j = 0; j < WAVESIZE; ++j) +#pragma unroll DIM2 + for(int j = 0; j < DIM2; ++j) + { rA[i][j] = A[myrows[i] + j * lda]; + } } -// for each pivot (main loop) -#pragma unroll WAVESIZE - for(int k = 0; k < WAVESIZE; ++k) + // for each pivot (main loop) +#pragma unroll DIM2 + for(int k = 0; k < DIM2; ++k) { // share current column for(int i = 0; i < nrows; ++i) @@ -269,7 +271,7 @@ __global__ void __launch_bounds__(GETF2_MAX_THDS) { myrows[i] = k; // share pivot row - for(int j = k + 1; j < WAVESIZE; ++j) + for(int j = k + 1; j < DIM2; ++j) common[j] = rA[i][j]; } else if(myrows[i] == k) @@ -286,7 +288,7 @@ __global__ void __launch_bounds__(GETF2_MAX_THDS) if(myrows[i] > k) { rA[i][k] *= pivot_value; - for(int j = k + 1; j < WAVESIZE; ++j) + for(int j = k + 1; j < DIM2; ++j) rA[i][j] -= rA[i][k] * common[j]; } } @@ -296,22 +298,25 @@ __global__ void __launch_bounds__(GETF2_MAX_THDS) // write results to global memory if(myrow == 0) *info = myinfo; + if(pivot) { for(int i = 0; i < nrows; ++i) { - if(myrows[i] < WAVESIZE) + if(myrows[i] < DIM2) ipiv[myrows[i]] = mypivs[i]; } } + for(int i = 0; i < nrows; ++i) { -#pragma unroll WAVESIZE - for(int j = 0; j < WAVESIZE; ++j) +#pragma unroll DIM2 + for(int j = 0; j < DIM2; ++j) + { A[myrows[i] + j * lda] = rA[i][j]; + } } } - /************************************************************************** Launcher of LUfact_panel kernels **************************************************************************/ @@ -330,13 +335,13 @@ rocblas_status LUfact_panel(rocblas_handle handle, const rocblas_int batch_count, const rocblas_int pivot) { -#define RUN_LUFACT_PANEL(DIM) \ - if(n == 64) \ - hipLaunchKernelGGL((LUfact_panel_kernel_blk), grid, block, lmemsize, stream, m, A, \ - shiftA, lda, strideA, ipiv, shiftP, strideP, info, batch_count, pivot); \ - else \ - hipLaunchKernelGGL((LUfact_panel_kernel), grid, block, lmemsize, stream, m, n, A, \ - shiftA, lda, strideA, ipiv, shiftP, strideP, info, batch_count, pivot) +#define RUN_LUFACT_PANEL_BLK(DIM1, DIM2) \ + hipLaunchKernelGGL((LUfact_panel_kernel_blk), grid, block, lmemsize, stream, m, \ + A, shiftA, lda, strideA, ipiv, shiftP, strideP, info, batch_count, pivot) + +#define RUN_LUFACT_PANEL(DIM) \ + hipLaunchKernelGGL((LUfact_panel_kernel), grid, block, lmemsize, stream, m, n, A, \ + shiftA, lda, strideA, ipiv, shiftP, strideP, info, batch_count, pivot) // determine sizes rocblas_int blocks = batch_count; @@ -356,13 +361,69 @@ rocblas_status LUfact_panel(rocblas_handle handle, // GETF2_OPTIM_MAX_SIZE are tunned) kernel launch switch(dim) { - case 2: RUN_LUFACT_PANEL(2); break; - case 3: RUN_LUFACT_PANEL(3); break; - case 4: RUN_LUFACT_PANEL(4); break; - case 5: RUN_LUFACT_PANEL(5); break; - case 6: RUN_LUFACT_PANEL(6); break; - case 7: RUN_LUFACT_PANEL(7); break; - case 8: RUN_LUFACT_PANEL(8); break; + case 2: + switch(n) + { + case 16: RUN_LUFACT_PANEL_BLK(2, 16); break; + case 32: RUN_LUFACT_PANEL_BLK(2, 32); break; + case 64: RUN_LUFACT_PANEL_BLK(2, 64); break; + default: RUN_LUFACT_PANEL(2); + } + break; + case 3: + switch(n) + { + case 16: RUN_LUFACT_PANEL_BLK(3, 16); break; + case 32: RUN_LUFACT_PANEL_BLK(3, 32); break; + case 64: RUN_LUFACT_PANEL_BLK(3, 64); break; + default: RUN_LUFACT_PANEL(3); + } + break; + case 4: + switch(n) + { + case 16: RUN_LUFACT_PANEL_BLK(4, 16); break; + case 32: RUN_LUFACT_PANEL_BLK(4, 32); break; + case 64: RUN_LUFACT_PANEL_BLK(4, 64); break; + default: RUN_LUFACT_PANEL(4); + } + break; + case 5: + switch(n) + { + case 16: RUN_LUFACT_PANEL_BLK(5, 16); break; + case 32: RUN_LUFACT_PANEL_BLK(5, 32); break; + case 64: RUN_LUFACT_PANEL_BLK(5, 64); break; + default: RUN_LUFACT_PANEL(5); + } + break; + case 6: + switch(n) + { + case 16: RUN_LUFACT_PANEL_BLK(6, 16); break; + case 32: RUN_LUFACT_PANEL_BLK(6, 32); break; + case 64: RUN_LUFACT_PANEL_BLK(6, 64); break; + default: RUN_LUFACT_PANEL(6); + } + break; + case 7: + switch(n) + { + case 16: RUN_LUFACT_PANEL_BLK(7, 16); break; + case 32: RUN_LUFACT_PANEL_BLK(7, 32); break; + case 64: RUN_LUFACT_PANEL_BLK(7, 64); break; + default: RUN_LUFACT_PANEL(7); + } + break; + case 8: + switch(n) + { + case 16: RUN_LUFACT_PANEL_BLK(8, 16); break; + case 32: RUN_LUFACT_PANEL_BLK(8, 32); break; + case 64: RUN_LUFACT_PANEL_BLK(8, 64); break; + default: RUN_LUFACT_PANEL(8); + } + break; default: ROCSOLVER_UNREACHABLE(); } @@ -722,7 +783,7 @@ rocblas_status rocsolver_getf2_getrf_argCheck(rocblas_handle handle, return rocblas_status_continue; } -template +template rocblas_status rocsolver_getf2_template(rocblas_handle handle, const rocblas_int m, const rocblas_int n, @@ -735,7 +796,6 @@ rocblas_status rocsolver_getf2_template(rocblas_handle handle, const rocblas_stride strideP, rocblas_int* info, const rocblas_int batch_count, - const rocblas_int pivot, T* scalars, rocblas_index_value_t* work, T* pivotval, @@ -766,11 +826,11 @@ rocblas_status rocsolver_getf2_template(rocblas_handle handle, { if(m <= GETF2_MAX_THDS) return LUfact_small(handle, m, n, A, shiftA, lda, strideA, ipiv, shiftP, strideP, - info, batch_count, pivot); + info, batch_count, PIVOT); else if((m <= GETF2_OPTIM_MAX_SIZE && !ISBATCHED) || (m <= GETF2_BATCH_OPTIM_MAX_SIZE && ISBATCHED)) return LUfact_panel(handle, m, n, A, shiftA, lda, strideA, ipiv, shiftP, strideP, - info, batch_count, pivot); + info, batch_count, PIVOT); } #endif @@ -781,7 +841,7 @@ rocblas_status rocsolver_getf2_template(rocblas_handle handle, for(rocblas_int j = 0; j < dim; ++j) { - if(pivot) + if(PIVOT) // find pivot. Use Fortran 1-based indexing for the ipiv array as iamax // does that as well! rocblasCall_iamax(handle, m - j, A, shiftA + idx2D(j, j, lda), 1, @@ -790,9 +850,9 @@ rocblas_status rocsolver_getf2_template(rocblas_handle handle, // adjust pivot indices and check singularity hipLaunchKernelGGL(getf2_check_singularity, dim3(batch_count), dim3(1), 0, stream, A, shiftA, strideA, ipiv, shiftP, strideP, j, lda, pivotval, pivotidx, info, - pivot); + PIVOT); - if(pivot) + if(PIVOT) // Swap pivot row and j-th row rocsolver_laswp_template(handle, n, A, shiftA, lda, strideA, j + 1, j + 1, ipiv, shiftP, strideP, 1, batch_count); diff --git a/rocsolver/library/src/lapack/roclapack_getf2_batched.cpp b/rocsolver/library/src/lapack/roclapack_getf2_batched.cpp index fa9099e7b..507f8cd19 100644 --- a/rocsolver/library/src/lapack/roclapack_getf2_batched.cpp +++ b/rocsolver/library/src/lapack/roclapack_getf2_batched.cpp @@ -1,10 +1,10 @@ /* ************************************************************************ - * Copyright (c) 2019-2020 Advanced Micro Devices, Inc. + * Copyright (c) 2019-2021 Advanced Micro Devices, Inc. * ************************************************************************ */ #include "roclapack_getf2.hpp" -template +template rocblas_status rocsolver_getf2_batched_impl(rocblas_handle handle, const rocblas_int m, const rocblas_int n, @@ -13,8 +13,7 @@ rocblas_status rocsolver_getf2_batched_impl(rocblas_handle handle, rocblas_int* ipiv, const rocblas_stride strideP, rocblas_int* info, - const rocblas_int batch_count, - const int pivot) + const rocblas_int batch_count) { using S = decltype(std::real(T{})); @@ -25,7 +24,7 @@ rocblas_status rocsolver_getf2_batched_impl(rocblas_handle handle, // argument checking rocblas_status st - = rocsolver_getf2_getrf_argCheck(handle, m, n, lda, A, ipiv, info, pivot, batch_count); + = rocsolver_getf2_getrf_argCheck(handle, m, n, lda, A, ipiv, info, PIVOT, batch_count); if(st != rocblas_status_continue) return st; @@ -66,8 +65,8 @@ rocblas_status rocsolver_getf2_batched_impl(rocblas_handle handle, init_scalars(handle, (T*)scalars); // execution - return rocsolver_getf2_template( - handle, m, n, A, shiftA, lda, strideA, ipiv, shiftP, strideP, info, batch_count, pivot, + return rocsolver_getf2_template( + handle, m, n, A, shiftA, lda, strideA, ipiv, shiftP, strideP, info, batch_count, (T*)scalars, (rocblas_index_value_t*)work, (T*)pivotval, (rocblas_int*)pivotidx); } @@ -89,8 +88,8 @@ rocblas_status rocsolver_sgetf2_batched(rocblas_handle handle, rocblas_int* info, const rocblas_int batch_count) { - return rocsolver_getf2_batched_impl(handle, m, n, A, lda, ipiv, strideP, info, - batch_count, 1); + return rocsolver_getf2_batched_impl(handle, m, n, A, lda, ipiv, strideP, info, + batch_count); } rocblas_status rocsolver_dgetf2_batched(rocblas_handle handle, @@ -103,8 +102,8 @@ rocblas_status rocsolver_dgetf2_batched(rocblas_handle handle, rocblas_int* info, const rocblas_int batch_count) { - return rocsolver_getf2_batched_impl(handle, m, n, A, lda, ipiv, strideP, info, - batch_count, 1); + return rocsolver_getf2_batched_impl(handle, m, n, A, lda, ipiv, strideP, info, + batch_count); } rocblas_status rocsolver_cgetf2_batched(rocblas_handle handle, @@ -117,8 +116,8 @@ rocblas_status rocsolver_cgetf2_batched(rocblas_handle handle, rocblas_int* info, const rocblas_int batch_count) { - return rocsolver_getf2_batched_impl(handle, m, n, A, lda, ipiv, strideP, - info, batch_count, 1); + return rocsolver_getf2_batched_impl(handle, m, n, A, lda, ipiv, + strideP, info, batch_count); } rocblas_status rocsolver_zgetf2_batched(rocblas_handle handle, @@ -131,8 +130,8 @@ rocblas_status rocsolver_zgetf2_batched(rocblas_handle handle, rocblas_int* info, const rocblas_int batch_count) { - return rocsolver_getf2_batched_impl(handle, m, n, A, lda, ipiv, strideP, - info, batch_count, 1); + return rocsolver_getf2_batched_impl(handle, m, n, A, lda, ipiv, + strideP, info, batch_count); } rocblas_status rocsolver_sgetf2_npvt_batched(rocblas_handle handle, @@ -144,7 +143,8 @@ rocblas_status rocsolver_sgetf2_npvt_batched(rocblas_handle handle, const rocblas_int batch_count) { rocblas_int* ipiv = nullptr; - return rocsolver_getf2_batched_impl(handle, m, n, A, lda, ipiv, 0, info, batch_count, 0); + return rocsolver_getf2_batched_impl(handle, m, n, A, lda, ipiv, 0, info, + batch_count); } rocblas_status rocsolver_dgetf2_npvt_batched(rocblas_handle handle, @@ -156,7 +156,8 @@ rocblas_status rocsolver_dgetf2_npvt_batched(rocblas_handle handle, const rocblas_int batch_count) { rocblas_int* ipiv = nullptr; - return rocsolver_getf2_batched_impl(handle, m, n, A, lda, ipiv, 0, info, batch_count, 0); + return rocsolver_getf2_batched_impl(handle, m, n, A, lda, ipiv, 0, info, + batch_count); } rocblas_status rocsolver_cgetf2_npvt_batched(rocblas_handle handle, @@ -168,8 +169,8 @@ rocblas_status rocsolver_cgetf2_npvt_batched(rocblas_handle handle, const rocblas_int batch_count) { rocblas_int* ipiv = nullptr; - return rocsolver_getf2_batched_impl(handle, m, n, A, lda, ipiv, 0, info, - batch_count, 0); + return rocsolver_getf2_batched_impl(handle, m, n, A, lda, ipiv, 0, + info, batch_count); } rocblas_status rocsolver_zgetf2_npvt_batched(rocblas_handle handle, @@ -181,8 +182,8 @@ rocblas_status rocsolver_zgetf2_npvt_batched(rocblas_handle handle, const rocblas_int batch_count) { rocblas_int* ipiv = nullptr; - return rocsolver_getf2_batched_impl(handle, m, n, A, lda, ipiv, 0, info, - batch_count, 0); + return rocsolver_getf2_batched_impl(handle, m, n, A, lda, ipiv, + 0, info, batch_count); } } // extern C diff --git a/rocsolver/library/src/lapack/roclapack_getf2_strided_batched.cpp b/rocsolver/library/src/lapack/roclapack_getf2_strided_batched.cpp index 8e75f2a2e..2c9157d72 100644 --- a/rocsolver/library/src/lapack/roclapack_getf2_strided_batched.cpp +++ b/rocsolver/library/src/lapack/roclapack_getf2_strided_batched.cpp @@ -1,10 +1,10 @@ /* ************************************************************************ - * Copyright (c) 2019-2020 Advanced Micro Devices, Inc. + * Copyright (c) 2019-2021 Advanced Micro Devices, Inc. * ************************************************************************ */ #include "roclapack_getf2.hpp" -template +template rocblas_status rocsolver_getf2_strided_batched_impl(rocblas_handle handle, const rocblas_int m, const rocblas_int n, @@ -14,8 +14,7 @@ rocblas_status rocsolver_getf2_strided_batched_impl(rocblas_handle handle, rocblas_int* ipiv, const rocblas_stride strideP, rocblas_int* info, - const rocblas_int batch_count, - const int pivot) + const rocblas_int batch_count) { using S = decltype(std::real(T{})); @@ -26,7 +25,7 @@ rocblas_status rocsolver_getf2_strided_batched_impl(rocblas_handle handle, // argument checking rocblas_status st - = rocsolver_getf2_getrf_argCheck(handle, m, n, lda, A, ipiv, info, pivot, batch_count); + = rocsolver_getf2_getrf_argCheck(handle, m, n, lda, A, ipiv, info, PIVOT, batch_count); if(st != rocblas_status_continue) return st; @@ -64,8 +63,8 @@ rocblas_status rocsolver_getf2_strided_batched_impl(rocblas_handle handle, init_scalars(handle, (T*)scalars); // execution - return rocsolver_getf2_template( - handle, m, n, A, shiftA, lda, strideA, ipiv, shiftP, strideP, info, batch_count, pivot, + return rocsolver_getf2_template( + handle, m, n, A, shiftA, lda, strideA, ipiv, shiftP, strideP, info, batch_count, (T*)scalars, (rocblas_index_value_t*)work, (T*)pivotval, (rocblas_int*)pivotidx); } @@ -88,8 +87,8 @@ rocblas_status rocsolver_sgetf2_strided_batched(rocblas_handle handle, rocblas_int* info, const rocblas_int batch_count) { - return rocsolver_getf2_strided_batched_impl(handle, m, n, A, lda, strideA, ipiv, strideP, - info, batch_count, 1); + return rocsolver_getf2_strided_batched_impl(handle, m, n, A, lda, strideA, ipiv, + strideP, info, batch_count); } rocblas_status rocsolver_dgetf2_strided_batched(rocblas_handle handle, @@ -103,8 +102,8 @@ rocblas_status rocsolver_dgetf2_strided_batched(rocblas_handle handle, rocblas_int* info, const rocblas_int batch_count) { - return rocsolver_getf2_strided_batched_impl(handle, m, n, A, lda, strideA, ipiv, - strideP, info, batch_count, 1); + return rocsolver_getf2_strided_batched_impl(handle, m, n, A, lda, strideA, ipiv, + strideP, info, batch_count); } rocblas_status rocsolver_cgetf2_strided_batched(rocblas_handle handle, @@ -118,8 +117,8 @@ rocblas_status rocsolver_cgetf2_strided_batched(rocblas_handle handle, rocblas_int* info, const rocblas_int batch_count) { - return rocsolver_getf2_strided_batched_impl( - handle, m, n, A, lda, strideA, ipiv, strideP, info, batch_count, 1); + return rocsolver_getf2_strided_batched_impl( + handle, m, n, A, lda, strideA, ipiv, strideP, info, batch_count); } rocblas_status rocsolver_zgetf2_strided_batched(rocblas_handle handle, @@ -133,8 +132,8 @@ rocblas_status rocsolver_zgetf2_strided_batched(rocblas_handle handle, rocblas_int* info, const rocblas_int batch_count) { - return rocsolver_getf2_strided_batched_impl( - handle, m, n, A, lda, strideA, ipiv, strideP, info, batch_count, 1); + return rocsolver_getf2_strided_batched_impl( + handle, m, n, A, lda, strideA, ipiv, strideP, info, batch_count); } rocblas_status rocsolver_sgetf2_npvt_strided_batched(rocblas_handle handle, @@ -147,8 +146,8 @@ rocblas_status rocsolver_sgetf2_npvt_strided_batched(rocblas_handle handle, const rocblas_int batch_count) { rocblas_int* ipiv = nullptr; - return rocsolver_getf2_strided_batched_impl(handle, m, n, A, lda, strideA, ipiv, 0, info, - batch_count, 0); + return rocsolver_getf2_strided_batched_impl(handle, m, n, A, lda, strideA, ipiv, + 0, info, batch_count); } rocblas_status rocsolver_dgetf2_npvt_strided_batched(rocblas_handle handle, @@ -161,8 +160,8 @@ rocblas_status rocsolver_dgetf2_npvt_strided_batched(rocblas_handle handle, const rocblas_int batch_count) { rocblas_int* ipiv = nullptr; - return rocsolver_getf2_strided_batched_impl(handle, m, n, A, lda, strideA, ipiv, 0, - info, batch_count, 0); + return rocsolver_getf2_strided_batched_impl(handle, m, n, A, lda, strideA, ipiv, + 0, info, batch_count); } rocblas_status rocsolver_cgetf2_npvt_strided_batched(rocblas_handle handle, @@ -175,8 +174,8 @@ rocblas_status rocsolver_cgetf2_npvt_strided_batched(rocblas_handle handle, const rocblas_int batch_count) { rocblas_int* ipiv = nullptr; - return rocsolver_getf2_strided_batched_impl( - handle, m, n, A, lda, strideA, ipiv, 0, info, batch_count, 0); + return rocsolver_getf2_strided_batched_impl( + handle, m, n, A, lda, strideA, ipiv, 0, info, batch_count); } rocblas_status rocsolver_zgetf2_npvt_strided_batched(rocblas_handle handle, @@ -189,8 +188,8 @@ rocblas_status rocsolver_zgetf2_npvt_strided_batched(rocblas_handle handle, const rocblas_int batch_count) { rocblas_int* ipiv = nullptr; - return rocsolver_getf2_strided_batched_impl( - handle, m, n, A, lda, strideA, ipiv, 0, info, batch_count, 0); + return rocsolver_getf2_strided_batched_impl( + handle, m, n, A, lda, strideA, ipiv, 0, info, batch_count); } } // extern C diff --git a/rocsolver/library/src/lapack/roclapack_getrf.cpp b/rocsolver/library/src/lapack/roclapack_getrf.cpp index 091f05ebb..5792e419b 100644 --- a/rocsolver/library/src/lapack/roclapack_getrf.cpp +++ b/rocsolver/library/src/lapack/roclapack_getrf.cpp @@ -1,18 +1,17 @@ /* ************************************************************************ - * Copyright (c) 2019-2020 Advanced Micro Devices, Inc. + * Copyright (c) 2019-2021 Advanced Micro Devices, Inc. * ************************************************************************ */ #include "roclapack_getrf.hpp" -template +template rocblas_status rocsolver_getrf_impl(rocblas_handle handle, const rocblas_int m, const rocblas_int n, U A, const rocblas_int lda, rocblas_int* ipiv, - rocblas_int* info, - const int pivot) + rocblas_int* info) { using S = decltype(std::real(T{})); @@ -22,7 +21,7 @@ rocblas_status rocsolver_getrf_impl(rocblas_handle handle, // logging is missing ??? // argument checking - rocblas_status st = rocsolver_getf2_getrf_argCheck(handle, m, n, lda, A, ipiv, info, pivot); + rocblas_status st = rocsolver_getf2_getrf_argCheck(handle, m, n, lda, A, ipiv, info, PIVOT); if(st != rocblas_status_continue) return st; @@ -44,7 +43,7 @@ rocblas_status rocsolver_getrf_impl(rocblas_handle handle, size_t size_pivotval, size_pivotidx; // size to store info about singularity of each subblock size_t size_iinfo; - rocsolver_getrf_getMemorySize( + rocsolver_getrf_getMemorySize( m, n, batch_count, &size_scalars, &size_work, &size_work1, &size_work2, &size_work3, &size_work4, &size_pivotval, &size_pivotidx, &size_iinfo); @@ -77,8 +76,8 @@ rocblas_status rocsolver_getrf_impl(rocblas_handle handle, init_scalars(handle, (T*)scalars); // execution - return rocsolver_getrf_template( - handle, m, n, A, shiftA, lda, strideA, ipiv, shiftP, strideP, info, batch_count, pivot, + return rocsolver_getrf_template( + handle, m, n, A, shiftA, lda, strideA, ipiv, shiftP, strideP, info, batch_count, (T*)scalars, (rocblas_index_value_t*)work, work1, work2, work3, work4, (T*)pivotval, (rocblas_int*)pivotidx, (rocblas_int*)iinfo, optim_mem); } @@ -99,7 +98,7 @@ rocblas_status rocsolver_sgetrf(rocblas_handle handle, rocblas_int* ipiv, rocblas_int* info) { - return rocsolver_getrf_impl(handle, m, n, A, lda, ipiv, info, 1); + return rocsolver_getrf_impl(handle, m, n, A, lda, ipiv, info); } rocblas_status rocsolver_dgetrf(rocblas_handle handle, @@ -110,7 +109,7 @@ rocblas_status rocsolver_dgetrf(rocblas_handle handle, rocblas_int* ipiv, rocblas_int* info) { - return rocsolver_getrf_impl(handle, m, n, A, lda, ipiv, info, 1); + return rocsolver_getrf_impl(handle, m, n, A, lda, ipiv, info); } rocblas_status rocsolver_cgetrf(rocblas_handle handle, @@ -121,7 +120,7 @@ rocblas_status rocsolver_cgetrf(rocblas_handle handle, rocblas_int* ipiv, rocblas_int* info) { - return rocsolver_getrf_impl(handle, m, n, A, lda, ipiv, info, 1); + return rocsolver_getrf_impl(handle, m, n, A, lda, ipiv, info); } rocblas_status rocsolver_zgetrf(rocblas_handle handle, @@ -132,7 +131,7 @@ rocblas_status rocsolver_zgetrf(rocblas_handle handle, rocblas_int* ipiv, rocblas_int* info) { - return rocsolver_getrf_impl(handle, m, n, A, lda, ipiv, info, 1); + return rocsolver_getrf_impl(handle, m, n, A, lda, ipiv, info); } rocblas_status rocsolver_sgetrf_npvt(rocblas_handle handle, @@ -143,7 +142,7 @@ rocblas_status rocsolver_sgetrf_npvt(rocblas_handle handle, rocblas_int* info) { rocblas_int* ipiv = nullptr; - return rocsolver_getrf_impl(handle, m, n, A, lda, ipiv, info, 0); + return rocsolver_getrf_impl(handle, m, n, A, lda, ipiv, info); } rocblas_status rocsolver_dgetrf_npvt(rocblas_handle handle, @@ -154,7 +153,7 @@ rocblas_status rocsolver_dgetrf_npvt(rocblas_handle handle, rocblas_int* info) { rocblas_int* ipiv = nullptr; - return rocsolver_getrf_impl(handle, m, n, A, lda, ipiv, info, 0); + return rocsolver_getrf_impl(handle, m, n, A, lda, ipiv, info); } rocblas_status rocsolver_cgetrf_npvt(rocblas_handle handle, @@ -165,7 +164,7 @@ rocblas_status rocsolver_cgetrf_npvt(rocblas_handle handle, rocblas_int* info) { rocblas_int* ipiv = nullptr; - return rocsolver_getrf_impl(handle, m, n, A, lda, ipiv, info, 0); + return rocsolver_getrf_impl(handle, m, n, A, lda, ipiv, info); } rocblas_status rocsolver_zgetrf_npvt(rocblas_handle handle, @@ -176,7 +175,7 @@ rocblas_status rocsolver_zgetrf_npvt(rocblas_handle handle, rocblas_int* info) { rocblas_int* ipiv = nullptr; - return rocsolver_getrf_impl(handle, m, n, A, lda, ipiv, info, 0); + return rocsolver_getrf_impl(handle, m, n, A, lda, ipiv, info); } } // extern C diff --git a/rocsolver/library/src/lapack/roclapack_getrf.hpp b/rocsolver/library/src/lapack/roclapack_getrf.hpp index 266585b21..4db4c2602 100644 --- a/rocsolver/library/src/lapack/roclapack_getrf.hpp +++ b/rocsolver/library/src/lapack/roclapack_getrf.hpp @@ -4,7 +4,7 @@ * Univ. of Tennessee, Univ. of California Berkeley, * Univ. of Colorado Denver and NAG Ltd.. * December 2016 - * Copyright (c) 2019-2020 Advanced Micro Devices, Inc. + * Copyright (c) 2019-2021 Advanced Micro Devices, Inc. * ***********************************************************************/ #pragma once @@ -39,7 +39,44 @@ __global__ void getrf_check_singularity(const rocblas_int n, } } -template +template +rocblas_int get_blksize(rocblas_int dim) +{ + rocblas_int max; + rocblas_int i; + +#if ISBATCHED +#if PIVOT + std::vector size{GETRF_BLKSIZES_BATCH}; + std::vector intervals{GETRF_INTERVALS_BATCH}; + max = GETRF_NUM_INTERVALS_BATCH; +#else + std::vector size{GETRF_NPVT_BLKSIZES_BATCH}; + std::vector intervals{GETRF_NPVT_INTERVALS_BATCH}; + max = GETRF_NPVT_NUM_INTERVALS_BATCH; +#endif +#else +#if PIVOT + std::vector size{GETRF_BLKSIZES_NORMAL}; + std::vector intervals{GETRF_INTERVALS_NORMAL}; + max = GETRF_NUM_INTERVALS_NORMAL; +#else + std::vector size{GETRF_NPVT_BLKSIZES_NORMAL}; + std::vector intervals{GETRF_NPVT_INTERVALS_NORMAL}; + max = GETRF_NPVT_NUM_INTERVALS_NORMAL; +#endif +#endif + + for(i = 0; i < max; ++i) + { + if(dim < intervals[i]) + break; + } + + return size[i]; +} + +template void rocsolver_getrf_getMemorySize(const rocblas_int m, const rocblas_int n, const rocblas_int batch_count, @@ -70,7 +107,10 @@ void rocsolver_getrf_getMemorySize(const rocblas_int m, return; } - if(m < GETRF_GETF2_SWITCHSIZE || n < GETRF_GETF2_SWITCHSIZE) + rocblas_int dim = min(m, n); + rocblas_int blk = get_blksize(dim); + + if(blk == 1) { // requirements for one single GETF2 rocsolver_getf2_getMemorySize(m, n, batch_count, size_scalars, size_work, @@ -83,22 +123,20 @@ void rocsolver_getrf_getMemorySize(const rocblas_int m, } else { - rocblas_int jb = GETRF_GETF2_SWITCHSIZE; - // requirements for calling GETF2 for the sub blocks - rocsolver_getf2_getMemorySize(m, jb, batch_count, size_scalars, size_work, + rocsolver_getf2_getMemorySize(m, blk, batch_count, size_scalars, size_work, size_pivotval, size_pivotidx); // to store info about singularity of sub blocks *size_iinfo = sizeof(rocblas_int) * batch_count; // extra workspace (for calling TRSM) - rocblasCall_trsm_mem(rocblas_side_left, jb, n - jb, batch_count, size_work1, + rocblasCall_trsm_mem(rocblas_side_left, blk, n - blk, batch_count, size_work1, size_work2, size_work3, size_work4); } } -template +template rocblas_status rocsolver_getrf_template(rocblas_handle handle, const rocblas_int m, const rocblas_int n, @@ -111,7 +149,6 @@ rocblas_status rocsolver_getrf_template(rocblas_handle handle, const rocblas_stride strideP, rocblas_int* info, const rocblas_int batch_count, - const rocblas_int pivot, T* scalars, rocblas_index_value_t* work, void* work1, @@ -145,13 +182,6 @@ rocblas_status rocsolver_getrf_template(rocblas_handle handle, static constexpr bool ISBATCHED = BATCHED || STRIDED; - // if the matrix is small, use the unblocked (level-2-blas) variant of the - // algorithm - if(m < GETRF_GETF2_SWITCHSIZE || n < GETRF_GETF2_SWITCHSIZE) - return rocsolver_getf2_template(handle, m, n, A, shiftA, lda, strideA, ipiv, - shiftP, strideP, info, batch_count, pivot, - scalars, work, pivotval, pivotidx); - // everything must be executed with scalars on the host rocblas_pointer_mode old_mode; rocblas_get_pointer_mode(handle, &old_mode); @@ -164,30 +194,37 @@ rocblas_status rocsolver_getrf_template(rocblas_handle handle, rocblas_int dim = min(m, n); // total number of pivots rocblas_int jb, sizePivot; - for(rocblas_int j = 0; j < dim; j += GETRF_GETF2_SWITCHSIZE) + rocblas_int blk = get_blksize(dim); + + if(blk == 1) + return rocsolver_getf2_template(handle, m, n, A, shiftA, lda, strideA, + ipiv, shiftP, strideP, info, batch_count, + scalars, work, pivotval, pivotidx); + + for(rocblas_int j = 0; j < dim; j += blk) //dim { // Factor diagonal and subdiagonal blocks - jb = min(dim - j, GETRF_GETF2_SWITCHSIZE); // number of columns in the block + jb = min(dim - j, blk); // number of columns in the block hipLaunchKernelGGL(reset_info, gridReset, threads, 0, stream, iinfo, batch_count, 0); - rocsolver_getf2_template(handle, m - j, jb, A, shiftA + idx2D(j, j, lda), lda, - strideA, ipiv, shiftP + j, strideP, iinfo, - batch_count, pivot, scalars, work, pivotval, pivotidx); + rocsolver_getf2_template( + handle, m - j, jb, A, shiftA + idx2D(j, j, lda), lda, strideA, ipiv, shiftP + j, + strideP, iinfo, batch_count, scalars, work, pivotval, pivotidx); // adjust pivot indices and check singularity sizePivot = min(m - j, jb); // number of pivots in the block blocksPivot = (sizePivot - 1) / BLOCKSIZE + 1; gridPivot = dim3(blocksPivot, batch_count, 1); hipLaunchKernelGGL(getrf_check_singularity, gridPivot, threads, 0, stream, sizePivot, j, - ipiv, shiftP + j, strideP, iinfo, info, pivot); + ipiv, shiftP + j, strideP, iinfo, info, PIVOT); // apply interchanges to columns 1 : j-1 - if(pivot) + if(PIVOT) rocsolver_laswp_template(handle, j, A, shiftA, lda, strideA, j + 1, j + jb, ipiv, shiftP, strideP, 1, batch_count); if(j + jb < n) { - if(pivot) + if(PIVOT) { // apply interchanges to columns j+jb : n rocsolver_laswp_template(handle, (n - j - jb), A, shiftA + idx2D(0, j + jb, lda), diff --git a/rocsolver/library/src/lapack/roclapack_getrf_batched.cpp b/rocsolver/library/src/lapack/roclapack_getrf_batched.cpp index 035b98b92..4edd9986b 100644 --- a/rocsolver/library/src/lapack/roclapack_getrf_batched.cpp +++ b/rocsolver/library/src/lapack/roclapack_getrf_batched.cpp @@ -1,10 +1,10 @@ /* ************************************************************************ - * Copyright (c) 2019-2020 Advanced Micro Devices, Inc. + * Copyright (c) 2019-2021 Advanced Micro Devices, Inc. * ************************************************************************ */ #include "roclapack_getrf.hpp" -template +template rocblas_status rocsolver_getrf_batched_impl(rocblas_handle handle, rocblas_int m, rocblas_int n, @@ -13,8 +13,7 @@ rocblas_status rocsolver_getrf_batched_impl(rocblas_handle handle, rocblas_int* ipiv, const rocblas_stride strideP, rocblas_int* info, - rocblas_int batch_count, - const int pivot) + rocblas_int batch_count) { using S = decltype(std::real(T{})); @@ -25,7 +24,7 @@ rocblas_status rocsolver_getrf_batched_impl(rocblas_handle handle, // argument checking rocblas_status st - = rocsolver_getf2_getrf_argCheck(handle, m, n, lda, A, ipiv, info, pivot, batch_count); + = rocsolver_getf2_getrf_argCheck(handle, m, n, lda, A, ipiv, info, PIVOT, batch_count); if(st != rocblas_status_continue) return st; @@ -45,7 +44,7 @@ rocblas_status rocsolver_getrf_batched_impl(rocblas_handle handle, size_t size_pivotval, size_pivotidx; // size to store info about singularity of each subblock size_t size_iinfo; - rocsolver_getrf_getMemorySize( + rocsolver_getrf_getMemorySize( m, n, batch_count, &size_scalars, &size_work, &size_work1, &size_work2, &size_work3, &size_work4, &size_pivotval, &size_pivotidx, &size_iinfo); @@ -78,8 +77,8 @@ rocblas_status rocsolver_getrf_batched_impl(rocblas_handle handle, init_scalars(handle, (T*)scalars); // execution - return rocsolver_getrf_template( - handle, m, n, A, shiftA, lda, strideA, ipiv, shiftP, strideP, info, batch_count, pivot, + return rocsolver_getrf_template( + handle, m, n, A, shiftA, lda, strideA, ipiv, shiftP, strideP, info, batch_count, (T*)scalars, (rocblas_index_value_t*)work, work1, work2, work3, work4, (T*)pivotval, (rocblas_int*)pivotidx, (rocblas_int*)iinfo, optim_mem); } @@ -102,8 +101,8 @@ rocblas_status rocsolver_sgetrf_batched(rocblas_handle handle, rocblas_int* info, const rocblas_int batch_count) { - return rocsolver_getrf_batched_impl(handle, m, n, A, lda, ipiv, strideP, info, - batch_count, 1); + return rocsolver_getrf_batched_impl(handle, m, n, A, lda, ipiv, strideP, info, + batch_count); } rocblas_status rocsolver_dgetrf_batched(rocblas_handle handle, @@ -116,8 +115,8 @@ rocblas_status rocsolver_dgetrf_batched(rocblas_handle handle, rocblas_int* info, const rocblas_int batch_count) { - return rocsolver_getrf_batched_impl(handle, m, n, A, lda, ipiv, strideP, info, - batch_count, 1); + return rocsolver_getrf_batched_impl(handle, m, n, A, lda, ipiv, strideP, info, + batch_count); } rocblas_status rocsolver_cgetrf_batched(rocblas_handle handle, @@ -130,8 +129,8 @@ rocblas_status rocsolver_cgetrf_batched(rocblas_handle handle, rocblas_int* info, const rocblas_int batch_count) { - return rocsolver_getrf_batched_impl(handle, m, n, A, lda, ipiv, strideP, - info, batch_count, 1); + return rocsolver_getrf_batched_impl(handle, m, n, A, lda, ipiv, + strideP, info, batch_count); } rocblas_status rocsolver_zgetrf_batched(rocblas_handle handle, @@ -144,8 +143,8 @@ rocblas_status rocsolver_zgetrf_batched(rocblas_handle handle, rocblas_int* info, const rocblas_int batch_count) { - return rocsolver_getrf_batched_impl(handle, m, n, A, lda, ipiv, strideP, - info, batch_count, 1); + return rocsolver_getrf_batched_impl(handle, m, n, A, lda, ipiv, + strideP, info, batch_count); } rocblas_status rocsolver_sgetrf_npvt_batched(rocblas_handle handle, @@ -157,7 +156,8 @@ rocblas_status rocsolver_sgetrf_npvt_batched(rocblas_handle handle, const rocblas_int batch_count) { rocblas_int* ipiv = nullptr; - return rocsolver_getrf_batched_impl(handle, m, n, A, lda, ipiv, 0, info, batch_count, 0); + return rocsolver_getrf_batched_impl(handle, m, n, A, lda, ipiv, 0, info, + batch_count); } rocblas_status rocsolver_dgetrf_npvt_batched(rocblas_handle handle, @@ -169,7 +169,8 @@ rocblas_status rocsolver_dgetrf_npvt_batched(rocblas_handle handle, const rocblas_int batch_count) { rocblas_int* ipiv = nullptr; - return rocsolver_getrf_batched_impl(handle, m, n, A, lda, ipiv, 0, info, batch_count, 0); + return rocsolver_getrf_batched_impl(handle, m, n, A, lda, ipiv, 0, info, + batch_count); } rocblas_status rocsolver_cgetrf_npvt_batched(rocblas_handle handle, @@ -181,8 +182,8 @@ rocblas_status rocsolver_cgetrf_npvt_batched(rocblas_handle handle, const rocblas_int batch_count) { rocblas_int* ipiv = nullptr; - return rocsolver_getrf_batched_impl(handle, m, n, A, lda, ipiv, 0, info, - batch_count, 0); + return rocsolver_getrf_batched_impl(handle, m, n, A, lda, ipiv, 0, + info, batch_count); } rocblas_status rocsolver_zgetrf_npvt_batched(rocblas_handle handle, @@ -194,8 +195,8 @@ rocblas_status rocsolver_zgetrf_npvt_batched(rocblas_handle handle, const rocblas_int batch_count) { rocblas_int* ipiv = nullptr; - return rocsolver_getrf_batched_impl(handle, m, n, A, lda, ipiv, 0, info, - batch_count, 0); + return rocsolver_getrf_batched_impl(handle, m, n, A, lda, ipiv, + 0, info, batch_count); } } // extern C diff --git a/rocsolver/library/src/lapack/roclapack_getrf_strided_batched.cpp b/rocsolver/library/src/lapack/roclapack_getrf_strided_batched.cpp index ccd768fad..0e55a75a1 100644 --- a/rocsolver/library/src/lapack/roclapack_getrf_strided_batched.cpp +++ b/rocsolver/library/src/lapack/roclapack_getrf_strided_batched.cpp @@ -1,10 +1,10 @@ /* ************************************************************************ - * Copyright (c) 2019-2020 Advanced Micro Devices, Inc. + * Copyright (c) 2019-2021 Advanced Micro Devices, Inc. * ************************************************************************ */ #include "roclapack_getrf.hpp" -template +template rocblas_status rocsolver_getrf_strided_batched_impl(rocblas_handle handle, const rocblas_int m, const rocblas_int n, @@ -14,8 +14,7 @@ rocblas_status rocsolver_getrf_strided_batched_impl(rocblas_handle handle, rocblas_int* ipiv, const rocblas_stride strideP, rocblas_int* info, - const rocblas_int batch_count, - const int pivot) + const rocblas_int batch_count) { using S = decltype(std::real(T{})); @@ -26,7 +25,7 @@ rocblas_status rocsolver_getrf_strided_batched_impl(rocblas_handle handle, // argument checking rocblas_status st - = rocsolver_getf2_getrf_argCheck(handle, m, n, lda, A, ipiv, info, pivot, batch_count); + = rocsolver_getf2_getrf_argCheck(handle, m, n, lda, A, ipiv, info, PIVOT, batch_count); if(st != rocblas_status_continue) return st; @@ -43,7 +42,7 @@ rocblas_status rocsolver_getrf_strided_batched_impl(rocblas_handle handle, size_t size_pivotval, size_pivotidx; // size to store info about singularity of each subblock size_t size_iinfo; - rocsolver_getrf_getMemorySize( + rocsolver_getrf_getMemorySize( m, n, batch_count, &size_scalars, &size_work, &size_work1, &size_work2, &size_work3, &size_work4, &size_pivotval, &size_pivotidx, &size_iinfo); @@ -76,8 +75,8 @@ rocblas_status rocsolver_getrf_strided_batched_impl(rocblas_handle handle, init_scalars(handle, (T*)scalars); // execution - return rocsolver_getrf_template( - handle, m, n, A, shiftA, lda, strideA, ipiv, shiftP, strideP, info, batch_count, pivot, + return rocsolver_getrf_template( + handle, m, n, A, shiftA, lda, strideA, ipiv, shiftP, strideP, info, batch_count, (T*)scalars, (rocblas_index_value_t*)work, work1, work2, work3, work4, (T*)pivotval, (rocblas_int*)pivotidx, (rocblas_int*)iinfo, optim_mem); } @@ -101,8 +100,8 @@ rocblas_status rocsolver_sgetrf_strided_batched(rocblas_handle handle, rocblas_int* info, const rocblas_int batch_count) { - return rocsolver_getrf_strided_batched_impl(handle, m, n, A, lda, strideA, ipiv, strideP, - info, batch_count, 1); + return rocsolver_getrf_strided_batched_impl(handle, m, n, A, lda, strideA, ipiv, + strideP, info, batch_count); } rocblas_status rocsolver_dgetrf_strided_batched(rocblas_handle handle, @@ -116,8 +115,8 @@ rocblas_status rocsolver_dgetrf_strided_batched(rocblas_handle handle, rocblas_int* info, const rocblas_int batch_count) { - return rocsolver_getrf_strided_batched_impl(handle, m, n, A, lda, strideA, ipiv, - strideP, info, batch_count, 1); + return rocsolver_getrf_strided_batched_impl(handle, m, n, A, lda, strideA, ipiv, + strideP, info, batch_count); } rocblas_status rocsolver_cgetrf_strided_batched(rocblas_handle handle, @@ -131,8 +130,8 @@ rocblas_status rocsolver_cgetrf_strided_batched(rocblas_handle handle, rocblas_int* info, const rocblas_int batch_count) { - return rocsolver_getrf_strided_batched_impl( - handle, m, n, A, lda, strideA, ipiv, strideP, info, batch_count, 1); + return rocsolver_getrf_strided_batched_impl( + handle, m, n, A, lda, strideA, ipiv, strideP, info, batch_count); } rocblas_status rocsolver_zgetrf_strided_batched(rocblas_handle handle, @@ -146,8 +145,8 @@ rocblas_status rocsolver_zgetrf_strided_batched(rocblas_handle handle, rocblas_int* info, const rocblas_int batch_count) { - return rocsolver_getrf_strided_batched_impl( - handle, m, n, A, lda, strideA, ipiv, strideP, info, batch_count, 1); + return rocsolver_getrf_strided_batched_impl( + handle, m, n, A, lda, strideA, ipiv, strideP, info, batch_count); } rocblas_status rocsolver_sgetrf_npvt_strided_batched(rocblas_handle handle, @@ -160,8 +159,8 @@ rocblas_status rocsolver_sgetrf_npvt_strided_batched(rocblas_handle handle, const rocblas_int batch_count) { rocblas_int* ipiv = nullptr; - return rocsolver_getrf_strided_batched_impl(handle, m, n, A, lda, strideA, ipiv, 0, info, - batch_count, 0); + return rocsolver_getrf_strided_batched_impl(handle, m, n, A, lda, strideA, ipiv, + 0, info, batch_count); } rocblas_status rocsolver_dgetrf_npvt_strided_batched(rocblas_handle handle, @@ -174,8 +173,8 @@ rocblas_status rocsolver_dgetrf_npvt_strided_batched(rocblas_handle handle, const rocblas_int batch_count) { rocblas_int* ipiv = nullptr; - return rocsolver_getrf_strided_batched_impl(handle, m, n, A, lda, strideA, ipiv, 0, - info, batch_count, 0); + return rocsolver_getrf_strided_batched_impl(handle, m, n, A, lda, strideA, ipiv, + 0, info, batch_count); } rocblas_status rocsolver_cgetrf_npvt_strided_batched(rocblas_handle handle, @@ -188,8 +187,8 @@ rocblas_status rocsolver_cgetrf_npvt_strided_batched(rocblas_handle handle, const rocblas_int batch_count) { rocblas_int* ipiv = nullptr; - return rocsolver_getrf_strided_batched_impl( - handle, m, n, A, lda, strideA, ipiv, 0, info, batch_count, 0); + return rocsolver_getrf_strided_batched_impl( + handle, m, n, A, lda, strideA, ipiv, 0, info, batch_count); } rocblas_status rocsolver_zgetrf_npvt_strided_batched(rocblas_handle handle, @@ -202,8 +201,8 @@ rocblas_status rocsolver_zgetrf_npvt_strided_batched(rocblas_handle handle, const rocblas_int batch_count) { rocblas_int* ipiv = nullptr; - return rocsolver_getrf_strided_batched_impl( - handle, m, n, A, lda, strideA, ipiv, 0, info, batch_count, 0); + return rocsolver_getrf_strided_batched_impl( + handle, m, n, A, lda, strideA, ipiv, 0, info, batch_count); } } // extern C