Skip to content

Commit

Permalink
Merge pull request #11 from magnusross/mps-support
Browse files Browse the repository at this point in the history
Adds support for torch MPS backend
  • Loading branch information
wesselb authored Sep 15, 2023
2 parents d470a4d + 47472ad commit 8c45363
Showing 1 changed file with 19 additions and 9 deletions.
28 changes: 19 additions & 9 deletions lab/torch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,29 @@ def global_random_state(_: TorchDType):
if B.ActiveDevice.active_name in {None, "cpu"}:
return torch.random.default_generator
else:
parts = B.ActiveDevice.active_name.lower().split(":")
parts = B.ActiveDevice.active_name.lower().split(":", 1)

if len(parts) == 0 or parts[0] not in {"cuda", "gpu"}:
raise RuntimeError(f'Unknown active device "{B.ActiveDevice.active_name}".')
if len(parts) == 0 or parts[0] not in {"cuda", "gpu", "mps"}:
raise RuntimeError(
f'Unknown active device "{B.ActiveDevice.active_name}".')

# Ensure that the generators are available.
if len(torch.cuda.default_generators) == 0:
torch.cuda.init()
if parts[0] == "mps":
if parts[1] != "0":
raise ValueError(
"Cannot specify a device number for PyTorch MPS.")

if len(parts) == 1:
return torch.cuda.default_generators[0]
import torch.mps as mps

return mps._get_default_mps_generator()
else:
return torch.cuda.default_generators[int(parts[1])]
# Ensure that the generators are available.
if len(torch.cuda.default_generators) == 0:
torch.cuda.init()

if len(parts) == 1:
return torch.cuda.default_generators[0]
else:
return torch.cuda.default_generators[int(parts[1])]


@dispatch
Expand Down

0 comments on commit 8c45363

Please sign in to comment.