From 8d6b69d9d39a00118b847259afeaf29df4fad331 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Wed, 20 Sep 2023 17:45:24 +0200 Subject: [PATCH] Fix has_fp64_support for MPS on Apple Silicon --- sklearn_pytorch_engine/_utils.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/sklearn_pytorch_engine/_utils.py b/sklearn_pytorch_engine/_utils.py index ae28846..69a5e19 100644 --- a/sklearn_pytorch_engine/_utils.py +++ b/sklearn_pytorch_engine/_utils.py @@ -35,6 +35,16 @@ def has_fp64_support(device): try: torch.zeros(1, dtype=torch.float64, device=device) return True - except RuntimeError as runtime_error: - if "data type is unsupported" in str(runtime_error): + except RuntimeError as e: + if "data type is unsupported" in str(e): return False + raise + except TypeError as e: + # On Apple Silicon M1 with the MPS device, the following error is + # raised: + # + # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS + # framework doesn't support float64. Please use float32 instead. + if "doesn't support float64" in str(e): + return False + raise