From eaa4e6f7f2772f5557c1dbeec919b9f95e27edca Mon Sep 17 00:00:00 2001 From: q yao Date: Thu, 24 Oct 2024 12:59:02 +0800 Subject: [PATCH] update check for triton (#2641) --- lmdeploy/pytorch/check_env/__init__.py | 8 +++++--- lmdeploy/pytorch/check_env/triton_custom_add.py | 8 ++++++++ lmdeploy/pytorch/kernels/cuda/pagedattention.py | 4 ++-- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/lmdeploy/pytorch/check_env/__init__.py b/lmdeploy/pytorch/check_env/__init__.py index 2b4b3cc521..291b1afb35 100644 --- a/lmdeploy/pytorch/check_env/__init__.py +++ b/lmdeploy/pytorch/check_env/__init__.py @@ -66,6 +66,10 @@ def check_env_triton(device: str): from packaging import version logger = get_logger('lmdeploy') + msg = ( + 'Please ensure that your device is functioning properly with .\n' # noqa: E501 + 'You can verify your environment by running ' + '`python -m lmdeploy.pytorch.check_env.triton_custom_add`.') try: logger.debug('Checking environment.') import torch @@ -87,11 +91,9 @@ def check_env_triton(device: str): 'This Error might caused by mismatching between NVIDIA Driver and nvcc compiler. \n' # noqa: E501 'Try solution https://github.com/triton-lang/triton/issues/1955#issuecomment-1929908209' # noqa: E501 ' or reinstall the driver.') - else: - msg = None _handle_exception(e, 'Triton', logger, msg) except Exception as e: - _handle_exception(e, 'Triton', logger) + _handle_exception(e, 'Triton', logger, msg) if device == 'cuda': device_cap = torch.cuda.get_device_capability() diff --git a/lmdeploy/pytorch/check_env/triton_custom_add.py b/lmdeploy/pytorch/check_env/triton_custom_add.py index ef77fb8105..077359110b 100644 --- a/lmdeploy/pytorch/check_env/triton_custom_add.py +++ b/lmdeploy/pytorch/check_env/triton_custom_add.py @@ -23,3 +23,11 @@ def custom_add(a, b): grid = (triton.cdiv(size, BLOCK), ) _add_kernel[grid](a, b, c, size, BLOCK=BLOCK) return c + + +if __name__ == '__main__': + a = torch.tensor([1, 2], device='cuda') + b = a.new_tensor([3, 4], device='cuda') + c = custom_add(a, b) + torch.testing.assert_close(c, a + b) + print('Done.') diff --git a/lmdeploy/pytorch/kernels/cuda/pagedattention.py b/lmdeploy/pytorch/kernels/cuda/pagedattention.py index 7790a44b19..e15ab911fc 100644 --- a/lmdeploy/pytorch/kernels/cuda/pagedattention.py +++ b/lmdeploy/pytorch/kernels/cuda/pagedattention.py @@ -1153,9 +1153,9 @@ def _get_block_d(Lk): if not is_decoding: BLOCK_DMODEL, BLOCK_DMODEL1, BLOCK_DV = _get_block_d(Lq) if _nv_cap[0] < 8: - BLOCK_M = max(16, min(BLOCK, 8192 // BLOCK_DMODEL)) + BLOCK_M = max(16, 8192 // BLOCK_DMODEL) else: - BLOCK_M = max(16, min(BLOCK, 16384 // BLOCK_DMODEL)) + BLOCK_M = max(16, 16384 // BLOCK_DMODEL) num_warps = 4 num_stages = 2 kv_head = k.shape[h_dim]