From 1bff3085300876409699cb2a0199233e46364990 Mon Sep 17 00:00:00 2001
From: Pete
Date: Thu, 12 Oct 2023 09:35:46 -0700
Subject: [PATCH] ensure bias is created in fp32 (#327)

---
 olmo/model.py | 11 ++++++-----
 olmo/util.py  | 11 +++++++++++
 2 files changed, 17 insertions(+), 5 deletions(-)

diff --git a/olmo/model.py b/olmo/model.py
index 71ef42f93..1c0fe0a6a 100644
--- a/olmo/model.py
+++ b/olmo/model.py
@@ -29,6 +29,7 @@
 )
 from .exceptions import OlmoConfigurationError
 from .initialization import init_weights
+from .util import ensure_finite_
 
 __all__ = [
     "LayerNormBase",
@@ -450,7 +451,7 @@ def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.
             raise NotImplementedError()
         if bias.dtype != target_dtype:
             bias = bias.to(target_dtype)
-            bias.masked_fill_(bias == float("-inf"), torch.finfo(target_dtype).min)
+            ensure_finite_(bias, check_neg_inf=True, check_pos_inf=False)
         return bias
 
     def attention(
@@ -902,7 +903,7 @@ def forward(
         # Transform the attention mask into what the blocks expect.
         if attention_mask is not None:
             # shape: (batch_size, 1, 1, seq_len)
-            attention_mask = attention_mask.to(dtype=x.dtype).view(batch_size, -1)[:, None, None, :]
+            attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[:, None, None, :]
             attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
 
         # Merge attention mask with attention bias.
@@ -922,7 +923,7 @@ def forward(
             elif attention_bias is None:
                 attention_bias = self.get_causal_attention_bias(past_length + seq_len, x.device)
             elif attention_bias.dtype in (torch.int8, torch.bool):
-                attention_bias = attention_bias.to(dtype=x.dtype)
+                attention_bias = attention_bias.to(dtype=torch.float)
                 attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min)
 
             # Transform to the right shape and data type.
@@ -931,7 +932,7 @@ def forward(
                 mask_len = attention_mask.shape[-1]
             elif past_key_values is not None:
                 mask_len = past_key_values[0][0].shape[-2] + input_ids.shape[-1]
-            attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(x.dtype)
+            attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(dtype=torch.float)
 
             # Add in the masking bias.
             if attention_mask is not None:
@@ -939,7 +940,7 @@ def forward(
                 # Might get -infs after adding attention mask, since dtype.min + dtype.min = -inf.
                 # `F.scaled_dot_product_attention()` doesn't handle -inf like you'd expect, instead
                 # it can produce NaNs.
-                attention_bias.masked_fill_(attention_bias == float("-inf"), torch.finfo(attention_bias.dtype).min)
+                ensure_finite_(attention_bias, check_neg_inf=True, check_pos_inf=False)
 
         attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
 
diff --git a/olmo/util.py b/olmo/util.py
index 5542ae1f5..0e9c3a249 100644
--- a/olmo/util.py
+++ b/olmo/util.py
@@ -283,6 +283,17 @@ def move_to_device(o: T, device: torch.device) -> T:
         return o
 
 
+def ensure_finite_(x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False):
+    """
+    Modify ``x`` in place to replace ``float("-inf")`` with the minimum value of the dtype when ``check_neg_inf``
+    is ``True`` and to replace ``float("inf")`` with the maximum value of the dtype when ``check_pos_inf`` is ``True``.
+    """
+    if check_neg_inf:
+        x.masked_fill_(x == float("-inf"), torch.finfo(x.dtype).min)
+    if check_pos_inf:
+        x.masked_fill_(x == float("inf"), torch.finfo(x.dtype).max)
+
+
 def is_distributed() -> bool:
     if "LOCAL_RANK" in os.environ:
         return True
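
A minimal usage sketch (not part of the patch itself) of the ensure_finite_ helper introduced in olmo/util.py above; the helper body is copied from the hunk, and the bias values are invented purely for illustration:

    # Sketch: ensure_finite_ replaces -inf entries of an attention bias in place,
    # so fp16/bf16 softmax inputs stay finite.
    import torch


    def ensure_finite_(x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False):
        # Same logic as the helper added to olmo/util.py in this patch.
        if check_neg_inf:
            x.masked_fill_(x == float("-inf"), torch.finfo(x.dtype).min)
        if check_pos_inf:
            x.masked_fill_(x == float("inf"), torch.finfo(x.dtype).max)


    # A (1, 1, 2, 2) causal-style bias in fp16; -inf marks a masked position.
    bias = torch.tensor([[[[0.0, float("-inf")], [0.0, 0.0]]]], dtype=torch.float16)
    ensure_finite_(bias, check_neg_inf=True, check_pos_inf=False)
    assert torch.isfinite(bias).all()  # the -inf entry is now torch.finfo(torch.float16).min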