diff --git a/olmo/model.py b/olmo/model.py
index ec4c597f8..5f94f8523 100644
--- a/olmo/model.py
+++ b/olmo/model.py
@@ -438,6 +438,21 @@ def reset_parameters(self):
             layer_id=self.layer_id,
         )
 
+    @classmethod
+    def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.Tensor:
+        target_dtype = input_dtype
+        if torch.is_autocast_enabled():
+            if bias.device.type == "cuda":
+                target_dtype = torch.get_autocast_gpu_dtype()
+            elif bias.device.type == "cpu":
+                target_dtype = torch.get_autocast_cpu_dtype()
+            else:
+                raise NotImplementedError()
+        if bias.dtype != target_dtype:
+            bias = bias.to(target_dtype)
+            bias.masked_fill_(bias == float("-inf"), torch.finfo(target_dtype).min)
+        return bias
+
     def attention(
         self,
         q: torch.Tensor,
@@ -486,7 +501,14 @@ def attention(
             q, k = self.rotary_emb(q, k)
 
         if attention_bias is not None:
-            attention_bias = attention_bias[:, :, key_len - query_len : key_len, :key_len]
+            # Resize and cast attention bias.
+            # The current dtype of the attention bias might not match the dtype that the SDP attn function will
+            # run in if AMP is enabled, and this can be a problem if some tokens are masked out due to padding
+            # as down-casting the attention bias to the autocast precision will result in -infs, which will
+            # cause the SDP attn function to produce NaNs.
+            attention_bias = self._cast_attn_bias(
+                attention_bias[:, :, key_len - query_len : key_len, :key_len], dtype
+            )
 
         # Get the attention scores.
         # shape: (B, nh, T, hs)
@@ -677,7 +699,7 @@ def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTens
         torch.ones(seq_len, seq_len, device=device, dtype=torch.float),
         diagonal=1,
     )
-    att_bias.masked_fill_(att_bias == 1, float("-inf"))
+    att_bias.masked_fill_(att_bias == 1, torch.finfo(att_bias.dtype).min)
     return att_bias.view(1, 1, seq_len, seq_len)  # type: ignore
@@ -882,9 +904,6 @@ def forward(
             # shape: (batch_size, 1, 1, seq_len)
             attention_mask = attention_mask.to(dtype=x.dtype).view(batch_size, -1)[:, None, None, :]
             attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
-            # TODO: fill w/ -inf instead?
-            # attention_mask = 1.0 - attention_mask
-            # attention_mask.masked_fill_(attention_mask == 1.0, float("-inf"))
 
         # Merge attention mask with attention bias.
         if (
@@ -917,6 +936,10 @@ def forward(
             # Add in the masking bias.
             if attention_mask is not None:
                 attention_bias = attention_bias + attention_mask
+                # Might get -infs after adding attention mask, since dtype.min + dtype.min = -inf.
+                # `F.scaled_dot_product_attention()` doesn't handle -inf like you'd expect, instead
+                # it can produce NaNs.
+                attention_bias.masked_fill_(attention_bias == float("-inf"), torch.finfo(attention_bias.dtype).min)
 
         attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
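
For reference, here is a minimal standalone sketch (not part of the diff) of the failure mode the new comments describe: a `finfo.min`-valued bias overflows to `-inf` when down-cast under AMP or when added to another mask, and fully masked rows can then make `F.scaled_dot_product_attention()` emit NaNs. Shapes and tensor names are illustrative only; assumes PyTorch >= 2.0.

```python
import torch
import torch.nn.functional as F

# float32's finite minimum overflows to -inf when the bias is down-cast for AMP.
fp32_min = torch.tensor(torch.finfo(torch.float32).min)
print(fp32_min.to(torch.float16))  # tensor(-inf, dtype=torch.float16)

# Adding two finfo.min-valued tensors (bias + mask) also overflows to -inf.
m = torch.full((1,), torch.finfo(torch.float32).min)
print(m + m)  # tensor([-inf])

# A fully masked (all -inf) row can make the softmax inside SDPA produce NaNs,
# e.g. for query positions that are entirely padding.
q, k, v = (torch.randn(1, 1, 4, 8) for _ in range(3))
inf_bias = torch.full((1, 1, 4, 4), float("-inf"))
print(F.scaled_dot_product_attention(q, k, v, attn_mask=inf_bias).isnan().any())

# Replacing -inf with the dtype's finite minimum (what the diff does) keeps the
# scores very negative but finite, so the output stays NaN-free.
safe_bias = inf_bias.masked_fill(inf_bias == float("-inf"), torch.finfo(q.dtype).min)
print(F.scaled_dot_product_attention(q, k, v, attn_mask=safe_bias).isnan().any())
```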