feat(kernel): support kv cache attention
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Feb 2, 2024
1 parent 10e38f5 commit 60cebc8
Showing 5 changed files with 162 additions and 8 deletions.
1 change: 1 addition & 0 deletions src/04kernel/include/kernel/attributes/attention_info.h
@@ -12,6 +12,7 @@ namespace refactor::kernel {

dim_t attLen(dim_t pastSeqLen) const noexcept;
size_t attSize(dim_t pastSeqLen) const noexcept;
size_t maxAttSize() const noexcept;
};

}// namespace refactor::kernel
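
For context, a sketch of the surrounding struct, reconstructed from how the files below use it; field order and exact members may differ in the real header:

struct AttentionInfo {
    DataType dataType;
    dim_t batch, nHead, nKVHead, seqLen, headDim, cacheLen;
    bool concatCache, resetCache;

    dim_t attLen(dim_t pastSeqLen) const noexcept;   // presumably pastSeqLen + seqLen
    size_t attSize(dim_t pastSeqLen) const noexcept; // bytes of one att matrix
    size_t maxAttSize() const noexcept;              // added by this commit
};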
4 changes: 4 additions & 0 deletions src/04kernel/src/attributes/attention_info.cc
@@ -10,4 +10,8 @@ namespace refactor::kernel {
return batch * nHead * seqLen * attLen(pastSeqLen) * dataType.size();
}

size_t AttentionInfo::maxAttSize() const noexcept {
return batch * nHead * seqLen * (cacheLen ? cacheLen : seqLen) * dataType.size();
}

}// namespace refactor::kernel
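
maxAttSize() is the bound used below to reserve the att workspace before pastSeqLen is known: with a cache present, one attention row can grow to cacheLen columns; without one it stays at seqLen. A quick numeric sketch, with invented values:

// e.g. batch = 1, nHead = 8, seqLen = 1, cacheLen = 1024, f16 (2 bytes):
// maxAttSize() = 1 * 8 * 1 * 1024 * 2 = 16384 bytes = 16 KiB, for any pastSeqLen.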
154 changes: 151 additions & 3 deletions src/04kernel/src/kernels/attention/cuda_kernel.cu
@@ -70,6 +70,20 @@ namespace refactor::kernel {
}
}

static __global__ void concatCache(
void *__restrict__ cache,
void const *__restrict__ value,
dim_t pageStrideI,
dim_t pageStrideO,
dim_t lineStride,
dim_t pastOffset) {

// One float4 per thread: `tid` walks the contiguous K/V input, whose pages
// are `pageStrideI` wide; the matching cache page is `pageStrideO` wide, and
// this step's lines start at `pastOffset` within it.
auto tid = blockIdx.x * blockDim.x + threadIdx.x,
dst = tid / pageStrideI * pageStrideO + pastOffset + tid % pageStrideI;
reinterpret_cast<float4 *>(cache)[dst] = reinterpret_cast<float4 const *>(value)[tid];
}
constexpr uint64_t DYNAMIC_WORKSPACE_SIZE = 40 << 20;// 40 MiB, found sufficient by experiment

RoutineWorkspace K::lower(Resources &res) const {
auto handle = res.fetchOrStore<CublasLtContext>()->handle;

@@ -125,8 +139,8 @@
.batchCount = static_cast<int32_t>(info.batch * info.nHead),
.batchStride = static_cast<int64_t>(info.seqLen * info.seqLen),
}) {
-auto [algoQK_, workspaceSizeQK_] = tune(context.handle, mul, q, k, att);
-auto [algoAV_, workspaceSizeAV_] = tune(context.handle, mul, att, v, q);
+auto [algoQK_, workspaceSizeQK_] = tune(context.handle, mul, q, k, att, DYNAMIC_WORKSPACE_SIZE);
+auto [algoAV_, workspaceSizeAV_] = tune(context.handle, mul, att, v, q, DYNAMIC_WORKSPACE_SIZE);
algoQK = algoQK_;
algoAV = algoAV_;
workspaceSizeQK = workspaceSizeQK_;
@@ -187,12 +201,146 @@
&d->algoAV,
workspaceAV, d->workspaceSizeAV,
stream);
-};
+}
+};
return {std::move(routine), workspaceSize};
}
TODO("");
}
if (info.concatCache && !info.resetCache) {
if (info.nHead == info.nKVHead) {
// RAII for closure
struct Descriptors {
MatMulDescriptor mul;
Descriptors(AttentionInfo info)
: mul(computeTypeConvert(info.dataType),
dataTypeConvert(info.dataType)) {}
};
auto const &context = *res.fetchOrStore<CublasLtContext>();
auto d = std::make_shared<Descriptors>(info);
auto attentionSize = info.maxAttSize();
auto workspaceSize = DYNAMIC_WORKSPACE_SIZE + attentionSize;
auto routine = [d = std::move(d), info = this->info]//
(Resources & res, void *workspace, void const *const *inputs, void *const *outputs) {
auto handle = res.fetchOrStore<CublasLtContext>()->handle;
auto q = inputs[0];
auto k = inputs[1];
auto v = inputs[2];
auto past = *reinterpret_cast<int64_t const *>(inputs[3]);
auto attLen = info.attLen(past);
auto o = reinterpret_cast<half *>(outputs[0]);
auto kCache = reinterpret_cast<half *>(outputs[1]);
auto vCache = reinterpret_cast<half *>(outputs[2]);
auto att = reinterpret_cast<half *>(reinterpret_cast<uint8_t *>(workspace) + DYNAMIC_WORKSPACE_SIZE);
auto stream = cudaStreamLegacy;
{
// Append this step's K and V ([batch·nHead, seqLen, headDim] halfs) to the
// caches, one float4 (8 halfs) per thread.
auto itemsPerLine = info.headDim * sizeof(half) / sizeof(float4);
auto threads = info.batch * info.nHead * info.seqLen * itemsPerLine;
auto blocks = (threads + 1023) / 1024;
concatCache<<<blocks, 1024, 0, stream>>>(
kCache, k,
info.seqLen * itemsPerLine,
info.cacheLen * itemsPerLine,
itemsPerLine,
past * itemsPerLine);
concatCache<<<blocks, 1024, 0, stream>>>(
vCache, v,
info.seqLen * itemsPerLine,
info.cacheLen * itemsPerLine,
itemsPerLine,
past * itemsPerLine);
}
MatrixDescriptor
// Q: [seqLen × headDim] row-major, one matrix per (batch, head).
q_(MatrixLayout{
.dataType = dataTypeConvert(info.dataType),
.rows = static_cast<uint64_t>(info.seqLen),
.cols = static_cast<uint64_t>(info.headDim),
.majorStride = static_cast<int64_t>(info.headDim),
.order = ROW_MAJOR,
.batchCount = static_cast<int32_t>(info.batch * info.nHead),
.batchStride = static_cast<int64_t>(info.seqLen * info.headDim),
}),
// K cache viewed column-major, so the first matmul yields Q·Kᵀ over attLen
// keys; batchStride spans the full cacheLen page.
k_(MatrixLayout{
.dataType = dataTypeConvert(info.dataType),
.rows = static_cast<uint64_t>(info.headDim),
.cols = static_cast<uint64_t>(attLen),
.majorStride = static_cast<int64_t>(info.headDim),
.order = COL_MAJOR,
.batchCount = static_cast<int32_t>(info.batch * info.nHead),
.batchStride = static_cast<int64_t>(info.cacheLen * info.headDim),
}),
// V cache: attLen valid rows inside a cacheLen-row page.
v_(MatrixLayout{
.dataType = dataTypeConvert(info.dataType),
.rows = static_cast<uint64_t>(attLen),
.cols = static_cast<uint64_t>(info.headDim),
.majorStride = static_cast<int64_t>(info.headDim),
.order = ROW_MAJOR,
.batchCount = static_cast<int32_t>(info.batch * info.nHead),
.batchStride = static_cast<int64_t>(info.cacheLen * info.headDim),
}),
// att: [seqLen × attLen] valid region of rows padded to cacheLen (majorStride).
att_(MatrixLayout{
.dataType = dataTypeConvert(info.dataType),
.rows = static_cast<uint64_t>(info.seqLen),
.cols = static_cast<uint64_t>(attLen),
.majorStride = static_cast<int64_t>(info.cacheLen),
.order = ROW_MAJOR,
.batchCount = static_cast<int32_t>(info.batch * info.nHead),
.batchStride = static_cast<int64_t>(info.cacheLen * info.seqLen),
});
{
auto [algo, workspaceSize] = tune(
handle, d->mul,
q_, k_, att_,
DYNAMIC_WORKSPACE_SIZE);
// att = α·Q·Kᵀ with α = 1/√headDim.
half alpha = rsqrtf(info.headDim), beta = 0;
cublasLtMatmul(
handle, d->mul.get(),
&alpha,
q, q_.get(),
k, k_.get(),
&beta,
att, att_.get(),
att, att_.get(),
&algo,
workspace, workspaceSize,
stream);
}
// Causal-masked softmax over the attLen valid columns of each row
// (rows are padded to cacheLen in the att buffer).
softmax<<<dim3(info.batch * info.nHead, info.seqLen),
std::min(1024u, attLen),
attLen * sizeof(float),
stream>>>(
att, AttentionCausualMask(), attLen, info.cacheLen);
{
auto [algo, workspaceSize] = tune(
handle, d->mul,
att_, v_, q_,
DYNAMIC_WORKSPACE_SIZE);
// o = att·V, written back through Q's layout descriptor.
half alpha = 1, beta = 0;
cublasLtMatmul(
handle, d->mul.get(),
&alpha,
att, att_.get(),
v, v_.get(),
&beta,
o, q_.get(),
o, q_.get(),
&algo,
workspace, workspaceSize,
stream);
}
};
return {std::move(routine), workspaceSize};
}
TODO("");
}
TODO("");
}
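For reading the kernel above: a host-side loop equivalent to one concatCache launch, assuming K/V are laid out as [batch·nHead, seqLen, headDim] and the caches as [batch·nHead, cacheLen, headDim]. This is an expository sketch, not part of the commit; concatCacheRef and pages are names introduced here:

void concatCacheRef(float4 *cache, float4 const *value,
                    size_t pages,         // batch * nHead
                    size_t pageStrideI,   // seqLen * itemsPerLine
                    size_t pageStrideO,   // cacheLen * itemsPerLine
                    size_t pastOffset) {  // past * itemsPerLine
    // Every element keeps its offset within its page; only the page base
    // changes, from a seqLen-wide input page to a cacheLen-wide cache page.
    for (size_t tid = 0; tid < pages * pageStrideI; ++tid)
        cache[tid / pageStrideI * pageStrideO + pastOffset + tid % pageStrideI] = value[tid];
}

The same cacheLen padding appears in the matmul descriptors: att_ declares majorStride = cacheLen while only attLen columns are valid, which is why the routine reserves maxAttSize() rather than attSize(past).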
8 changes: 4 additions & 4 deletions src/04kernel/src/utilities/cuda/cublaslt_utils.cu
@@ -101,23 +101,23 @@ namespace refactor::kernel::cublas {
MatMulDescriptor const &matmul,
MatrixDescriptor const &a,
MatrixDescriptor const &b,
-MatrixDescriptor const &c) {
+MatrixDescriptor const &c,
+uint64_t maxWorkspace) {

int device;
CUDA_ASSERT(cudaGetDevice(&device));
cudaDeviceProp prop;
CUDA_ASSERT(cudaGetDeviceProperties(&prop, device));

-auto workspace = std::numeric_limits<uint64_t>::max();
uint32_t alignment = prop.textureAlignment;

cublasLtMatmulPreference_t preference;
CUBLASLT_ASSERT(cublasLtMatmulPreferenceCreate(&preference));
CUBLASLT_ASSERT(cublasLtMatmulPreferenceSetAttribute(
preference,
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
-&workspace,
-sizeof(workspace)));
+&maxWorkspace,
+sizeof(maxWorkspace)));
CUBLASLT_ASSERT(cublasLtMatmulPreferenceSetAttribute(
preference,
CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES,
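With the new parameter, callers bound the cuBLASLt heuristic search to their real budget instead of the previous numeric_limits<uint64_t>::max() preference. A minimal usage sketch, mirroring the call sites above (descriptor setup elided):

// Cap the search at the 40 MiB budget; the returned workspaceSize
// then fits inside the routine's preallocated workspace buffer.
auto [algo, workspaceSize] = tune(handle, d->mul, q_, k_, att_, DYNAMIC_WORKSPACE_SIZE);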
3 changes: 2 additions & 1 deletion src/04kernel/src/utilities/cuda/cublaslt_utils.cuh
@@ -68,7 +68,8 @@ namespace refactor::kernel::cublas {
MatMulDescriptor const &,
MatrixDescriptor const &,
MatrixDescriptor const &,
-MatrixDescriptor const &);
+MatrixDescriptor const &,
+uint64_t);

}// namespace refactor::kernel::cublas

