-
Notifications
You must be signed in to change notification settings - Fork 316
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add attention_mask argument to loss_fn() and lm_cross_entropy_loss() …
…and adjust the cross entropy calculation to ignore masked (padding) tokens.
- Loading branch information
Showing
3 changed files
with
49 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import torch | ||
|
||
from transformer_lens.HookedTransformer import HookedTransformer | ||
|
||
|
||
def test_cross_entropy_attention_mask(): | ||
"""Check that adding a bunch of masked tokens to the input does not change the loss.""" | ||
MODEL = "solu-1l" | ||
model = HookedTransformer.from_pretrained(MODEL) | ||
|
||
# Step 1: Get the default loss on a prompt | ||
prompt = ["The quick brown fox jumps over the lazy dog."] | ||
default_tokens = model.to_tokens(prompt) | ||
default_attention_mask = torch.ones_like(default_tokens) | ||
default_loss = model(default_tokens, return_type="loss") | ||
ones_mask_loss = model( | ||
default_tokens, attention_mask=default_attention_mask, return_type="loss" | ||
) | ||
assert torch.allclose(default_loss, ones_mask_loss, atol=1e-6) | ||
|
||
# Step 2: Get the loss when we add some extra tokens to the input and set their attention mask | ||
# to zero | ||
extra_prompt = ["Lorem ipsum dolor sit amet, consectetur adipiscing elit."] | ||
extra_tokens = model.to_tokens(extra_prompt) | ||
extra_zeros_attention_mask = torch.zeros_like(extra_tokens) | ||
|
||
combined_tokens = torch.cat([default_tokens, extra_tokens], dim=1) | ||
combined_attention_mask = torch.cat([default_attention_mask, extra_zeros_attention_mask], dim=1) | ||
combined_masked_loss = model( | ||
combined_tokens, attention_mask=combined_attention_mask, return_type="loss" | ||
) | ||
assert torch.allclose(default_loss, combined_masked_loss) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters