From dc3692c13f6487090cf38c5e6bb968c60832c41b Mon Sep 17 00:00:00 2001
From: y
Date: Sat, 10 Aug 2024 03:11:49 +0800
Subject: [PATCH 1/2] fix MPS reliance on flash_attn

---
 web_demo_2.6.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/web_demo_2.6.py b/web_demo_2.6.py
index 2bebc19d..73a1080b 100644
--- a/web_demo_2.6.py
+++ b/web_demo_2.6.py
@@ -15,7 +15,9 @@
 import traceback
 import re
 import modelscope_studio as mgr
-
+from typing import Union
+from transformers.dynamic_module_utils import get_imports
+from unittest.mock import patch
 
 # README, How to run demo on different devices
 
@@ -66,6 +68,9 @@
         model = load_checkpoint_and_dispatch(model, model_path, dtype=torch.bfloat16, device_map=device_map)
     else:
         model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)
+    if device == 'mps':
+        with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
+            pass
     model = model.to(device=device)
 tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 model.eval()
@@ -79,6 +84,12 @@
 IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
 VIDEO_EXTENSIONS = {'.mp4', '.mkv', '.mov', '.avi', '.flv', '.wmv', '.webm', '.m4v'}
 
+def fixed_get_imports(filename: Union[str, os.PathLike]) -> list[str]:
+    imports = get_imports(filename)
+    if not torch.cuda.is_available() and "flash_attn" in imports:
+        imports.remove("flash_attn")
+    return imports
+
 def get_file_extension(filename):
     return os.path.splitext(filename)[1].lower()
 

From a242b4449c062cb569e86d9ead9de01bdc21f4f7 Mon Sep 17 00:00:00 2001
From: y
Date: Sat, 10 Aug 2024 03:11:49 +0800
Subject: [PATCH 2/2] fix MPS reliance on flash_attn

---
 web_demo_2.6.py | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/web_demo_2.6.py b/web_demo_2.6.py
index 2bebc19d..d60a6f4c 100644
--- a/web_demo_2.6.py
+++ b/web_demo_2.6.py
@@ -15,7 +15,9 @@
 import traceback
 import re
 import modelscope_studio as mgr
-
+from typing import Union
+from transformers.dynamic_module_utils import get_imports
+from unittest.mock import patch
 
 # README, How to run demo on different devices
 
@@ -65,7 +67,11 @@
 
         model = load_checkpoint_and_dispatch(model, model_path, dtype=torch.bfloat16, device_map=device_map)
     else:
-        model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)
+        if device == 'mps':
+            with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
+                model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)
+        else:
+            model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)
     model = model.to(device=device)
 tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 model.eval()
@@ -79,6 +85,12 @@
 IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
 VIDEO_EXTENSIONS = {'.mp4', '.mkv', '.mov', '.avi', '.flv', '.wmv', '.webm', '.m4v'}
 
+def fixed_get_imports(filename: Union[str, os.PathLike]) -> list[str]:
+    imports = get_imports(filename)
+    if not torch.cuda.is_available() and "flash_attn" in imports:
+        imports.remove("flash_attn")
+    return imports
+
 def get_file_extension(filename):
     return os.path.splitext(filename)[1].lower()
 
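
The patched get_imports is only consulted while from_pretrained imports the checkpoint's remote modeling code, so fixed_get_imports must already be defined and the loading call must sit inside the with patch(...) block, as PATCH 2/2 does. Below is a minimal self-contained sketch of the same workaround; "path/to/checkpoint" is a placeholder, not the demo's actual model path.

# Minimal sketch of the workaround in isolation (assumes transformers is
# installed and the checkpoint ships remote code that imports flash_attn;
# "path/to/checkpoint" is a placeholder).
import os
from typing import Union
from unittest.mock import patch

import torch
from transformers import AutoModel
from transformers.dynamic_module_utils import get_imports


def fixed_get_imports(filename: Union[str, os.PathLike]) -> list[str]:
    # Strip flash_attn from the remote module's import list when CUDA is
    # unavailable, so trust_remote_code loading works on MPS/CPU machines
    # that cannot install flash_attn.
    imports = get_imports(filename)
    if not torch.cuda.is_available() and "flash_attn" in imports:
        imports.remove("flash_attn")
    return imports


model_path = "path/to/checkpoint"  # placeholder path
# Swap get_imports only for the duration of the remote-code import.
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True,
                                      torch_dtype=torch.bfloat16)
model = model.to(device="mps")
model.eval()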