Skip to content

Commit

Permalink
add mps
Browse files Browse the repository at this point in the history
  • Loading branch information
derpyplops committed Aug 31, 2023
1 parent 4a6b654 commit 4e00e48
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
5 changes: 4 additions & 1 deletion elk/utils/gpu_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,10 @@ def select_usable_devices(
# Trivial case: no GPUs requested or available
num_visible = torch.cuda.device_count()
if num_gpus == 0 or num_visible == 0:
return ["cpu"]
if torch.backends.mps.is_available():
return ["mps"]
else:
return ["cpu"]

# Sanity checks
if num_gpus > num_visible:
Expand Down
2 changes: 1 addition & 1 deletion elk/utils/hf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def instantiate_model(
kwargs["torch_dtype"] = torch.float16

# CPUs generally don't support anything other than fp32.
elif device.type == "cpu":
elif device.type in ("cpu", "mps"):
kwargs["torch_dtype"] = torch.float32

# If the model is fp32 but bf16 is available, convert to bf16.
Expand Down

0 comments on commit 4e00e48

Please sign in to comment.