Skip to content

Commit

Permalink
Update HPLAI_blas.cc
Browse files Browse the repository at this point in the history
  • Loading branch information
WuK authored Apr 23, 2021
1 parent 62aafcf commit 70650ee
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions src/blas/HPLAI_blas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,13 @@ void blas::copy<float, float>(

#if defined(HPLAI_CUBLASGEMMEX_COMPUTETYPE)

// Converts a BLAS++ operation selector into the matching cuBLAS enumerator.
//
// trans: the blas::Op value to translate.
// Returns CUBLAS_OP_N for blas::Op::NoTrans, CUBLAS_OP_T for blas::Op::Trans,
// and CUBLAS_OP_C for every other value (presumably the conjugate-transpose
// case -- confirm against the blas::Op declaration).
static cublasOperation_t HPLAI_op2cublas(blas::Op trans)
{
    if (trans == blas::Op::NoTrans)
    {
        return CUBLAS_OP_N;
    }
    if (trans == blas::Op::Trans)
    {
        return CUBLAS_OP_T;
    }
    // Fallback: anything that is neither NoTrans nor Trans.
    return CUBLAS_OP_C;
}

static cudaDataType_t HPLAI_GET_cudaDataType_t(float t)
{
return CUDA_R_32F;
Expand Down Expand Up @@ -447,8 +454,8 @@ void blas::gemm<HPLAI_T_AFLOAT, HPLAI_T_AFLOAT, HPLAI_T_AFLOAT>(
HPLAI_T_AFLOAT rone = HPLAI_rone;
cublasGemmEx(
HPLAI_DEVICE_BLASPP_QUEUE->handle(),
blas::device::op2cublas(TRANSA),
blas::device::op2cublas(TRANSB),
HPLAI_op2cublas(TRANSA),
HPLAI_op2cublas(TRANSB),
M1,
N1,
K1,
Expand Down

0 comments on commit 70650ee

Please sign in to comment.