diff --git a/intel_extension_for_transformers/transformers/modeling/modeling_auto.py b/intel_extension_for_transformers/transformers/modeling/modeling_auto.py
index 1314e464eff..62aefc2bb88 100644
--- a/intel_extension_for_transformers/transformers/modeling/modeling_auto.py
+++ b/intel_extension_for_transformers/transformers/modeling/modeling_auto.py
@@ -588,6 +588,7 @@ def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, None]:
             pretrained_model_name_or_path,
             *model_args,
             config=config,
+            model_hub=model_hub,
             **kwargs,
         )
         logger.info(
@@ -1452,9 +1453,13 @@ def train_func(model):
                 model.quantization_config = None
             return model
         else:
-            model = cls.ORIG_MODEL.from_pretrained(
-                pretrained_model_name_or_path, *model_args, config=config, **kwargs
-            )
+            if model_hub == "modelscope":
+                from modelscope import AutoModelForCausalLM  # pylint: disable=E0401
+                model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
+            else:
+                model = cls.ORIG_MODEL.from_pretrained(
+                    pretrained_model_name_or_path, *model_args, config=config, **kwargs
+                )
         if (
             not torch.cuda.is_available()
             or device_map == "cpu"
@@ -1520,7 +1525,7 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         device_map = kwargs.pop("device_map", "auto")
         use_safetensors = kwargs.pop("use_safetensors", None)
         kwarg_attn_imp = kwargs.pop("attn_implementation", None)
-
+        model_hub = kwargs.pop("model_hub", None)
         # lm-eval device map is dictionary
         device_map = device_map[""] if isinstance(device_map, dict) and "" in device_map else device_map
 
@@ -1709,7 +1714,20 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
                 "_raise_exceptions_for_missing_entries": False,
                 "_commit_hash": commit_hash,
             }
-            resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
+            if model_hub == "modelscope":
+                from modelscope import snapshot_download  # pylint: disable=E0401
+                model_dir = snapshot_download(pretrained_model_name_or_path)
+                if os.path.exists(os.path.join(model_dir, "model.safetensors")):
+                    resolved_archive_file = os.path.join(model_dir, "model.safetensors")
+                elif os.path.exists(os.path.join(model_dir, "pytorch_model.bin")):
+                    resolved_archive_file = os.path.join(model_dir, "pytorch_model.bin")
+                else:
+                    raise FileNotFoundError(
+                        f"No model checkpoint (model.safetensors or pytorch_model.bin) found in {model_dir}"
+                    )
+            else:
+                resolved_archive_file = cached_file(pretrained_model_name_or_path, filename,
+                                                    **cached_file_kwargs)
             # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
             # result when internet is up, the repo and revision exist, but the file does not.
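
For reviewers, a minimal usage sketch of the new model_hub kwarg (not part of the diff): the entry point AutoModelForCausalLM from intel_extension_for_transformers.transformers and the load_in_4bit kwarg are assumed from the surrounding project, and the ModelScope repo id below is a placeholder.

    # Hypothetical example: route the checkpoint download through ModelScope
    # instead of the Hugging Face Hub. Requires `pip install modelscope`.
    from intel_extension_for_transformers.transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "qwen/Qwen-7B-Chat",     # placeholder ModelScope repo id
        load_in_4bit=True,       # weight-only quantization path of this class
        model_hub="modelscope",  # new kwarg introduced by this patch
    )

When model_hub is unset (or any other value), both from_pretrained and load_low_bit fall through to the original cls.ORIG_MODEL.from_pretrained / cached_file paths, so default Hugging Face Hub behavior is unchanged.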