Check bf16 model in torch engine #1270

Merged · 3 commits · Mar 13, 2024
Changes from all commits
97 changes: 68 additions & 29 deletions lmdeploy/pytorch/check_env/__init__.py
@@ -66,41 +66,80 @@ def check_transformers_version(model_path: str,
"""check transformers version."""
from packaging import version
logger = get_logger('lmdeploy')
logger.debug('Checking <transformers> version.')

trans_version = None
try:
import transformers
trans_version = version.parse(transformers.__version__)
except Exception as e:
_handle_exception(e, 'transformers', logger)
def __check_transformers_version():
"""check transformers version."""
logger.debug('Checking <transformers> version.')
trans_version = None
try:
import transformers
trans_version = version.parse(transformers.__version__)
except Exception as e:
_handle_exception(e, 'transformers', logger)
return transformers, trans_version

def __check_config(trans_version):
"""check config."""
logger.debug('Checking <Model> AutoConfig.from_pretrained.')
try:
from transformers import AutoConfig
config = AutoConfig.from_pretrained(
model_path, trust_remote_code=trust_remote_code)
except Exception as e:
message = (
f'Load model config with transformers=={trans_version}'
' failed. '
'Please make sure model can be loaded with transformers API.')
_handle_exception(e, 'transformers', logger, message=message)
return config

def __check_model_transformers_version(config, trans_version):
"""check model transformers version."""
logger.debug('Checking <Model> required transformers version.')
try:
model_trans_version = getattr(config, 'transformers_version')
model_trans_version = version.parse(model_trans_version)
assert trans_version >= model_trans_version, 'Version mismatch.'
except Exception as e:
message = (f'model `{model_path}` requires '
f'transformers version {model_trans_version} '
f'but transformers {trans_version} is installed.')
_handle_exception(e, 'transformers', logger, message=message)

def __check_model_dtype_support(config):
"""Checking model dtype support."""
logger.debug('Checking <Model> dtype support.')

model_trans_version = None
try:
from transformers import AutoConfig
config = AutoConfig.from_pretrained(
model_path, trust_remote_code=trust_remote_code)
model_trans_version = getattr(config, 'transformers_version')
except Exception as e:
message = (
f'Load model config with transformers=={trans_version} failed. '
'Please make sure model can be loaded with transformers API.')
_handle_exception(e, 'transformers', logger, message=message)
import torch

try:
model_trans_version = version.parse(model_trans_version)
assert trans_version >= model_trans_version
except Exception as e:
message = {
f'model `{model_path}` requires '
f'transformers version {model_trans_version} '
f'but transformers {trans_version} is installed.'
}
_handle_exception(e, 'transformers', logger, message=message)
from lmdeploy.pytorch.config import ModelConfig

try:
model_config = ModelConfig.from_hf_config(config,
model_path=model_path)
if model_config.dtype == torch.bfloat16:
assert torch.cuda.is_bf16_supported(), (
Collaborator

In this case, how about issuing a warning and falling back to fp16?

Collaborator Author

Users may not notice a warning; failing the check is more conspicuous, and it also highlights that a fallback would cause precision issues.

Collaborator

Is there a way to fall back to fp16 inside the engine?

Collaborator Author
hf_model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch_dtype,

model = AutoModelForCausalLM.from_config(
config,
torch_dtype=torch_dtype,

It is possible, but there is no way to guarantee the correctness of results computed that way, and users can easily ignore the warning.

Collaborator

Fair enough. Let's see the community's feedback first.

'bf16 is not supported on your device')
except AssertionError as e:
message = (f'Your device does not support `{model_config.dtype}`. '
'Try edit `torch_dtype` in `config.json`.\n'
'Note that this might have negative effect!')
_handle_exception(e, 'Model', logger, message=message)
except Exception as e:
message = (f'Checking failed with error {e}. '
'Please send issue to LMDeploy with error logs.')
_handle_exception(e, 'Model', logger, message=message)

return model_config

_, trans_version = __check_transformers_version()
config = __check_config(trans_version)
__check_model_transformers_version(config, trans_version)
__check_model_dtype_support(config)
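
The review thread above weighs warn-and-fallback against failing fast; the merged check fails fast. For illustration only, the following is a hedged sketch of what the discussed fp16 fallback might look like. It is not part of this PR; it assumes `config`, `model_path` and `logger` as in the function above, and that `ModelConfig` allows its `dtype` field to be reassigned.

import torch

from lmdeploy.pytorch.config import ModelConfig


def _maybe_fallback_to_fp16(config, model_path, logger):
    """Hypothetical alternative from the review thread, NOT what this PR merges."""
    model_config = ModelConfig.from_hf_config(config, model_path=model_path)
    if model_config.dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
        # The PR author's caveat: running a bf16 checkpoint in fp16 can change
        # results, and users tend to miss warnings.
        logger.warning('bf16 is not supported on this device; '
                       'falling back to fp16 (accuracy may degrade).')
        model_config.dtype = torch.float16
    return model_config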


def check_model(model_path: str, trust_remote_code: bool = True):
"""check model requirements."""
logger = get_logger('lmdeploy')
logger.info('Checking model.')
check_transformers_version(model_path)
check_transformers_version(model_path, trust_remote_code)
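
For context, a minimal usage sketch of the merged check, assuming lmdeploy is installed and using a hypothetical local checkpoint directory `./my-bf16-model` (not from the PR). On a GPU without bf16 support, `check_model` fails for a bf16 checkpoint, and its error message suggests editing `torch_dtype` in `config.json`; the snippet shows one way a user might apply that suggestion.

import json
from pathlib import Path

import torch

from lmdeploy.pytorch.check_env import check_model

model_path = './my-bf16-model'  # hypothetical bf16 checkpoint directory

if not torch.cuda.is_bf16_supported():
    # Apply the suggestion from the error message: switch the checkpoint to
    # fp16. As the message notes, this may have a negative effect on accuracy.
    cfg_file = Path(model_path) / 'config.json'
    cfg = json.loads(cfg_file.read_text())
    cfg['torch_dtype'] = 'float16'
    cfg_file.write_text(json.dumps(cfg, indent=2))

check_model(model_path, trust_remote_code=True)  # failures are routed to _handle_exception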