refactor PyTorchEngine check env #2870

Merged · 7 commits · Dec 12, 2024
Changes from 4 commits
273 changes: 5 additions & 268 deletions lmdeploy/pytorch/check_env/__init__.py
@@ -1,277 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
from logging import Logger
from typing import List

from lmdeploy.utils import get_logger


def _handle_exception(e: Exception,
                      mod_name: str,
                      logger: Logger,
                      message: str = None):
    red_color = '\033[31m'
    reset_color = '\033[0m'
    if message is None:
        message = 'Please ensure it has been installed correctly.'
    logger.debug('Exception', exc_info=1)
    logger.error(f'{type(e).__name__}: {e}')
    logger.error(f'{red_color}'
                 f'<{mod_name}> test failed!\n'
                 f'{message}'
                 f'{reset_color}')
    exit(1)
from .base import BaseChecker # noqa: F401


def check_env_deeplink(device_type: str):
    """check Deeplink environment."""
    try_import_deeplink(device_type)
    from .deeplink import DeeplinkChecker
    checker = DeeplinkChecker(device_type)
    checker.handle()


def try_import_deeplink(device_type: str):
    """import dlinfer if specific device_type is set."""
    deeplink_device_type_list = [
        'ascend',
        'npu',
        'maca',
    ]
    if device_type in deeplink_device_type_list:
        logger = get_logger('lmdeploy')
        try:
            import dlinfer.framework.lmdeploy_ext  # noqa: F401
        except Exception as e:
            _handle_exception(e, 'PyTorch', logger)


def check_env_torch():
    """check PyTorch environment."""
    logger = get_logger('lmdeploy')

    try:
        logger.debug('Checking <PyTorch> environment.')
        import torch

        a = torch.tensor([1, 2], device='cuda')
        b = a.new_tensor([3, 4], device='cuda')
        c = a + b
        torch.testing.assert_close(c, a.new_tensor([4, 6]))
    except Exception as e:
        _handle_exception(e, 'PyTorch', logger)


MAX_TRITON_VERSION = '3.0.0'


def check_env_triton(device: str):
    """check OpenAI Triton environment."""
    from packaging import version
    logger = get_logger('lmdeploy')

    msg = (
        'Please ensure that your device is functioning properly with <Triton>.\n'  # noqa: E501
        'You can verify your environment by running '
        '`python -m lmdeploy.pytorch.check_env.triton_custom_add`.')
    try:
        logger.debug('Checking <Triton> environment.')
        import torch
        import triton
        triton_version = version.parse(triton.__version__)
        if triton_version > version.parse(MAX_TRITON_VERSION):
            logger.warning(
                f'Engine has not been tested on triton>{MAX_TRITON_VERSION}.')

        from .triton_custom_add import custom_add
        a = torch.tensor([1, 2], device='cuda')
        b = a.new_tensor([3, 4], device='cuda')
        c = custom_add(a, b)
        torch.testing.assert_close(c, a + b)
    except RuntimeError as e:
        ptxas_error = 'device kernel image is invalid'
        if len(e.args) > 0 and ptxas_error in e.args[0]:
            msg = (
                'This error might be caused by a mismatch between the NVIDIA driver and the nvcc compiler.\n'  # noqa: E501
                'Try the solution in https://github.com/triton-lang/triton/issues/1955#issuecomment-1929908209'  # noqa: E501
                ' or reinstall the driver.')
        _handle_exception(e, 'Triton', logger, msg)
    except Exception as e:
        _handle_exception(e, 'Triton', logger, msg)

    if device == 'cuda':
        device_cap = torch.cuda.get_device_capability()
        TRITON_VER_231 = version.parse('2.3.1')

        if device_cap[0] <= 7:
            if triton_version <= TRITON_VER_231:
                err = RuntimeError(
                    'Attention triton kernel does not fully support '
                    'triton<3.0.0 on devices with capability<8. '
                    'Please upgrade your triton version.')
                _handle_exception(err, 'Triton', logger)


def check_env(device_type: str):
    """check all environment."""
    logger = get_logger('lmdeploy')
    logger.info('Checking environment for PyTorch Engine.')
    check_env_deeplink(device_type)
    check_env_torch()
    if device_type == 'cuda':
        check_env_triton('cuda')


MIN_TRANSFORMERS_VERSION = '4.33.0'
MAX_TRANSFORMERS_VERSION = '4.44.1'


def check_awq(hf_config, device_type):
    """check awq support."""
    logger = get_logger('lmdeploy')
    if device_type == 'cuda':
        quantization_config = getattr(hf_config, 'quantization_config', dict())
        quant_method = quantization_config.get('quant_method', None)
        if quant_method != 'awq':
            return
        try:
            import awq  # noqa
        except Exception as e:
            _handle_exception(e, 'autoawq', logger)

        try:
            import awq_ext  # noqa
        except Exception:
            logger.debug('Exception:', exc_info=1)
            logger.warning('Failed to import `awq_ext`. '
                           'Try reinstalling it from source: '
                           'https://github.com/casper-hansen/AutoAWQ_kernels')


def check_transformers_version(model_path: str,
                               trust_remote_code: bool = True,
                               dtype: str = 'auto',
                               device_type: str = 'cuda'):
    """check transformers version."""
    from packaging import version
    logger = get_logger('lmdeploy')

    def __check_transformers_version():
        """check transformers version."""
        logger.debug('Checking <transformers> version.')
        trans_version = None
        try:
            import transformers
            trans_version = version.parse(transformers.__version__)
            min_version = version.parse(MIN_TRANSFORMERS_VERSION)
            max_version = version.parse(MAX_TRANSFORMERS_VERSION)
            if trans_version < min_version or trans_version > max_version:
                logger.warning('LMDeploy requires transformers version: '
                               f'[{MIN_TRANSFORMERS_VERSION} ~ '
                               f'{MAX_TRANSFORMERS_VERSION}], '
                               'but found version: '
                               f'{transformers.__version__}')
        except Exception as e:
            _handle_exception(e, 'transformers', logger)
        return transformers, trans_version

    def __check_config(trans_version):
        """check config."""
        logger.debug('Checking <Model> AutoConfig.from_pretrained.')
        try:
            from transformers import AutoConfig
            config = AutoConfig.from_pretrained(
                model_path, trust_remote_code=trust_remote_code)
        except Exception as e:
            message = (
                f'Load model config with transformers=={trans_version}'
                ' failed. '
                'Please make sure model can be loaded with transformers API.')
            _handle_exception(e, 'transformers', logger, message=message)
        return config

    def __check_model_transformers_version(config, trans_version):
        """check model transformers version."""
        logger.debug('Checking <Model> required transformers version.')
        try:
            model_trans_version = getattr(config, 'transformers_version', None)
            if model_trans_version is not None:
                model_trans_version = version.parse(model_trans_version)
                assert trans_version >= model_trans_version, \
                    'Version mismatch.'
        except Exception as e:
            message = (f'model `{model_path}` requires '
                       f'transformers version {model_trans_version} '
                       f'but transformers {trans_version} is installed.')
            _handle_exception(e, 'transformers', logger, message=message)

    def __check_model_dtype_support(config, device_type):
        """Checking model dtype support."""
        logger.debug('Checking <Model> dtype support.')

        import torch

        from lmdeploy.pytorch.config import ModelConfig
        from lmdeploy.utils import is_bf16_supported

        try:
            model_config = ModelConfig.from_hf_config(config,
                                                      model_path=model_path,
                                                      dtype=dtype)
            if model_config.dtype == torch.bfloat16:
                assert is_bf16_supported(device_type), (
                    'bf16 is not supported on your device')
        except AssertionError as e:
            message = (
                f'Your device does not support `{model_config.dtype}`. '
                'You can set `dtype` to float16 in PyTorchEngineConfig or '
                'pass `--dtype float16` to api_server.\n'
                'Note that this might have a negative effect!')
            _handle_exception(e, 'Model', logger, message=message)
        except Exception as e:
            message = (f'Checking failed with error {e}. '
                       'Please send an issue to LMDeploy with the error logs.')
            _handle_exception(e, 'Model', logger, message=message)

        return model_config

    _, trans_version = __check_transformers_version()
    config = __check_config(trans_version)
    __check_model_transformers_version(config, trans_version)
    __check_model_dtype_support(config, device_type)
    check_awq(config, device_type)


def check_model(model_path: str,
                trust_remote_code: bool = True,
                dtype: str = 'auto',
                device_type: str = 'cuda'):
    """check model requirements."""
    logger = get_logger('lmdeploy')
    logger.info('Checking model.')
    check_transformers_version(model_path, trust_remote_code, dtype,
                               device_type)


def check_adapter(path: str):
    """check adapter."""
    logger = get_logger('lmdeploy')
    logger.debug(f'Checking <Adapter>: {path}.')

    try:
        from peft import PeftConfig
        PeftConfig.from_pretrained(path)
    except Exception as e:
        message = ('Please make sure the adapter can be loaded with '
                   '`peft.PeftConfig.from_pretrained`\n')
        err_msg = '' if len(e.args) == 0 else e.args[0]
        if 'got an unexpected keyword argument' in err_msg:
            message += ('Or try removing all unexpected keywords '
                        'in `adapter_config.json`.')
        _handle_exception(e, 'Model', logger, message=message)


def check_adapters(adapter_paths: List[str]):
    """check adapters."""
    if len(adapter_paths) <= 0:
        return
    logger = get_logger('lmdeploy')
    logger.info('Checking adapters.')
    for path in adapter_paths:
        check_adapter(path)
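Under this refactor, each of the ad-hoc check functions above is replaced by a small checker class built on the `BaseChecker` base class added in `base.py` below. As a rough illustration only — the class name `TorchChecker` and its exact body are assumptions for this sketch, not necessarily what the PR ships — the old `check_env_torch` recast in the new style might look like:

# Hypothetical sketch, not code from this PR: the deleted check_env_torch()
# recast as a BaseChecker subclass (see base.py below for the base class).
from lmdeploy.pytorch.check_env.base import BaseChecker


class TorchChecker(BaseChecker):
    """check PyTorch environment."""

    def __init__(self, device: str = 'cuda', logger=None):
        super().__init__(logger)
        self.device = device

    def check(self):
        """run the same tensor smoke test as the deleted check_env_torch()."""
        try:
            import torch
            a = torch.tensor([1, 2], device=self.device)
            b = a.new_tensor([3, 4], device=self.device)
            torch.testing.assert_close(a + b, a.new_tensor([4, 6]))
        except Exception as e:
            self.log_and_exit(e, 'PyTorch')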
31 changes: 31 additions & 0 deletions lmdeploy/pytorch/check_env/adapter.py
@@ -0,0 +1,31 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base import BaseChecker


class AdapterChecker(BaseChecker):
    """check adapter is available."""

    def __init__(self, adapter_path: str, logger=None):
        super().__init__(logger)
        self.adapter_path = adapter_path

    def check(self):
        """check."""
        path = self.adapter_path

        try:
            import peft  # noqa: F401
        except Exception as e:
            self.log_and_exit(e, 'Adapter', message='Failed to import peft.')

        try:
            from peft import PeftConfig
            PeftConfig.from_pretrained(path)
        except Exception as e:
            message = ('Please make sure the adapter can be loaded with '
                       '`peft.PeftConfig.from_pretrained`\n')
            err_msg = '' if len(e.args) == 0 else e.args[0]
            if 'got an unexpected keyword argument' in err_msg:
                message += ('Or try removing all unexpected keywords '
                            'in `adapter_config.json`.')
            self.log_and_exit(e, 'Adapter', message=message)
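A checker is used by constructing it and calling `handle()` (defined in `base.py` below), which runs any registered prerequisite checkers, then `check()`, and exits with a red error message on failure. A minimal usage sketch; the adapter path is a placeholder:

from lmdeploy.pytorch.check_env.adapter import AdapterChecker

# '/path/to/adapter' is a placeholder for a real PEFT adapter directory.
checker = AdapterChecker('/path/to/adapter')
checker.handle()  # exits the process if the adapter cannot be loaded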
62 changes: 62 additions & 0 deletions lmdeploy/pytorch/check_env/base.py
@@ -0,0 +1,62 @@
# Copyright (c) OpenMMLab. All rights reserved.
from logging import Logger
from typing import List

from lmdeploy.utils import get_logger

RED_COLOR = '\033[31m'
RESET_COLOR = '\033[0m'


def _red_text(text: str):
    """red text."""
    return f'{RED_COLOR}{text}{RESET_COLOR}'


class BaseChecker:
    """base checker."""

    def __init__(self, logger: Logger = None):
        if logger is None:
            logger = get_logger('lmdeploy')
        self.logger = logger
        self._is_passed = False
        self._required_checker: List[BaseChecker] = list()

    def get_logger(self):
        """get logger."""
        return self.logger

    def register_required_checker(self, checker: 'BaseChecker'):
        """register a checker that must pass before this one."""
        self._required_checker.append(checker)

    def handle(self):
        """run required checkers, then this checker, at most once."""
        is_passed = getattr(self, '_is_passed', False)
        if not is_passed:
            checker_name = type(self).__name__
            self.logger.debug(f'Checking <{checker_name}>:')
            for checker in self._required_checker:
                checker.handle()
            self.check()
            self._is_passed = True

    def log_and_exit(self,
                     e: Exception = None,
                     mod_name: str = None,
                     message: str = None):
        logger = self.logger
        if mod_name is None:
            mod_name = type(self).__name__
        if message is None:
            message = 'Please check your environment.'
        logger.debug('Exception', exc_info=1)
        if e is not None:
            logger.error(f'{type(e).__name__}: {e}')
        logger.error(f'<{mod_name}> check failed!\n{_red_text(message)}')
        exit(1)

    def check(self):
        """check."""
        raise NotImplementedError('check not implemented.')
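`handle()` gives checkers memoized, depth-first execution: prerequisites registered via `register_required_checker` run before the dependent checker, and a checker that has already passed is skipped on later calls. A small illustration with two invented checkers (the names are examples only, and it assumes the `_is_passed` flag is set as above):

from lmdeploy.pytorch.check_env.base import BaseChecker


class FooChecker(BaseChecker):

    def check(self):
        self.logger.info('foo ok')


class BarChecker(BaseChecker):

    def check(self):
        self.logger.info('bar ok')


bar = BarChecker()
bar.register_required_checker(FooChecker())
bar.handle()  # runs FooChecker.check() first, then BarChecker.check()
bar.handle()  # no-op: _is_passed is already True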