Skip to content

Commit

Permalink
ci: reduce binary size (#172)
Browse files Browse the repository at this point in the history
1. Do not generate prefill kernels for `page_size=8`
2. Build with `-Xfatbin=-compress-all` to reduce binary size.

Followup of #171 , @Qubitium the cuda architectures to be compiled could
be controlled by environment variable `TORCH_CUDA_ARCH_LIST`, so I
removed the gencode/archs specified in compile args.
  • Loading branch information
yzh119 authored Mar 11, 2024
1 parent 2657813 commit bd5b60a
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/release_wheel.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ on:
# required: true

env:
TORCH_CUDA_ARCH_LIST: "8.0 8.6 8.9 9.0+PTX"
TORCH_CUDA_ARCH_LIST: "8.0 8.9 9.0+PTX"

jobs:
build:
Expand Down
2 changes: 1 addition & 1 deletion python/generate_batch_paged_prefill_inst.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ def get_cu_file_str(
dtype_in,
dtype_out,
idtype,
page_size_choices=[1, 8, 16, 32],
):
num_frags_x_choices = [1, 2]
page_size_choices = [1, 8, 16, 32]
insts = "\n".join(
[
"""template cudaError_t BatchPrefillWithPagedKVCacheDispatched<page_storage, {kv_layout}, {num_frags_x}, {page_size}, {group_size}, {head_dim}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {causal}, {dtype_in}, {dtype_out}, {idtype}>(
Expand Down
6 changes: 3 additions & 3 deletions python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ def get_instantiation_cu() -> List[str]:
dtype,
dtype,
idtype,
page_size_choices=[1, 16, 32],
)
write_if_different(root / prefix / fname, content)

Expand Down Expand Up @@ -378,9 +379,8 @@ def __init__(self, *args, **kwargs) -> None:
str(root.resolve() / "include"),
],
extra_compile_args={
"cxx": ["-O3", "-std=c++17"],
"nvcc": ["-O3", "-std=c++17", "--threads", "8", "-gencode", "arch=compute_80,code=sm_80",
"-gencode", "arch=compute_89,code=sm_89", "-gencode", "arch=compute_90,code=sm_90"],
"cxx": ["-O3"],
"nvcc": ["-O3", "-std=c++17", "--threads", "8", "-Xfatbin", "-compress-all"],
},
)
)
Expand Down

0 comments on commit bd5b60a

Please sign in to comment.