Skip to content

Commit

Permalink
Fix has_fp64_support for MPS on Apple Silicon
Browse files Browse the repository at this point in the history
  • Loading branch information
ogrisel committed Sep 20, 2023
1 parent f39ff1d commit 8d6b69d
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions sklearn_pytorch_engine/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,16 @@ def has_fp64_support(device):
try:
torch.zeros(1, dtype=torch.float64, device=device)
return True
except RuntimeError as runtime_error:
if "data type is unsupported" in str(runtime_error):
except RuntimeError as e:
if "data type is unsupported" in str(e):
return False
raise
except TypeError as e:
# On Apple Silicon M1 with the MPS device, the following error is
# raised:
#
# TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS
# framework doesn't support float64. Please use float32 instead.
if "doesn't support float64" in str(e):
return False
raise

0 comments on commit 8d6b69d

Please sign in to comment.