From 6773d727714bbccb2980807b4f55a605184ce5a3 Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Wed, 25 Oct 2023 11:35:00 -0700
Subject: [PATCH] Fix dtype casting on CPU

---
 olmo/model.py | 33 ++++++++++++++++-----------------
 1 file changed, 16 insertions(+), 17 deletions(-)

diff --git a/olmo/model.py b/olmo/model.py
index 1c0fe0a6a..e873e5aa1 100644
--- a/olmo/model.py
+++ b/olmo/model.py
@@ -127,16 +127,15 @@ def build(cls, config: ModelConfig, size: Optional[int] = None, **kwargs) -> Lay
             raise NotImplementedError(f"Not sure how to handle '{config.layer_norm_type}' LayerNorm type")
 
     def _cast_if_autocast_enabled(self, tensor: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
-        if torch.is_autocast_enabled():
-            if dtype is None:
-                if tensor.device.type == "cuda":
-                    dtype = torch.get_autocast_gpu_dtype()
-                elif tensor.device.type == "cpu":
-                    dtype = torch.get_autocast_cpu_dtype()
-                else:
-                    raise NotImplementedError()
-            return tensor.to(dtype=dtype)
-        return tensor
+        # NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
+        # `is_autocast_cpu_enabled()` for CPU autocast.
+        # See https://github.com/pytorch/pytorch/issues/110966.
+        if tensor.device.type == "cuda" and torch.is_autocast_enabled():
+            return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_gpu_dtype())
+        elif tensor.device.type == "cpu" and torch.is_autocast_cpu_enabled():
+            return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_cpu_dtype())
+        else:
+            return tensor
 
     def reset_parameters(self):
         if self.weight is not None:
@@ -442,13 +441,13 @@ def reset_parameters(self):
     @classmethod
     def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.Tensor:
         target_dtype = input_dtype
-        if torch.is_autocast_enabled():
-            if bias.device.type == "cuda":
-                target_dtype = torch.get_autocast_gpu_dtype()
-            elif bias.device.type == "cpu":
-                target_dtype = torch.get_autocast_cpu_dtype()
-            else:
-                raise NotImplementedError()
+        # NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
+        # `is_autocast_cpu_enabled()` for CPU autocast.
+        # See https://github.com/pytorch/pytorch/issues/110966.
+        if bias.device.type == "cuda" and torch.is_autocast_enabled():
+            target_dtype = torch.get_autocast_gpu_dtype()
+        elif bias.device.type == "cpu" and torch.is_autocast_cpu_enabled():
+            target_dtype = torch.get_autocast_cpu_dtype()
         if bias.dtype != target_dtype:
             bias = bias.to(target_dtype)
             ensure_finite_(bias, check_neg_inf=True, check_pos_inf=False)