diff --git a/src/BLR/BLRBatch.cpp b/src/BLR/BLRBatch.cpp
index 55c66290..5dabc764 100644
--- a/src/BLR/BLRBatch.cpp
+++ b/src/BLR/BLRBatch.cpp
@@ -273,52 +273,56 @@ namespace strumpack {
 #if defined(STRUMPACK_USE_KBLAS)
     auto B = tile_.size();
     if (!B) return;
-    int maxm = 0, maxn = 0, maxminmn = 0;
-    std::vector<int> mn(2*B);
+    int maxm = 0, maxn = 0, maxminmn = KBLAS_ARA_BLOCK_SIZE;
+    std::vector<int> m_n_maxr(3*B);
     for (std::size_t i=0; i<B; i++) {
       int m = tile_[i]->get()->D().rows(),
         n = tile_[i]->get()->D().cols();
-      auto minmn = std::max(std::min(m, n), KBLAS_ARA_BLOCK_SIZE);
+      auto minmn = std::min(m, n);
       maxminmn = std::max(maxminmn, minmn);
       maxm = std::max(maxm, m);
       maxn = std::max(maxn, n);
-      mn[i  ] = m;
-      mn[i+B] = n;
+      m_n_maxr[i    ] = m;
+      m_n_maxr[i+B  ] = n;
+      m_n_maxr[i+2*B] = m*n/(m+n);
     }
     std::size_t smem_size = 0;
     for (std::size_t i=0; i<B; i++)
       smem_size += tile_[i]->get()->D().rows()*maxminmn +
         tile_[i]->get()->D().cols()*maxminmn;
     std::size_t dmem_size =
-      gpu::round_up(3*B*sizeof(int)) +
+      gpu::round_up(5*B*sizeof(int)) +
       gpu::round_up(3*B*sizeof(scalar_t*)) +
       gpu::round_up(smem_size*sizeof(scalar_t));
     auto dmem = workspace.get_device_bytes(dmem_size);
     auto dm = dmem.template as<int>();
     auto dn = dm + B;
-    auto dr = dn + B;
-    auto dA = gpu::aligned_ptr<scalar_t*>(dr+B);
+    auto dmaxr = dn + B;
+    auto dr = dmaxr + B;
+    auto dinfo = dr + B;
+    auto dA = gpu::aligned_ptr<scalar_t*>(dinfo+B);
     auto dU = dA + B;
     auto dV = dU + B;
     auto smem = gpu::aligned_ptr<scalar_t>(dV+B);
     std::vector<scalar_t*> AUV(3*B);
     for (std::size_t i=0; i<B; i++) {
       int m = tile_[i]->get()->D().rows(),
         n = tile_[i]->get()->D().cols();
       AUV[i    ] = tile_[i]->get()->D().data();
       AUV[i+  B] = smem;  smem += m*maxminmn;
       AUV[i+2*B] = smem;  smem += n*maxminmn;
     }
-    gpu_check(gpu::copy_host_to_device(dm, mn.data(), 2*B));
+    gpu_check(gpu::copy_host_to_device(dm, m_n_maxr.data(), 3*B));
     gpu_check(gpu::copy_host_to_device(dA, AUV.data(), 3*B));
     gpu::kblas::ara
       (handle, dm, dn, dA, dm, dU, dm, dV, dn, dr,
-       tol, maxm, maxn, maxminmn, KBLAS_ARA_BLOCK_SIZE, 10, 1, B);
-    std::vector<int> ranks(B);
+       tol, maxm, maxn, dmaxr, KBLAS_ARA_BLOCK_SIZE, 10, dinfo, 1, B);
+    std::vector<int> ranks(B), info(B);
     gpu_check(gpu::copy_device_to_host(ranks.data(), dr, B));
+    gpu_check(gpu::copy_device_to_host(info.data(), dinfo, B));
     for (std::size_t i=0; i<B; i++) {

[NOTE(review): an extraction-garbled span was irrecoverably lost here — the
remainder of this BLRBatch.cpp hunk (the loop body consuming ranks/info), the
"diff --git"/"index"/"---"/"+++" header lines for the KBLAS wrapper source
file, and any earlier hunks in that file (e.g. the real-typed sara/dara ara
overloads). The first hunk header and the opening signature lines of the
complex<float> overload below are reconstructed by necessity, not verbatim.]

             std::complex<float>** A_batch, int* lda_batch,
             std::complex<float>** B_batch, int* ldb_batch,
             int* ranks_batch, float tol,
-            int max_rows, int max_cols, int max_rank,
-            int bs, int r, int relative, int num_ops) {
+            int max_rows, int max_cols, int* max_rank,
+            int bs, int r, int* info, int relative, int num_ops) {
      kblas_cara_batch
        (handle, rows_batch, cols_batch,
         (cuComplex**)M_batch, ldm_batch,
         (cuComplex**)A_batch, lda_batch,
         (cuComplex**)B_batch, ldb_batch, ranks_batch,
-        tol, max_rows, max_cols, max_rank, bs, r,
+        tol, max_rows, max_cols, max_rank, bs, r, info,
         handle.kblas_rand_state(), relative, num_ops);
    }
    void ara(BLASHandle& handle, int* rows_batch, int* cols_batch,
@@ -112,14 +113,14 @@ namespace strumpack {
             std::complex<double>** A_batch, int* lda_batch,
             std::complex<double>** B_batch, int* ldb_batch,
             int* ranks_batch, double tol,
-            int max_rows, int max_cols, int max_rank, int bs, int r,
-            int relative, int num_ops) {
+            int max_rows, int max_cols, int* max_rank,
+            int bs, int r, int* info, int relative, int num_ops) {
      kblas_zara_batch
        (handle, rows_batch, cols_batch,
         (cuDoubleComplex**)M_batch, ldm_batch,
         (cuDoubleComplex**)A_batch, lda_batch,
         (cuDoubleComplex**)B_batch, ldb_batch, ranks_batch,
-        tol, max_rows, max_cols, max_rank, bs, r,
+        tol, max_rows, max_cols, max_rank, bs, r, info,
         handle.kblas_rand_state(), relative, num_ops);
    }