From 811ea4d5828fe9f7899d5630ad4c5b5d3fcf59ef Mon Sep 17 00:00:00 2001
From: grimoire
Date: Mon, 11 Mar 2024 11:45:38 +0800
Subject: [PATCH 1/3] add model dtype check

---
 lmdeploy/pytorch/check_env/__init__.py | 93 ++++++++++++++++++--------
 1 file changed, 65 insertions(+), 28 deletions(-)

diff --git a/lmdeploy/pytorch/check_env/__init__.py b/lmdeploy/pytorch/check_env/__init__.py
index 62ba7f11f7..44e9fb9619 100644
--- a/lmdeploy/pytorch/check_env/__init__.py
+++ b/lmdeploy/pytorch/check_env/__init__.py
@@ -66,37 +66,74 @@ def check_transformers_version(model_path: str,
     """check transformers version."""
     from packaging import version
     logger = get_logger('lmdeploy')
 
-    logger.debug('Checking version.')
-    trans_version = None
-    try:
-        import transformers
-        trans_version = version.parse(transformers.__version__)
-    except Exception as e:
-        _handle_exception(e, 'transformers', logger)
+    def __check_transformers_version():
+        """check transformers version."""
+        logger.debug('Checking version.')
+        trans_version = None
+        try:
+            import transformers
+            trans_version = version.parse(transformers.__version__)
+        except Exception as e:
+            _handle_exception(e, 'transformers', logger)
+        return transformers, trans_version
+
+    def __check_config(trans_version):
+        """check config."""
+        logger.debug('Checking AutoConfig.from_pretrained.')
+        try:
+            from transformers import AutoConfig
+            config = AutoConfig.from_pretrained(
+                model_path, trust_remote_code=trust_remote_code)
+        except Exception as e:
+            message = (
+                f'Load model config with transformers=={trans_version} failed. '
+                'Please make sure model can be loaded with transformers API.')
+            _handle_exception(e, 'transformers', logger, message=message)
+        return config
+
+    def __check_model_transformers_version(config, trans_version):
+        """check model transformers version."""
+        logger.debug('Checking required transformers version.')
+        try:
+            model_trans_version = getattr(config, 'transformers_version')
+            model_trans_version = version.parse(model_trans_version)
+            assert trans_version >= model_trans_version, 'Version mismatch.'
+        except Exception as e:
+            message = (f'model `{model_path}` requires '
+                       f'transformers version {model_trans_version} '
+                       f'but transformers {trans_version} is installed.')
+            _handle_exception(e, 'transformers', logger, message=message)
+
+    def __check_model_dtype_support(config):
+        """Checking model dtype support."""
+        logger.debug('Checking dtype support.')
 
-    model_trans_version = None
-    try:
-        from transformers import AutoConfig
-        config = AutoConfig.from_pretrained(
-            model_path, trust_remote_code=trust_remote_code)
-        model_trans_version = getattr(config, 'transformers_version')
-    except Exception as e:
-        message = (
-            f'Load model config with transformers=={trans_version} failed. '
-            'Please make sure model can be loaded with transformers API.')
-        _handle_exception(e, 'transformers', logger, message=message)
+        import torch
 
-    try:
-        model_trans_version = version.parse(model_trans_version)
-        assert trans_version >= model_trans_version
-    except Exception as e:
-        message = {
-            f'model `{model_path}` requires '
-            f'transformers version {model_trans_version} '
-            f'but transformers {trans_version} is installed.'
-        }
-        _handle_exception(e, 'transformers', logger, message=message)
+        from lmdeploy.pytorch.config import ModelConfig
+
+        try:
+            model_config = ModelConfig.from_hf_config(config,
+                                                      model_path=model_path)
+            if model_config.dtype == torch.bfloat16:
+                assert torch.cuda.is_bf16_supported(), (
+                    'bf16 is not supported on your device')
+        except AssertionError as e:
+            message = (f'Your device does not support {model_config.dtype}. '
+                       'Try edit `torch_dtype` in `config.json`.')
+            _handle_exception(e, 'Model', logger, message=message)
+        except Exception as e:
+            message = (f'Checking failed with error {e}',
+                       'Please send issue to LMDeploy with error logs.')
+            _handle_exception(e, 'Model', logger, message=message)
+
+        return model_config
+
+    _, trans_version = __check_transformers_version()
+    config = __check_config(trans_version)
+    __check_model_transformers_version(config, trans_version)
+    __check_model_dtype_support(config)
 
 
 def check_model(model_path: str, trust_remote_code: bool = True):
From 09a148bf3a0e92137dd7879102ec758386a11f7e Mon Sep 17 00:00:00 2001
From: grimoire
Date: Mon, 11 Mar 2024 11:46:10 +0800
Subject: [PATCH 2/3] fix trust remote code

---
 lmdeploy/pytorch/check_env/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lmdeploy/pytorch/check_env/__init__.py b/lmdeploy/pytorch/check_env/__init__.py
index 44e9fb9619..642b739145 100644
--- a/lmdeploy/pytorch/check_env/__init__.py
+++ b/lmdeploy/pytorch/check_env/__init__.py
@@ -140,4 +140,4 @@ def check_model(model_path: str, trust_remote_code: bool = True):
     """check model requirements."""
     logger = get_logger('lmdeploy')
     logger.info('Checking model.')
-    check_transformers_version(model_path)
+    check_transformers_version(model_path, trust_remote_code)
From 0dc1aeedbb8971beb2262a705ea093f420f16132 Mon Sep 17 00:00:00 2001
From: grimoire
Date: Mon, 11 Mar 2024 11:50:08 +0800
Subject: [PATCH 3/3] update log message

---
 lmdeploy/pytorch/check_env/__init__.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/lmdeploy/pytorch/check_env/__init__.py b/lmdeploy/pytorch/check_env/__init__.py
index 642b739145..7b9f20f6b4 100644
--- a/lmdeploy/pytorch/check_env/__init__.py
+++ b/lmdeploy/pytorch/check_env/__init__.py
@@ -87,7 +87,8 @@ def __check_config(trans_version):
             model_path, trust_remote_code=trust_remote_code)
         except Exception as e:
             message = (
-                f'Load model config with transformers=={trans_version} failed. '
+                f'Load model config with transformers=={trans_version}'
+                ' failed. '
                 'Please make sure model can be loaded with transformers API.')
             _handle_exception(e, 'transformers', logger, message=message)
         return config
@@ -117,11 +118,12 @@ def __check_model_dtype_support(config):
         model_config = ModelConfig.from_hf_config(config,
                                                   model_path=model_path)
         if model_config.dtype == torch.bfloat16:
-            assert torch.cuda.is_bf16_supported(), (
+            assert not torch.cuda.is_bf16_supported(), (
                 'bf16 is not supported on your device')
     except AssertionError as e:
-        message = (f'Your device does not support {model_config.dtype}. '
-                   'Try edit `torch_dtype` in `config.json`.')
+        message = (f'Your device does not support `{model_config.dtype}`. '
+                   'Try edit `torch_dtype` in `config.json`.\n'
+                   'Note that this might have negative effect!')
         _handle_exception(e, 'Model', logger, message=message)
     except Exception as e:
         message = (f'Checking failed with error {e}',