From 70650eef9a14a92785937827ef3ec1940555c860 Mon Sep 17 00:00:00 2001 From: WuK Date: Fri, 23 Apr 2021 12:58:25 +0800 Subject: [PATCH] Update HPLAI_blas.cc --- src/blas/HPLAI_blas.cc | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/blas/HPLAI_blas.cc b/src/blas/HPLAI_blas.cc index 229a61a..9c8d1c2 100644 --- a/src/blas/HPLAI_blas.cc +++ b/src/blas/HPLAI_blas.cc @@ -304,6 +304,13 @@ void blas::copy( #if defined(HPLAI_CUBLASGEMMEX_COMPUTETYPE) +static cublasOperation_t HPLAI_op2cublas(blas::Op trans) /* File-local helper: converts a blas::Op into a cublasOperation_t. Returns CUBLAS_OP_N for blas::Op::NoTrans, CUBLAS_OP_T for blas::Op::Trans, and CUBLAS_OP_C for every other value of trans. */ +{ + return trans == blas::Op::NoTrans ? CUBLAS_OP_N + : trans == blas::Op::Trans ? CUBLAS_OP_T + : CUBLAS_OP_C; +} + static cudaDataType_t HPLAI_GET_cudaDataType_t(float t) { return CUDA_R_32F; @@ -447,8 +454,8 @@ void blas::gemm( HPLAI_T_AFLOAT rone = HPLAI_rone; cublasGemmEx( HPLAI_DEVICE_BLASPP_QUEUE->handle(), - blas::device::op2cublas(TRANSA), - blas::device::op2cublas(TRANSB), + HPLAI_op2cublas(TRANSA), + HPLAI_op2cublas(TRANSB), /* both transpose arguments now go through the local HPLAI_op2cublas helper in place of blas::device::op2cublas */ M1, N1, K1,