Skip to content

Commit

Permalink
Updates for the new KBLAS ARA interface
Browse files (browse the repository at this point in the history)
  • Loading branch information
pghysels committed Sep 22, 2023
1 parent dd455f2 commit 779d058
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 31 deletions.
32 changes: 18 additions & 14 deletions src/BLR/BLRBatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -273,52 +273,56 @@ namespace strumpack {
#if defined(STRUMPACK_USE_KBLAS)
auto B = tile_.size();
if (!B) return;
int maxm = 0, maxn = 0, maxminmn = 0;
std::vector<int> mn(2*B);
int maxm = 0, maxn = 0, maxminmn = KBLAS_ARA_BLOCK_SIZE;
std::vector<int> m_n_maxr(3*B);
for (std::size_t i=0; i<B; i++) {
int m = tile_[i]->get()->D().rows(),
n = tile_[i]->get()->D().cols();
auto minmn = std::max(std::min(m, n), KBLAS_ARA_BLOCK_SIZE);
auto minmn = std::min(m, n);
maxminmn = std::max(maxminmn, minmn);
maxm = std::max(maxm, m);
maxn = std::max(maxn, n);
mn[i ] = m;
mn[i+B] = n;
m_n_maxr[i ] = m;
m_n_maxr[i+B ] = n;
m_n_maxr[i+2*B] = m*n/(m+n);
}
std::size_t smem_size = 0;
for (std::size_t i=0; i<B; i++)
smem_size += tile_[i]->get()->D().rows()*maxminmn +
tile_[i]->get()->D().cols()*maxminmn;
std::size_t dmem_size =
gpu::round_up(3*B*sizeof(int)) +
gpu::round_up(5*B*sizeof(int)) +
gpu::round_up(3*B*sizeof(scalar_t*)) +
gpu::round_up(smem_size*sizeof(scalar_t));
auto dmem = workspace.get_device_bytes(dmem_size);
auto dm = dmem.template as<int>();
auto dn = dm + B;
auto dr = dn + B;
auto dA = gpu::aligned_ptr<scalar_t*>(dr+B);
auto dmaxr = dn + B;
auto dr = dmaxr + B;
auto dinfo = dr + B;
auto dA = gpu::aligned_ptr<scalar_t*>(dinfo+B);
auto dU = dA + B;
auto dV = dU + B;
auto smem = gpu::aligned_ptr<scalar_t>(dV+B);
std::vector<scalar_t*> AUV(3*B);
for (std::size_t i=0; i<B; i++) {
auto m = mn[i], n = mn[i+B];
auto m = m_n_maxr[i], n = m_n_maxr[i+B];
AUV[i ] = tile_[i]->get()->D().data();
AUV[i+ B] = smem; smem += m*maxminmn;
AUV[i+2*B] = smem; smem += n*maxminmn;
}
gpu_check(gpu::copy_host_to_device(dm, mn.data(), 2*B));
gpu_check(gpu::copy_host_to_device(dm, m_n_maxr.data(), 3*B));
gpu_check(gpu::copy_host_to_device(dA, AUV.data(), 3*B));
gpu::kblas::ara
(handle, dm, dn, dA, dm, dU, dm, dV, dn, dr,
tol, maxm, maxn, maxminmn, KBLAS_ARA_BLOCK_SIZE, 10, 1, B);
std::vector<int> ranks(B);
tol, maxm, maxn, dmaxr, KBLAS_ARA_BLOCK_SIZE, 10, dinfo, 1, B);
std::vector<int> ranks(B), info(B);
gpu_check(gpu::copy_device_to_host(ranks.data(), dr, B));
gpu_check(gpu::copy_device_to_host(info.data(), dinfo, B));
for (std::size_t i=0; i<B; i++) {
auto rank = ranks[i], m = mn[i], n = mn[i+B];
auto rank = ranks[i], m = m_n_maxr[i], n = m_n_maxr[i+B];
STRUMPACK_FLOPS(blas::ara_flops(m, n, rank, 10));
if (rank*(m+n) < m*n) {
if (info[i] == KBLAS_Success) {
auto dA = AUV[i];
DenseMW_t tU(m, rank, dA, m),
tV(rank, n, dA+m*rank, rank);
Expand Down
35 changes: 18 additions & 17 deletions src/dense/KBLASWrapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,10 @@
#include "DenseMatrix.hpp"
#include "CUDAWrapper.hpp"

#include "kblas.h"
#include "kblas_operators.h"
#include "batch_geqp.h"
#include "batch_qr.h"
#include "batch_ara.h"
#include "kblas.h"
#include "kblas_defs.h"

namespace strumpack {
namespace gpu {
Expand Down Expand Up @@ -71,55 +70,57 @@ namespace strumpack {
void ara(BLASHandle& handle, int* rows_batch, int* cols_batch,
float** M_batch, int* ldm_batch,
float** A_batch, int* lda_batch,
float** B_batch, int* ldb_batch, int* ranks_batch,
float tol, int max_rows, int max_cols, int max_rank,
int bs, int r, int relative, int num_ops) {
float** B_batch, int* ldb_batch,
int* ranks_batch, float tol,
int max_rows, int max_cols, int* max_rank,
int bs, int r, int* info, int relative, int num_ops) {
kblas_sara_batch
(handle, rows_batch, cols_batch, M_batch, ldm_batch,
A_batch, lda_batch, B_batch, ldb_batch, ranks_batch,
tol, max_rows, max_cols, max_rank, bs, r,
tol, max_rows, max_cols, max_rank, bs, r, info,
handle.kblas_rand_state(), relative, num_ops);
}
void ara(BLASHandle& handle, int* rows_batch, int* cols_batch,
double** M_batch, int* ldm_batch,
double** A_batch, int* lda_batch,
double** B_batch, int* ldb_batch, int* ranks_batch,
double tol, int max_rows, int max_cols, int max_rank,
int bs, int r, int relative, int num_ops) {
double** B_batch, int* ldb_batch,
int* ranks_batch, double tol,
int max_rows, int max_cols, int* max_rank,
int bs, int r, int* info, int relative, int num_ops) {
kblas_dara_batch
(handle, rows_batch, cols_batch, M_batch, ldm_batch,
A_batch, lda_batch, B_batch, ldb_batch, ranks_batch,
tol, max_rows, max_cols, max_rank, bs, r,
tol, max_rows, max_cols, max_rank, bs, r, info,
handle.kblas_rand_state(), relative, num_ops);
}
void ara(BLASHandle& handle, int* rows_batch, int* cols_batch,
std::complex<float>** M_batch, int* ldm_batch,
std::complex<float>** A_batch, int* lda_batch,
std::complex<float>** B_batch, int* ldb_batch,
int* ranks_batch, float tol,
int max_rows, int max_cols, int max_rank,
int bs, int r, int relative, int num_ops) {
int max_rows, int max_cols, int* max_rank,
int bs, int r, int* info, int relative, int num_ops) {
kblas_cara_batch
(handle, rows_batch, cols_batch,
(cuComplex**)M_batch, ldm_batch,
(cuComplex**)A_batch, lda_batch,
(cuComplex**)B_batch, ldb_batch, ranks_batch,
tol, max_rows, max_cols, max_rank, bs, r,
tol, max_rows, max_cols, max_rank, bs, r, info,
handle.kblas_rand_state(), relative, num_ops);
}
void ara(BLASHandle& handle, int* rows_batch, int* cols_batch,
std::complex<double>** M_batch, int* ldm_batch,
std::complex<double>** A_batch, int* lda_batch,
std::complex<double>** B_batch, int* ldb_batch,
int* ranks_batch, double tol,
int max_rows, int max_cols, int max_rank, int bs, int r,
int relative, int num_ops) {
int max_rows, int max_cols, int* max_rank,
int bs, int r, int* info, int relative, int num_ops) {
kblas_zara_batch
(handle, rows_batch, cols_batch,
(cuDoubleComplex**)M_batch, ldm_batch,
(cuDoubleComplex**)A_batch, lda_batch,
(cuDoubleComplex**)B_batch, ldb_batch, ranks_batch,
tol, max_rows, max_cols, max_rank, bs, r,
tol, max_rows, max_cols, max_rank, bs, r, info,
handle.kblas_rand_state(), relative, num_ops);
}

Expand Down

0 comments on commit 779d058

Please sign in to comment.