
Commit

Merge branch 'main' into epwalsh/olmo-core
epwalsh authored Apr 2, 2024
2 parents 378905c + 1c12980 commit 50e96f1
Showing 4 changed files with 28 additions and 4 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -27,7 +27,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Don't log garbage on nodes that aren't rank 0
- Don't crash in the HF code when we are referring to a tokenizer in a local file
- Changed `tie_weights` method to a no-op as weight tying is handled in olmo/model.py
- Fixed the size calculation for qk layer norm
- Fixed pipeline test failure that occurs due to a bug in transformers version 4.39.1

## [v0.2.5](https://github.com/allenai/OLMo/releases/tag/v0.2.5) - 2024-03-06

18 changes: 16 additions & 2 deletions hf_olmo/modeling_olmo.py
@@ -3,6 +3,7 @@

import torch
from transformers import PreTrainedModel
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.auto import AutoModelForCausalLM

@@ -57,6 +58,9 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[
Cache
] = None, # This is a hack mitigation of an issue in transformers `4.39.x` https://github.com/huggingface/transformers/issues/29426
) -> Union[Tuple, CausalLMOutputWithPast]:
if use_cache is None:
use_cache = self.config.use_cache
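
The extra keyword works as a mitigation because the generation utilities in transformers `4.39.x` can pass `cache_position` to `forward`, and a signature that does not accept it fails before the model ever runs. A minimal sketch of that failure mode, using hypothetical stand-in functions rather than the OLMo code:

```python
# Minimal sketch of why accepting an unused `cache_position` keyword helps.
# The two forward functions are hypothetical stand-ins, not the OLMo implementation.
import torch

def forward_without_hack(input_ids, attention_mask=None):
    return input_ids  # placeholder body

def forward_with_hack(input_ids, attention_mask=None, cache_position=None):
    return input_ids  # placeholder body; `cache_position` is accepted and ignored

kwargs = {"input_ids": torch.tensor([[1, 2, 3]]), "cache_position": torch.arange(3)}
forward_with_hack(**kwargs)       # fine: the unexpected keyword is absorbed and dropped
# forward_without_hack(**kwargs)  # TypeError: unexpected keyword argument 'cache_position'
```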
@@ -148,8 +152,18 @@ def set_output_embeddings(self, value: torch.nn.Module):
self.model.transformer.ff_out = value

def tie_weights(self):
if self.config.weight_tying:
self.model.transformer.ff_out = self.model.transformer.wte
"""
This function is intentionally left as a no-op.
Weight tying is handled as follows:
- When the model is initialized, the `ff_out` layer is conditionally defined based on the `weight_tying` configuration.
See: `if not config.weight_tying: self.transformer.update(...)` in `olmo/model.py`.
- When computing logits, the `wte` weights are used directly if `weight_tying` is enabled.
See: `if self.config.weight_tying: logits = F.linear(x, self.transformer.wte.weight, None)` in the `forward` method.
Therefore, there is no need to explicitly tie the weights in this function.
"""
pass


# Register the model so that it is available for transformer pipelines, auto-loading, etc.
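For context on the new `tie_weights` docstring above, here is a minimal, self-contained sketch of the weight-tying arrangement it describes; `TinyLM` is a simplified stand-in, not the actual module in `olmo/model.py`, and only the two decision points mirror the real code:

```python
# Simplified stand-in for the weight-tying logic described in the docstring above.
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyLM(nn.Module):
    def __init__(self, vocab_size: int, d_model: int, weight_tying: bool):
        super().__init__()
        self.weight_tying = weight_tying
        self.transformer = nn.ModuleDict({"wte": nn.Embedding(vocab_size, d_model)})
        if not weight_tying:
            # A separate output projection exists only when weights are NOT tied.
            self.transformer["ff_out"] = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        h = self.transformer["wte"](input_ids)  # stand-in for the full transformer stack
        if self.weight_tying:
            # Reuse the embedding matrix as the output projection.
            return F.linear(h, self.transformer["wte"].weight, None)
        return self.transformer["ff_out"](h)

tokens = torch.randint(0, 8, (1, 3))
tied, untied = TinyLM(8, 4, weight_tying=True), TinyLM(8, 4, weight_tying=False)
assert tied(tokens).shape == untied(tokens).shape == (1, 3, 8)
```

Because the tied and untied cases are resolved at construction time and at logit computation, there is nothing left for `tie_weights` to do, which is why the method becomes a no-op.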
6 changes: 6 additions & 0 deletions olmo/config.py
@@ -988,6 +988,12 @@ class TrainConfig(BaseConfig):
How often to log to the console.
"""

gen1_gc_interval: Optional[int] = 1
"""
How often (in steps) to run generation 1 garbage collection.
Set to ``None`` to use automatic garbage collection (i.e. we don't mess with it).
"""

compile: Optional[CompilerConfig] = None
"""
Settings for compiling the model with ``torch.compile()``.
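To illustrate the intended semantics of the new `gen1_gc_interval` setting, a small self-contained sketch; `should_collect_gen1` is a hypothetical helper whose check mirrors the condition added in `olmo/train.py` below:

```python
# Self-contained sketch of the decision the trainer makes with this setting.
from typing import Optional

def should_collect_gen1(global_step: int, gen1_gc_interval: Optional[int]) -> bool:
    """Return True when a manual generation-1 collection is due on this step."""
    return gen1_gc_interval is not None and global_step % gen1_gc_interval == 0

assert should_collect_gen1(global_step=10, gen1_gc_interval=1)           # default: every step
assert should_collect_gen1(global_step=100, gen1_gc_interval=50)         # every 50 steps
assert not should_collect_gen1(global_step=101, gen1_gc_interval=50)
assert not should_collect_gen1(global_step=10, gen1_gc_interval=None)    # GC left automatic
```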
6 changes: 4 additions & 2 deletions olmo/train.py
@@ -948,7 +948,8 @@ def fit(self):
self._gc_init_state = gc.isenabled() # cache if garbage collection is enabled, reset on close.

# Disable automatic garbage collection, FSDP doesn't work well with it.
gc.disable()
if self.cfg.gen1_gc_interval is not None:
gc.disable()

if self.cfg.load_path is not None and self.global_step > 0 and self.cfg.eval_on_load:
eval_metrics = self.eval()
@@ -1155,7 +1156,8 @@ def on_trace_ready(p):
break

# Run generation 1 garbage collection.
gc.collect(1)
if self.cfg.gen1_gc_interval is not None and self.global_step % self.cfg.gen1_gc_interval == 0:
gc.collect(1)

# Python Profiler stuff
# We do this now, at the bottom of this loop, so we capture the work of getting the next batch.
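Putting the two `olmo/train.py` changes together, a standalone sketch of the manual generation-1 GC pattern, assuming a bare training loop; `train_loop` is hypothetical, and the real trainer caches and restores the interpreter's GC state on close as shown above:

```python
# Standalone sketch of the manual generation-1 GC pattern, assuming a bare loop.
import gc
from typing import Optional

def train_loop(num_steps: int, gen1_gc_interval: Optional[int] = 1) -> None:
    gc_was_enabled = gc.isenabled()  # cache so the original state can be restored
    if gen1_gc_interval is not None:
        gc.disable()  # automatic GC doesn't play well with FSDP; collect manually instead
    try:
        for global_step in range(1, num_steps + 1):
            # ... one training step would run here ...
            if gen1_gc_interval is not None and global_step % gen1_gc_interval == 0:
                gc.collect(1)  # collect only generation 1 to keep pauses short and predictable
    finally:
        if gc_was_enabled:
            gc.enable()  # restore the interpreter's original GC state on exit

train_loop(num_steps=5, gen1_gc_interval=2)
```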
