11/24/24 record
KellerJordan committed Nov 25, 2024
1 parent 4344f95 commit 42aab06
Showing 6 changed files with 9,643 additions and 21 deletions.
README.md: 13 changes (7 additions, 6 deletions)
@@ -2,8 +2,8 @@

This is a modified variant of the [PyTorch GPT-2 trainer](https://github.com/karpathy/llm.c/blob/7b929300217ff1a974b63791a228928b39b26409/train_gpt2.py) from
Andrej Karpathy's [llm.c](https://github.com/karpathy/llm.c) repo, which attains the same final validation loss in:
-* 1.0B tokens instead of 10B
-* 5.0 minutes on 8xH100 instead of 45
+* 0.9B tokens instead of 10B
+* 4.7 minutes on 8xH100 instead of 45

It uses the following techniques:
* Modernized architecture: Rotary embeddings, QK-Norm, and ReLU^2.
@@ -13,10 +13,10 @@ It uses the following techniques:
* Architectural shortcuts: value residual and embedding shortcut (partially following https://arxiv.org/abs/2410.17897).
* Momentum warmup.
* Tanh soft logit capping (following Gemma 2).
-* FlexAttention.
+* FlexAttention with window size warmup.

The training has attained this speed due to the contributions of meself, [@Grad62304977](https://x.com/Grad62304977),
-[@jxbz](https://x.com/jxbz), [@bozavlado](https://x.com/bozavlado), [@brendanh0gan](https://x.com/brendanh0gan), & [@KoszarskyB](https://x.com/KoszarskyB).
+[@jxbz](https://x.com/jxbz), [@bozavlado](https://x.com/bozavlado), [@brendanh0gan](https://x.com/brendanh0gan), [@KoszarskyB](https://x.com/KoszarskyB), & [@fernbear.bsky.social](https://bsky.app/profile/fernbear.bsky.social)

---
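
The window size warmup added in this commit is the new ingredient in the record. As a rough illustration (not the repository's actual code), it can be implemented with FlexAttention by building a causal sliding-window block mask whose window grows over training; the helper names, window sizes, and linear schedule below are assumptions for the sketch.

```python
# Hypothetical sketch of causal sliding-window attention whose window "warms up"
# (grows) over training. Names, sizes, and the linear schedule are illustrative,
# not the record's exact configuration. Requires a recent PyTorch with FlexAttention.
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

SEQ_LEN = 1024  # illustrative sequence length

def window_size(step: int, total_steps: int, start: int = 128, end: int = SEQ_LEN) -> int:
    # Linear warmup from a short window to the full context (assumed schedule).
    frac = min(step / max(total_steps, 1), 1.0)
    return int(start + frac * (end - start))

def sliding_window_block_mask(window: int, seq_len: int = SEQ_LEN, device: str = "cuda"):
    # Each query attends only to itself and at most `window - 1` previous tokens.
    def mask_mod(b, h, q_idx, kv_idx):
        return (q_idx >= kv_idx) & (q_idx - kv_idx < window)
    return create_block_mask(mask_mod, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device=device)

# Inside the training loop, with q, k, v of shape [batch, heads, SEQ_LEN, head_dim]:
#     block_mask = sliding_window_block_mask(window_size(step, total_steps))
#     out = flex_attention(q, k, v, block_mask=block_mask)
```

In practice one would rebuild the block mask only when the window actually changes (e.g. by quantizing the window to multiples of FlexAttention's 128-token block size), since `create_block_mask` is not free.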

@@ -31,7 +31,7 @@ python data/cached_fineweb10B.py 10 # downloads only the first 1.0B training tokens
./run.sh
```

-The result will be a transformer with 124M active parameters trained for 1875 steps on 1.0B tokens of Fineweb [1], achieving ~3.278 validation loss (w/ up to 0.005 inter-run stddev).
+The result will be a transformer with 124M active parameters trained for 1750 steps on 0.9B tokens of Fineweb [1], achieving ~3.278 mean validation loss (w/ up to 0.005 inter-run stddev).
For comparison, the default llm.c PyTorch trainer yields [>3.28 validation loss after training for 19560 steps on 10B tokens](https://github.com/karpathy/llm.c/discussions/481#:~:text=By%20the%20end%20of%20the%20optimization%20we%27ll%20get%20to%20about%203.29).

## Running it on fewer GPUs or with less memory
@@ -70,7 +70,8 @@ The following is the progression of world records for the task of *training a mo
9. [8.2 minutes: Shortcuts & tweaks](https://x.com/kellerjordan0/status/1854296101303800108) (11/06/24) [[reproducible log](./records/110624_ShortcutsTweaks/dd7304a6-cc43-4d5e-adb8-c070111464a1.txt)]
11. [7.8 minutes: Bfloat16 activations](https://x.com/kellerjordan0/status/1855267054774865980) (11/08/24) [[reproducible log](./records/110824_CastBf16/a833bed8-2fa8-4cfe-af05-58c1cc48bc30.txt)]
12. [7.23 minutes: U-net & 2x lr](https://x.com/kellerjordan0/status/1856053121103093922) (11/10/24) [[reproducible log](./records/111024_UNetDoubleLr/c87bb826-797b-4f37-98c7-d3a5dad2de74.txt)]
-13. [5.0 minutes: FlexAttention]() (11/19/24) [[reproducible log](./records/111924_FlexAttention/8384493d-dba9-4991-b16b-8696953f5e6d.txt)] (requires PyTorch 2.6.0)
+13. [5.03 minutes: FlexAttention](https://x.com/kellerjordan0/status/1859331370268623321) (11/19/24) [[reproducible log](./records/111924_FlexAttention/8384493d-dba9-4991-b16b-8696953f5e6d.txt)] (requires PyTorch 2.6.0)
+14. [4.66 minutes: Window Warmup]() (11/24/24) [[reproducible log](./records/112424_WindowWarmup/cf9e4571-c5fc-4323-abf3-a98d862ec6c8.txt)]

Please see the X threads for the contributors to each record.

