This repository provides a custom training pipeline for Hugging Face Transformer models that replaces the standard softmax with a numerically stable alternative (StableMax or LogStableMax) and applies Orthogonal Gradient Updates to encourage generalization. Built on top of Hugging Face Transformers, this project integrates the findings of "Grokking at the Edge of Numerical Stability" by Lucas Prieto, Melih Barsbey, Pedro A.M. Mediano, and Tolga Birdal into Hugging Face training pipelines.
A really nice video describing it: https://www.youtube.com/watch?v=H3OofROzlA0
StableMax is a drop-in replacement for Softmax aimed at preventing the numerical instabilities sometimes observed when logits grow excessively. Instead of the exponential used by Softmax, StableMax applies an elementwise transform

$$
s(x) = \begin{cases} x + 1 & \text{if } x \ge 0 \\ \dfrac{1}{1 - x} & \text{if } x < 0 \end{cases}
$$

and then normalizes across a specified dimension, i.e. $\text{StableMax}(x_i) = s(x_i) / \sum_j s(x_j)$. This can help avoid issues such as "Softmax Collapse" or large logit blow-ups after perfect training accuracy is reached. We also provide LogStableMax, which outputs log-probabilities directly.
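For illustration, here is a standalone PyTorch sketch of these two transforms (a re-implementation of the formula above, not necessarily identical to the modules shipped in `src/`):

```python
import torch

def stablemax(logits: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """StableMax sketch: s(x) = x + 1 for x >= 0, 1 / (1 - x) for x < 0,
    then normalize along `dim` (in place of Softmax's exp-then-normalize)."""
    # clamp(max=0) keeps the unselected branch finite so torch.where stays NaN-free.
    s = torch.where(logits >= 0, logits + 1.0, 1.0 / (1.0 - logits.clamp(max=0.0)))
    return s / s.sum(dim=dim, keepdim=True)

def log_stablemax(logits: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Log-space variant that yields log-probabilities (e.g. for torch.nn.NLLLoss)."""
    s = torch.where(logits >= 0, logits + 1.0, 1.0 / (1.0 - logits.clamp(max=0.0)))
    return torch.log(s) - torch.log(s.sum(dim=dim, keepdim=True))
```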
In orthogonal gradient decomposition ("⊥Grad"), the gradient is split into two components:

- **Parallel component** along the current weight vector $\theta$.
- **Orthogonal component**, which is the remainder.
By discarding the parallel component and updating only in directions orthogonal to the current weights, the model is encouraged to explore new directions for generalization. This technique can help reduce overfitting and keep parameter norms in check, especially in large-scale training.
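For intuition, a minimal sketch of this projection on a single parameter tensor (a simplified illustration, not the repository's actual optimizer decorator):

```python
import torch

def orthogonalize_gradient(param: torch.Tensor, eps: float = 1e-12) -> None:
    """Project param.grad onto the subspace orthogonal to the current weights:
    g_orth = g - (<g, w> / <w, w>) * w. Applied in place before the optimizer step."""
    if param.grad is None:
        return
    w = param.detach().flatten()
    g = param.grad.flatten()
    # Coefficient of the component of g that points along w.
    coeff = torch.dot(g, w) / (torch.dot(w, w) + eps)
    # Discard the parallel component, keep only the orthogonal remainder.
    param.grad -= (coeff * w).view_as(param.grad)
```

In the repository this logic is applied by wrapping the optimizer, and parameters matched by the skip-substring list (e.g. biases and LayerNorm weights) are left untouched.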
- **StableMax and LogStableMax**
  - A numerically stable alternative to Softmax (probabilities or log-probabilities).
- **Orthogonal Gradient Decorator**
  - Compatible with many PyTorch optimizers (such as `AdamW`); modifies gradients before the optimizer step.
- **Custom Trainer**
  - Inherits from Hugging Face's `Trainer`, adding:
    - Automatic final-layer replacement to handle StableMax.
    - Integration with orthogonal gradients.
    - Built-in model config mapping (`MODEL_CONFIG_MAPPING`) for quick setup of known architectures.
- **Examples**
  - `examples/minimal_usage.py`: Demonstrates a simple GPT-2 training loop on a toy dataset.
  - `examples/llama3.2-1b-alpaca.py`: Fine-tunes Llama-3.2-1B on the Alpaca-cleaned dataset using StableMax and orthogonal gradients.
- Clone the repository:
  ```bash
  git clone https://github.com/YourUsername/my_repo.git
  cd my_repo
  ```
- Install the required Python packages:
  ```bash
  pip install -r requirements.txt
  ```
- (Optional) Install as a package:
  ```bash
  pip install -e .
  ```
  This makes `src/` importable from anywhere in your environment.
A quick demonstration on a small GPT-2 model with a toy dataset:
```bash
cd examples
python minimal_usage.py
```

`minimal_usage.py` shows how to:

- Load a small GPT-2 model & tokenizer.
- Create a small dataset.
- Use `CustomTrainingArguments` and `CustomTrainer`.
- Enable `use_stable_max`, `use_log_stable_max`, or `use_orthogonal_optimizer`.
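A rough sketch of that pattern, assuming `CustomTrainer` and `CustomTrainingArguments` accept the usual `Trainer`/`TrainingArguments` keywords plus the flags above (the import path is a guess; see `examples/minimal_usage.py` for the authoritative version):

```python
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling

# Hypothetical import path; adjust to wherever the package exposes these classes.
from src.trainer import CustomTrainer, CustomTrainingArguments

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

# Tiny toy corpus, just enough to exercise the training loop.
texts = ["a tiny toy corpus for the demo"] * 32
train_dataset = Dataset.from_dict({"text": texts}).map(
    lambda ex: tokenizer(ex["text"], truncation=True, max_length=32),
    batched=True,
    remove_columns=["text"],
)

args = CustomTrainingArguments(
    output_dir="out",
    per_device_train_batch_size=8,
    num_train_epochs=1,
    use_stable_max=True,            # or use_log_stable_max=True
    use_orthogonal_optimizer=True,  # ⊥Grad-style updates
)

trainer = CustomTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()
```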
To train Llama-3.2-1B on the Alpaca-cleaned dataset with StableMax and orthogonal gradient updates, run:
```bash
cd examples
python llama3.2-1b-alpaca.py
```

Key steps in `llama3.2-1b-alpaca.py`:

- Loads the `meta-llama/Llama-3.2-1B` model from Hugging Face.
- Tokenizes the Alpaca-cleaned dataset with a custom prompt format.
- Fine-tunes with `use_stable_max=True` and `use_orthogonal_optimizer=True`.
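The exact prompt template lives in the script; as a purely illustrative example of what an Alpaca-style format looks like (the script's actual template may differ):

```python
def build_prompt(example: dict) -> str:
    """Illustrative Alpaca-style prompt builder for instruction / input / output records."""
    if example.get("input"):
        return (
            "Below is an instruction that describes a task, paired with an input "
            "that provides further context. Write a response that appropriately "
            "completes the request.\n\n"
            f"### Instruction:\n{example['instruction']}\n\n"
            f"### Input:\n{example['input']}\n\n"
            f"### Response:\n{example['output']}"
        )
    return (
        "Below is an instruction that describes a task. Write a response that "
        "appropriately completes the request.\n\n"
        f"### Instruction:\n{example['instruction']}\n\n"
        f"### Response:\n{example['output']}"
    )
```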
- **StableMax or LogStableMax**: Toggle via `use_stable_max` or `use_log_stable_max`.
- **Orthogonal Gradient**: Toggle via `use_orthogonal_optimizer`.
- **Expand Final Layer**: Experimental; some tasks might benefit from expanding the final layer's output dimension by one.
- **Skip Parameter Types**: Provide substrings like `["bias", "LayerNorm"]` to avoid orthogonal decomposition on certain parameters.
All of these can be set in `CustomTrainingArguments`. See `examples/minimal_usage.py` for a template.
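For instance, a configuration combining these options might look like the following; the field names for the last two options are guesses, so check `CustomTrainingArguments` in `src/` for the real ones:

```python
from src.trainer import CustomTrainingArguments  # hypothetical import path, as above

args = CustomTrainingArguments(
    output_dir="out",
    use_log_stable_max=True,                      # log-probabilities, pair with an NLL-style loss
    use_orthogonal_optimizer=True,                # ⊥Grad updates
    orthogonal_skip_types=["bias", "LayerNorm"],  # hypothetical name: substrings of params to skip
    expand_final_layer=True,                      # hypothetical name: experimental +1 output dimension
)
```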
- **Why StableMax?**
  Softmax can become numerically unstable when logits are very large. StableMax transforms logits in a way that avoids overflow, so training can keep making progress after near-perfect training accuracy.
- **When should I use `LogStableMax`?**
  If you prefer working in log-space (e.g., with `torch.nn.NLLLoss`), `LogStableMax` directly yields log-probabilities.
- **How does the orthogonal gradient help?**
  It removes gradient components parallel to the existing weight vector. This can reduce "runaway norm" issues and help generalization by forcing updates in new directions.
- **What if I only want orthogonal gradients without StableMax?**
  Keep `use_stable_max=False` and `use_log_stable_max=False`, but set `use_orthogonal_optimizer=True`.
- **Does this code work with DeepSpeed or FSDP?**
  Yes, though you might need additional configuration (e.g., `--deepspeed ds_config.json`). Ensure that custom operations (like the orthogonal decomposition) do not conflict with distributed memory partitioning.
Pull requests, bug reports, and feature requests are welcome! If you'd like to add more model entries to `MODEL_CONFIG_MAPPING` or refine the stable/log transforms, feel free to open an issue or submit a PR.
MIT License
This project builds on the work presented in the paper "Grokking at the Edge of Numerical Stability" by Lucas Prieto, Melih Barsbey, Pedro A.M. Mediano, and Tolga Birdal. We thank the authors for their insights into Softmax Collapse (SC) and their contributions to StableMax and ⊥Grad, which inspired the development of this repository.
The original paper and code can be found at: https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability