diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8646b1285..87d7602a9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -28,7 +28,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Don't log garbage on nodes that aren't rank 0
 - Don't crash in the HF code when we are referring to a tokenizer in a local file
 - Corrected the `resize_token_embeddings` method in the `OLMoForCausalLM` class to properly update the token embeddings when resizing the vocabulary.
+- Changed `tie_weights` method to a no-op as weight tying is handled in olmo/model.py
 - Fixed the size calculation for qk layer norm
+- Fixed pipeline test failure that occurs due to a bug in transformers version 4.39.1
 
 ## [v0.2.5](https://github.com/allenai/OLMo/releases/tag/v0.2.5) - 2024-03-06
 
diff --git a/hf_olmo/modeling_olmo.py b/hf_olmo/modeling_olmo.py
index fd598854f..1bf763917 100644
--- a/hf_olmo/modeling_olmo.py
+++ b/hf_olmo/modeling_olmo.py
@@ -4,6 +4,7 @@
 import torch
 from transformers import PreTrainedModel
+from transformers.cache_utils import Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.models.auto import AutoModelForCausalLM
 
@@ -60,6 +61,9 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[
+            Cache
+        ] = None,  # This is a hack mitigation of an issue in transformers `4.39.x` https://github.com/huggingface/transformers/issues/29426
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         if use_cache is None:
             use_cache = self.config.use_cache
@@ -151,8 +155,18 @@ def set_output_embeddings(self, value: torch.nn.Module):
         self.model.transformer.ff_out = value
 
     def tie_weights(self):
-        if self.config.weight_tying:
-            self.model.transformer.ff_out = self.model.transformer.wte
+        """
+        This function is intentionally left as a no-op.
+
+        Weight tying is handled as follows:
+        - When the model is initialized, the `ff_out` layer is conditionally defined based on the `weight_tying` configuration.
+          See: `if not config.weight_tying: self.transformer.update(...)` in `olmo/model.py`.
+        - When computing logits, the `wte` weights are used directly if `weight_tying` is enabled.
+          See: `if self.config.weight_tying: logits = F.linear(x, self.transformer.wte.weight, None)` in the `forward` method.
+
+        Therefore, there is no need to explicitly tie the weights in this function.
+        """
+        pass
 
     def resize_token_embeddings(
         self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
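
To illustrate why the no-op `tie_weights` is safe, here is a minimal sketch of the weight-tying pattern the new docstring describes: the output projection only exists when weights are untied, and when `weight_tying` is enabled the embedding matrix is reused directly for the logits. `ToyConfig` and `ToyModel` are hypothetical stand-ins, not the actual `olmo/model.py`; only the `weight_tying` branches mirror the cited lines.

```python
# Sketch (assumed names) of the weight-tying scheme referenced by the no-op tie_weights.
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class ToyConfig:
    d_model: int = 32
    vocab_size: int = 100
    weight_tying: bool = True


class ToyModel(nn.Module):
    def __init__(self, config: ToyConfig):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(
            {"wte": nn.Embedding(config.vocab_size, config.d_model)}
        )
        # Mirrors `if not config.weight_tying: self.transformer.update(...)`:
        # a separate `ff_out` projection is only created when weights are NOT tied.
        if not config.weight_tying:
            self.transformer.update(
                {"ff_out": nn.Linear(config.d_model, config.vocab_size, bias=False)}
            )

    def forward(self, input_ids: torch.LongTensor) -> torch.Tensor:
        x = self.transformer["wte"](input_ids)  # stand-in for the full transformer stack
        if self.config.weight_tying:
            # Reuse the embedding weights directly, so no explicit tying step is needed.
            logits = F.linear(x, self.transformer["wte"].weight, None)
        else:
            logits = self.transformer["ff_out"](x)
        return logits


# Usage: with weight_tying=True the logits share the embedding matrix,
# which is why an HF-style tie_weights() has nothing left to do.
model = ToyModel(ToyConfig(weight_tying=True))
print(model(torch.tensor([[1, 2, 3]])).shape)  # torch.Size([1, 3, 100])
```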