From 4e00e48b30057d1af9c58fa7d23f61fac27863e6 Mon Sep 17 00:00:00 2001 From: jon Date: Fri, 1 Sep 2023 00:34:29 +0300 Subject: [PATCH] add mps --- elk/utils/gpu_utils.py | 5 ++++- elk/utils/hf_utils.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/elk/utils/gpu_utils.py b/elk/utils/gpu_utils.py index a4294298..0492b1ec 100644 --- a/elk/utils/gpu_utils.py +++ b/elk/utils/gpu_utils.py @@ -55,7 +55,10 @@ def select_usable_devices( # Trivial case: no GPUs requested or available num_visible = torch.cuda.device_count() if num_gpus == 0 or num_visible == 0: - return ["cpu"] + if torch.backends.mps.is_available(): + return ["mps"] + else: + return ["cpu"] # Sanity checks if num_gpus > num_visible: diff --git a/elk/utils/hf_utils.py b/elk/utils/hf_utils.py index 9f429921..670e99ca 100644 --- a/elk/utils/hf_utils.py +++ b/elk/utils/hf_utils.py @@ -47,7 +47,7 @@ def instantiate_model( kwargs["torch_dtype"] = torch.float16 # CPUs generally don't support anything other than fp32. - elif device.type == "cpu": + elif device.type in ("cpu", "mps"): kwargs["torch_dtype"] = torch.float32 # If the model is fp32 but bf16 is available, convert to bf16.