From b4e71599f0badcf2b16f3e8e54f7957efbfc6ff6 Mon Sep 17 00:00:00 2001 From: intellinjun Date: Fri, 5 Jul 2024 11:36:49 +0800 Subject: [PATCH 1/2] enable modelscope for itrex Signed-off-by: intellinjun --- .../transformers/modeling/modeling_auto.py | 27 +++++++++++++++---- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/intel_extension_for_transformers/transformers/modeling/modeling_auto.py b/intel_extension_for_transformers/transformers/modeling/modeling_auto.py index a5be8cdc519..4e5c86ec7ad 100644 --- a/intel_extension_for_transformers/transformers/modeling/modeling_auto.py +++ b/intel_extension_for_transformers/transformers/modeling/modeling_auto.py @@ -587,6 +587,7 @@ def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, None]: pretrained_model_name_or_path, *model_args, config=config, + model_hub=model_hub, **kwargs, ) logger.info( @@ -1451,9 +1452,13 @@ def train_func(model): model.quantization_config = None return model else: - model = cls.ORIG_MODEL.from_pretrained( - pretrained_model_name_or_path, *model_args, config=config, **kwargs - ) + if model_hub=="modelscope": + from modelscope import AutoModelForCausalLM + model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path,trust_remote_code=True) + else: + model = cls.ORIG_MODEL.from_pretrained( + pretrained_model_name_or_path, *model_args, config=config, **kwargs + ) if ( not torch.cuda.is_available() or device_map == "cpu" @@ -1519,7 +1524,7 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): device_map = kwargs.pop("device_map", "auto") use_safetensors = kwargs.pop("use_safetensors", None) kwarg_attn_imp = kwargs.pop("attn_implementation", None) - + model_hub = kwargs.pop("model_hub", None) # lm-eval device map is dictionary device_map = device_map[""] if isinstance(device_map, dict) and "" in device_map else device_map @@ -1708,7 +1713,19 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): "_raise_exceptions_for_missing_entries": False, "_commit_hash": commit_hash, } - resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) + if model_hub=="modelscope": + from modelscope import snapshot_download + model_dir = snapshot_download(pretrained_model_name_or_path) + if os.path.exists(model_dir+"/model.safetensors"): + resolved_archive_file = model_dir+"/model.safetensors" + elif os.path.exists(model_dir+"/pytorch_model.bin"): + resolved_archive_file = model_dir+"/pytorch_model.bin" + else: + assert ( + resolved_archive_file is not None + ), "Don't detect this model checkpoint" + else: + resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None # result when internet is up, the repo and revision exist, but the file does not. From 6c26dd2d24701f70a885807b47df36d18195c378 Mon Sep 17 00:00:00 2001 From: intellinjun Date: Tue, 9 Jul 2024 09:04:03 +0800 Subject: [PATCH 2/2] fix format error Signed-off-by: intellinjun --- .../transformers/modeling/modeling_auto.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/intel_extension_for_transformers/transformers/modeling/modeling_auto.py b/intel_extension_for_transformers/transformers/modeling/modeling_auto.py index 4e5c86ec7ad..ba965d7a45f 100644 --- a/intel_extension_for_transformers/transformers/modeling/modeling_auto.py +++ b/intel_extension_for_transformers/transformers/modeling/modeling_auto.py @@ -1453,7 +1453,7 @@ def train_func(model): return model else: if model_hub=="modelscope": - from modelscope import AutoModelForCausalLM + from modelscope import AutoModelForCausalLM # pylint: disable=E0401 model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path,trust_remote_code=True) else: model = cls.ORIG_MODEL.from_pretrained( @@ -1714,7 +1714,7 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): "_commit_hash": commit_hash, } if model_hub=="modelscope": - from modelscope import snapshot_download + from modelscope import snapshot_download # pylint: disable=E0401 model_dir = snapshot_download(pretrained_model_name_or_path) if os.path.exists(model_dir+"/model.safetensors"): resolved_archive_file = model_dir+"/model.safetensors" @@ -1725,7 +1725,8 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): resolved_archive_file is not None ), "Don't detect this model checkpoint" else: - resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) + resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, + **cached_file_kwargs) # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None # result when internet is up, the repo and revision exist, but the file does not.