Commit fix_mps_index_error
suquark committed Nov 26, 2023
1 parent a754c48 commit 1263fde
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions fastchat/model/model_adapter.py
@@ -212,8 +212,16 @@ def load_model(
kwargs["max_memory"] = {i: max_gpu_memory for i in range(num_gpus)}
elif device == "mps":
kwargs = {"torch_dtype": torch.float16}
# Avoid bugs in mps backend by not using in-place operations.
replace_llama_attn_with_non_inplace_operations()
import transformers

version = tuple(int(v) for v in transformers.__version__.split("."))
if version < (4, 35, 0):
# NOTE: Recent transformers library seems has fix the mps issue, also
# it has made some changes causing compatibility issues with the
# inplace operation. So we only apply the patch for older versions.

# Avoid bugs in mps backend by not using in-place operations.
replace_llama_attn_with_non_inplace_operations()
elif device == "xpu":
kwargs = {"torch_dtype": torch.bfloat16}
# Try to load ipex, while it looks unused, it links into torch for xpu support

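Note on the version gate: parsing transformers.__version__ into a tuple of ints works for plain releases such as "4.34.1", but it raises ValueError on pre-release strings like "4.36.0.dev0". The sketch below (not part of this commit) shows a more tolerant check using packaging.version, which is usually available alongside transformers; the helper name needs_mps_inplace_patch is hypothetical.

    # Sketch of a more tolerant version gate (an assumption, not this commit's code).
    # Assumes the `packaging` package is installed; it is a common dependency of
    # transformers, but verify it exists in your environment.
    import transformers
    from packaging import version


    def needs_mps_inplace_patch(threshold: str = "4.35.0") -> bool:
        """Return True if the installed transformers predates the MPS fix."""
        installed = version.parse(transformers.__version__)
        # version.parse handles pre-release tags like "4.36.0.dev0" that would
        # break a naive tuple(int(v) for v in ...) comparison.
        return installed < version.parse(threshold)


    if __name__ == "__main__":
        print("apply in-place patch:", needs_mps_inplace_patch())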