Add attention_mask argument to loss_fn() and lm_cross_entropy_loss(), and adjust the cross entropy calculation to ignore masked (padding) tokens.
UFO-101 committed Aug 13, 2024
1 parent 3be089b commit a52bfac
Showing 3 changed files with 49 additions and 3 deletions.
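As a rough usage sketch (not part of this commit), the change lets a caller pass an explicit attention mask for a padded batch so that the loss ignores padding positions. The model name comes from the test added below; the prompts, the pad-token fallback, and the batch construction here are illustrative assumptions, not code from the repository.

```python
import torch
from transformer_lens.HookedTransformer import HookedTransformer

model = HookedTransformer.from_pretrained("solu-1l")

# Two prompts of different lengths; the shorter one is right-padded.
tokens_a = model.to_tokens(["The quick brown fox jumps over the lazy dog."])
tokens_b = model.to_tokens(["Hello world."])

pad_id = model.tokenizer.pad_token_id or 0  # illustrative fallback if no pad token is set
pad_len = tokens_a.shape[1] - tokens_b.shape[1]
padding = torch.full((1, pad_len), pad_id, dtype=tokens_b.dtype)
tokens_b_padded = torch.cat([tokens_b, padding], dim=1)

batch = torch.cat([tokens_a, tokens_b_padded], dim=0)
attention_mask = torch.ones_like(batch)
attention_mask[1, tokens_b.shape[1]:] = 0  # zero out the padding positions

# With this commit, the returned loss ignores next-token predictions at masked positions.
loss = model(batch, attention_mask=attention_mask, return_type="loss")
```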
32 changes: 32 additions & 0 deletions tests/integration/test_cross_entropy_loss.py
@@ -0,0 +1,32 @@
import torch

from transformer_lens.HookedTransformer import HookedTransformer


def test_cross_entropy_attention_mask():
"""Check that adding a bunch of masked tokens to the input does not change the loss."""
MODEL = "solu-1l"
model = HookedTransformer.from_pretrained(MODEL)

# Step 1: Get the default loss on a prompt
prompt = ["The quick brown fox jumps over the lazy dog."]
default_tokens = model.to_tokens(prompt)
default_attention_mask = torch.ones_like(default_tokens)
default_loss = model(default_tokens, return_type="loss")
ones_mask_loss = model(
default_tokens, attention_mask=default_attention_mask, return_type="loss"
)
assert torch.allclose(default_loss, ones_mask_loss, atol=1e-6)

# Step 2: Get the loss when we add some extra tokens to the input and set their attention mask
# to zero
extra_prompt = ["Lorem ipsum dolor sit amet, consectetur adipiscing elit."]
extra_tokens = model.to_tokens(extra_prompt)
extra_zeros_attention_mask = torch.zeros_like(extra_tokens)

combined_tokens = torch.cat([default_tokens, extra_tokens], dim=1)
combined_attention_mask = torch.cat([default_attention_mask, extra_zeros_attention_mask], dim=1)
combined_masked_loss = model(
combined_tokens, attention_mask=combined_attention_mask, return_type="loss"
)
assert torch.allclose(default_loss, combined_masked_loss)
5 changes: 3 additions & 2 deletions transformer_lens/HookedTransformer.py
@@ -587,7 +587,7 @@ def forward(
assert (
tokens is not None
), "tokens must be passed in if return_type is 'loss' or 'both'"
- loss = self.loss_fn(logits, tokens, per_token=loss_per_token)
+ loss = self.loss_fn(logits, tokens, attention_mask, per_token=loss_per_token)
if return_type == "loss":
return loss
elif return_type == "both":
@@ -600,6 +600,7 @@ def loss_fn(
self,
logits: Float[torch.Tensor, "batch pos d_vocab"],
tokens: Int[torch.Tensor, "batch pos"],
attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
per_token: bool = False,
):
"""Wrapper around `utils.lm_cross_entropy_loss`.
@@ -608,7 +609,7 @@ def loss_fn(
"""
if tokens.device != logits.device:
tokens = tokens.to(logits.device)
- return utils.lm_cross_entropy_loss(logits, tokens, per_token)
+ return utils.lm_cross_entropy_loss(logits, tokens, attention_mask, per_token)

@overload
def run_with_cache(
15 changes: 14 additions & 1 deletion transformer_lens/utils.py
@@ -115,24 +115,37 @@ def to_numpy(tensor):
def lm_cross_entropy_loss(
logits: Float[torch.Tensor, "batch pos d_vocab"],
tokens: Int[torch.Tensor, "batch pos"],
attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
per_token: bool = False,
) -> Union[Float[torch.Tensor, ""], Float[torch.Tensor, "batch pos"]]:
"""Cross entropy loss for the language model, gives the loss for predicting the NEXT token.
Args:
logits (torch.Tensor): Logits. Shape [batch, pos, d_vocab]
tokens (torch.Tensor[int64]): Input tokens. Shape [batch, pos]
attention_mask (torch.Tensor[int64], optional): Attention mask. Shape [batch, pos]. Used to
mask out padding tokens. Defaults to None.
per_token (bool, optional): Whether to return the log probs predicted for the correct token, or the loss (ie mean of the predicted log probs). Note that the returned array has shape [batch, seq-1] as we cannot predict the first token (alternately, we ignore the final logit). Defaults to False.
"""
log_probs = F.log_softmax(logits, dim=-1)
# Use torch.gather to find the log probs of the correct tokens
# Offsets needed because we're predicting the NEXT token (this means the final logit is meaningless)
# None and [..., 0] needed because the tensor used in gather must have the same rank.
predicted_log_probs = log_probs[..., :-1, :].gather(dim=-1, index=tokens[..., 1:, None])[..., 0]

if attention_mask is not None:
# Ignore token positions which are masked out or where the next token is masked out
# (generally padding tokens)
next_token_mask = torch.logical_and(attention_mask[:, :-1], attention_mask[:, 1:])
predicted_log_probs *= next_token_mask
n_tokens = next_token_mask.sum().item()
else:
n_tokens = predicted_log_probs.numel()

if per_token:
return -predicted_log_probs
else:
- return -predicted_log_probs.mean()
+ return -predicted_log_probs.sum() / n_tokens


def lm_accuracy(
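For reference, here is a small standalone sketch (not part of the commit) of the masking arithmetic that `lm_cross_entropy_loss` now applies: a prediction at position i targets the token at position i + 1, so it only counts if both positions are unmasked, and the excluded positions drop out of both the sum and the token count, so padding cannot dilute the mean loss. The toy values below are made up for illustration.

```python
import torch

# Toy log probs of the correct next token at each position: shape [batch, pos-1].
predicted_log_probs = torch.tensor([[-1.0, -2.0, -3.0, -4.0]])

# Attention mask over the original positions (shape [batch, pos]);
# the last two positions are padding.
attention_mask = torch.tensor([[1, 1, 1, 0, 0]])

# Keep a position only if both it and the next position are unmasked.
next_token_mask = torch.logical_and(attention_mask[:, :-1], attention_mask[:, 1:])
# next_token_mask == [[True, True, False, False]]

masked_log_probs = predicted_log_probs * next_token_mask
n_tokens = next_token_mask.sum()

loss = -masked_log_probs.sum() / n_tokens
print(loss)  # (1.0 + 2.0) / 2 = 1.5; the padded positions are ignored
```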
