Skip to content

Commit

Permalink
Test MPS specification
Browse files Browse the repository at this point in the history
  • Loading branch information
wesselb committed Sep 15, 2023
1 parent 533882f commit 3f82352
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 11 deletions.
2 changes: 1 addition & 1 deletion lab/torch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def global_random_state(_: TorchDType):
raise RuntimeError(f'Unknown active device "{B.ActiveDevice.active_name}".')

if parts[0] == "mps":
if int(parts[1]) != 0:
if len(parts) == 2 and int(parts[1]) != 0:
raise ValueError("Cannot specify a device number for PyTorch MPS.")

import torch.mps as mps
Expand Down
31 changes: 21 additions & 10 deletions tests/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
import lab.tensorflow
import lab.torch

# noinspection PyUnresolvedReferences
from .util import PositiveTensor, Tensor, approx, check_lazy_shapes, to_np
from .util import PositiveTensor, Tensor, approx, check_lazy_shapes, to_np # noqa


@pytest.mark.parametrize(
Expand Down Expand Up @@ -173,7 +172,7 @@ def test_randbeta_parameters(t, check_lazy_shapes):
approx(B.randbeta(t, alpha=1, beta=1e-6), 1, atol=1e-6)


def test_torch_global_random_state(mocker):
def test_torch_global_random_state(mocker, monkeypatch):
# Check CPU specifications.
B.ActiveDevice.active_name = None
assert B.global_random_state(torch.float32) is torch.random.default_generator
Expand All @@ -190,23 +189,35 @@ def test_torch_global_random_state(mocker):
assert torch_cuda_init.called_once()

# Now set some fake default generators.
torch.cuda.default_generators = (0, 1)
monkeypatch.setattr("torch.cuda.default_generators", (33, 34))
monkeypatch.setattr("torch.mps._get_default_mps_generator", lambda: 35)

# Check GPU specifications.
B.ActiveDevice.active_name = "cuda"
assert B.global_random_state(torch.float32) == 0
assert B.global_random_state(torch.float32) == 33
B.ActiveDevice.active_name = "gpu"
assert B.global_random_state(torch.float32) == 0
assert B.global_random_state(torch.float32) == 33
B.ActiveDevice.active_name = "gpu:0"
assert B.global_random_state(torch.float32) == 0
assert B.global_random_state(torch.float32) == 33
B.ActiveDevice.active_name = "gpu:1"
assert B.global_random_state(torch.float32) == 1
assert B.global_random_state(torch.float32) == 34
with pytest.raises(RuntimeError):
B.ActiveDevice.active_name = "weird-device"
assert B.global_random_state(torch.float32) == 1
B.global_random_state(torch.float32)

# Check MPS specification.
B.ActiveDevice.active_name = "mps"
assert B.global_random_state(torch.float32) == 35
B.ActiveDevice.active_name = "mps:0"
assert B.global_random_state(torch.float32) == 35
with pytest.raises(
ValueError,
match="(?i)cannot specify a device number for PyTorch MPS",
):
B.ActiveDevice.active_name = "mps:1"
B.global_random_state(torch.float32)

# Reset back to defaults.
torch.cuda.default_generators = ()
B.ActiveDevice.active_name = None


Expand Down

0 comments on commit 3f82352

Please sign in to comment.