Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add custom resize_token_embeddings method to OLMoForCausalLM (#491) #501

Merged
merged 18 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Don't log garbage on nodes that aren't rank 0
- Don't crash in the HF code when we are referring to a tokenizer in a local file
- Corrected the `resize_token_embeddings` method in the `OLMoForCausalLM` class to properly update the token embeddings when resizing the vocabulary.
- Changed `tie_weights` method to a no-op as weight tying is handled in olmo/model.py
- Fixed the size calculation for qk layer norm
- Fixed pipeline test failure that occurs due to a bug in transformers version 4.39.1
Expand Down
56 changes: 56 additions & 0 deletions hf_olmo/modeling_olmo.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from dataclasses import fields
from typing import List, Optional, Tuple, Union

Expand All @@ -12,6 +13,8 @@

from .configuration_olmo import OLMoConfig

log = logging.getLogger(__name__)


def create_model_config_from_pretrained_config(config: OLMoConfig):
"""
Expand Down Expand Up @@ -165,6 +168,59 @@ def tie_weights(self):
"""
pass

def resize_token_embeddings(
self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
) -> torch.nn.Embedding:
"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment seems to be the exact same as the one from the parent class. If it is, then it's not needed here (code editors and IDEs should be able to show the parent class's comment when one hovers over resize_token_embeddings). Having the comment somewhat implies that we are defining what the method does, but it is the parent doing it.

Resizes input token embeddings matrix of the model if `new_num_tokens != config.embedding_size`.

Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.

Arguments:
new_num_tokens (`int`, *optional*):
The new number of tokens in the embedding matrix. Increasing the size will add newly initialized
vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything.
pad_to_multiple_of (`int`, *optional*):
If set will pad the embedding matrix to a multiple of the provided value. If `new_num_tokens` is set to
`None` will just pad the embedding to a multiple of `pad_to_multiple_of`.

This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
`>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
details about this, or help on choosing the correct value for resizing, refer to this guide:
https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc

Return:
`torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.

Note:
This method differs from the base class implementation by resizing the `embedding_size` attribute of the
model configuration instead of the `vocab_size`. It also includes a warning if the resized `embedding_size`
is less than the `vocab_size`. In OLMo, `embedding_size` refers to the dimensionality of the model's token
embeddings, while `vocab_size` refers to the number of unique tokens in the vocabulary.
"""
model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
if new_num_tokens is None and pad_to_multiple_of is None:
return model_embeds

# Update base model and current model config
self.config.embedding_size = model_embeds.weight.shape[0]
Copy link
Collaborator

@2015aroras 2015aroras Mar 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a warning for when the new embedding size is less than the vocab size? Something like "Resizing token embeddings to size <size>, which is less than the vocab size <vocab size> of the tokenizer.

self.model.config.embedding_size = model_embeds.weight.shape[0]

# Check if the embedding size is less than the vocab size
if self.config.embedding_size < self.config.vocab_size:
warning_message = (
f"Resizing token embeddings to size {self.config.embedding_size}, which is less than the vocab size "
f"{self.config.vocab_size} defined in the model configuration. Make sure your tokenizer's vocabulary "
"size is less than or equal to the new token embedding size."
)
log.warning(warning_message)

# Tie weights again if needed
self.tie_weights()

return model_embeds


# Register the model so that it is available for transformer pipelines, auto-loading, etc.
AutoModelForCausalLM.register(OLMoConfig, OLMoForCausalLM)
Loading