diff --git a/lib/cublas/wrappers.jl b/lib/cublas/wrappers.jl
index 4e9902de83..101759990d 100644
--- a/lib/cublas/wrappers.jl
+++ b/lib/cublas/wrappers.jl
@@ -2444,3 +2444,100 @@ function xt_trsm(side::Char, uplo::Char, transa::Char, diag::Char, alpha::Number
     # TODO: better way to perform synchronous copy
     xt_trsm!(side, uplo, transa, diag, alpha, A, @sync(copy(B)))
 end
+
+# TODO: use https://docs.nvidia.com/cuda/cublas/#cublasltmatmul for a more robust
+# computeType mapping. Currently no one uses Lux with unusual type combinations, so
+# we don't need to worry about it too much and can just fall back to the generic
+# implementation.
+# Computes C = act(α * A * B + β * C + bias).
+# Intermediates can be stored in `aux` if provided.
+# If we fail to find an appropriate algorithm and need to terminate, we return -1.
+function gemmBiasCublasLt!(transA::Char, transB::Char, transC::Char, alpha::Number,
+                           A::StridedCuMatrix{aT}, B::StridedCuMatrix{bT}, beta::Number,
+                           C::StridedCuMatrix{cT}, σ::F,
+                           bias::Union{Nothing, StridedCuVector},
+                           aux::Union{Nothing, StridedCuMatrix}=nothing) where {F, aT, bT, cT}
+    m, n = size(C)
+    k = size(A, transA == 'N' ? 2 : 1)
+
+    # TODO: size check for the bias term
+    # TODO: general size check
+
+    operationDesc = Ref{cublasLtMatmulDesc_t}()
+
+    # While querying the compute type, promote the types
+    computeType = gemmExComputeType(aT, bT, cT, m, k, n)
+    computeType === nothing && return -1
+    dataType = convert(CUDA.cudaDataType, cT)
+    cublasLtMatmulDescCreate(operationDesc, computeType, dataType)
+
+    # Set the transpose attributes on the operation descriptor
+    Atransop = transA == 'N' ? CUBLAS_OP_N : CUBLAS_OP_T
+    Btransop = transB == 'N' ? CUBLAS_OP_N : CUBLAS_OP_T
+    Ctransop = transC == 'N' ? CUBLAS_OP_N : CUBLAS_OP_T
+
+    cublasLtMatmulDescSetAttribute(operationDesc[], CUBLASLT_MATMUL_DESC_TRANSA,
+                                   Ref{cublasOperation_t}(Atransop), sizeof(Atransop))
+    cublasLtMatmulDescSetAttribute(operationDesc[], CUBLASLT_MATMUL_DESC_TRANSB,
+                                   Ref{cublasOperation_t}(Btransop), sizeof(Btransop))
+    cublasLtMatmulDescSetAttribute(operationDesc[], CUBLASLT_MATMUL_DESC_TRANSC,
+                                   Ref{cublasOperation_t}(Ctransop), sizeof(Ctransop))
+
+    # Decide on the epilogue
+    # epilogue, activation_fused = __epilogue_act(σ, bias, aux)
+    # cublasLtMatmulDescSetAttribute(
+    #     operationDesc[], CUBLASLT_MATMUL_DESC_EPILOGUE,
+    #     Ref{cublasLtEpilogue_t}(epilogue), sizeof(epilogue))
+
+    # # We have a bias, so set the bias pointer
+    # if bias !== nothing
+    #     bias_ptr = Ref{CuPtr{Cvoid}}(pointer(bias))
+    #     cublasLtMatmulDescSetAttribute(
+    #         operationDesc[], CUBLASLT_MATMUL_DESC_BIAS_POINTER,
+    #         bias_ptr, sizeof(bias_ptr))
+    # end
+
+    # if aux !== nothing
+    #     aux_ptr = Ref{CuPtr{Cvoid}}(pointer(aux))
+    #     cublasLtMatmulDescSetAttribute(
+    #         operationDesc[], CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
+    #         aux_ptr, sizeof(aux_ptr))
+    #     ldaux = max(1, stride(aux, 2))
+    #     cublasLtMatmulDescSetAttribute(
+    #         operationDesc[], CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD,
+    #         Ref{Csize_t}(ldaux), sizeof(ldaux))
+    # end
+
+    # Create the matrix layouts
+    Adesc = Ref{cublasLtMatrixLayout_t}()
+    Bdesc = Ref{cublasLtMatrixLayout_t}()
+    Cdesc = Ref{cublasLtMatrixLayout_t}()
+    cublasLtMatrixLayoutCreate(Adesc, convert(CUDA.cudaDataType, aT), size(A, 1),
+                               size(A, 2), max(1, stride(A, 2)))
+    cublasLtMatrixLayoutCreate(Bdesc, convert(CUDA.cudaDataType, bT), size(B, 1),
+                               size(B, 2), max(1, stride(B, 2)))
+    cublasLtMatrixLayoutCreate(Cdesc, convert(CUDA.cudaDataType, cT), m, n,
+                               max(1, stride(C, 2)))
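+    # NB: the layouts above describe A, B and C as they are stored in memory; the
+    # transposition is carried by the CUBLASLT_MATMUL_DESC_TRANS{A,B,C} attributes
+    # on the operation descriptor, not by the layouts themselves.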
+
+    # Create the preference
+    preference = Ref{cublasLtMatmulPreference_t}()
+    cublasLtMatmulPreferenceCreate(preference)
+
+    # Create the cuBLASLt handle
+    lightHandle = Ref{cublasLtHandle_t}()
+    cublasLtCreate(lightHandle)
+
+    # Search for the best available algorithm
+    heuristic = Ref{cublasLtMatmulHeuristicResult_t}()
+    returnedResults = Ref{Cint}(0)
+    cublasLtMatmulAlgoGetHeuristic(lightHandle[], operationDesc[], Adesc[], Bdesc[],
+                                   Cdesc[], Cdesc[], preference[], 1, heuristic,
+                                   returnedResults)
+
+    status = 0
+    if returnedResults[] == 0
+        status = -1
+    else
+        # The scale factors must match the scaleType the operation descriptor was
+        # created with, i.e. the data type of C
+        cublasLtMatmul(lightHandle[], operationDesc[], Ref{cT}(alpha), A, Adesc[], B,
+                       Bdesc[], Ref{cT}(beta), C, Cdesc[], C, Cdesc[],
+                       Ref(heuristic[].algo), CU_NULL, 0, CUDA.stream())
+    end
+
+    # Release the descriptors, the preference, and the handle
+    cublasLtMatmulPreferenceDestroy(preference[])
+    cublasLtMatrixLayoutDestroy(Cdesc[])
+    cublasLtMatrixLayoutDestroy(Bdesc[])
+    cublasLtMatrixLayoutDestroy(Adesc[])
+    cublasLtMatmulDescDestroy(operationDesc[])
+    cublasLtDestroy(lightHandle[])
+
+    # !activation_fused && (C .= σ.(C))
+
+    return status
+end
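+
+# A minimal usage sketch (hypothetical shapes; while the epilogue wiring above is
+# still commented out, the bias and activation are not applied, so this computes
+# C = alpha * A * B + beta * C):
+#
+#   A = CUDA.rand(Float32, 128, 64)
+#   B = CUDA.rand(Float32, 64, 32)
+#   C = CUDA.zeros(Float32, 128, 32)
+#   ret = gemmBiasCublasLt!('N', 'N', 'N', 1f0, A, B, 0f0, C, identity, nothing)
+#   ret == 0 || error("no suitable cuBLASLt algorithm found")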