add check for device with cap 7.x (#2535)
* add check for device with cap 7.x

* update hint

* update->upgrade
grimoire authored Oct 9, 2024
1 parent a5ee8df commit 7c6b107
Showing 6 changed files with 50 additions and 14 deletions.
24 changes: 19 additions & 5 deletions lmdeploy/pytorch/check_env/__init__.py
@@ -56,10 +56,10 @@ def check_env_torch():
        _handle_exception(e, 'PyTorch', logger)


-MAX_TRITON_VERSION = '2.2.0'
+MAX_TRITON_VERSION = '3.0.0'


-def check_env_triton():
+def check_env_triton(device: str):
    """check OpenAI Triton environment."""
    from packaging import version
    logger = get_logger('lmdeploy')
@@ -68,8 +68,8 @@ def check_env_triton():
        logger.debug('Checking <Triton> environment.')
        import torch
        import triton
-        if version.parse(
-                triton.__version__) > version.parse(MAX_TRITON_VERSION):
+        triton_version = version.parse(triton.__version__)
+        if triton_version > version.parse(MAX_TRITON_VERSION):
            logger.warning(
                f'Engine has not been tested on triton>{MAX_TRITON_VERSION}.')

@@ -91,6 +91,20 @@ def check_env_triton():
    except Exception as e:
        _handle_exception(e, 'Triton', logger)

+    if device == 'cuda':
+        device_cap = torch.cuda.get_device_capability()
+        TRITON_VER_220 = version.parse('2.2.0')
+        TRITON_VER_231 = version.parse('2.3.1')
+
+        if device_cap[0] <= 7:
+            if (triton_version >= TRITON_VER_220
+                    and triton_version <= TRITON_VER_231):
+                err = RuntimeError(
+                    'Attention triton kernel does not fully support '
+                    'triton[2.2.0~2.3.1] on device with capability<8. '
+                    'Please upgrade your triton version.')
+                _handle_exception(err, 'Triton', logger)
+

def check_env(device_type: str):
    """check all environment."""
@@ -99,7 +113,7 @@ def check_env(device_type: str):
    check_env_deeplink(device_type)
    check_env_torch()
    if device_type == 'cuda':
-        check_env_triton()
+        check_env_triton('cuda')


MIN_TRANSFORMERS_VERSION = '4.33.0'
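For reference, the gate added above can be reproduced outside lmdeploy to see whether a given machine would trip it. This is a minimal sketch, not part of the commit; probe_triton_support is an illustrative name, and it assumes a CUDA build of torch plus triton and packaging are installed with a GPU visible:

    # Minimal sketch (not part of this commit) of what check_env_triton now
    # flags: triton 2.2.0~2.3.1 on GPUs with compute capability below 8.
    import torch
    import triton
    from packaging import version


    def probe_triton_support() -> None:  # illustrative helper name
        major, minor = torch.cuda.get_device_capability()
        triton_ver = version.parse(triton.__version__)
        in_bad_range = (version.parse('2.2.0') <= triton_ver
                        <= version.parse('2.3.1'))
        print(f'compute capability {major}.{minor}, triton {triton_ver}')
        if major <= 7 and in_bad_range:
            print('This combination would fail the new check; '
                  'upgrade triton beyond 2.3.1.')
        else:
            print('This combination passes the new check.')


    if __name__ == '__main__':
        probe_triton_support()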
5 changes: 3 additions & 2 deletions lmdeploy/pytorch/models/llama.py
@@ -114,10 +114,11 @@ def __init__(self,
        super().__init__()
        quantization_config = getattr(config, 'quantization_config', None)
        # gate up
+        mlp_bias = getattr(config, 'mlp_bias', False)
        self.gate_up_proj = build_merged_colwise_linear(
            config.hidden_size,
            [config.intermediate_size, config.intermediate_size],
-            bias=config.mlp_bias,
+            bias=mlp_bias,
            dtype=dtype,
            device=device,
            quant_config=quantization_config,
@@ -130,7 +131,7 @@ def __init__(self,
        # down
        self.down_proj = build_rowwise_linear(config.intermediate_size,
                                              config.hidden_size,
-                                              bias=config.mlp_bias,
+                                              bias=mlp_bias,
                                              quant_config=quantization_config,
                                              dtype=dtype,
                                              device=device,
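Note on the llama.py change: configs produced by older transformers releases may not define mlp_bias, so reading config.mlp_bias directly can raise AttributeError. A tiny standalone sketch of the pattern (hypothetical config classes, purely for illustration):

    # Why the commit switches from `config.mlp_bias` to
    # `getattr(config, 'mlp_bias', False)` (hypothetical config objects).
    class OldLlamaConfig:
        hidden_size = 4096          # no `mlp_bias` attribute at all


    class NewLlamaConfig:
        hidden_size = 4096
        mlp_bias = False            # newer config versions define it


    for cfg in (OldLlamaConfig(), NewLlamaConfig()):
        # Direct attribute access would raise AttributeError for OldLlamaConfig;
        # getattr with a default keeps both config generations working.
        mlp_bias = getattr(cfg, 'mlp_bias', False)
        print(type(cfg).__name__, mlp_bias)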
2 changes: 1 addition & 1 deletion requirements/runtime.txt
@@ -18,5 +18,5 @@ tiktoken
torch<=2.3.1,>=2.0.0
torchvision<=0.18.1,>=0.15.0
transformers
-triton>=2.1.0,<=2.3.1; sys_platform == "linux"
+triton>=2.1.0,<=3.0.0; sys_platform == "linux"
uvicorn
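To confirm a local environment satisfies the relaxed pin, a small sketch (assumes triton and packaging are importable):

    # Check the installed triton against the new bound in requirements/runtime.txt.
    from packaging import version
    import triton

    ver = version.parse(triton.__version__)
    low, high = version.parse('2.1.0'), version.parse('3.0.0')
    print(f'triton {ver}:',
          'OK' if low <= ver <= high else 'outside the supported range')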
11 changes: 9 additions & 2 deletions tests/pytorch/kernel/test_apply_rotary.py
@@ -11,6 +11,11 @@ def _rotate_half(x):
    return torch.cat((-x2, x1), dim=-1)


+def _bf16_mark():
+    return pytest.mark.skipif(not torch.cuda.is_bf16_supported(),
+                              reason='bf16 not supported.')
+
+
class TestApplyRotary:

    @pytest.fixture
@@ -87,8 +92,10 @@ def gt(self, q_states, k_states, cos, sin, position_ids_1d):

        yield q_embed, k_embed

-    @pytest.mark.parametrize('dtype',
-                             [torch.bfloat16, torch.float16, torch.float32],
+    @pytest.mark.parametrize('dtype', [
+        pytest.param(torch.bfloat16, marks=_bf16_mark()), torch.float16,
+        torch.float32
+    ],
                             indirect=True)
    @pytest.mark.parametrize(('num_heads_q', 'num_heads_k'), [(8, 8), (8, 4)],
                             indirect=True)
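The two test files below receive the same treatment. As a standalone sketch of the pattern (not taken from the repo): a skipif mark attached to a single pytest.param skips only the bf16 case when the GPU cannot do bfloat16, while the fp16 and fp32 cases still run. It assumes a CUDA build of torch:

    # Standalone sketch of the per-parameter skip mark used in these tests.
    import pytest
    import torch


    def _bf16_mark():
        return pytest.mark.skipif(not torch.cuda.is_bf16_supported(),
                                  reason='bf16 not supported.')


    @pytest.mark.parametrize('dtype', [
        pytest.param(torch.bfloat16, marks=_bf16_mark()),
        torch.float16,
        torch.float32,
    ])
    def test_dtype_roundtrip(dtype):
        # Only the bf16 case is skipped on GPUs without bfloat16 support.
        x = torch.ones(4, dtype=dtype)
        assert x.dtype == dtype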
11 changes: 9 additions & 2 deletions tests/pytorch/kernel/test_multinomial_sampling.py
@@ -4,6 +4,11 @@
from lmdeploy.pytorch.kernels import multinomial_sampling


+def _bf16_mark():
+    return pytest.mark.skipif(not torch.cuda.is_bf16_supported(),
+                              reason='bf16 not supported.')
+
+
class TestMultinomialSampling:

    @pytest.fixture
@@ -50,8 +55,10 @@ def gt(self, batch_size, select_ids, indices):
        batch_ids = torch.arange(batch_size).cuda()
        yield indices[batch_ids, select_ids]

-    @pytest.mark.parametrize('dtype',
-                             [torch.float32, torch.half, torch.bfloat16])
+    @pytest.mark.parametrize('dtype', [
+        torch.float32, torch.half,
+        pytest.param(torch.bfloat16, marks=_bf16_mark())
+    ])
    @pytest.mark.parametrize(['num_tokens', 'select_ids'], [
        (8, (4, 2) * 30),
        (2000, (500, 1500)),
11 changes: 9 additions & 2 deletions tests/pytorch/kernel/test_rms_norm.py
@@ -2,6 +2,11 @@
import torch


+def _bf16_mark():
+    return pytest.mark.skipif(not torch.cuda.is_bf16_supported(),
+                              reason='bf16 not supported.')
+
+
class TestRMSNorm:

    @pytest.fixture(scope='class')
@@ -28,8 +33,10 @@ def gt(self, input, weight, eps):
        input = input * torch.rsqrt(variance + eps)
        return weight * input.to(input_dtype)

-    @pytest.mark.parametrize('dtype',
-                             [torch.bfloat16, torch.float16, torch.float32],
+    @pytest.mark.parametrize('dtype', [
+        pytest.param(torch.bfloat16, marks=_bf16_mark()), torch.float16,
+        torch.float32
+    ],
                             indirect=True)
    def test_rms_norm(self, input, weight, eps, gt):
        from lmdeploy.pytorch.kernels import rms_norm
