diff --git a/src/blas/HPLAI_blas.cc b/src/blas/HPLAI_blas.cc
index 229a61a..9c8d1c2 100644
--- a/src/blas/HPLAI_blas.cc
+++ b/src/blas/HPLAI_blas.cc
@@ -304,6 +304,22 @@ void blas::copy(
 
 #if defined(HPLAI_CUBLASGEMMEX_COMPUTETYPE)
 
+/// Translate a blas::Op transpose flag into the cuBLAS operation code.
+/// NoTrans yields CUBLAS_OP_N, Trans yields CUBLAS_OP_T, anything else CUBLAS_OP_C.
+static cublasOperation_t HPLAI_op2cublas(blas::Op trans)
+{
+    if (trans == blas::Op::NoTrans)
+    {
+        return CUBLAS_OP_N;
+    }
+    if (trans == blas::Op::Trans)
+    {
+        return CUBLAS_OP_T;
+    }
+    // Every remaining value falls through to CUBLAS_OP_C; no value is rejected.
+    return CUBLAS_OP_C;
+}
+
 static cudaDataType_t HPLAI_GET_cudaDataType_t(float t)
 {
     return CUDA_R_32F;
@@ -447,8 +463,8 @@ void blas::gemm(
         HPLAI_T_AFLOAT rone = HPLAI_rone;
         cublasGemmEx(
             HPLAI_DEVICE_BLASPP_QUEUE->handle(),
-            blas::device::op2cublas(TRANSA),
-            blas::device::op2cublas(TRANSB),
+            HPLAI_op2cublas(TRANSA),
+            HPLAI_op2cublas(TRANSB),
             M1,
             N1,
             K1,