diff --git a/tests/integration/test_cross_entropy_loss.py b/tests/integration/test_cross_entropy_loss.py
new file mode 100644
index 000000000..040070031
--- /dev/null
+++ b/tests/integration/test_cross_entropy_loss.py
@@ -0,0 +1,32 @@
+import torch
+
+from transformer_lens.HookedTransformer import HookedTransformer
+
+
+def test_cross_entropy_attention_mask():
+    """Check that adding a bunch of masked tokens to the input does not change the loss."""
+    MODEL = "solu-1l"
+    model = HookedTransformer.from_pretrained(MODEL)
+
+    # Step 1: Get the default loss on a prompt
+    prompt = ["The quick brown fox jumps over the lazy dog."]
+    default_tokens = model.to_tokens(prompt)
+    default_attention_mask = torch.ones_like(default_tokens)
+    default_loss = model(default_tokens, return_type="loss")
+    ones_mask_loss = model(
+        default_tokens, attention_mask=default_attention_mask, return_type="loss"
+    )
+    assert torch.allclose(default_loss, ones_mask_loss, atol=1e-6)
+
+    # Step 2: Get the loss when we add some extra tokens to the input and set their attention mask
+    # to zero
+    extra_prompt = ["Lorem ipsum dolor sit amet, consectetur adipiscing elit."]
+    extra_tokens = model.to_tokens(extra_prompt)
+    extra_zeros_attention_mask = torch.zeros_like(extra_tokens)
+
+    combined_tokens = torch.cat([default_tokens, extra_tokens], dim=1)
+    combined_attention_mask = torch.cat([default_attention_mask, extra_zeros_attention_mask], dim=1)
+    combined_masked_loss = model(
+        combined_tokens, attention_mask=combined_attention_mask, return_type="loss"
+    )
+    assert torch.allclose(default_loss, combined_masked_loss)
diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py
index 3a35fa307..8ee2e74f8 100644
--- a/transformer_lens/HookedTransformer.py
+++ b/transformer_lens/HookedTransformer.py
@@ -587,7 +587,7 @@ def forward(
             assert (
                 tokens is not None
             ), "tokens must be passed in if return_type is 'loss' or 'both'"
-            loss = self.loss_fn(logits, tokens, per_token=loss_per_token)
+            loss = self.loss_fn(logits, tokens, attention_mask, per_token=loss_per_token)
             if return_type == "loss":
                 return loss
             elif return_type == "both":
@@ -600,6 +600,7 @@ def loss_fn(
         self,
         logits: Float[torch.Tensor, "batch pos d_vocab"],
         tokens: Int[torch.Tensor, "batch pos"],
+        attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
         per_token: bool = False,
     ):
         """Wrapper around `utils.lm_cross_entropy_loss`.
@@ -608,7 +609,7 @@
         """
         if tokens.device != logits.device:
             tokens = tokens.to(logits.device)
-        return utils.lm_cross_entropy_loss(logits, tokens, per_token)
+        return utils.lm_cross_entropy_loss(logits, tokens, attention_mask, per_token)
 
     @overload
     def run_with_cache(
diff --git a/transformer_lens/utils.py b/transformer_lens/utils.py
index f5d3bacbe..7e5828e14 100644
--- a/transformer_lens/utils.py
+++ b/transformer_lens/utils.py
@@ -115,6 +115,7 @@ def to_numpy(tensor):
 def lm_cross_entropy_loss(
     logits: Float[torch.Tensor, "batch pos d_vocab"],
     tokens: Int[torch.Tensor, "batch pos"],
+    attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
     per_token: bool = False,
 ) -> Union[Float[torch.Tensor, ""], Float[torch.Tensor, "batch pos"]]:
     """Cross entropy loss for the language model, gives the loss for predicting the NEXT token.
@@ -122,6 +123,8 @@
     Args:
         logits (torch.Tensor): Logits. Shape [batch, pos, d_vocab]
         tokens (torch.Tensor[int64]): Input tokens. Shape [batch, pos]
+        attention_mask (torch.Tensor[int64], optional): Attention mask. Shape [batch, pos]. Used to
+            mask out padding tokens. Defaults to None.
         per_token (bool, optional): Whether to return the log probs predicted for the correct token, or the loss (ie mean of the predicted log probs). Note that the returned array has shape [batch, seq-1] as we cannot predict the first token (alternately, we ignore the final logit). Defaults to False.
     """
     log_probs = F.log_softmax(logits, dim=-1)
@@ -129,10 +132,20 @@
     # Offsets needed because we're predicting the NEXT token (this means the final logit is meaningless)
     # None and [..., 0] needed because the tensor used in gather must have the same rank.
     predicted_log_probs = log_probs[..., :-1, :].gather(dim=-1, index=tokens[..., 1:, None])[..., 0]
+
+    if attention_mask is not None:
+        # Ignore token positions which are masked out or where the next token is masked out
+        # (generally padding tokens)
+        next_token_mask = torch.logical_and(attention_mask[:, :-1], attention_mask[:, 1:])
+        predicted_log_probs *= next_token_mask
+        n_tokens = next_token_mask.sum().item()
+    else:
+        n_tokens = predicted_log_probs.numel()
+
     if per_token:
         return -predicted_log_probs
     else:
-        return -predicted_log_probs.mean()
+        return -predicted_log_probs.sum() / n_tokens
 
 
 def lm_accuracy(
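For readers skimming the diff, the masking logic added to `lm_cross_entropy_loss` can be summarised with a standalone sketch (toy tensors, not part of the change itself): a position contributes to the loss only when both it and the token it predicts are unmasked, and the mean is taken over the surviving positions rather than over every position.

```python
# Illustrative sketch of the masked loss computation; tensor values are toy data.
import torch
import torch.nn.functional as F

batch, pos, d_vocab = 2, 4, 10
logits = torch.randn(batch, pos, d_vocab)
tokens = torch.randint(0, d_vocab, (batch, pos))
# Second sequence ends in two padding tokens.
attention_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])

log_probs = F.log_softmax(logits, dim=-1)
# Predicting the NEXT token: drop the final logit and the first token.
predicted_log_probs = log_probs[..., :-1, :].gather(dim=-1, index=tokens[..., 1:, None])[..., 0]

# A position only counts if both it and the following token are unmasked.
next_token_mask = torch.logical_and(attention_mask[:, :-1], attention_mask[:, 1:])
masked_log_probs = predicted_log_probs * next_token_mask
loss = -masked_log_probs.sum() / next_token_mask.sum()
print(loss)
```

This is why the integration test above passes: appending fully masked tokens adds zero terms to the numerator and nothing to the denominator, so the loss is unchanged.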