From b99a5da255631ea4d24f03bdefccf44bf89aa108 Mon Sep 17 00:00:00 2001 From: q yao Date: Thu, 12 Dec 2024 17:45:14 +0800 Subject: [PATCH] refactor PyTorchEngine check env (#2870) * refactor checker * config builder * fix * fix * update triton * remove dockerfile update * update torch version --- lmdeploy/pytorch/check_env/__init__.py | 273 +-------------------- lmdeploy/pytorch/check_env/adapter.py | 31 +++ lmdeploy/pytorch/check_env/base.py | 62 +++++ lmdeploy/pytorch/check_env/deeplink.py | 25 ++ lmdeploy/pytorch/check_env/model.py | 117 +++++++++ lmdeploy/pytorch/check_env/torch.py | 21 ++ lmdeploy/pytorch/check_env/transformers.py | 29 +++ lmdeploy/pytorch/check_env/triton.py | 60 +++++ lmdeploy/pytorch/engine/engine.py | 78 +++--- lmdeploy/pytorch/engine/engine_checker.py | 77 ++++++ requirements/runtime.txt | 4 +- 11 files changed, 474 insertions(+), 303 deletions(-) create mode 100644 lmdeploy/pytorch/check_env/adapter.py create mode 100644 lmdeploy/pytorch/check_env/base.py create mode 100644 lmdeploy/pytorch/check_env/deeplink.py create mode 100644 lmdeploy/pytorch/check_env/model.py create mode 100644 lmdeploy/pytorch/check_env/torch.py create mode 100644 lmdeploy/pytorch/check_env/transformers.py create mode 100644 lmdeploy/pytorch/check_env/triton.py create mode 100644 lmdeploy/pytorch/engine/engine_checker.py diff --git a/lmdeploy/pytorch/check_env/__init__.py b/lmdeploy/pytorch/check_env/__init__.py index 7d7243822..bc95a32be 100644 --- a/lmdeploy/pytorch/check_env/__init__.py +++ b/lmdeploy/pytorch/check_env/__init__.py @@ -1,277 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. -from logging import Logger -from typing import List - -from lmdeploy.utils import get_logger - - -def _handle_exception(e: Exception, - mod_name: str, - logger: Logger, - message: str = None): - red_color = '\033[31m' - reset_color = '\033[0m' - if message is None: - message = 'Please ensure it has been installed correctly.' 
- logger.debug('Exception', exc_info=1) - logger.error(f'{type(e).__name__}: {e}') - logger.error(f'{red_color}' - f'<{mod_name}> test failed!\n' - f'{message}' - f'{reset_color}') - exit(1) +from .base import BaseChecker # noqa: F401 def check_env_deeplink(device_type: str): """check Deeplink environment.""" - try_import_deeplink(device_type) + from .deeplink import DeeplinkChecker + checker = DeeplinkChecker(device_type) + checker.handle() def try_import_deeplink(device_type: str): - """import dlinfer if specific device_type is set.""" - deeplink_device_type_list = [ - 'ascend', - 'npu', - 'maca', - ] - if device_type in deeplink_device_type_list: - logger = get_logger('lmdeploy') - try: - import dlinfer.framework.lmdeploy_ext # noqa: F401 - except Exception as e: - _handle_exception(e, 'PyTorch', logger) - - -def check_env_torch(): - """check PyTorch environment.""" - logger = get_logger('lmdeploy') - - try: - logger.debug('Checking environment.') - import torch - - a = torch.tensor([1, 2], device='cuda') - b = a.new_tensor([3, 4], device='cuda') - c = a + b - torch.testing.assert_close(c, a.new_tensor([4, 6])) - except Exception as e: - _handle_exception(e, 'PyTorch', logger) - - -MAX_TRITON_VERSION = '3.0.0' - - -def check_env_triton(device: str): - """check OpenAI Triton environment.""" - from packaging import version - logger = get_logger('lmdeploy') - - msg = ( - 'Please ensure that your device is functioning properly with .\n' # noqa: E501 - 'You can verify your environment by running ' - '`python -m lmdeploy.pytorch.check_env.triton_custom_add`.') - try: - logger.debug('Checking environment.') - import torch - import triton - triton_version = version.parse(triton.__version__) - if triton_version > version.parse(MAX_TRITON_VERSION): - logger.warning( - f'Engine has not been tested on triton>{MAX_TRITON_VERSION}.') - - from .triton_custom_add import custom_add - a = torch.tensor([1, 2], device='cuda') - b = a.new_tensor([3, 4], device='cuda') - c = custom_add(a, b) - torch.testing.assert_close(c, a + b) - except RuntimeError as e: - ptxas_error = 'device kernel image is invalid' - if len(e.args) > 0 and ptxas_error in e.args[0]: - msg = ( - 'This Error might caused by mismatching between NVIDIA Driver and nvcc compiler. \n' # noqa: E501 - 'Try solution https://github.com/triton-lang/triton/issues/1955#issuecomment-1929908209' # noqa: E501 - ' or reinstall the driver.') - _handle_exception(e, 'Triton', logger, msg) - except Exception as e: - _handle_exception(e, 'Triton', logger, msg) - - if device == 'cuda': - device_cap = torch.cuda.get_device_capability() - TRITON_VER_231 = version.parse('2.3.1') - - if device_cap[0] <= 7: - if triton_version <= TRITON_VER_231: - err = RuntimeError( - 'Attention triton kernel does not fully support ' - 'triton<3.0.0 on device with capability<8. 
' - 'Please upgrade your triton version.') - _handle_exception(err, 'Triton', logger) - - -def check_env(device_type: str): - """check all environment.""" - logger = get_logger('lmdeploy') - logger.info('Checking environment for PyTorch Engine.') + """check Deeplink environment.""" check_env_deeplink(device_type) - check_env_torch() - if device_type == 'cuda': - check_env_triton('cuda') - - -MIN_TRANSFORMERS_VERSION = '4.33.0' -MAX_TRANSFORMERS_VERSION = '4.44.1' - - -def check_awq(hf_config, device_type): - """check awq support.""" - logger = get_logger('lmdeploy') - if device_type == 'cuda': - quantization_config = getattr(hf_config, 'quantization_config', dict()) - quant_method = quantization_config.get('quant_method', None) - if quant_method != 'awq': - return - try: - import awq # noqa - except Exception as e: - _handle_exception(e, 'autoawq', logger) - - try: - import awq_ext # noqa - except Exception: - logger.debug('Exception:', exc_info=1) - logger.warning('Failed to import `awq_ext`. ' - 'Try reinstall it from source: ' - 'https://github.com/casper-hansen/AutoAWQ_kernels') - - -def check_transformers_version(model_path: str, - trust_remote_code: bool = True, - dtype: str = 'auto', - device_type: str = 'cuda'): - """check transformers version.""" - from packaging import version - logger = get_logger('lmdeploy') - - def __check_transformers_version(): - """check transformers version.""" - logger.debug('Checking version.') - trans_version = None - try: - import transformers - trans_version = version.parse(transformers.__version__) - min_version = version.parse(MIN_TRANSFORMERS_VERSION) - max_version = version.parse(MAX_TRANSFORMERS_VERSION) - if trans_version < min_version or trans_version > max_version: - logger.warning('LMDeploy requires transformers version: ' - f'[{MIN_TRANSFORMERS_VERSION} ~ ' - f'{MAX_TRANSFORMERS_VERSION}], ' - 'but found version: ' - f'{transformers.__version__}') - except Exception as e: - _handle_exception(e, 'transformers', logger) - return transformers, trans_version - - def __check_config(trans_version): - """check config.""" - logger.debug('Checking AutoConfig.from_pretrained.') - try: - from transformers import AutoConfig - config = AutoConfig.from_pretrained( - model_path, trust_remote_code=trust_remote_code) - except Exception as e: - message = ( - f'Load model config with transformers=={trans_version}' - ' failed. ' - 'Please make sure model can be loaded with transformers API.') - _handle_exception(e, 'transformers', logger, message=message) - return config - - def __check_model_transformers_version(config, trans_version): - """check model transformers version.""" - logger.debug('Checking required transformers version.') - try: - model_trans_version = getattr(config, 'transformers_version', None) - if model_trans_version is not None: - model_trans_version = version.parse(model_trans_version) - assert trans_version >= model_trans_version, \ - 'Version mismatch.' 
- except Exception as e: - message = (f'model `{model_path}` requires ' - f'transformers version {model_trans_version} ' - f'but transformers {trans_version} is installed.') - _handle_exception(e, 'transformers', logger, message=message) - - def __check_model_dtype_support(config, device_type): - """Checking model dtype support.""" - logger.debug('Checking dtype support.') - - import torch - - from lmdeploy.pytorch.config import ModelConfig - from lmdeploy.utils import is_bf16_supported - - try: - model_config = ModelConfig.from_hf_config(config, - model_path=model_path, - dtype=dtype) - if model_config.dtype == torch.bfloat16: - assert is_bf16_supported(device_type), ( - 'bf16 is not supported on your device') - except AssertionError as e: - message = ( - f'Your device does not support `{model_config.dtype}`. ' - 'You can set `dtype` to float16 in PyTorchEngineConfig or ' - '`--dtype float16` to api_server.\n' - 'Note that this might have negative effect!') - _handle_exception(e, 'Model', logger, message=message) - except Exception as e: - message = (f'Checking failed with error {e}', - 'Please send issue to LMDeploy with error logs.') - _handle_exception(e, 'Model', logger, message=message) - - return model_config - - _, trans_version = __check_transformers_version() - config = __check_config(trans_version) - __check_model_transformers_version(config, trans_version) - __check_model_dtype_support(config, device_type) - check_awq(config, device_type) - - -def check_model(model_path: str, - trust_remote_code: bool = True, - dtype: str = 'auto', - device_type: str = 'cuda'): - """check model requirements.""" - logger = get_logger('lmdeploy') - logger.info('Checking model.') - check_transformers_version(model_path, trust_remote_code, dtype, - device_type) - - -def check_adapter(path: str): - """check adapter.""" - logger = get_logger('lmdeploy') - logger.debug(f'Checking : {path}.') - - try: - from peft import PeftConfig - PeftConfig.from_pretrained(path) - except Exception as e: - message = ('Please make sure the adapter can be loaded with ' - '`peft.PeftConfig.from_pretrained`\n') - err_msg = '' if len(e.args) == 0 else e.args[0] - if 'got an unexpected keyword argument' in err_msg: - message += ('Or try remove all unexpected keywords ' - 'in `adapter_config.json`.') - _handle_exception(e, 'Model', logger, message=message) - - -def check_adapters(adapter_paths: List[str]): - """check adapters.""" - if len(adapter_paths) <= 0: - return - logger = get_logger('lmdeploy') - logger.info('Checking adapters.') - for path in adapter_paths: - check_adapter(path) diff --git a/lmdeploy/pytorch/check_env/adapter.py b/lmdeploy/pytorch/check_env/adapter.py new file mode 100644 index 000000000..bcaf5fd0e --- /dev/null +++ b/lmdeploy/pytorch/check_env/adapter.py @@ -0,0 +1,31 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .base import BaseChecker
+
+
+class AdapterChecker(BaseChecker):
+    """check adapter is available."""
+
+    def __init__(self, adapter_path: str, logger=None):
+        super().__init__(logger)
+        self.adapter_path = adapter_path
+
+    def check(self):
+        """check."""
+        path = self.adapter_path
+
+        try:
+            import peft  # noqa: F401
+        except Exception as e:
+            self.log_and_exit(e, 'Adapter', message='Failed to import peft.')
+
+        try:
+            from peft import PeftConfig
+            PeftConfig.from_pretrained(path)
+        except Exception as e:
+            message = ('Please make sure the adapter can be loaded with '
+                       '`peft.PeftConfig.from_pretrained`\n')
+            err_msg = '' if len(e.args) == 0 else e.args[0]
+            if 'got an unexpected keyword argument' in err_msg:
+                message += ('Or try removing all unexpected keywords '
+                            'in `adapter_config.json`.')
+            self.log_and_exit(e, 'Adapter', message=message)
diff --git a/lmdeploy/pytorch/check_env/base.py b/lmdeploy/pytorch/check_env/base.py
new file mode 100644
index 000000000..ed5e5a600
--- /dev/null
+++ b/lmdeploy/pytorch/check_env/base.py
@@ -0,0 +1,62 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from logging import Logger
+from typing import List
+
+from lmdeploy.utils import get_logger
+
+RED_COLOR = '\033[31m'
+RESET_COLOR = '\033[0m'
+
+
+def _red_text(text: str):
+    """red text."""
+    return f'{RED_COLOR}{text}{RESET_COLOR}'
+
+
+class BaseChecker:
+    """base checker."""
+
+    def __init__(self, logger: Logger = None):
+        if logger is None:
+            logger = get_logger('lmdeploy')
+        self.logger = logger
+        self._is_passed = False
+        self._required_checker: List[BaseChecker] = list()
+
+    def get_logger(self):
+        """get logger."""
+        return self.logger
+
+    def register_required_checker(self, checker: 'BaseChecker'):
+        """register a required checker, run before this one."""
+        self._required_checker.append(checker)
+
+    def handle(self):
+        """run required checkers, then this checker, at most once."""
+        if not self._is_passed:
+            checker_name = type(self).__name__
+            self.logger.debug(f'Checking <{checker_name}>:')
+            for checker in self._required_checker:
+                checker.handle()
+            self.check()
+            self._is_passed = True
+
+    def log_and_exit(self,
+                     e: Exception = None,
+                     mod_name: str = None,
+                     message: str = None):
+        """log the failure and exit the process."""
+        logger = self.logger
+        if mod_name is None:
+            mod_name = type(self).__name__
+        if message is None:
+            message = 'Please check your environment.'
+        if e is not None:
+            logger.debug('Exception', exc_info=1)
+            logger.error(f'{type(e).__name__}: {e}')
+        logger.error(f'<{mod_name}> check failed!\n{_red_text(message)}')
+        exit(1)
+
+    def check(self):
+        """check."""
+        raise NotImplementedError('check not implemented.')
diff --git a/lmdeploy/pytorch/check_env/deeplink.py b/lmdeploy/pytorch/check_env/deeplink.py
new file mode 100644
index 000000000..74ab5a7b8
--- /dev/null
+++ b/lmdeploy/pytorch/check_env/deeplink.py
@@ -0,0 +1,25 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base import BaseChecker
+
+deeplink_device_type_list = [
+    'ascend',
+    'npu',
+    'maca',
+]
+
+
+class DeeplinkChecker(BaseChecker):
+    """check dlinfer is available."""
+
+    def __init__(self, device_type: str, logger=None) -> None:
+        super().__init__(logger=logger)
+        self.device_type = device_type
+
+    def check(self):
+        """check."""
+        device_type = self.device_type
+        if device_type in deeplink_device_type_list:
+            try:
+                import dlinfer.framework.lmdeploy_ext  # noqa: F401
+            except Exception as e:
+                self.log_and_exit(e, 'dlinfer', 'dlinfer is not available.')
diff --git a/lmdeploy/pytorch/check_env/model.py b/lmdeploy/pytorch/check_env/model.py
new file mode 100644
index 000000000..4b721e50e
--- /dev/null
+++ b/lmdeploy/pytorch/check_env/model.py
@@ -0,0 +1,117 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from packaging import version
+
+from .base import BaseChecker
+
+
+class ModelChecker(BaseChecker):
+    """check model is available."""
+
+    def __init__(self,
+                 model_path: str,
+                 trust_remote_code: bool,
+                 dtype: str,
+                 device_type: str,
+                 logger=None) -> None:
+        super().__init__(logger=logger)
+        self.model_path = model_path
+        self.trust_remote_code = trust_remote_code
+        self.device_type = device_type
+        self.dtype = dtype
+
+    def check_config(self, trans_version):
+        """check config."""
+        model_path = self.model_path
+        trust_remote_code = self.trust_remote_code
+        try:
+            from transformers import AutoConfig
+            config = AutoConfig.from_pretrained(
+                model_path, trust_remote_code=trust_remote_code)
+        except Exception as e:
+            message = (
+                f'Load model config with transformers=={trans_version}'
+                ' failed. '
+                'Please make sure model can be loaded with transformers API.')
+            self.log_and_exit(e, 'transformers', message=message)
+        return config
+
+    def check_trans_version(self, config, trans_version):
+        """check transformers version."""
+        model_path = self.model_path
+        try:
+            model_trans_version = getattr(config, 'transformers_version', None)
+            if model_trans_version is not None:
+                model_trans_version = version.parse(model_trans_version)
+                assert trans_version >= model_trans_version, (
+                    'Version mismatch.')
+        except Exception as e:
+            message = (f'model `{model_path}` requires '
+                       f'transformers version {model_trans_version} '
+                       f'but transformers {trans_version} is installed.')
+            self.log_and_exit(e, 'transformers', message=message)
+
+    def check_dtype(self, config):
+        """check dtype."""
+        logger = self.get_logger()
+        model_path = self.model_path
+        device_type = self.device_type
+        dtype = self.dtype
+        try:
+            import torch
+
+            from lmdeploy.pytorch.config import ModelConfig
+            from lmdeploy.utils import is_bf16_supported
+            model_config = ModelConfig.from_hf_config(config,
+                                                      model_path=model_path,
+                                                      dtype=dtype)
+            if model_config.dtype == torch.bfloat16:
+                if not is_bf16_supported(device_type):
+                    logger.warning('Device does not support bfloat16.')
+        except Exception as e:
+            message = (f'Checking failed with error: {e}. '
+                       'Please send issue to LMDeploy with error logs.')
+            self.log_and_exit(e, 'Model', message=message)
+
+    def check_awq(self, config):
+        """check awq."""
+        logger = self.get_logger()
+        device_type = self.device_type
+        if device_type != 'cuda':
+            return
+
+        quantization_config = getattr(config, 'quantization_config', dict())
+        quant_method = quantization_config.get('quant_method', None)
+        if quant_method != 'awq':
+            return
+        try:
+            import awq  # noqa
+        except Exception as e:
+            self.log_and_exit(e, 'autoawq')
+
+        try:
+            import awq_ext  # noqa
+        except Exception as e:
+            logger.debug('Exception:', exc_info=1)
+            self.log_and_exit(
+                e,
+                'awq_ext',
+                message='Failed to import `awq_ext`. '
+                'Try reinstalling it from source: '
+                'https://github.com/casper-hansen/AutoAWQ_kernels')
+
+    def check(self):
+        """check."""
+        import transformers
+        trans_version = version.parse(transformers.__version__)
+
+        # config
+        config = self.check_config(trans_version)
+
+        # transformers version
+        self.check_trans_version(config, trans_version)
+
+        # dtype check
+        self.check_dtype(config)
+
+        # awq
+        self.check_awq(config)
diff --git a/lmdeploy/pytorch/check_env/torch.py b/lmdeploy/pytorch/check_env/torch.py
new file mode 100644
index 000000000..14b24e04a
--- /dev/null
+++ b/lmdeploy/pytorch/check_env/torch.py
@@ -0,0 +1,21 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base import BaseChecker
+
+
+class TorchChecker(BaseChecker):
+    """check pytorch is available."""
+
+    def __init__(self, device: str = 'cuda', logger=None) -> None:
+        super().__init__(logger=logger)
+        self.device = device
+
+    def check(self):
+        """check."""
+        try:
+            import torch
+            a = torch.tensor([1, 2], device=self.device)
+            b = a.new_tensor([3, 4], device=self.device)
+            c = a + b
+            torch.testing.assert_close(c, a.new_tensor([4, 6]))
+        except Exception as e:
+            self.log_and_exit(e, 'PyTorch', 'PyTorch is not available.')
diff --git a/lmdeploy/pytorch/check_env/transformers.py b/lmdeploy/pytorch/check_env/transformers.py
new file mode 100644
index 000000000..9d97cd6dc
--- /dev/null
+++ b/lmdeploy/pytorch/check_env/transformers.py
@@ -0,0 +1,29 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from packaging import version
+
+from .base import BaseChecker
+
+MIN_TRANSFORMERS_VERSION = '4.33.0'
+MAX_TRANSFORMERS_VERSION = '4.46.1'
+
+
+class TransformersChecker(BaseChecker):
+    """check transformers is available."""
+
+    def check(self):
+        """check."""
+        import transformers
+        logger = self.get_logger()
+        try:
+            trans_version = version.parse(transformers.__version__)
+            min_version = version.parse(MIN_TRANSFORMERS_VERSION)
+            max_version = version.parse(MAX_TRANSFORMERS_VERSION)
+            if trans_version < min_version or trans_version > max_version:
+                logger.warning('LMDeploy requires transformers version: '
+                               f'[{MIN_TRANSFORMERS_VERSION} ~ '
+                               f'{MAX_TRANSFORMERS_VERSION}], '
+                               'but found version: '
+                               f'{transformers.__version__}')
+        except Exception as e:
+            self.log_and_exit(e, 'transformers',
+                              'transformers is not available.')
diff --git a/lmdeploy/pytorch/check_env/triton.py b/lmdeploy/pytorch/check_env/triton.py
new file mode 100644
index 000000000..4cc58c549
--- /dev/null
+++ b/lmdeploy/pytorch/check_env/triton.py
@@ -0,0 +1,60 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from packaging import version
+
+from .base import BaseChecker
+
+MAX_TRITON_VERSION = '3.1.0'
+MIN_TRITON_VERSION = '3.0.0'
+
+
+class TritonChecker(BaseChecker):
+    """check triton is available."""
+
+    def check_version(self):
+        """check version."""
+        logger = self.get_logger()
+
+        # version check
+        import triton
+        max_version = version.parse(MAX_TRITON_VERSION)
+        min_version = version.parse(MIN_TRITON_VERSION)
+        triton_version = version.parse(triton.__version__)
+
+        if triton_version > max_version:
+            logger.warning('PytorchEngine has not been tested on '
+                           f'triton>{MAX_TRITON_VERSION}.')
+        if triton_version < min_version:
+            msg = (f'triton>={MIN_TRITON_VERSION} is required. '
+                   f'Found triton=={triton_version}')
+            self.log_and_exit(mod_name='Triton', message=msg)
+
+    def check(self):
+        """check."""
+        logger = self.get_logger()
+
+        msg = (
+            'Please ensure that your device is functioning properly with Triton.\n'  # noqa: E501
+            'You can verify your environment by running '
+            '`python -m lmdeploy.pytorch.check_env.triton_custom_add`.')
+        try:
+            logger.debug('Checking environment.')
+            import torch
+
+            from .triton_custom_add import custom_add
+            a = torch.tensor([1, 2], device='cuda')
+            b = a.new_tensor([3, 4], device='cuda')
+            c = custom_add(a, b)
+            torch.testing.assert_close(c, a + b)
+        except RuntimeError as e:
+            ptxas_error = 'device kernel image is invalid'
+            if len(e.args) > 0 and ptxas_error in e.args[0]:
+                msg = (
+                    'This error might be caused by a mismatch between the NVIDIA driver and the nvcc compiler.\n'  # noqa: E501
+                    'Try solution https://github.com/triton-lang/triton/issues/1955#issuecomment-1929908209'  # noqa: E501
+                    ' or reinstall the driver.')
+            self.log_and_exit(e, 'Triton', msg)
+        except Exception as e:
+            self.log_and_exit(e, 'Triton', msg)
+
+        # version check
+        self.check_version()
diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py
index 715291a90..b74c0f64a 100644
--- a/lmdeploy/pytorch/engine/engine.py
+++ b/lmdeploy/pytorch/engine/engine.py
@@ -14,13 +14,13 @@
                            logging_timer)
 from ..adapter.adapter import AdapterManager
-from ..check_env import check_adapters, check_env, check_model
 from ..config import BackendConfig, CacheConfig, SchedulerConfig
 from ..devices import DeviceContext, get_device_manager
 from ..messages import (InputEmbeddingRangeType, InputEmbeddingType,
                         MessageStatus, SchedulerSequence)
 from ..model_inputs import ModelInputs, MRopeModelInputs, VisionModelInputs
 from ..paging import Scheduler
+from .engine_checker import EngineChecker
 from .logits_process import FusedLogitsProcessor, SamplingInputs
 from .model_agent import build_model_agent
 from .request import Request, RequestManager, RequestType, Response
@@ -78,6 +78,40 @@ def _check_finish(scheduler: Scheduler, current_iter: int):
     return False
+
+
+def _build_scheduler_config(engine_config: PytorchEngineConfig):
+    """build scheduler config."""
+    scheduler_config = SchedulerConfig(
+        max_batches=engine_config.max_batch_size,
+        max_session_len=engine_config.session_len,
+        prefill_interval=engine_config.prefill_interval)
+    return scheduler_config
+
+
+def _build_cache_config(engine_config: PytorchEngineConfig):
+    """build cache config."""
+    cache_config = CacheConfig(
+        max_batches=engine_config.max_batch_size,
+        block_size=engine_config.block_size,
+        num_cpu_blocks=engine_config.num_cpu_blocks,
+        num_gpu_blocks=engine_config.num_gpu_blocks,
+        cache_max_entry_count=engine_config.cache_max_entry_count,
+        max_prefill_token_num=engine_config.max_prefill_token_num,
+        enable_prefix_caching=engine_config.enable_prefix_caching,
+        quant_policy=engine_config.quant_policy,
+        device_type=engine_config.device_type,
+    )
+    return cache_config
+
+
+def _build_backend_config(engine_config: PytorchEngineConfig):
+    """build backend config."""
+    backend_config = BackendConfig(
+        eager_mode=engine_config.eager_mode,
+        device_type=engine_config.device_type,
+    )
+    return backend_config
+
+
 class Engine:
     """The inference engine of lmdeploy pytorch.
@@ -95,44 +129,23 @@ def __init__(self, engine_config = PytorchEngineConfig() else: engine_config = copy.deepcopy(engine_config) - check_env(engine_config.device_type) - check_model(model_path, trust_remote_code, engine_config.dtype, - engine_config.device_type) if engine_config.max_batch_size is None: engine_config.max_batch_size = get_max_batch_size( engine_config.device_type) - adapters = engine_config.adapters - if adapters is not None: - check_adapters(list(adapters.values())) - assert engine_config.max_batch_size > 0, 'max_batch_size should be' \ - f' greater than 0, but got {engine_config.max_batch_size}' - assert engine_config.dtype in ['auto', 'float16', 'bfloat16'], \ - f'unsupported specified data type {engine_config.dtype}' + checker = EngineChecker(model_path=model_path, + engine_config=engine_config, + trust_remote_code=trust_remote_code, + logger=logger) + checker.handle() + + adapters = engine_config.adapters self.engine_config = engine_config self.tp = engine_config.tp self.device_context = DeviceContext( device_type=engine_config.device_type) - scheduler_config = SchedulerConfig( - max_batches=engine_config.max_batch_size, - max_session_len=engine_config.session_len, - prefill_interval=engine_config.prefill_interval) - - # block_size = 1 to enable unified paging - cache_config = CacheConfig( - max_batches=engine_config.max_batch_size, - block_size=engine_config.block_size, - num_cpu_blocks=engine_config.num_cpu_blocks, - num_gpu_blocks=engine_config.num_gpu_blocks, - cache_max_entry_count=engine_config.cache_max_entry_count, - max_prefill_token_num=engine_config.max_prefill_token_num, - enable_prefix_caching=engine_config.enable_prefix_caching, - quant_policy=engine_config.quant_policy, - device_type=engine_config.device_type, - ) - if not os.path.exists(model_path): model_path = get_model(model_path, engine_config.download_dir, engine_config.revision) @@ -141,10 +154,9 @@ def __init__(self, if adapters is not None and len(adapters) > 0: adapters = self._download_adapters(adapters, engine_config) - backend_config = BackendConfig( - eager_mode=engine_config.eager_mode, - device_type=engine_config.device_type, - ) + scheduler_config = _build_scheduler_config(engine_config) + cache_config = _build_cache_config(engine_config) + backend_config = _build_backend_config(engine_config) with get_device_manager().context(self.device_context): self.model_agent = build_model_agent( diff --git a/lmdeploy/pytorch/engine/engine_checker.py b/lmdeploy/pytorch/engine/engine_checker.py new file mode 100644 index 000000000..1654ece4b --- /dev/null +++ b/lmdeploy/pytorch/engine/engine_checker.py @@ -0,0 +1,77 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from lmdeploy.messages import PytorchEngineConfig
+
+from ..check_env.adapter import AdapterChecker
+from ..check_env.base import BaseChecker
+from ..check_env.model import ModelChecker
+from ..check_env.torch import TorchChecker
+from ..check_env.transformers import TransformersChecker
+
+
+class EngineChecker(BaseChecker):
+    """check engine arguments and environment."""
+
+    def __init__(self,
+                 model_path: str,
+                 engine_config: PytorchEngineConfig,
+                 trust_remote_code: bool = True,
+                 logger=None):
+        super().__init__(logger)
+        logger = self.get_logger()
+
+        self.engine_config = engine_config
+
+        dtype = engine_config.dtype
+        device_type = engine_config.device_type
+
+        # pytorch
+        torch_checker = TorchChecker(logger=logger)
+        self.register_required_checker(torch_checker)
+
+        if device_type == 'cuda':
+            # triton
+            from ..check_env.triton import TritonChecker
+            triton_checker = TritonChecker(logger=logger)
+            triton_checker.register_required_checker(torch_checker)
+            self.register_required_checker(triton_checker)
+        else:
+            # deeplink
+            from ..check_env.deeplink import DeeplinkChecker
+            dl_checker = DeeplinkChecker(device_type, logger=logger)
+            self.register_required_checker(dl_checker)
+
+        # transformers
+        trans_checker = TransformersChecker()
+
+        # model
+        model_checker = ModelChecker(model_path=model_path,
+                                     trust_remote_code=trust_remote_code,
+                                     dtype=dtype,
+                                     device_type=device_type,
+                                     logger=logger)
+        model_checker.register_required_checker(torch_checker)
+        model_checker.register_required_checker(trans_checker)
+        self.register_required_checker(model_checker)
+
+        # adapters
+        adapters = engine_config.adapters
+        if adapters is not None:
+            adapter_paths = list(adapters.values())
+            for adapter in adapter_paths:
+                adapter_checker = AdapterChecker(adapter, logger=logger)
+                self.register_required_checker(adapter_checker)
+
+    def check(self):
+        """check."""
+        engine_config = self.engine_config
+        logger = self.get_logger()
+
+        if engine_config.thread_safe:
+            logger.warning('thread-safe mode has been deprecated and'
+                           ' will be removed in the future.')
+
+        if engine_config.max_batch_size <= 0:
+            self.log_and_exit(
+                mod_name='Engine',
+                message='max_batch_size should be'
+                f' greater than 0, but got {engine_config.max_batch_size}')
diff --git a/requirements/runtime.txt b/requirements/runtime.txt
index 400c492b0..a11a74942 100644
--- a/requirements/runtime.txt
+++ b/requirements/runtime.txt
@@ -15,8 +15,8 @@
 safetensors
 sentencepiece
 shortuuid
 tiktoken
-torch<=2.4.0,>=2.0.0
-torchvision<=0.19.0,>=0.15.0
+torch<=2.5.1,>=2.0.0
+torchvision<=0.20.1,>=0.15.0
 transformers
-triton>=2.2.0,<=3.0.0; sys_platform == "linux"
+triton==3.0.0; sys_platform == "linux"
 uvicorn
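
The refactor above replaces the monolithic `check_env`/`check_model` functions with a graph of `BaseChecker` objects: `register_required_checker` records prerequisites, and `handle()` runs each required checker before the checker's own `check()`, exiting through `log_and_exit` on the first failure and running each checker at most once. A minimal sketch of a downstream checker built on this API (the `FlashAttnChecker` name and the `flash_attn` import probe are hypothetical illustrations, not part of this patch):

from lmdeploy.pytorch.check_env.base import BaseChecker
from lmdeploy.pytorch.check_env.torch import TorchChecker


class FlashAttnChecker(BaseChecker):
    """check flash_attn is available (hypothetical example)."""

    def check(self):
        """check."""
        try:
            import flash_attn  # noqa: F401
        except Exception as e:
            self.log_and_exit(e, 'flash_attn',
                              message='flash_attn is not available.')


checker = FlashAttnChecker()
# TorchChecker runs first; a shared checker instance runs at most once.
checker.register_required_checker(TorchChecker())
checker.handle()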