From 40437d055ace374e9370d573abcd400a1b107ed0 Mon Sep 17 00:00:00 2001 From: Daniel Justus Date: Fri, 20 Oct 2023 09:34:23 +0000 Subject: [PATCH] only ipu cos sin in fp16 --- besskge/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/besskge/utils.py b/besskge/utils.py index 31bd6a9..e93fd40 100644 --- a/besskge/utils.py +++ b/besskge/utils.py @@ -66,7 +66,7 @@ def complex_rotation(v: torch.Tensor, r: torch.Tensor) -> torch.Tensor: Row-wise rotated tensors. """ # Always compute sin and cos in fp16, as faster on IPU - if r.dtype == torch.float32: + if r.dtype == torch.float32 and r.device.type == "ipu": r_cos = torch.cos(r.to(dtype=torch.float16)).to(dtype=torch.float32) r_sin = torch.sin(r.to(dtype=torch.float16)).to(dtype=torch.float32) else: