Skip to content

Commit

Permalink
Check bf16 model in torch engine (#1270)
Browse files Browse the repository at this point in the history
* add model dtype check

* fix trust remote code

* update log message
  • Loading branch information
grimoire authored Mar 13, 2024
1 parent 9c3069f commit 920d719
Showing 1 changed file with 68 additions and 29 deletions.
97 changes: 68 additions & 29 deletions lmdeploy/pytorch/check_env/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,41 +66,80 @@ def check_transformers_version(model_path: str,
"""check transformers version."""
from packaging import version
logger = get_logger('lmdeploy')
logger.debug('Checking <transformers> version.')

trans_version = None
try:
import transformers
trans_version = version.parse(transformers.__version__)
except Exception as e:
_handle_exception(e, 'transformers', logger)
def __check_transformers_version():
"""check transformers version."""
logger.debug('Checking <transformers> version.')
trans_version = None
try:
import transformers
trans_version = version.parse(transformers.__version__)
except Exception as e:
_handle_exception(e, 'transformers', logger)
return transformers, trans_version

def __check_config(trans_version):
    """Load the model's HF config via ``AutoConfig.from_pretrained``.

    Failures are reported through ``_handle_exception`` together with
    the installed transformers version for easier triage.
    """
    logger.debug('Checking <Model> AutoConfig.from_pretrained.')
    try:
        from transformers import AutoConfig
        config = AutoConfig.from_pretrained(
            model_path, trust_remote_code=trust_remote_code)
    except Exception as err:
        hint = (f'Load model config with transformers=={trans_version} '
                'failed. '
                'Please make sure model can be loaded with transformers API.')
        _handle_exception(err, 'transformers', logger, message=hint)
    return config

def __check_model_transformers_version(config, trans_version):
    """Check that installed transformers satisfies the model's requirement.

    Parses ``config.transformers_version`` and compares it against the
    installed version; mismatches go through ``_handle_exception``.
    """
    logger.debug('Checking <Model> required transformers version.')
    # Pre-bind so the except-branch f-string never references an unbound
    # name when `config` has no `transformers_version` attribute (getattr
    # without a default raises AttributeError before the assignment).
    model_trans_version = None
    try:
        model_trans_version = getattr(config, 'transformers_version')
        model_trans_version = version.parse(model_trans_version)
        assert trans_version >= model_trans_version, 'Version mismatch.'
    except Exception as e:
        message = (f'model `{model_path}` requires '
                   f'transformers version {model_trans_version} '
                   f'but transformers {trans_version} is installed.')
        _handle_exception(e, 'transformers', logger, message=message)

def __check_model_dtype_support(config):
    """Check that the current device supports the model's torch dtype.

    Builds a ``ModelConfig`` from the HF config; for bf16 models,
    verifies the CUDA device actually supports bfloat16. Returns the
    built ``ModelConfig``.
    """
    logger.debug('Checking <Model> dtype support.')

    import torch

    from lmdeploy.pytorch.config import ModelConfig

    model_config = None
    try:
        model_config = ModelConfig.from_hf_config(config,
                                                  model_path=model_path)
        if model_config.dtype == torch.bfloat16:
            # Fail when the device does NOT support bf16. (The original
            # asserted the negation, which raised exactly on devices
            # that DO support bf16.)
            assert torch.cuda.is_bf16_supported(), (
                'bf16 is not supported on your device')
    except AssertionError as e:
        message = (f'Your device does not support `{model_config.dtype}`. '
                   'Try edit `torch_dtype` in `config.json`.\n'
                   'Note that this might have negative effect!')
        _handle_exception(e, 'Model', logger, message=message)
    except Exception as e:
        # Single concatenated string — the original comma made this a
        # tuple, which garbles the logged message.
        message = (f'Checking failed with error {e}. '
                   'Please send issue to LMDeploy with error logs.')
        _handle_exception(e, 'Model', logger, message=message)

    return model_config

# Run the nested checks in dependency order: installed transformers
# version first, then config loadability, then the model's required
# transformers version, and finally device dtype support.
_, trans_version = __check_transformers_version()
config = __check_config(trans_version)
__check_model_transformers_version(config, trans_version)
__check_model_dtype_support(config)


def check_model(model_path: str, trust_remote_code: bool = True):
    """check model requirements.

    Args:
        model_path (str): model path or hub id to check.
        trust_remote_code (bool): forwarded to transformers so
            remote-code models can load their config.
    """
    logger = get_logger('lmdeploy')
    logger.info('Checking model.')
    # Single call forwarding trust_remote_code; the diff residue also
    # retained the stale one-argument call, which would have dropped
    # the flag (and, read literally, run the check twice).
    check_transformers_version(model_path, trust_remote_code)

0 comments on commit 920d719

Please sign in to comment.