Skip to content

Commit

Permalink
fix (kernel): 修复attention算子中的concat (fix the concat step in the attention operator kernel)
Browse files Browse the repository at this point in the history
  • Loading branch information
PanZezhong1725 committed Feb 21, 2024
1 parent e16c679 commit bfa8e9f
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 19 deletions.
25 changes: 13 additions & 12 deletions src/04kernel/src/kernels/attention/cuda_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,13 @@ namespace refactor::kernel {
void const *__restrict__ value,
dim_t pageStrideI,
dim_t pageStrideO,
dim_t lineStride,
dim_t pastOffset) {

auto tid = blockIdx.x * blockDim.x + threadIdx.x,
dst = tid / pageStrideO * pageStrideI + pastOffset + tid % pageStrideO;
reinterpret_cast<float4 *>(cache)[dst] = reinterpret_cast<float4 const *>(value)[tid];
dim_t pastOffset,
dim_t n_items) {
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < n_items) {
auto dst = tid / pageStrideI * pageStrideO + pastOffset + (tid % pageStrideI);
reinterpret_cast<float4 *>(cache)[dst] = reinterpret_cast<float4 const *>(value)[tid];
}
}
constexpr uint64_t DYNAMIC_WORKSPACE_SIZE = 40 << 20;// 试出来 40MiB 是够用的

Expand Down Expand Up @@ -231,7 +232,8 @@ namespace refactor::kernel {
auto q = inputs[0];
auto k = inputs[1];
auto v = inputs[2];
auto past = *reinterpret_cast<int64_t const *>(inputs[3]);
int64_t past;
cudaMemcpy(&past, inputs[3], sizeof(int64_t), cudaMemcpyDeviceToHost);
auto attLen = info.attLen(past);
auto o = reinterpret_cast<half *>(outputs[0]);
auto kCache = reinterpret_cast<half *>(outputs[1]);
Expand All @@ -242,19 +244,18 @@ namespace refactor::kernel {
auto itemsPerLine = info.headDim * sizeof(half) / sizeof(float4);
auto threads = info.batch * info.nHead * info.seqLen * itemsPerLine;
auto blocks = (threads + 1023) / 1024;
concatCache<<<blocks, 1024, 0, stream>>>(
kCache, k,
info.seqLen * itemsPerLine,
info.cacheLen * itemsPerLine,
itemsPerLine,
past * itemsPerLine);
past * itemsPerLine,
threads);
concatCache<<<blocks, 1024, 0, stream>>>(
vCache, v,
info.seqLen * itemsPerLine,
info.cacheLen * itemsPerLine,
itemsPerLine,
past * itemsPerLine);
past * itemsPerLine,
threads);
}
MatrixDescriptor
q_(MatrixLayout{
Expand Down
12 changes: 5 additions & 7 deletions src/08-01llm/src/operators/attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,10 @@ namespace refactor::llm {
if (pastSeqLen.dataType != DataType::I64 || pastSeqLen.shape != Shape{DimExpr(1)}) {
return Err(InferError(ERROR_MSG("Past seqlen error")));
}
auto pastSeqLenVal = pastSeqLen.data->get<int64_t>()[0];
if (maxSeqLen <= 0) {
auto pastSeqLenVal = pastSeqLen.data->get<int64_t>()[0];
return outputs(pastSeqLenVal + seqlen);
} else if (maxSeqLen >= pastSeqLenVal + seqlen) {
} else if (maxSeqLen >= 1 + seqlen) {
return outputs(maxSeqLen);
} else {
return Err(InferError(ERROR_MSG("max_seq_len must not less than seqlen")));
Expand All @@ -94,7 +94,6 @@ namespace refactor::llm {
if (pastSeqLen.dataType != DataType::I64 || pastSeqLen.shape != Shape{DimExpr(1)}) {
return Err(InferError(ERROR_MSG("Past seqlen error")));
}
auto pastSeqLenVal = pastSeqLen.data->get<int64_t>()[0];

auto const &kCahce = inputs[4],
&vCache = inputs[5];
Expand All @@ -107,15 +106,14 @@ namespace refactor::llm {
kCahce.shape[3] != kvShape[3] ||
kCahce.shape[0] != kvShape[0] ||
kCahce.shape[2] != kvShape[2] ||
kCahce.shape[3] != kvShape[3] ||
pastSeqLenVal < kCacheSeqLen ||
pastSeqLenVal < vCacheSeqLen) {
kCahce.shape[3] != kvShape[3]) {
return Err(InferError(ERROR_MSG("KV cache error")));
}

if (maxSeqLen <= 0) {
auto pastSeqLenVal = pastSeqLen.data->get<int64_t>()[0];
return outputs(pastSeqLenVal + seqlen);
} else if (maxSeqLen >= pastSeqLenVal + seqlen) {
} else if (maxSeqLen >= 1 + seqlen) {
return outputs(maxSeqLen);
} else {
return Err(InferError(ERROR_MSG("max_seq_len must not less than seqlen")));
Expand Down

0 comments on commit bfa8e9f

Please sign in to comment.