
Commit

ensure bias is created in fp32 (#327)
epwalsh authored Oct 12, 2023
1 parent d4744d0 commit 1bff308
Showing 2 changed files with 17 additions and 5 deletions.
11 changes: 6 additions & 5 deletions olmo/model.py
@@ -29,6 +29,7 @@
 )
 from .exceptions import OlmoConfigurationError
 from .initialization import init_weights
+from .util import ensure_finite_
 
 __all__ = [
     "LayerNormBase",
@@ -450,7 +451,7 @@ def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.Tensor:
             raise NotImplementedError()
         if bias.dtype != target_dtype:
             bias = bias.to(target_dtype)
-            bias.masked_fill_(bias == float("-inf"), torch.finfo(target_dtype).min)
+            ensure_finite_(bias, check_neg_inf=True, check_pos_inf=False)
         return bias
 
     def attention(
@@ -902,7 +903,7 @@ def forward(
         # Transform the attention mask into what the blocks expect.
         if attention_mask is not None:
             # shape: (batch_size, 1, 1, seq_len)
-            attention_mask = attention_mask.to(dtype=x.dtype).view(batch_size, -1)[:, None, None, :]
+            attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[:, None, None, :]
             attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
 
         # Merge attention mask with attention bias.
@@ -922,7 +923,7 @@ def forward(
             elif attention_bias is None:
                 attention_bias = self.get_causal_attention_bias(past_length + seq_len, x.device)
             elif attention_bias.dtype in (torch.int8, torch.bool):
-                attention_bias = attention_bias.to(dtype=x.dtype)
+                attention_bias = attention_bias.to(dtype=torch.float)
                 attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min)
 
             # Transform to the right shape and data type.
@@ -931,15 +932,15 @@
                 mask_len = attention_mask.shape[-1]
             elif past_key_values is not None:
                 mask_len = past_key_values[0][0].shape[-2] + input_ids.shape[-1]
-            attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(x.dtype)
+            attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(dtype=torch.float)
 
             # Add in the masking bias.
             if attention_mask is not None:
                 attention_bias = attention_bias + attention_mask
                 # Might get -infs after adding attention mask, since dtype.min + dtype.min = -inf.
                 # `F.scaled_dot_product_attention()` doesn't handle -inf like you'd expect, instead
                 # it can produce NaNs.
-                attention_bias.masked_fill_(attention_bias == float("-inf"), torch.finfo(attention_bias.dtype).min)
+                ensure_finite_(attention_bias, check_neg_inf=True, check_pos_inf=False)
 
         attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
 
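The comment preserved in the last hunk above is the motivation for routing through `ensure_finite_`: in half precision, `dtype.min + dtype.min` overflows to `-inf`, and a softmax over a row that is entirely `-inf` is 0/0, which is how `F.scaled_dot_product_attention()` can end up producing NaNs. A minimal standalone sketch of that failure mode (not part of the commit, plain PyTorch only):

    import torch

    # In fp16, the sum of two dtype-minimum biases overflows to -inf.
    m = torch.finfo(torch.float16).min
    print(torch.tensor(m, dtype=torch.float16) + torch.tensor(m, dtype=torch.float16))  # tensor(-inf, dtype=torch.float16)

    # Softmax over an all -inf row is 0/0, i.e. NaN -- the same computation
    # scaled_dot_product_attention can end up doing internally on a fully masked row.
    row = torch.full((4,), float("-inf"))
    print(torch.softmax(row, dim=-1))  # tensor([nan, nan, nan, nan])

    # Clamping -inf back to the dtype minimum (what ensure_finite_ does) keeps the
    # row finite; softmax then degrades to a uniform distribution instead of NaNs.
    row.masked_fill_(row == float("-inf"), torch.finfo(row.dtype).min)
    print(torch.softmax(row, dim=-1))  # tensor([0.2500, 0.2500, 0.2500, 0.2500])
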
11 changes: 11 additions & 0 deletions olmo/util.py
@@ -283,6 +283,17 @@ def move_to_device(o: T, device: torch.device) -> T:
     return o
 
 
+def ensure_finite_(x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False):
+    """
+    Modify ``x`` in place to replace ``float("-inf")`` with the minimum value of the dtype when ``check_neg_inf``
+    is ``True`` and to replace ``float("inf")`` with the maximum value of the dtype when ``check_pos_inf`` is ``True``.
+    """
+    if check_neg_inf:
+        x.masked_fill_(x == float("-inf"), torch.finfo(x.dtype).min)
+    if check_pos_inf:
+        x.masked_fill_(x == float("inf"), torch.finfo(x.dtype).max)
+
+
 def is_distributed() -> bool:
     if "LOCAL_RANK" in os.environ:
         return True

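For reference, a quick usage sketch of the new helper (assuming the package is importable as `olmo`); it modifies the tensor in place:

    import torch
    from olmo.util import ensure_finite_

    bias = torch.tensor([0.0, float("-inf"), float("inf")])
    ensure_finite_(bias, check_neg_inf=True, check_pos_inf=True)
    # -inf is clamped to the fp32 minimum and +inf to the fp32 maximum.
    print(bias)  # tensor([ 0.0000e+00, -3.4028e+38,  3.4028e+38])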