Check bf16 model in torch engine #1270

Merged
merged 3 commits into from
Mar 13, 2024
Merged
Changes from 1 commit
add model dtype check
grimoire committed Mar 11, 2024

commit 811ea4d5828fe9f7899d5630ad4c5b5d3fcf59ef
93 changes: 65 additions & 28 deletions lmdeploy/pytorch/check_env/__init__.py
@@ -66,37 +66,74 @@ def check_transformers_version(model_path: str,
     """check transformers version."""
     from packaging import version
     logger = get_logger('lmdeploy')
-    logger.debug('Checking <transformers> version.')
 
-    trans_version = None
-    try:
-        import transformers
-        trans_version = version.parse(transformers.__version__)
-    except Exception as e:
-        _handle_exception(e, 'transformers', logger)
+    def __check_transformers_version():
+        """check transformers version."""
+        logger.debug('Checking <transformers> version.')
+        trans_version = None
+        try:
+            import transformers
+            trans_version = version.parse(transformers.__version__)
+        except Exception as e:
+            _handle_exception(e, 'transformers', logger)
+        return transformers, trans_version
+
+    def __check_config(trans_version):
+        """check config."""
+        logger.debug('Checking <Model> AutoConfig.from_pretrained.')
+        try:
+            from transformers import AutoConfig
+            config = AutoConfig.from_pretrained(
+                model_path, trust_remote_code=trust_remote_code)
+        except Exception as e:
+            message = (
+                f'Load model config with transformers=={trans_version} failed. '
+                'Please make sure model can be loaded with transformers API.')
+            _handle_exception(e, 'transformers', logger, message=message)
+        return config
+
+    def __check_model_transformers_version(config, trans_version):
+        """check model transformers version."""
+        logger.debug('Checking <Model> required transformers version.')
+        try:
+            model_trans_version = getattr(config, 'transformers_version')
+            model_trans_version = version.parse(model_trans_version)
+            assert trans_version >= model_trans_version, 'Version mismatch.'
+        except Exception as e:
+            message = (f'model `{model_path}` requires '
+                       f'transformers version {model_trans_version} '
+                       f'but transformers {trans_version} is installed.')
+            _handle_exception(e, 'transformers', logger, message=message)
+
+    def __check_model_dtype_support(config):
+        """Checking model dtype support."""
+        logger.debug('Checking <Model> dtype support.')
 
-    model_trans_version = None
-    try:
-        from transformers import AutoConfig
-        config = AutoConfig.from_pretrained(
-            model_path, trust_remote_code=trust_remote_code)
-        model_trans_version = getattr(config, 'transformers_version')
-    except Exception as e:
-        message = (
-            f'Load model config with transformers=={trans_version} failed. '
-            'Please make sure model can be loaded with transformers API.')
-        _handle_exception(e, 'transformers', logger, message=message)
+        import torch
 
-    try:
-        model_trans_version = version.parse(model_trans_version)
-        assert trans_version >= model_trans_version
-    except Exception as e:
-        message = {
-            f'model `{model_path}` requires '
-            f'transformers version {model_trans_version} '
-            f'but transformers {trans_version} is installed.'
-        }
-        _handle_exception(e, 'transformers', logger, message=message)
+        from lmdeploy.pytorch.config import ModelConfig
+
+        try:
+            model_config = ModelConfig.from_hf_config(config,
+                                                      model_path=model_path)
+            if model_config.dtype == torch.bfloat16:
+                assert torch.cuda.is_bf16_supported(), (
+                    'bf16 is not supported on your device')
+        except AssertionError as e:
+            message = (f'Your device does not support {model_config.dtype}. '
+                       'Try edit `torch_dtype` in `config.json`.')
+            _handle_exception(e, 'Model', logger, message=message)
+        except Exception as e:
+            message = (f'Checking failed with error {e}',
+                       'Please send issue to LMDeploy with error logs.')
+            _handle_exception(e, 'Model', logger, message=message)
+
+        return model_config
+
+    _, trans_version = __check_transformers_version()
+    config = __check_config(trans_version)
+    __check_model_transformers_version(config, trans_version)
+    __check_model_dtype_support(config)
 
 
 def check_model(model_path: str, trust_remote_code: bool = True):
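
For readers who want to try the dtype check outside LMDeploy, here is a minimal sketch of the same idea: read `torch_dtype` from a model's `config.json` via `transformers.AutoConfig` and verify bf16 capability with `torch.cuda.is_bf16_supported()`. The `pick_dtype` helper, its float16 fallback, and the placeholder model path are illustrative assumptions, not part of this PR; the PR itself reports the problem through `_handle_exception` rather than silently falling back.

# Standalone sketch (not part of this PR): probe whether the current device
# can run a bf16 checkpoint, mirroring the new __check_model_dtype_support.
import torch
from transformers import AutoConfig


def pick_dtype(model_path: str, trust_remote_code: bool = True) -> torch.dtype:
    """Return the model's configured dtype, downgrading bf16 if unsupported."""
    config = AutoConfig.from_pretrained(model_path,
                                        trust_remote_code=trust_remote_code)
    # config.json may omit torch_dtype; assume float16 for inference then.
    dtype = getattr(config, 'torch_dtype', None) or torch.float16
    if isinstance(dtype, str):
        # older configs store the dtype as a string such as "bfloat16"
        dtype = getattr(torch, dtype)
    bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    if dtype == torch.bfloat16 and not bf16_ok:
        # same remedy the check suggests: edit `torch_dtype` in config.json,
        # or fall back to float16 as done here.
        print('bf16 is not supported on this device, using float16 instead')
        dtype = torch.float16
    return dtype


if __name__ == '__main__':
    # hypothetical local checkpoint path used only for illustration
    print(pick_dtype('/path/to/your/model'))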