Skip to content

Commit

Permalink
Vendor punica kernels (#63)
Browse files Browse the repository at this point in the history
  • Loading branch information
tgaddair authored Nov 26, 2023
1 parent 8c8109c commit 611dfd1
Show file tree
Hide file tree
Showing 40 changed files with 8,057 additions and 20 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v3
with:
submodules: recursive

- name: Docker meta
id: meta
Expand Down
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "server/punica_kernels/third_party/cutlass"]
path = server/punica_kernels/third_party/cutlass
url = https://github.com/NVIDIA/cutlass.git
10 changes: 4 additions & 6 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -144,15 +144,13 @@ RUN make build-vllm
# Build punica CUDA kernels
FROM kernel-builder as punica-builder

RUN /opt/conda/bin/conda install packaging

WORKDIR /usr/src

COPY server/Makefile-punica Makefile
COPY server/punica_kernels/ .

ENV TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX"
# Build specific version of punica
RUN make build-punica
ENV TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX"
RUN python setup.py build

# Text Generation Inference base image
FROM nvidia/cuda:11.8.0-base-ubuntu20.04 as base
Expand Down Expand Up @@ -195,7 +193,7 @@ COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39 /
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages

# Copy builds artifacts from punica builder
COPY --from=punica-builder /usr/src/punica/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
COPY --from=punica-builder /usr/src/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages

# Install flash-attention dependencies
RUN pip install einops --no-cache-dir
Expand Down
13 changes: 0 additions & 13 deletions server/Makefile-punica

This file was deleted.

2 changes: 1 addition & 1 deletion server/lorax_server/utils/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from accelerate import init_empty_weights

try:
from punica.ops import add_lora_sgmv_cutlass
from lorax_server.utils.sgmv import add_lora_sgmv_cutlass
HAS_SGMV = True
except ImportError:
warnings.warn("Could not import SGMV kernel from Punica, falling back to loop.")
Expand Down
36 changes: 36 additions & 0 deletions server/lorax_server/utils/sgmv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import torch


import punica_kernels as _kernels


# Source: https://github.com/punica-ai/punica/blob/master/src/punica/ops/__init__.py
def add_lora_sgmv_cutlass(
    y: torch.Tensor,
    x: torch.Tensor,
    wa_ptr: torch.Tensor,
    wb_ptr: torch.Tensor,
    s: torch.IntTensor,
    layer_idx: int,
    lora_rank: int,
):
    """Accumulate a segmented batch of LoRA products into ``y`` in-place.

    Semantics:
      y[s[i]:s[i+1]] += x[s[i]:s[i+1]] @ deref(wa_ptr[i]) @ deref(wb_ptr[i])

    Args:
      y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
      x: Shape: `[B, H1]`. Input vectors.
      wa_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\
        Weight matrix shape: `[num_layers, H1, R]`.
      wb_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\
        Weight matrix shape: `[num_layers, R, H2]`.
      s: Shape: `[S+1]`, DType: torch.int32. Indptr of the weight matrices.\
        `s[0] == 0`, `s[-1] == B`.
      layer_idx: Layer index of the weight matrices.
      lora_rank: Rank `R` of the LoRA factors; sizes the intermediate buffer.
    """
    # Workspace sized by the kernel for the number of segments S.
    num_segments = wa_ptr.size(0)
    scratch_bytes = _kernels.sgmv_cutlass_tmp_size(num_segments)
    scratch = torch.empty((scratch_bytes,), dtype=torch.uint8, device=x.device)
    # Intermediate projection v = x @ A, accumulated into zeros.
    v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
    _kernels.sgmv_cutlass(v, x, wa_ptr, s, scratch, layer_idx)
    # Second stage expands back: y += v @ B.
    _kernels.sgmv_cutlass(y, v, wb_ptr, s, scratch, layer_idx)
1 change: 1 addition & 0 deletions server/punica_kernels/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
These kernels are forked from the [Punica](https://github.com/punica-ai/punica) project.
5 changes: 5 additions & 0 deletions server/punica_kernels/punica_kernels/bgmv/bgmv_all.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
// Explicit-instantiation translation unit for the BGMV kernels: stamps out
// every (narrow, wide) shape pair listed in bgmv_config.h, in both
// directions, for fp16 and bf16 element types.
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16)
38 changes: 38 additions & 0 deletions server/punica_kernels/punica_kernels/bgmv/bgmv_config.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#pragma once

// Forward declaration of the host-side BGMV launcher (defined in
// bgmv_impl.cuh).  Computes, for each batch row b:
//   Y[b] += scale * W[indicies[b] * num_layers + layer_idx] @ X[b]
// feat_in / feat_out are compile-time template parameters, so the macros
// below enumerate every supported (narrow, wide) shape combination.
template <int feat_in, int feat_out, typename T>
void bgmv_kernel(T* __restrict__ Y, const T* __restrict__ X,
                 const T* __restrict__ W, const int64_t* __restrict__ indicies,
                 int64_t batch_size, int64_t num_layers, int64_t layer_idx,
                 float scale);

// clang-format off

// Supported "wide" dimensions — presumably common transformer hidden /
// intermediate sizes (TODO confirm against callers).  f is invoked once per
// (T, narrow, wide) triple.
#define FOR_BGMV_WIDE(f, T, narrow) \
    f(T, narrow, 768) \
    f(T, narrow, 1024) \
    f(T, narrow, 2048) \
    f(T, narrow, 2560) \
    f(T, narrow, 3072) \
    f(T, narrow, 4096) \
    f(T, narrow, 5120) \
    f(T, narrow, 7168) \
    f(T, narrow, 8192) \
    f(T, narrow, 9216) \
    f(T, narrow, 10240) \
    f(T, narrow, 11008) \
    f(T, narrow, 12288) \
    f(T, narrow, 13824) \
    f(T, narrow, 16384) \
    f(T, narrow, 20480) \
    f(T, narrow, 28672) \
    f(T, narrow, 36864) \
    f(T, narrow, 49152) \

// Supported "narrow" dimensions: powers of two from 8 to 64.
// NOTE(review): these look like LoRA ranks — verify against callers.
#define FOR_BGMV_WIDE_NARROW(f, T) \
    FOR_BGMV_WIDE(f, T, 8) \
    FOR_BGMV_WIDE(f, T, 16) \
    FOR_BGMV_WIDE(f, T, 32) \
    FOR_BGMV_WIDE(f, T, 64)

// clang-format on
217 changes: 217 additions & 0 deletions server/punica_kernels/punica_kernels/bgmv/bgmv_impl.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
#pragma once

#include <cooperative_groups.h>
#include <cuda_runtime.h>

#include <cuda/pipeline>
#include <iostream>

#include "../flashinfer/vec_dtypes.cuh"

namespace cg = cooperative_groups;

// BGMV "shrink" direction: contracts a wide input (feat_in) down to a narrow
// output (feat_out); one block computes one output scalar Y[batch_idx, j].
//
// Launch layout (see bgmv_kernel below): grid = (feat_out, batch_size),
// block = (tx=32, ty=4).  The feat_in-long dot product is processed in
// tiles of tile_size = 32 * 4 * vec_size elements, staged through shared
// memory with a two-stage cuda::pipeline so the next tile's global->shared
// copies overlap the current tile's arithmetic.  Per-warp partial sums are
// reduced with __shfl_down_sync, then combined across the 4 warps via the
// y_warpwise shared buffer.
// nthrs = (32, 4)
template <int feat_in, int feat_out, typename T>
__global__ void bgmv_shrink_kernel(T* __restrict__ Y, const T* __restrict__ X,
                                   const T* __restrict__ W,
                                   const int64_t* __restrict__ indicies,
                                   int64_t num_layers, int64_t layer_idx,
                                   float scale) {
  auto block = cg::this_thread_block();
  size_t j = blockIdx.x;          // output feature index
  size_t batch_idx = blockIdx.y;  // batch row
  constexpr size_t vec_size = 16 / sizeof(T);  // elements per 16-byte vector
  constexpr size_t tx = 32;
  constexpr size_t ty = 4;
  constexpr size_t num_pipeline_stages = 2;
  constexpr size_t tile_size = tx * ty * vec_size;
  // Double-buffered staging tiles: one slot per pipeline stage.
  __shared__ T W_shared[num_pipeline_stages * tile_size];
  __shared__ T X_shared[num_pipeline_stages * tile_size];
  // One partial sum slot per warp (indexed by threadIdx.y).
  __shared__ float y_warpwise[ty];

  // Flattened row selector into W: indicies[] chooses the weight set for
  // this batch row, layer_idx the layer within it.
  int64_t idx = indicies[batch_idx] * num_layers + layer_idx;

  size_t W_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
  size_t X_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
  auto pipe = cuda::make_pipeline();

  // Prime the pipeline: stage 0 copies the first tile of W and X.
  // pipeline load W/X and compute WX;
  pipe.producer_acquire();
  cuda::memcpy_async(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
                     W + (idx * feat_out + j) * feat_in +
                         (threadIdx.y * tx + threadIdx.x) * vec_size,
                     cuda::aligned_size_t<16>(16), pipe);
  cuda::memcpy_async(
      X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
      X + (batch_idx * feat_in) + (threadIdx.y * tx + threadIdx.x) * vec_size,
      cuda::aligned_size_t<16>(16), pipe);
  pipe.producer_commit();
  size_t copy_idx, compute_idx;
  float y = 0.f;
  flashinfer::vec_t<T, vec_size> x_vec, w_vec;
  size_t tile_idx;

#pragma unroll
  for (tile_idx = 1; tile_idx < (feat_in + tile_size - 1) / tile_size;
       ++tile_idx) {
    copy_idx = tile_idx % num_pipeline_stages;
    // pipeline stage: async copy W fragment (and X fragment) for tile_idx.
    // The guard masks per-warp chunks past the end of feat_in.
    pipe.producer_acquire();
    if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) {
      cuda::memcpy_async(W_shared + W_shared_offset[copy_idx] +
                             (threadIdx.y * tx + threadIdx.x) * vec_size,
                         W + (idx * feat_out + j) * feat_in +
                             tile_idx * tile_size +
                             (threadIdx.y * tx + threadIdx.x) * vec_size,
                         cuda::aligned_size_t<16>(16), pipe);
      cuda::memcpy_async(X_shared + X_shared_offset[copy_idx] +
                             (threadIdx.y * tx + threadIdx.x) * vec_size,
                         X + (batch_idx * feat_in) + tile_idx * tile_size +
                             (threadIdx.y * tx + threadIdx.x) * vec_size,
                         cuda::aligned_size_t<16>(16), pipe);
    }
    pipe.producer_commit();

    compute_idx = (tile_idx - 1) % num_pipeline_stages;
    // pipeline stage: compute WX on the previously copied tile.
    pipe.consumer_wait();
    block.sync();
    x_vec.load(X_shared + X_shared_offset[compute_idx] +
               (threadIdx.y * tx + threadIdx.x) * vec_size);
    w_vec.load(W_shared + W_shared_offset[compute_idx] +
               (threadIdx.y * tx + threadIdx.x) * vec_size);
    float sum = 0.f;
#pragma unroll
    for (size_t i = 0; i < vec_size; ++i) {
      sum += float(w_vec[i]) * float(x_vec[i]) * scale;
    }
    // Warp-level tree reduction: after this loop lane 0 holds the warp total.
#pragma unroll
    for (size_t offset = tx / 2; offset > 0; offset /= 2) {
      sum += __shfl_down_sync(0xffffffff, sum, offset);
    }
    // NOTE(review): all 32 lanes store to y_warpwise[threadIdx.y], but only
    // lane 0 holds the full warp sum after the shuffle — verify that lane
    // 0's value is the one that lands (upstream punica does the same).
    y_warpwise[threadIdx.y] = sum;
    block.sync();
    // Every thread accumulates the block-wide total for this tile.
#pragma unroll
    for (size_t i = 0; i < ty; ++i) {
      y += y_warpwise[i];
    }

    block.sync();
    pipe.consumer_release();
  }

  compute_idx = (tile_idx - 1) % num_pipeline_stages;
  // final pipeline stage: consume the last tile (no further copy issued).
  pipe.consumer_wait();
  block.sync();
  x_vec.load(X_shared + X_shared_offset[compute_idx] +
             (threadIdx.y * tx + threadIdx.x) * vec_size);
  w_vec.load(W_shared + W_shared_offset[compute_idx] +
             (threadIdx.y * tx + threadIdx.x) * vec_size);
  float sum = 0.f;
#pragma unroll
  for (size_t i = 0; i < vec_size; ++i) {
    sum += float(w_vec[i]) * float(x_vec[i]) * scale;
  }
#pragma unroll
  for (size_t offset = tx / 2; offset > 0; offset /= 2) {
    sum += __shfl_down_sync(0xffffffff, sum, offset);
  }
  // Zero out warps whose chunk of the last tile fell past feat_in.
  y_warpwise[threadIdx.y] =
      ((tile_idx - 1) * tile_size + threadIdx.y * tx * vec_size < feat_in)
          ? sum
          : 0.f;
  block.sync();
#pragma unroll
  for (size_t i = 0; i < ty; ++i) {
    y += y_warpwise[i];
  }

  block.sync();
  pipe.consumer_release();

  // write Y; a single thread accumulates the block's result in place.
  if (block.thread_rank() == 0) {
    Y[batch_idx * feat_out + j] += y;
  }
}

// BGMV "expand" direction: projects a narrow input (feat_in) up to a wide
// output (feat_out).
//
// Launch layout (see bgmv_kernel below):
// grid = (feat_out / (ty * tz), batch_size), block = (tx, ty, tz), where
// the tx lanes together cover the entire feat_in vector — so each (y, z)
// thread pair owns exactly one output element and no tiling over feat_in
// is needed.
// nthrs = (2, 16, 4)
template <int feat_in, int feat_out, typename T>
__global__ void bgmv_expand_kernel(T* __restrict__ Y, const T* __restrict__ X,
                                   const T* __restrict__ W,
                                   const int64_t* __restrict__ indicies,
                                   int64_t num_layers, int64_t layer_idx,
                                   float scale) {
  auto block = cg::this_thread_block();
  constexpr size_t vec_size = 16 / sizeof(T);  // elements per 16-byte vector
  constexpr size_t tx = feat_in / vec_size;    // lanes per output element
  static_assert(feat_in % vec_size == 0);
  constexpr size_t ty = 32 / tx;  // output elements per warp
  static_assert(32 % tx == 0);
  constexpr size_t tz = 4;        // warps per block
  size_t tile_idx = blockIdx.x;   // which group of ty*tz output features
  size_t batch_idx = blockIdx.y;  // batch row
  // Flattened row selector into W (weight set chosen per batch row).
  int64_t idx = indicies[batch_idx] * num_layers + layer_idx;

  // load X; each lane grabs its vec_size-wide slice of the input row.
  flashinfer::vec_t<T, vec_size> x_vec;
  x_vec.load(X + batch_idx * feat_in + threadIdx.x * vec_size);

  // load W; thread_rank() linearizes (x, y, z) so consecutive threads read
  // consecutive vectors of the block's ty*tz weight rows.
  flashinfer::vec_t<T, vec_size> w_vec;
  w_vec.load(W + (idx * feat_out + tile_idx * tz * ty) * feat_in +
             block.thread_rank() * vec_size);

  float sum = 0.f;
#pragma unroll
  for (size_t i = 0; i < vec_size; ++i) {
    sum += float(w_vec[i]) * float(x_vec[i]) * scale;
  }

  // Reduce across the tx lanes that share one output element.
  cg::thread_block_tile g = cg::tiled_partition<tx>(block);
#pragma unroll
  for (size_t offset = tx / 2; offset > 0; offset /= 2) {
    sum += g.shfl_down(sum, offset);
  }
  sum = g.shfl(sum, 0);  // broadcast the tile total back to all lanes

  // Lane 0 of each tile accumulates its output element in place.
  if (threadIdx.x == 0) {
    Y[batch_idx * feat_out + tile_idx * (tz * ty) + threadIdx.z * ty +
      threadIdx.y] += sum;
  }
}

// Host-side launcher for the batched gather matrix-vector multiply (BGMV):
//   Y[b] += scale * W[indicies[b] * num_layers + layer_idx] @ X[b]
// for b in [0, batch_size).  Dispatches at compile time between the
// "expand" kernel (feat_in < feat_out) and the "shrink" kernel
// (feat_in >= feat_out) defined above.  Launches on the default stream.
// NOTE(review): no cudaGetLastError() after the launches — configuration
// errors will only surface at the next synchronizing CUDA call.
template <int feat_in, int feat_out, typename T>
void bgmv_kernel(T* __restrict__ Y, const T* __restrict__ X,
                 const T* __restrict__ W, const int64_t* __restrict__ indicies,
                 int64_t batch_size, int64_t num_layers, int64_t layer_idx,
                 float scale) {
  // Elements per 16-byte vectorized load; a compile-time constant.
  constexpr size_t vec_size = 16 / sizeof(T);
  if constexpr (feat_in < feat_out) {
    // Expand: tx lanes span the feat_in vector, each (y, z) thread pair
    // owns one output element, each block covers ty * tz output features.
    constexpr size_t tx = feat_in / vec_size;
    constexpr size_t ty = 32 / tx;
    constexpr size_t tz = 4;
    dim3 nblks(feat_out / (ty * tz), batch_size);
    dim3 nthrs(tx, ty, tz);

    bgmv_expand_kernel<feat_in, feat_out>
        <<<nblks, nthrs>>>(Y, X, W, indicies, num_layers, layer_idx, scale);
  } else {
    // Shrink: one block per output feature, 32x4 threads tiling feat_in.
    // Compile-time precondition check — the branch's masking works at
    // vec_size * 32 (per-warp chunk) granularity.  This was a runtime
    // assert(), but feat_in is a template parameter and <cassert> is never
    // included in this file, so static_assert is both stronger and safer.
    static_assert(feat_in % (vec_size * 32) == 0,
                  "feat_in must be a multiple of a full-warp vector chunk");
    dim3 nblks(feat_out, batch_size);
    dim3 nthrs(32, 4);
    bgmv_shrink_kernel<feat_in, feat_out>
        <<<nblks, nthrs>>>(Y, X, W, indicies, num_layers, layer_idx, scale);
  }
}

// Emits an explicit instantiation of the bgmv_kernel launcher for one
// (feat_in, feat_out, T) combination.
#define INST_BGMV(feat_in, feat_out, T)                                      \
  template void bgmv_kernel<feat_in, feat_out>(                              \
      T* __restrict__ Y, const T* __restrict__ X, const T* __restrict__ W,   \
      const int64_t* __restrict__ indicies, int64_t batch_size,              \
      int64_t num_layers, int64_t layer_idx, float scale);

// Instantiates both directions for one shape pair: expand (narrow -> wide)
// and shrink (wide -> narrow).
#define INST_BGMV_TWOSIDE(T, narrow, wide) \
  INST_BGMV(narrow, wide, T)               \
  INST_BGMV(wide, narrow, T)
5 changes: 5 additions & 0 deletions server/punica_kernels/punica_kernels/flashinfer/.clang-format
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# https://github.com/yzh119/flashinfer/blob/main/.clang-format
BasedOnStyle: Google
DerivePointerAlignment: false
ColumnLimit: 100
PointerAlignment: Left
Loading

0 comments on commit 611dfd1

Please sign in to comment.