diff --git a/.gitmodules b/.gitmodules index 099f2906d..c0d8a7b60 100644 --- a/.gitmodules +++ b/.gitmodules @@ -3,4 +3,4 @@ url = https://github.com/NVIDIA/cutlass.git [submodule "server/punica_kernels/third_party/flashinfer"] path = server/punica_kernels/third_party/flashinfer - url = https://github.com/flashinfer-ai/flashinfer.git + url = https://github.com/tgaddair/flashinfer.git diff --git a/docs/guides/contributing/development_env.md b/docs/guides/contributing/development_env.md index 8b5395628..f33c4c2d6 100644 --- a/docs/guides/contributing/development_env.md +++ b/docs/guides/contributing/development_env.md @@ -149,7 +149,7 @@ apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install && rm -rf /var/lib/apt/lists/* conda update --force conda -/opt/conda/bin/conda install -c "nvidia/label/cuda-12.1.0" cuda==12.1 cudnn && \ +/opt/conda/bin/conda install -c "nvidia" cuda==12.4 cudnn && \ /opt/conda/bin/conda clean -ya ``` diff --git a/server/punica_kernels/punica_kernels/bgmv/bgmv_config.h b/server/punica_kernels/punica_kernels/bgmv/bgmv_config.h index ac2a66d92..947822ea8 100644 --- a/server/punica_kernels/punica_kernels/bgmv/bgmv_config.h +++ b/server/punica_kernels/punica_kernels/bgmv/bgmv_config.h @@ -50,6 +50,7 @@ void bgmv_kernel(T *__restrict__ Y, const T *__restrict__ X, f(T, narrow, 14336) \ f(T, narrow, 15360) \ f(T, narrow, 16384) \ + f(T, narrow, 18944) \ f(T, narrow, 20480) \ f(T, narrow, 22016) \ f(T, narrow, 24576) \ @@ -78,7 +79,6 @@ void bgmv_kernel(T *__restrict__ Y, const T *__restrict__ X, FOR_BGMV_WIDE(f, T, 16) \ FOR_BGMV_WIDE(f, T, 32) \ FOR_BGMV_WIDE(f, T, 64) \ - FOR_BGMV_WIDE(f, T, 96) \ FOR_BGMV_WIDE(f, T, 128) // clang-format on diff --git a/server/punica_kernels/third_party/flashinfer b/server/punica_kernels/third_party/flashinfer index 8159aec0b..1baac901e 160000 --- a/server/punica_kernels/third_party/flashinfer +++ b/server/punica_kernels/third_party/flashinfer @@ -1 +1 @@ -Subproject commit 8159aec0baa883eb8ad84ffb4f47a5b7c1f65984 +Subproject commit 1baac901e7c394e5011acab38cb68c8e3a9dbf4d