From e16c6791ff4cfae247f95b49de2f7d1aa8083b23 Mon Sep 17 00:00:00 2001
From: YdrMaster <ydrml@hotmail.com>
Date: Fri, 2 Feb 2024 16:37:09 +0800
Subject: [PATCH] feat(kernel): support kv cache attention
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: YdrMaster <ydrml@hotmail.com>
---
 .../kernel/attributes/attention_info.h        |   1 +
 src/04kernel/src/attributes/attention_info.cc |   4 +
 .../src/kernels/attention/cuda_kernel.cu      | 154 +++++++++++++++++-
 .../src/utilities/cuda/cublaslt_utils.cu      |   8 +-
 .../src/utilities/cuda/cublaslt_utils.cuh     |   3 +-
 5 files changed, 162 insertions(+), 8 deletions(-)

diff --git a/src/04kernel/include/kernel/attributes/attention_info.h b/src/04kernel/include/kernel/attributes/attention_info.h
index 9cd64a56..38677681 100644
--- a/src/04kernel/include/kernel/attributes/attention_info.h
+++ b/src/04kernel/include/kernel/attributes/attention_info.h
@@ -12,6 +12,7 @@ namespace refactor::kernel {
 
         dim_t attLen(dim_t pastSeqLen) const noexcept;
         size_t attSize(dim_t pastSeqLen) const noexcept;
+        size_t maxAttSize() const noexcept;
     };
 
 }// namespace refactor::kernel
diff --git a/src/04kernel/src/attributes/attention_info.cc b/src/04kernel/src/attributes/attention_info.cc
index c16c59fa..a867fd3f 100644
--- a/src/04kernel/src/attributes/attention_info.cc
+++ b/src/04kernel/src/attributes/attention_info.cc
@@ -10,4 +10,8 @@ namespace refactor::kernel {
         return batch * nHead * seqLen * attLen(pastSeqLen) * dataType.size();
     }
 
+    size_t AttentionInfo::maxAttSize() const noexcept {
+        return batch * nHead * seqLen * (cacheLen ? cacheLen : seqLen) * dataType.size();
+    }
+
 }// namespace refactor::kernel
diff --git a/src/04kernel/src/kernels/attention/cuda_kernel.cu b/src/04kernel/src/kernels/attention/cuda_kernel.cu
index 89ce72f4..da0a69ad 100644
--- a/src/04kernel/src/kernels/attention/cuda_kernel.cu
+++ b/src/04kernel/src/kernels/attention/cuda_kernel.cu
@@ -70,6 +70,24 @@ namespace refactor::kernel {
         }
     }
 
+    // Appends the new k/v tensor to the kv cache. Each input page of
+    // seqLen lines is copied into a cache page of cacheLen lines at a
+    // line offset of `pastOffset`; one thread moves one float4.
+    static __global__ void concatCache(
+        void *__restrict__ cache,
+        void const *__restrict__ value,
+        dim_t count,
+        dim_t pageStrideI,
+        dim_t pageStrideO,
+        dim_t pastOffset) {
+
+        if (auto tid = blockIdx.x * blockDim.x + threadIdx.x; tid < count) {
+            auto dst = tid / pageStrideI * pageStrideO + pastOffset + tid % pageStrideI;
+            reinterpret_cast<float4 *>(cache)[dst] = reinterpret_cast<float4 const *>(value)[tid];
+        }
+    }
+
+    constexpr uint64_t DYNAMIC_WORKSPACE_SIZE = 40 << 20;// 40 MiB, found sufficient by experiment
+
     RoutineWorkspace K::lower(Resources &res) const {
         auto handle = res.fetchOrStore<CublasLtContext>()->handle;
 
@@ -125,8 +143,8 @@ namespace refactor::kernel {
                        .batchCount = static_cast<int32_t>(info.batch * info.nHead),
                        .batchStride = static_cast<int64_t>(info.seqLen * info.seqLen),
                    }) {
-                auto [algoQK_, workspaceSizeQK_] = tune(context.handle, mul, q, k, att);
-                auto [algoAV_, workspaceSizeAV_] = tune(context.handle, mul, att, v, q);
+                auto [algoQK_, workspaceSizeQK_] = tune(context.handle, mul, q, k, att, DYNAMIC_WORKSPACE_SIZE);
+                auto [algoAV_, workspaceSizeAV_] = tune(context.handle, mul, att, v, q, DYNAMIC_WORKSPACE_SIZE);
                 algoQK = algoQK_;
                 algoAV = algoAV_;
                 workspaceSizeQK = workspaceSizeQK_;
@@ -187,12 +205,142 @@ namespace refactor::kernel {
                             &d->algoAV,
                             workspaceAV, d->workspaceSizeAV,
                             stream);
-                    };
+                    }
                 };
 
                 return {std::move(routine), workspaceSize};
             }
+            TODO("");
         }
+        if (info.concatCache && !info.resetCache) {
+            if (info.nHead == info.nKVHead) {
+
+                // RAII for closure
+                struct Descriptors {
+                    MatMulDescriptor mul;
+
+                    Descriptors(AttentionInfo info)
+                        : mul(computeTypeConvert(info.dataType),
+                              dataTypeConvert(info.dataType)) {}
+                };
+
+                auto const &context = *res.fetchOrStore<CublasLtContext>();
+                auto d = std::make_shared<Descriptors>(info);
+                auto attentionSize = info.maxAttSize();
+                auto workspaceSize = DYNAMIC_WORKSPACE_SIZE + attentionSize;
+
+                auto routine = [d = std::move(d), info = this->info]//
+                    (Resources & res, void *workspace, void const *const *inputs, void *const *outputs) {
+                        auto handle = res.fetchOrStore<CublasLtContext>()->handle;
+                        auto q = inputs[0];
+                        auto k = inputs[1];
+                        auto v = inputs[2];
+                        auto past = *reinterpret_cast<int64_t const *>(inputs[3]);// past sequence length (scalar)
+                        auto attLen = info.attLen(past);
+                        auto o = reinterpret_cast<half *>(outputs[0]);
+                        auto kCache = reinterpret_cast<half *>(outputs[1]);
+                        auto vCache = reinterpret_cast<half *>(outputs[2]);
+                        auto att = reinterpret_cast<half *>(reinterpret_cast<uint8_t *>(workspace) + DYNAMIC_WORKSPACE_SIZE);
+                        auto stream = cudaStreamLegacy;
+                        {// append the new k/v pages to the caches at line offset `past`
+                            auto itemsPerLine = info.headDim * sizeof(half) / sizeof(float4);// 8 halves per float4
+                            auto threads = info.batch * info.nHead * info.seqLen * itemsPerLine;
+                            auto blocks = (threads + 1023) / 1024;
+
+                            concatCache<<<blocks, 1024, 0, stream>>>(
+                                kCache, k, threads,
+                                info.seqLen * itemsPerLine,
+                                info.cacheLen * itemsPerLine,
+                                past * itemsPerLine);
+                            concatCache<<<blocks, 1024, 0, stream>>>(
+                                vCache, v, threads,
+                                info.seqLen * itemsPerLine,
+                                info.cacheLen * itemsPerLine,
+                                past * itemsPerLine);
+                        }
+                        MatrixDescriptor
+                            q_(MatrixLayout{
+                                .dataType = dataTypeConvert(info.dataType),
+                                .rows = static_cast<uint64_t>(info.seqLen),
+                                .cols = static_cast<uint64_t>(info.headDim),
+                                .majorStride = static_cast<int64_t>(info.headDim),
+                                .order = ROW_MAJOR,
+                                .batchCount = static_cast<int32_t>(info.batch * info.nHead),
+                                .batchStride = static_cast<int64_t>(info.seqLen * info.headDim),
+                            }),
+                            k_(MatrixLayout{
+                                .dataType = dataTypeConvert(info.dataType),
+                                .rows = static_cast<uint64_t>(info.headDim),
+                                .cols = static_cast<uint64_t>(attLen),
+                                .majorStride = static_cast<int64_t>(info.headDim),
+                                .order = COL_MAJOR,
+                                .batchCount = static_cast<int32_t>(info.batch * info.nHead),
+                                .batchStride = static_cast<int64_t>(info.cacheLen * info.headDim),
+                            }),
+                            v_(MatrixLayout{
+                                .dataType = dataTypeConvert(info.dataType),
+                                .rows = static_cast<uint64_t>(attLen),
+                                .cols = static_cast<uint64_t>(info.headDim),
+                                .majorStride = static_cast<int64_t>(info.headDim),
+                                .order = ROW_MAJOR,
+                                .batchCount = static_cast<int32_t>(info.batch * info.nHead),
+                                .batchStride = static_cast<int64_t>(info.cacheLen * info.headDim),
+                            }),
+                            att_(MatrixLayout{
+                                .dataType = dataTypeConvert(info.dataType),
+                                .rows = static_cast<uint64_t>(info.seqLen),
+                                .cols = static_cast<uint64_t>(attLen),
+                                .majorStride = static_cast<int64_t>(info.cacheLen),
+                                .order = ROW_MAJOR,
+                                .batchCount = static_cast<int32_t>(info.batch * info.nHead),
+                                .batchStride = static_cast<int64_t>(info.cacheLen * info.seqLen),
+                            });
+                        {// att = q x k^T / sqrt(headDim)
+                            auto [algo, workspaceSize] = tune(
+                                handle, d->mul,
+                                q_, k_, att_,
+                                DYNAMIC_WORKSPACE_SIZE);
+                            half alpha = rsqrtf(info.headDim), beta = 0;
+                            cublasLtMatmul(
+                                handle, d->mul.get(),
+                                &alpha,
+                                q, q_.get(),
+                                kCache, k_.get(),
+                                &beta,
+                                att, att_.get(),
+                                att, att_.get(),
+                                &algo,
+                                workspace, workspaceSize,
+                                stream);
+                        }
+                        // causal-masked softmax on each attention row; rows are cacheLen wide in the buffer
+                        softmax<<<dim3(info.batch * info.nHead, info.seqLen),
+                                  std::min(1024u, attLen), attLen * sizeof(float), stream>>>(
+                            att, AttentionCausualMask(), attLen, info.cacheLen);
+                        {// o = att x v
+                            auto [algo, workspaceSize] = tune(
+                                handle, d->mul,
+                                att_, v_, q_,
+                                DYNAMIC_WORKSPACE_SIZE);
+                            half alpha = 1, beta = 0;
+                            cublasLtMatmul(
+                                handle, d->mul.get(),
+                                &alpha,
+                                att, att_.get(),
+                                vCache, v_.get(),
+                                &beta,
+                                o, q_.get(),
+                                o, q_.get(),
+                                &algo,
+                                workspace, workspaceSize,
+                                stream);
+                        }
+                    };
+
+                return {std::move(routine), workspaceSize};
+            }
+            TODO("");
+        }
 
         TODO("");
     }
diff --git a/src/04kernel/src/utilities/cuda/cublaslt_utils.cu b/src/04kernel/src/utilities/cuda/cublaslt_utils.cu
index 6fc7e717..ab797e8f 100644
--- a/src/04kernel/src/utilities/cuda/cublaslt_utils.cu
+++ b/src/04kernel/src/utilities/cuda/cublaslt_utils.cu
@@ -101,14 +101,14 @@ namespace refactor::kernel::cublas {
         MatMulDescriptor const &matmul,
         MatrixDescriptor const &a,
         MatrixDescriptor const &b,
-        MatrixDescriptor const &c) {
+        MatrixDescriptor const &c,
+        uint64_t maxWorkspace) {
 
         int device;
         CUDA_ASSERT(cudaGetDevice(&device));
         cudaDeviceProp prop;
         CUDA_ASSERT(cudaGetDeviceProperties(&prop, device));
 
-        auto workspace = std::numeric_limits<uint64_t>::max();
         uint32_t alignment = prop.textureAlignment;
 
         cublasLtMatmulPreference_t preference;
@@ -116,8 +116,8 @@ namespace refactor::kernel::cublas {
         CUBLASLT_ASSERT(cublasLtMatmulPreferenceSetAttribute(
             preference,
             CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
-            &workspace,
-            sizeof(workspace)));
+            &maxWorkspace,
+            sizeof(maxWorkspace)));
         CUBLASLT_ASSERT(cublasLtMatmulPreferenceSetAttribute(
             preference,
             CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES,
diff --git a/src/04kernel/src/utilities/cuda/cublaslt_utils.cuh b/src/04kernel/src/utilities/cuda/cublaslt_utils.cuh
index ccaad7ec..33de075a 100644
--- a/src/04kernel/src/utilities/cuda/cublaslt_utils.cuh
+++ b/src/04kernel/src/utilities/cuda/cublaslt_utils.cuh
@@ -68,7 +68,8 @@ namespace refactor::kernel::cublas {
 
         MatMulDescriptor const &,
         MatrixDescriptor const &,
         MatrixDescriptor const &,
-        MatrixDescriptor const &);
+        MatrixDescriptor const &,
+        uint64_t);
 
 }// namespace refactor::kernel::cublas
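
Note on the concatCache indexing (not part of the patch): each thread copies
one float4 of the freshly computed k/v tensor into the cache, relocating pages
of seqLen lines into pages of cacheLen lines at line offset `past`. The snippet
below is a minimal standalone host-side sketch of the same mapping; the
concrete sizes are made up for illustration, and `concatDst` is a hypothetical
helper that mirrors the kernel's arithmetic after its bounds check:

    #include <cassert>
    #include <cstdint>

    // Mirrors the kernel: tid walks the input (pages of pageStrideI items),
    // dst walks the cache (pages of pageStrideO items, whose first pastOffset
    // items are already occupied by earlier steps).
    static std::uint32_t concatDst(std::uint32_t tid,
                                   std::uint32_t pageStrideI,  // seqLen * itemsPerLine
                                   std::uint32_t pageStrideO,  // cacheLen * itemsPerLine
                                   std::uint32_t pastOffset) { // past * itemsPerLine
        return tid / pageStrideI * pageStrideO + pastOffset + tid % pageStrideI;
    }

    int main() {
        // Assumed sizes: seqLen = 2, cacheLen = 8, past = 3, itemsPerLine = 2
        // => pageStrideI = 4, pageStrideO = 16, pastOffset = 6.
        assert(concatDst(0, 4, 16, 6) == 6);  // page 0: lands after 3 cached lines
        assert(concatDst(3, 4, 16, 6) == 9);  // page 0: last new item
        assert(concatDst(4, 4, 16, 6) == 22); // page 1: starts at 16, plus offset 6
        return 0;
    }

Dividing by the input page stride rather than the output one is the essential
point: tid / pageStrideI selects the (batch, head) page and tid % pageStrideI
the position among that page's new items, independent of how long the cache
pages are; the `tid < count` guard in the kernel keeps the rounded-up final
block from writing past the last page.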