diff --git a/besskge/utils.py b/besskge/utils.py index 31bd6a9..e93fd40 100644 --- a/besskge/utils.py +++ b/besskge/utils.py @@ -66,7 +66,7 @@ def complex_rotation(v: torch.Tensor, r: torch.Tensor) -> torch.Tensor: Row-wise rotated tensors. """ # Always compute sin and cos in fp16, as faster on IPU - if r.dtype == torch.float32: + if r.dtype == torch.float32 and r.device.type == "ipu": r_cos = torch.cos(r.to(dtype=torch.float16)).to(dtype=torch.float32) r_sin = torch.sin(r.to(dtype=torch.float16)).to(dtype=torch.float32) else: