Add custom resize_token_embeddings method to OLMoForCausalLM (#491) #501

Merged
18 commits:
5b4e866  adds custom resize_token_embeddings method (djliden)
4e4ff1f  Merge branch 'allenai:main' into dl/fix-embedding-resize (djliden)
cf774dc  fixes indentation (djliden)
17149cf  updates changelog (djliden)
8f8e8c2  Merge branch 'main' into dl/fix-embedding-resize (djliden)
36140c4  adds warning if embedding size < vocab size (djliden)
1dedce9  fixes import ordering (djliden)
2da225e  Merge branch 'main' into dl/fix-embedding-resize (djliden)
ba72c66  Merge branch 'main' into dl/fix-embedding-resize (djliden)
04065f0  Merge branch 'allenai:main' into dl/fix-embedding-resize (djliden)
f395a55  updates comment (djliden)
ee9e9ff  Merge branch 'main' into dl/fix-embedding-resize (djliden)
e1ec4b3  minor fixes (djliden)
1595326  formats with ruff (djliden)
e186f5b  Merge branch 'main' into dl/fix-embedding-resize (djliden)
aa5687d  Merge branch 'main' into dl/fix-embedding-resize (2015aroras)
3cdfcde  Merge branch 'main' into dl/fix-embedding-resize (2015aroras)
81d30d4  Merge branch 'main' into dl/fix-embedding-resize (2015aroras)
modeling_olmo.py:

@@ -1,3 +1,4 @@
+import logging
 from dataclasses import fields
 from typing import List, Optional, Tuple, Union
 
@@ -12,6 +13,8 @@
 from .configuration_olmo import OLMoConfig
 
+log = logging.getLogger(__name__)
+
 
 def create_model_config_from_pretrained_config(config: OLMoConfig):
     """
@@ -165,6 +168,59 @@ def tie_weights(self):
         """
         pass
 
+    def resize_token_embeddings(
+        self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
+    ) -> torch.nn.Embedding:
+        """
+        Resizes input token embeddings matrix of the model if `new_num_tokens != config.embedding_size`.
+
+        Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
+
+        Arguments:
+            new_num_tokens (`int`, *optional*):
+                The new number of tokens in the embedding matrix. Increasing the size will add newly initialized
+                vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`,
+                just returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing
+                anything.
+            pad_to_multiple_of (`int`, *optional*):
+                If set will pad the embedding matrix to a multiple of the provided value. If `new_num_tokens` is
+                set to `None` will just pad the embedding to a multiple of `pad_to_multiple_of`.
+
+                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute
+                capability `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple
+                of 128. For more details about this, or help on choosing the correct value for resizing, refer to
+                this guide:
+                https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
+
+        Return:
+            `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
+
+        Note:
+            This method differs from the base class implementation by resizing the `embedding_size` attribute of
+            the model configuration instead of the `vocab_size`. It also includes a warning if the resized
+            `embedding_size` is less than the `vocab_size`. In OLMo, `embedding_size` refers to the dimensionality
+            of the model's token embeddings, while `vocab_size` refers to the number of unique tokens in the
+            vocabulary.
+        """
+        model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+        if new_num_tokens is None and pad_to_multiple_of is None:
+            return model_embeds
+
+        # Update base model and current model config
+        self.config.embedding_size = model_embeds.weight.shape[0]
+        self.model.config.embedding_size = model_embeds.weight.shape[0]
+
+        # Check if the embedding size is less than the vocab size
+        if self.config.embedding_size < self.config.vocab_size:
+            warning_message = (
+                f"Resizing token embeddings to size {self.config.embedding_size}, which is less than the vocab size "
+                f"{self.config.vocab_size} defined in the model configuration. Make sure your tokenizer's vocabulary "
+                "size is less than or equal to the new token embedding size."
+            )
+            log.warning(warning_message)
+
+        # Tie weights again if needed
+        self.tie_weights()
+
+        return model_embeds
+
 
 # Register the model so that it is available for transformer pipelines, auto-loading, etc.
 AutoModelForCausalLM.register(OLMoConfig, OLMoForCausalLM)

Review comment on `self.config.embedding_size = model_embeds.weight.shape[0]`: "Could you add a warning for when the new embedding size is less than the vocab size? Something like …"
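For reference, a minimal usage sketch of the new method (not part of the PR). It assumes the `hf_olmo` package is importable, so the `AutoModelForCausalLM.register(...)` call at the bottom of the diff has run; the checkpoint name and the added token below are illustrative and should be adjusted to your setup:

```python
import hf_olmo  # noqa: F401  # registers OLMoConfig/OLMoForCausalLM with the HF auto classes
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-1B")
model = AutoModelForCausalLM.from_pretrained("allenai/OLMo-1B")

# Grow the vocabulary, then resize the embedding matrix to match.
tokenizer.add_special_tokens({"additional_special_tokens": ["<|new-token|>"]})
embeddings = model.resize_token_embeddings(len(tokenizer))

# The custom method updates `embedding_size` (not `vocab_size`) on the config,
# and logs a warning if the new size drops below `config.vocab_size`.
print(embeddings.weight.shape[0], model.config.embedding_size)
```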
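The `pad_to_multiple_of` rounding described in the docstring can be illustrated with a small sketch. The helper name below is hypothetical; `transformers` performs this rounding internally when resizing:

```python
def padded_embedding_size(num_tokens: int, pad_to_multiple_of: int) -> int:
    # Round `num_tokens` up to the nearest multiple of `pad_to_multiple_of`.
    return ((num_tokens + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of

# e.g. a 50257-token vocabulary padded to a multiple of 128 becomes 50304,
# a Tensor-Core-friendly embedding size.
print(padded_embedding_size(50257, 128))  # 50304
```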
Review comment on the new docstring: "This docstring seems to be exactly the same as the one from the parent class. If it is, it's not needed here (code editors and IDEs should be able to show the parent class's docstring when one hovers over `resize_token_embeddings`). Having the docstring here somewhat implies that we are defining what the method does, when it is really the parent that does so."