From 7c6b1077060a66991d166aab8475b0e4043382b0 Mon Sep 17 00:00:00 2001
From: q yao
Date: Wed, 9 Oct 2024 11:53:04 +0800
Subject: [PATCH] add check for device with cap 7.x (#2535)

* add check for device with cap 7.x

* update hint

* update->upgrade
---
 lmdeploy/pytorch/check_env/__init__.py     | 24 +++++++++++++++----
 lmdeploy/pytorch/models/llama.py           |  5 ++--
 requirements/runtime.txt                   |  2 +-
 tests/pytorch/kernel/test_apply_rotary.py  | 11 +++++++--
 .../kernel/test_multinomial_sampling.py    | 11 +++++++--
 tests/pytorch/kernel/test_rms_norm.py      | 11 +++++++--
 6 files changed, 50 insertions(+), 14 deletions(-)

diff --git a/lmdeploy/pytorch/check_env/__init__.py b/lmdeploy/pytorch/check_env/__init__.py
index b066d24423..ea2dda8e8d 100644
--- a/lmdeploy/pytorch/check_env/__init__.py
+++ b/lmdeploy/pytorch/check_env/__init__.py
@@ -56,10 +56,10 @@ def check_env_torch():
         _handle_exception(e, 'PyTorch', logger)
 
 
-MAX_TRITON_VERSION = '2.2.0'
+MAX_TRITON_VERSION = '3.0.0'
 
 
-def check_env_triton():
+def check_env_triton(device: str):
     """check OpenAI Triton environment."""
     from packaging import version
     logger = get_logger('lmdeploy')
@@ -68,8 +68,8 @@ def check_env_triton():
         logger.debug('Checking <Triton> environment.')
         import torch
         import triton
-        if version.parse(
-                triton.__version__) > version.parse(MAX_TRITON_VERSION):
+        triton_version = version.parse(triton.__version__)
+        if triton_version > version.parse(MAX_TRITON_VERSION):
             logger.warning(
                 f'Engine has not been tested on triton>{MAX_TRITON_VERSION}.')
 
@@ -91,6 +91,20 @@
     except Exception as e:
         _handle_exception(e, 'Triton', logger)
 
+    if device == 'cuda':
+        device_cap = torch.cuda.get_device_capability()
+        TRITON_VER_220 = version.parse('2.2.0')
+        TRITON_VER_231 = version.parse('2.3.1')
+
+        if device_cap[0] <= 7:
+            if (triton_version >= TRITON_VER_220
+                    and triton_version <= TRITON_VER_231):
+                err = RuntimeError(
+                    'Attention triton kernel does not fully support '
+                    'triton[2.2.0~2.3.1] on device with capability<8. '
+                    'Please upgrade your triton version.')
+                _handle_exception(err, 'Triton', logger)
+
 
 def check_env(device_type: str):
     """check all environment."""
@@ -99,7 +113,7 @@ def check_env(device_type: str):
     check_env_deeplink(device_type)
     check_env_torch()
     if device_type == 'cuda':
-        check_env_triton()
+        check_env_triton('cuda')
 
 
 MIN_TRANSFORMERS_VERSION = '4.33.0'
diff --git a/lmdeploy/pytorch/models/llama.py b/lmdeploy/pytorch/models/llama.py
index 4af4e55ba4..2641429683 100644
--- a/lmdeploy/pytorch/models/llama.py
+++ b/lmdeploy/pytorch/models/llama.py
@@ -114,10 +114,11 @@ def __init__(self,
         super().__init__()
         quantization_config = getattr(config, 'quantization_config', None)
         # gate up
+        mlp_bias = getattr(config, 'mlp_bias', False)
         self.gate_up_proj = build_merged_colwise_linear(
             config.hidden_size,
             [config.intermediate_size, config.intermediate_size],
-            bias=config.mlp_bias,
+            bias=mlp_bias,
             dtype=dtype,
             device=device,
             quant_config=quantization_config,
@@ -130,7 +131,7 @@ def __init__(self,
         # down
         self.down_proj = build_rowwise_linear(config.intermediate_size,
                                               config.hidden_size,
-                                              bias=config.mlp_bias,
+                                              bias=mlp_bias,
                                               quant_config=quantization_config,
                                               dtype=dtype,
                                               device=device,
diff --git a/requirements/runtime.txt b/requirements/runtime.txt
index afa048bb5e..7e5058c17b 100644
--- a/requirements/runtime.txt
+++ b/requirements/runtime.txt
@@ -18,5 +18,5 @@ tiktoken
 torch<=2.3.1,>=2.0.0
 torchvision<=0.18.1,>=0.15.0
 transformers
-triton>=2.1.0,<=2.3.1; sys_platform == "linux"
+triton>=2.1.0,<=3.0.0; sys_platform == "linux"
 uvicorn
diff --git a/tests/pytorch/kernel/test_apply_rotary.py b/tests/pytorch/kernel/test_apply_rotary.py
index 0050a59ef8..c8ca1dd77c 100644
--- a/tests/pytorch/kernel/test_apply_rotary.py
+++ b/tests/pytorch/kernel/test_apply_rotary.py
@@ -11,6 +11,11 @@ def _rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)
 
 
+def _bf16_mark():
+    return pytest.mark.skipif(not torch.cuda.is_bf16_supported(),
+                              reason='bf16 not supported.')
+
+
 class TestApplyRotary:
 
     @pytest.fixture
@@ -87,8 +92,10 @@ def gt(self, q_states, k_states, cos, sin, position_ids_1d):
 
         yield q_embed, k_embed
 
-    @pytest.mark.parametrize('dtype',
-                             [torch.bfloat16, torch.float16, torch.float32],
+    @pytest.mark.parametrize('dtype', [
+        pytest.param(torch.bfloat16, marks=_bf16_mark()), torch.float16,
+        torch.float32
+    ],
                              indirect=True)
     @pytest.mark.parametrize(('num_heads_q', 'num_heads_k'), [(8, 8), (8, 4)],
                              indirect=True)
diff --git a/tests/pytorch/kernel/test_multinomial_sampling.py b/tests/pytorch/kernel/test_multinomial_sampling.py
index 260224feec..9636fa5d3f 100644
--- a/tests/pytorch/kernel/test_multinomial_sampling.py
+++ b/tests/pytorch/kernel/test_multinomial_sampling.py
@@ -4,6 +4,11 @@
 from lmdeploy.pytorch.kernels import multinomial_sampling
 
 
+def _bf16_mark():
+    return pytest.mark.skipif(not torch.cuda.is_bf16_supported(),
+                              reason='bf16 not supported.')
+
+
 class TestMultinomialSampling:
 
     @pytest.fixture
@@ -50,8 +55,10 @@ def gt(self, batch_size, select_ids, indices):
         batch_ids = torch.arange(batch_size).cuda()
         yield indices[batch_ids, select_ids]
 
-    @pytest.mark.parametrize('dtype',
-                             [torch.float32, torch.half, torch.bfloat16])
+    @pytest.mark.parametrize('dtype', [
+        torch.float32, torch.half,
+        pytest.param(torch.bfloat16, marks=_bf16_mark())
+    ])
     @pytest.mark.parametrize(['num_tokens', 'select_ids'], [
         (8, (4, 2) * 30),
         (2000, (500, 1500)),
diff --git a/tests/pytorch/kernel/test_rms_norm.py b/tests/pytorch/kernel/test_rms_norm.py
index 0511ac5f43..b731f372de 100644
--- a/tests/pytorch/kernel/test_rms_norm.py
+++ b/tests/pytorch/kernel/test_rms_norm.py
@@ -2,6 +2,11 @@
 import torch
 
 
+def _bf16_mark():
+    return pytest.mark.skipif(not torch.cuda.is_bf16_supported(),
+                              reason='bf16 not supported.')
+
+
 class TestRMSNorm:
 
     @pytest.fixture(scope='class')
@@ -28,8 +33,10 @@ def gt(self, input, weight, eps):
         input = input * torch.rsqrt(variance + eps)
         return weight * input.to(input_dtype)
 
-    @pytest.mark.parametrize('dtype',
-                             [torch.bfloat16, torch.float16, torch.float32],
+    @pytest.mark.parametrize('dtype', [
+        pytest.param(torch.bfloat16, marks=_bf16_mark()), torch.float16,
+        torch.float32
+    ],
                              indirect=True)
     def test_rms_norm(self, input, weight, eps, gt):
         from lmdeploy.pytorch.kernels import rms_norm
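
For reference, the gate added to check_env_triton above can be exercised on its own. The snippet below is a minimal sketch, not part of the patch: the wrapper name check_triton_capability and the direct RuntimeError raise (instead of lmdeploy's _handle_exception) are illustrative assumptions; the torch.cuda.get_device_capability and packaging.version calls are the ones the hunk uses.

# Illustrative sketch (not part of the patch): reproduce the new check
# from lmdeploy/pytorch/check_env/__init__.py as a quick local test.
# The helper name `check_triton_capability` is hypothetical.
import torch
import triton
from packaging import version


def check_triton_capability() -> None:
    if not torch.cuda.is_available():
        return
    triton_version = version.parse(triton.__version__)
    device_cap = torch.cuda.get_device_capability()
    # triton 2.2.0~2.3.1 attention kernels are not fully supported on
    # pre-Ampere GPUs (compute capability < 8); a newer triton is required.
    if device_cap[0] <= 7 and (version.parse('2.2.0') <= triton_version
                               <= version.parse('2.3.1')):
        raise RuntimeError(
            'Attention triton kernel does not fully support '
            'triton[2.2.0~2.3.1] on device with capability<8. '
            'Please upgrade your triton version.')


if __name__ == '__main__':
    check_triton_capability()
    print('triton/device capability check passed')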