Fix SDP NaN bug
epwalsh committed Oct 10, 2023
1 parent fddded5 commit 85f9bc4
Showing 1 changed file with 28 additions and 5 deletions.
33 changes: 28 additions & 5 deletions olmo/model.py
@@ -438,6 +438,21 @@ def reset_parameters(self):
            layer_id=self.layer_id,
        )

+    @classmethod
+    def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.Tensor:
+        target_dtype = input_dtype
+        if torch.is_autocast_enabled():
+            if bias.device.type == "cuda":
+                target_dtype = torch.get_autocast_gpu_dtype()
+            elif bias.device.type == "cpu":
+                target_dtype = torch.get_autocast_cpu_dtype()
+            else:
+                raise NotImplementedError()
+        if bias.dtype != target_dtype:
+            bias = bias.to(target_dtype)
+            bias.masked_fill_(bias == float("-inf"), torch.finfo(target_dtype).min)
+        return bias

    def attention(
        self,
        q: torch.Tensor,
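
A quick standalone sketch (not from the diff, assuming float16 as the autocast precision) of the overflow `_cast_attn_bias` guards against: float32's minimum value does not fit in float16, so a plain down-cast of the masked positions produces -inf, while re-filling with the target dtype's own minimum keeps the bias finite.

import torch

# Hypothetical repro: a float32 bias filled with finfo(float32).min overflows to -inf in float16.
bias = torch.full((1, 1, 1, 4), torch.finfo(torch.float32).min)
print(bias.to(torch.float16))  # all -inf

# Re-filling with the *target* dtype's min, as _cast_attn_bias does, keeps the bias finite.
cast = bias.to(torch.float16)
cast.masked_fill_(cast == float("-inf"), torch.finfo(torch.float16).min)
print(cast)  # all -65504., the float16 minimum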
@@ -486,7 +501,14 @@ def attention(
            q, k = self.rotary_emb(q, k)

        if attention_bias is not None:
-            attention_bias = attention_bias[:, :, key_len - query_len : key_len, :key_len]
+            # Resize and cast attention bias.
+            # The current dtype of the attention bias might not match the dtype that the SDP attn function will
+            # run in if AMP is enabled, and this can be a problem if some tokens are masked out due to padding
+            # as down-casting the attention bias to the autocast precision will result in -infs, which will
+            # cause the SDP attn function to produce NaNs.
+            attention_bias = self._cast_attn_bias(
+                attention_bias[:, :, key_len - query_len : key_len, :key_len], dtype
+            )

        # Get the attention scores.
        # shape: (B, nh, T, hs)
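
To see why the cast and the -inf replacement matter, here is a minimal illustration (my own snippet with tiny made-up shapes, not part of the commit) of the NaN behaviour described in the comment above: `F.scaled_dot_product_attention` can return NaNs when a query's bias row is entirely -inf, while the same mask built from the dtype's finite minimum behaves as expected.

import torch
import torch.nn.functional as F

q = torch.randn(1, 1, 2, 4)
k = torch.randn(1, 1, 2, 4)
v = torch.randn(1, 1, 2, 4)

# Every key masked out with -inf: softmax sees an all -inf row, which is NaN.
inf_bias = torch.full((1, 1, 2, 2), float("-inf"))
print(F.scaled_dot_product_attention(q, k, v, attn_mask=inf_bias).isnan().any())  # typically tensor(True)

# The same mask built from finfo.min keeps the output finite.
min_bias = torch.full((1, 1, 2, 2), torch.finfo(torch.float32).min)
print(F.scaled_dot_product_attention(q, k, v, attn_mask=min_bias).isnan().any())  # tensor(False)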
@@ -677,7 +699,7 @@ def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor
        torch.ones(seq_len, seq_len, device=device, dtype=torch.float),
        diagonal=1,
    )
-    att_bias.masked_fill_(att_bias == 1, float("-inf"))
+    att_bias.masked_fill_(att_bias == 1, torch.finfo(att_bias.dtype).min)
    return att_bias.view(1, 1, seq_len, seq_len)  # type: ignore
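
For reference, a small check (run outside the model, assuming seq_len=3) of what the updated bias looks like: zeros on and below the diagonal, and the dtype's minimum rather than -inf above it.

import torch

seq_len, device = 3, torch.device("cpu")
att_bias = torch.triu(
    torch.ones(seq_len, seq_len, device=device, dtype=torch.float),
    diagonal=1,
)
att_bias.masked_fill_(att_bias == 1, torch.finfo(att_bias.dtype).min)
print(att_bias)  # 0. on/below the diagonal, -3.4028e+38 above it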


@@ -882,9 +904,6 @@ def forward(
            # shape: (batch_size, 1, 1, seq_len)
            attention_mask = attention_mask.to(dtype=x.dtype).view(batch_size, -1)[:, None, None, :]
            attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
-            # TODO: fill w/ -inf instead?
-            # attention_mask = 1.0 - attention_mask
-            # attention_mask.masked_fill_(attention_mask == 1.0, float("-inf"))

        # Merge attention mask with attention bias.
        if (
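
As a side note on the mask construction in the hunk above, here is a tiny worked example (hypothetical batch of one, seq_len=3, float32; the `.to(dtype=x.dtype)` cast is dropped since there is no `x` here) of how a 1/0 padding mask becomes an additive bias of zeros and dtype-min values.

import torch

attention_mask = torch.tensor([[1.0, 1.0, 0.0]])                # 1 = real token, 0 = padding
attention_mask = attention_mask.view(1, -1)[:, None, None, :]   # shape: (batch_size, 1, 1, seq_len)
attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
print(attention_mask)  # zeros for real tokens, -3.4028e+38 for the padded position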
@@ -917,6 +936,10 @@ def forward(
            # Add in the masking bias.
            if attention_mask is not None:
                attention_bias = attention_bias + attention_mask
+                # Might get -infs after adding attention mask, since dtype.min + dtype.min = -inf.
+                # `F.scaled_dot_product_attention()` doesn't handle -inf like you'd expect, instead
+                # it can produce NaNs.
+                attention_bias.masked_fill_(attention_bias == float("-inf"), torch.finfo(attention_bias.dtype).min)

        attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
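
One more standalone arithmetic check (not from the diff) of the comment above: two dtype-min values sum to -inf in float32, which is why the added masked_fill_ is needed after the mask is merged into the bias.

import torch

m = torch.finfo(torch.float32).min
summed = torch.tensor([m]) + torch.tensor([m])   # dtype.min + dtype.min overflows to -inf
print(summed)                                    # tensor([-inf])
summed.masked_fill_(summed == float("-inf"), m)  # same fix as applied in forward()
print(summed)                                    # tensor([-3.4028e+38])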
