optimize paged attention on triton3 (#2553)
* optimize paged attention on triton3
* fix w8a8 kernel
* optimize prefill
* optimize short decoding
* optimize sm<8
* optimize short context
* fix triton2.2.0
* recovery test
* add ut for custom layout
* update stride
* update ut
grimoire authored Oct 18, 2024
1 parent fec94c9 commit 7dc0a5c
Showing 8 changed files with 259 additions and 151 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/unit-test.yml
@@ -53,7 +53,7 @@ jobs:
       - name: Install pytorch
         run: |
           python3 -m pip cache dir
-          python3 -m pip install torch==2.1.0 torchvision==0.16.0 --index-url https://download.pytorch.org/whl/cu118
+          python3 -m pip install torch==2.3.1 torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cu118
       - name: Build lmdeploy
         run: |
           python3 -m pip install cmake
@@ -77,7 +77,7 @@ jobs:
         run: |
           python3 -m pip install pynvml packaging protobuf transformers_stream_generator
           # manually install flash attn
-          python3 -m pip install /root/packages/flash_attn-2.3.6+cu118torch2.1cxx11abiFALSE-cp38-cp38-linux_x86_64.whl
+          python3 -m pip install /root/packages/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp38-cp38-linux_x86_64.whl
           python3 -m pip install -r requirements.txt -r requirements/test.txt
           python3 -m pip install .
       - name: Check env
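
The bump pairs torch 2.3.1 / torchvision 0.18.1 from the cu118 index with a flash-attn 2.6.3 wheel built against torch 2.3. A quick sanity check that a runner picked up the intended versions (a hedged sketch, not part of the workflow itself):

import torch
import torchvision

# Versions expected after this bump; adjust if the workflow moves on again.
assert torch.__version__.startswith('2.3.1'), torch.__version__
assert torchvision.__version__.startswith('0.18.1'), torchvision.__version__
assert torch.version.cuda == '11.8'  # wheels come from the cu118 index
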
18 changes: 10 additions & 8 deletions lmdeploy/pytorch/check_env/__init__.py
@@ -93,15 +93,13 @@ def check_env_triton(device: str):
 
     if device == 'cuda':
         device_cap = torch.cuda.get_device_capability()
-        TRITON_VER_220 = version.parse('2.2.0')
         TRITON_VER_231 = version.parse('2.3.1')
 
         if device_cap[0] <= 7:
-            if (triton_version >= TRITON_VER_220
-                    and triton_version <= TRITON_VER_231):
+            if triton_version <= TRITON_VER_231:
                 err = RuntimeError(
                     'Attention triton kernel does not fully support '
-                    'triton[2.2.0~2.3.1] on device with capability<8. '
+                    'triton<3.0.0 on device with capability<8. '
                     'Please upgrade your triton version.')
                 _handle_exception(err, 'Triton', logger)
 
@@ -142,7 +140,8 @@ def check_awq(hf_config):
 
 
 def check_transformers_version(model_path: str,
-                               trust_remote_code: bool = True):
+                               trust_remote_code: bool = True,
+                               dtype: str = 'auto'):
     """check transformers version."""
     from packaging import version
     logger = get_logger('lmdeploy')
@@ -206,7 +205,8 @@ def __check_model_dtype_support(config):
 
         try:
             model_config = ModelConfig.from_hf_config(config,
-                                                      model_path=model_path)
+                                                      model_path=model_path,
+                                                      dtype=dtype)
             if model_config.dtype == torch.bfloat16:
                 assert torch.cuda.is_bf16_supported(), (
                     'bf16 is not supported on your device')
@@ -229,11 +229,13 @@ def __check_model_dtype_support(config):
     check_awq(config)
 
 
-def check_model(model_path: str, trust_remote_code: bool = True):
+def check_model(model_path: str,
+                trust_remote_code: bool = True,
+                dtype: str = 'auto'):
     """check model requirements."""
     logger = get_logger('lmdeploy')
     logger.info('Checking model.')
-    check_transformers_version(model_path, trust_remote_code)
+    check_transformers_version(model_path, trust_remote_code, dtype)
 
 
 def check_adapter(path: str):
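
Net effect of the check_env_triton change: on GPUs with compute capability below 8, every triton release up to and including 2.3.1 is now rejected (previously only the 2.2.0~2.3.1 window was), steering those users to the triton 3.x this PR optimizes for. A standalone sketch of the relaxed gate, raising the error directly instead of routing it through _handle_exception:

import torch
import triton
from packaging import version

triton_version = version.parse(triton.__version__)
TRITON_VER_231 = version.parse('2.3.1')

if torch.cuda.get_device_capability()[0] <= 7:
    if triton_version <= TRITON_VER_231:
        raise RuntimeError('Attention triton kernel does not fully support '
                           'triton<3.0.0 on device with capability<8. '
                           'Please upgrade your triton version.')
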
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/engine/engine.py
@@ -96,7 +96,7 @@ def __init__(self,
         else:
             engine_config = copy.deepcopy(engine_config)
         check_env(engine_config.device_type)
-        check_model(model_path, trust_remote_code)
+        check_model(model_path, trust_remote_code, engine_config.dtype)
         if engine_config.max_batch_size is None:
             engine_config.max_batch_size = get_max_batch_size(
                 engine_config.device_type)
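
With the engine forwarding engine_config.dtype, the model check can reject an unsupported dtype (for example, bfloat16 on hardware without bf16 support) before the engine is built. A hypothetical direct call with the updated signature; the path is a placeholder, and the dtype strings mirror the engine config ('auto', 'float16', 'bfloat16'):

from lmdeploy.pytorch.check_env import check_model

# '/path/to/model' is a placeholder for a local HF checkpoint.
check_model('/path/to/model', trust_remote_code=True, dtype='bfloat16')
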
61 changes: 35 additions & 26 deletions lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py
@@ -378,12 +378,21 @@ def fill_kv_cache(k_states: Tensor,
                   block_offsets: Tensor,
                   k_scales_zeros: Tensor = None,
                   v_scales_zeros: Tensor = None,
-                  quant_policy: Literal[0, 4, 8] = 0):
+                  quant_policy: Literal[0, 4, 8] = 0,
+                  kv_layout: str = 'bshd'):
     """fill key/value state to cache for paged attention."""
+    if kv_layout == 'bshd':
+        b_dim, s_dim, h_dim, d_dim = (0, 1, 2, 3)
+    elif kv_layout == 'bhsd':
+        b_dim, s_dim, h_dim, d_dim = (0, 2, 1, 3)
+    else:
+        raise RuntimeError('Unsupported layout.')
 
     block_offsets = block_offsets.contiguous()
     batch_size = block_offsets.size(0)
-    block_size, num_heads, head_dim = k_caches.size()[1:]
+    block_size = k_caches.size(s_dim)
+    num_heads = k_caches.size(h_dim)
+    head_dim = k_caches.size(d_dim)
     head_dim_v = v_states.size(-1)
     max_num_blocks = triton.cdiv(max_q_seq_length, block_size) + 1
 
@@ -412,14 +421,14 @@ def fill_kv_cache(k_states: Tensor,
         stride_vss=v_states.stride(-3),
         stride_vsh=v_states.stride(-2),
         stride_vsd=v_states.stride(-1),
-        stride_kcn=k_caches.stride(0),
-        stride_kcb=k_caches.stride(1),
-        stride_kch=k_caches.stride(2),
-        stride_kcd=k_caches.stride(3),
-        stride_vcn=v_caches.stride(0),
-        stride_vcb=v_caches.stride(1),
-        stride_vch=v_caches.stride(2),
-        stride_vcd=v_caches.stride(3),
+        stride_kcn=k_caches.stride(b_dim),
+        stride_kcb=k_caches.stride(s_dim),
+        stride_kch=k_caches.stride(h_dim),
+        stride_kcd=k_caches.stride(d_dim),
+        stride_vcn=v_caches.stride(b_dim),
+        stride_vcb=v_caches.stride(s_dim),
+        stride_vch=v_caches.stride(h_dim),
+        stride_vcd=v_caches.stride(d_dim),
         stride_boff=block_offsets.stride(0),
         BLOCK=BLOCK,
         BLOCK_D=BLOCK_D,
@@ -450,22 +459,22 @@ def fill_kv_cache(k_states: Tensor,
         stride_vss=v_states.stride(-3),
         stride_vsh=v_states.stride(-2),
         stride_vsd=v_states.stride(-1),
-        stride_kcn=k_caches.stride(0),
-        stride_kcb=k_caches.stride(1),
-        stride_kch=k_caches.stride(2),
-        stride_kcd=k_caches.stride(3),
-        stride_vcn=v_caches.stride(0),
-        stride_vcb=v_caches.stride(1),
-        stride_vch=v_caches.stride(2),
-        stride_vcd=v_caches.stride(3),
-        stride_kszn=k_scales_zeros.stride(0),
-        stride_kszb=k_scales_zeros.stride(1),
-        stride_kszh=k_scales_zeros.stride(2),
-        stride_kszd=k_scales_zeros.stride(3),
-        stride_vszn=v_scales_zeros.stride(0),
-        stride_vszb=v_scales_zeros.stride(1),
-        stride_vszh=v_scales_zeros.stride(2),
-        stride_vszd=v_scales_zeros.stride(3),
+        stride_kcn=k_caches.stride(b_dim),
+        stride_kcb=k_caches.stride(s_dim),
+        stride_kch=k_caches.stride(h_dim),
+        stride_kcd=k_caches.stride(d_dim),
+        stride_vcn=v_caches.stride(b_dim),
+        stride_vcb=v_caches.stride(s_dim),
+        stride_vch=v_caches.stride(h_dim),
+        stride_vcd=v_caches.stride(d_dim),
+        stride_kszn=k_scales_zeros.stride(b_dim),
+        stride_kszb=k_scales_zeros.stride(s_dim),
+        stride_kszh=k_scales_zeros.stride(h_dim),
+        stride_kszd=k_scales_zeros.stride(d_dim),
+        stride_vszn=v_scales_zeros.stride(b_dim),
+        stride_vszb=v_scales_zeros.stride(s_dim),
+        stride_vszh=v_scales_zeros.stride(h_dim),
+        stride_vszd=v_scales_zeros.stride(d_dim),
         quant_policy=quant_policy,
         stride_boff=block_offsets.stride(0),
         BLOCK=BLOCK,
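
The new kv_layout argument decouples the cache's logical axes (batch, block, head, head_dim) from their physical order in memory: sizes and strides are looked up through the mapped indices, so the same Triton kernels serve both layouts. A minimal standalone sketch of that mapping (shape values are arbitrary, chosen for illustration):

import torch

def cache_layout_dims(kv_layout: str):
    # Mirrors the dispatch at the top of fill_kv_cache.
    if kv_layout == 'bshd':    # (batch, block_size, head, head_dim)
        return 0, 1, 2, 3
    elif kv_layout == 'bhsd':  # (batch, head, block_size, head_dim)
        return 0, 2, 1, 3
    raise RuntimeError('Unsupported layout.')

# Example 'bhsd' cache: 16 blocks, 8 kv heads, block_size 64, head_dim 128.
k_caches = torch.empty(16, 8, 64, 128)
b_dim, s_dim, h_dim, d_dim = cache_layout_dims('bhsd')
block_size = k_caches.size(s_dim)  # 64, independent of physical layout
num_heads = k_caches.size(h_dim)   # 8
# Kernels receive strides in logical (batch, block, head, dim) order.
strides = tuple(k_caches.stride(d) for d in (b_dim, s_dim, h_dim, d_dim))
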
(Diffs for the remaining 4 changed files are not shown.)
