diff --git a/elk/utils/gpu_utils.py b/elk/utils/gpu_utils.py
index a4294298..0492b1ec 100644
--- a/elk/utils/gpu_utils.py
+++ b/elk/utils/gpu_utils.py
@@ -55,7 +55,10 @@ def select_usable_devices(
     # Trivial case: no GPUs requested or available
     num_visible = torch.cuda.device_count()
     if num_gpus == 0 or num_visible == 0:
-        return ["cpu"]
+        if torch.backends.mps.is_available():
+            return ["mps"]
+        else:
+            return ["cpu"]
 
     # Sanity checks
     if num_gpus > num_visible:
diff --git a/elk/utils/hf_utils.py b/elk/utils/hf_utils.py
index 9f429921..670e99ca 100644
--- a/elk/utils/hf_utils.py
+++ b/elk/utils/hf_utils.py
@@ -47,7 +47,7 @@ def instantiate_model(
         kwargs["torch_dtype"] = torch.float16
 
     # CPUs generally don't support anything other than fp32.
-    elif device.type == "cpu":
+    elif device.type in ("cpu", "mps"):
         kwargs["torch_dtype"] = torch.float32
 
     # If the model is fp32 but bf16 is available, convert to bf16.