From 314e1e398b8b6d81fa2340a853c39cd19c08b407 Mon Sep 17 00:00:00 2001 From: cjknight Date: Mon, 4 Nov 2024 23:28:42 -0600 Subject: [PATCH] v0 gemm working with cuda backend --- gpu/mini-apps/math/main.cpp | 11 +++++++++-- gpu/src/mathlib_cuda.cpp | 4 ++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/gpu/mini-apps/math/main.cpp b/gpu/mini-apps/math/main.cpp index 1260dde9..d9fb4c43 100644 --- a/gpu/mini-apps/math/main.cpp +++ b/gpu/mini-apps/math/main.cpp @@ -292,13 +292,13 @@ int main( int argc, char* argv[] ) const int lda = _NUM_COLS_A; // lead dimension of second matrix A^T const int ldc = _NUM_COLS_B; // lead dimension of result matrix C^T - cublasDgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha, d_b, ldb, d_a, lda, &beta, d_c, ldc); + ml->gemm((char *) "N", (char *) "N", &m, &n, &k, &alpha, d_b, &ldb, d_a, &lda, &beta, d_c, &ldc, handle); pm->dev_barrier(); double t0 = MPI_Wtime(); for(int i=0; i<_NUM_ITERATIONS_CPU; ++i) - cublasDgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha, d_b, ldb, d_a, lda, &beta, d_c, ldc); + ml->gemm((char *) "N", (char *) "N", &m, &n, &k, &alpha, d_b, &ldb, d_a, &lda, &beta, d_c, &ldc, handle); pm->dev_barrier(); t = MPI_Wtime() - t0; #endif @@ -319,6 +319,13 @@ int main( int argc, char* argv[] ) // Clean up +#if defined(_USE_GPU) +#if defined(_GPU_CUDA) + cublasDestroy(handle); + pm->dev_stream_destroy(stream); +#endif +#endif + delete ml; pm->dev_free(d_a); diff --git a/gpu/src/mathlib_cuda.cpp b/gpu/src/mathlib_cuda.cpp index 0cd4c7e3..a4aa87a3 100644 --- a/gpu/src/mathlib_cuda.cpp +++ b/gpu/src/mathlib_cuda.cpp @@ -20,9 +20,9 @@ void MATHLIB::gemm(const char * transa, const char * transb, { #ifdef _SINGLE_PRECISION - cublasSgemm(q, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc); + cublasSgemm(q, CUBLAS_OP_N, CUBLAS_OP_N, *m, *n, *k, alpha, a, *lda, b, *ldb, beta, c, *ldc); #else - cublasDgemm(q, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc); + cublasDgemm(q, CUBLAS_OP_N, CUBLAS_OP_N, *m, *n, *k, alpha, a, *lda, b, *ldb, beta, c, *ldc); #endif }