kernel_reduction.cuh
// Copyright 2024 FP6-LLM authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/kernel_reduction.cuh
/***************************************************************************
* Copyright 2023 The Flash-LLM Authors. All rights reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
***************************************************************************/
// Used for the reduction of the result matrix when Split-K is used
// Reduction_Workspace: (Split_K, M_Global, N_Global), column major
// C: (M_Global, N_Global), column major
// Each thread deals with 8 output elements; each element is the sum of Split_K partial results
// Read Global: Each Warp/ThreadBlock: 32 threads_per_warp * 8 float_per_thread (256bit) -> 256 float per warp
// Write Global: Each Warp/ThreadBlock: 32 threads_per_warp * 8 half_per_thread (128bit) -> 256 half per warp
// GridSize = (M_Global*N_Global) / 256
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#define REDUCTION_ELEMENT_PER_THREADBLOCK 256
#define HALF_PER_128BIT 8
__global__ void SplitK_Reduction(half* C, float* Reduction_Workspace, size_t M_Global, size_t N_Global, int Split_K)
{
    // Each thread block is a single warp; every thread owns 8 consecutive output elements.
    half*  WARP_GPTR_C = C + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x;
    float* WARP_GPTR_R = Reduction_Workspace + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x;
    half*  THREAD_GPTR_C = WARP_GPTR_C + threadIdx.x * HALF_PER_128BIT;
    float* THREAD_GPTR_R = WARP_GPTR_R + threadIdx.x * HALF_PER_128BIT;
    // Initializing Thread-Local Results
    float Results[HALF_PER_128BIT];
#pragma unroll
    for (int i = 0; i < HALF_PER_128BIT; i++) Results[i] = 0.0f;
    // Reduction: accumulate the same 8 outputs from each of the Split_K partial-result matrices
    for (int i = 0; i < Split_K; i++) {
#pragma unroll
        for (int j = 0; j < HALF_PER_128BIT; j++) Results[j] += THREAD_GPTR_R[j];
        THREAD_GPTR_R += M_Global * N_Global;  // advance to the next (M_Global x N_Global) slice of the workspace
    }
    // Writing the fp32 sums back to global memory as fp16
#pragma unroll
    for (int i = 0; i < HALF_PER_128BIT; i++) THREAD_GPTR_C[i] = __float2half_rn(Results[i]);
}
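
// --------------------------------------------------------------------------
// Minimal host-side launch sketch (not part of the upstream fp6_llm header,
// added only for illustration). It assumes M_Global * N_Global is a multiple
// of REDUCTION_ELEMENT_PER_THREADBLOCK and that d_C / d_workspace are device
// pointers laid out as described in the comments above. The helper name
// LaunchSplitKReduction is hypothetical.
static inline void LaunchSplitKReduction(half*        d_C,
                                         float*       d_workspace,
                                         size_t       M_Global,
                                         size_t       N_Global,
                                         int          Split_K,
                                         cudaStream_t stream = nullptr)
{
    // GridSize = (M_Global * N_Global) / 256, with one warp (32 threads) per block.
    dim3 grid(static_cast<unsigned int>((M_Global * N_Global) / REDUCTION_ELEMENT_PER_THREADBLOCK));
    dim3 block(32);
    SplitK_Reduction<<<grid, block, 0, stream>>>(d_C, d_workspace, M_Global, N_Global, Split_K);
}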