diff --git a/CHANGELOG.md b/CHANGELOG.md
index 28b894008..87d7602a9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -27,6 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Don't log garbage on nodes that aren't rank 0
 - Don't crash in the HF code when we are referring to a tokenizer in a local file
+- Corrected the `resize_token_embeddings` method in the `OLMoForCausalLM` class to properly update the token embeddings when resizing the vocabulary.
 - Changed `tie_weights` method to a no-op as weight tying is handled in olmo/model.py
 - Fixed the size calculation for qk layer norm
 - Fixed pipeline test failure that occurs due to a bug in transformers version 4.39.1
diff --git a/hf_olmo/modeling_olmo.py b/hf_olmo/modeling_olmo.py
index 8850acb91..1bf763917 100644
--- a/hf_olmo/modeling_olmo.py
+++ b/hf_olmo/modeling_olmo.py
@@ -1,3 +1,4 @@
+import logging
 from dataclasses import fields
 from typing import List, Optional, Tuple, Union
 
@@ -12,6 +13,8 @@
 
 from .configuration_olmo import OLMoConfig
 
+log = logging.getLogger(__name__)
+
 
 def create_model_config_from_pretrained_config(config: OLMoConfig):
     """
@@ -165,6 +168,59 @@ def tie_weights(self):
         """
         pass
 
+    def resize_token_embeddings(
+        self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
+    ) -> torch.nn.Embedding:
+        """
+        Resizes input token embeddings matrix of the model if `new_num_tokens != config.embedding_size`.
+
+        Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
+
+        Arguments:
+            new_num_tokens (`int`, *optional*):
+                The new number of tokens in the embedding matrix. Increasing the size will add newly initialized
+                vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
+                returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything.
+            pad_to_multiple_of (`int`, *optional*):
+                If set will pad the embedding matrix to a multiple of the provided value. If `new_num_tokens` is set to
+                `None` will just pad the embedding to a multiple of `pad_to_multiple_of`.
+
+                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
+                `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
+                details about this, or help on choosing the correct value for resizing, refer to this guide:
+                https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
+
+        Return:
+            `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
+
+        Note:
+            This method differs from the base class implementation by resizing the `embedding_size` attribute of the
+            model configuration instead of the `vocab_size`. It also logs a warning if the resized `embedding_size`
+            is less than the `vocab_size`. In OLMo, `embedding_size` is the number of rows in the token embedding
+            matrix (it may be padded above `vocab_size` for throughput), while `vocab_size` counts the unique tokens.
+ """ + model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of) + if new_num_tokens is None and pad_to_multiple_of is None: + return model_embeds + + # Update base model and current model config + self.config.embedding_size = model_embeds.weight.shape[0] + self.model.config.embedding_size = model_embeds.weight.shape[0] + + # Check if the embedding size is less than the vocab size + if self.config.embedding_size < self.config.vocab_size: + warning_message = ( + f"Resizing token embeddings to size {self.config.embedding_size}, which is less than the vocab size " + f"{self.config.vocab_size} defined in the model configuration. Make sure your tokenizer's vocabulary " + "size is less than or equal to the new token embedding size." + ) + log.warning(warning_message) + + # Tie weights again if needed + self.tie_weights() + + return model_embeds + # Register the model so that it is available for transformer pipelines, auto-loading, etc. AutoModelForCausalLM.register(OLMoConfig, OLMoForCausalLM)