Skip to content
This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Enable modelscope for itrex #1655

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,7 @@ def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, None]:
pretrained_model_name_or_path,
*model_args,
config=config,
model_hub=model_hub,
**kwargs,
)
logger.info(
Expand Down Expand Up @@ -1452,9 +1453,13 @@ def train_func(model):
model.quantization_config = None
return model
else:
model = cls.ORIG_MODEL.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
if model_hub=="modelscope":
from modelscope import AutoModelForCausalLM # pylint: disable=E0401
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path,trust_remote_code=True)
else:
model = cls.ORIG_MODEL.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **kwargs
)
if (
not torch.cuda.is_available()
or device_map == "cpu"
Expand Down Expand Up @@ -1520,7 +1525,7 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
device_map = kwargs.pop("device_map", "auto")
use_safetensors = kwargs.pop("use_safetensors", None)
kwarg_attn_imp = kwargs.pop("attn_implementation", None)

model_hub = kwargs.pop("model_hub", None)
# lm-eval device map is dictionary
device_map = device_map[""] if isinstance(device_map, dict) and "" in device_map else device_map

Expand Down Expand Up @@ -1709,7 +1714,20 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
"_raise_exceptions_for_missing_entries": False,
"_commit_hash": commit_hash,
}
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
if model_hub=="modelscope":
from modelscope import snapshot_download # pylint: disable=E0401
model_dir = snapshot_download(pretrained_model_name_or_path)
if os.path.exists(model_dir+"/model.safetensors"):
resolved_archive_file = model_dir+"/model.safetensors"
elif os.path.exists(model_dir+"/pytorch_model.bin"):
resolved_archive_file = model_dir+"/pytorch_model.bin"
else:
assert (
resolved_archive_file is not None
), "Don't detect this model checkpoint"
else:
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename,
**cached_file_kwargs)

# Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
# result when internet is up, the repo and revision exist, but the file does not.
Expand Down
Loading