End-to-End LLM Model Development with Torchtitan and Torchtune #341
@@ -68,18 +68,62 @@ By following these steps, you ensure that the necessary model components are in

## 3. Continuous Pretraining

In this step, you will fine-tune the Llama3 model, starting from the original checkpoint, on the WikiText dataset. This process, known as Full-Parameter Finetuning, updates all the parameters in the original model. The configuration file used for this process is `./tutorials/e2e-llama3-70b-development/configs/full_finetune_distributed.yaml`.

### Memory Consumption Challenges
One of the primary challenges during such training is memory consumption. A typical model trained in mixed precision with AdamW requires 18 bytes of GPU memory per model parameter (6 bytes for the parameters in mixed precision training, 8 bytes for the AdamW optimizer states, and 4 bytes for the gradients), plus activation memory. For more details, see the Hugging Face [model memory anatomy](https://huggingface.co/docs/transformers/model_memory_anatomy) blog post. At 18 bytes per parameter, training a 70B-parameter model requires roughly 1.26 TB of GPU memory before activations, which far exceeds the 80 GB available on a single H100 GPU. To address this issue, `torchtune` integrates PyTorch Fully Sharded Data Parallel (FSDP).
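As a rough sanity check on that figure (a back-of-the-envelope estimate only; it ignores activations, temporary buffers, and communication overhead), you can reproduce the arithmetic yourself:

```bash
# Back-of-the-envelope estimate: 18 bytes/parameter
# = 6 (bf16 weights + fp32 master copy) + 8 (AdamW moments) + 4 (gradients).
PARAMS=70000000000      # 70B parameters
BYTES_PER_PARAM=18
echo "$(( PARAMS * BYTES_PER_PARAM / 10**9 )) GB needed vs. 80 GB on a single H100"
# -> 1260 GB, i.e. ~1.26 TB, before any activation memory
```
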
### Basic concepts and relevant configuration

**FSDP** is a distributed training feature designed to efficiently handle large model training by sharding model parameters, gradients, and optimizer states across multiple devices. This approach significantly reduces memory consumption and optimizes resource utilization, making it possible to train models that are too large to fit on a single GPU. In `torchtune`, users can launch an FSDP training job with the command `tune run full_finetune_distributed`.
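For reference, a minimal single-node invocation of that recipe might look like the following (an illustrative sketch; the sbatch script used later in this section assembles the full multi-node command, including the checkpoint-path overrides, for you):

```bash
# Sketch: full-parameter FSDP finetuning on one node with 8 GPUs.
# The config path matches the one used throughout this tutorial.
tune run --nnodes 1 --nproc_per_node 8 full_finetune_distributed \
  --config tutorials/e2e-llama3-70b-development/configs/full_finetune_distributed.yaml
```
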
**The WikiText language modeling dataset** is a collection of over 100 million tokens extracted from the set of verified Good and Featured articles on Wikipedia. `torchtune` provides a prebuilt module for this dataset, and the configuration file selects it as follows:

```yaml
dataset:
  _component_: torchtune.datasets.wikitext_dataset
```
### Submit the training job

Submit the job with the following command:

```bash
sbatch tutorials/e2e-llama3-70b-development/full_finetune_distributed.sbatch
```
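After submitting, you can confirm that Slurm has accepted the job before inspecting its output:

```bash
# List your jobs; the finetuning job appears here once Slurm accepts it.
squeue -u "$USER"
# The job writes its output under logs/ (see the tail command below).
ls logs/
```
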

This script launches the FSDP training job across two instances (note `--nnodes 2` in the generated command below). Once the job has been scheduled, you will see the following outputs in the log file named `logs/full-finetuning*`:

```bash
# tail -f logs/full-finetuning*
Executing following command:
tune run --master_addr 10.1.62.14 --master_port 28415 --nproc_per_node=8 --nnodes 2 --rdzv_backend=c10d --rdzv_endpoint=p5-st-p5-1 full_finetune_distributed --config /fsx/ubuntu/awsome-distributed-training/3.test_cases/torchtune/slurm/tutorials/e2e-llama3-70b-development/configs/full_finetune_distributed.yaml tokenizer.path=/fsx/ubuntu/models/torchtune/meta-llama/Meta-Llama-3-70B/original/tokenizer.model checkpointer.checkpoint_dir=/fsx/ubuntu/models/torchtune/meta-llama/Meta-Llama-3-70B checkpointer.output_dir=/fsx/ubuntu/models/torchtune/meta-llama/Meta-Llama-3-70B-tuned output_dir=/fsx/ubuntu/models/torchtune/meta-llama/Meta-Llama-3-70B-tuned/log metric_logger.log_dir=/fsx/ubuntu/models/torchtune/meta-llama/Meta-Llama-3-70B-tuned/log/metrics
...
0: wandb: Currently logged in as: <YOURUSERNAME>. Use `wandb login --relogin` to force relogin
0: wandb: Tracking run with wandb version 0.17.0
0: wandb: Run data is saved locally in /fsx/ubuntu/models/torchtune/meta-llama/Meta-Llama-3-70B-tuned/log/metrics/wandb/run-20240527_001350-oziekm6j
0: wandb: Run `wandb offline` to turn off syncing.
0: wandb: Syncing run helpful-surf-1
0: wandb: ⭐️ View project at https://wandb.ai/<YOURUSERNAME>/torchtune
0: wandb: 🚀 View run at https://wandb.ai/<YOURUSERNAME>/torchtune/runs/oziekm6j
0: 2024-05-27:00:13:50,919 INFO [metric_logging.py:225] Logging /fsx/ubuntu/models/torchtune/meta-llama/Meta-Llama-3-70B/torchtune_config.yaml to W&B under Files
...
```

Notice that the job is being tracked by WANDB because of the following section in the config file:

```yaml
metric_logger:
  _component_: torchtune.utils.metric_logging.WandBLogger
  log_dir: None
```
On the WANDB dashboard (`https://wandb.ai/<YOURUSERNAME>/torchtune`), you can monitor the learning curve, compute resource utilization, log outputs, and more.
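Note that metrics only reach the dashboard if the compute nodes can authenticate to WANDB. Assuming `wandb` is installed in the same Python environment the job uses, a one-time login from the head node (the credentials land in your home directory on the shared filesystem) is usually enough:

```bash
# Store your W&B API key (written to ~/.netrc); alternatively, export
# WANDB_API_KEY in the job environment.
wandb login
```
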
## 4. Instruction-tuning

In this step, you will fine-tune the Llama model using Low-Rank Adaptation (LoRA) with the Alpaca dataset. We will first cover the basic concepts and relevant configurations found in the [config file](configs/lora_finetune_distributed.yaml), followed by a detailed fine-tuning tutorial.

### Basic Concepts and Relevant Configurations

@@ -115,9 +159,6 @@ dataset:

As the config suggests, we use a predefined dataset class provided by torchtune.
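Any field in this config can also be overridden at launch time with torchtune's `key=value` syntax, which is the same mechanism the sbatch scripts use to inject checkpoint paths. For example (an illustrative override, not part of the tutorial's scripts), you could disable training on the instruction prompt without editing the YAML:

```bash
# Same recipe and config, with dataset.train_on_input flipped at launch time.
tune run --nnodes 1 --nproc_per_node 8 lora_finetune_distributed \
  --config tutorials/e2e-llama3-70b-development/configs/lora_finetune_distributed.yaml \
  dataset.train_on_input=False
```
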
### Submit the finetuning job

@@ -128,13 +169,12 @@ You can submit the finetuning job with the following command:

```bash
sbatch tutorials/e2e-llama3-70b-development/lora_finetune_distributed.sbatch
```

Once the job has been scheduled, you will see the following outputs in the log file under `logs/`:

```bash
...
Executing following command:
tune run --master_addr 10.1.28.89 --master_port 14280 --nproc_per_node=8 --nnodes 1 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=p5-st-p5-2 lora_finetune_distributed
...
0: wandb: Currently logged in as: <YOURUSERNAME>. Use `wandb login --relogin` to force relogin
0: wandb: Tracking run with wandb version 0.17.0
...
```

@@ -149,7 +189,7 @@ torchtune run --master_addr 10.1.28.89 --master_port 14280 --nproc_per_node=8 --

As the output indicates, we run a single-node distributed training job with 8 GPUs here.

```bash
tune run --master_addr 10.1.28.89 --master_port 14280 --nproc_per_node=8 --nnodes 1 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=p5-st-p5-2 lora_finetune_distributed
```
@@ -191,33 +231,6 @@ You can submit sample evaluation job by:

```bash
sbatch evaluate.sbatch
```

You will see:

```
Running loglikelihood requests:   6%|▋         | 23/400 [00:01<00:18, 20.53it/s]
Running loglikelihood requests:  16%|█▌        | 62/400 [00:02<00:15, 22.65it/s]
Running loglikelihood requests:  24%|██▍       | 98/400 [00:04<00:13, 22.50it/s]
Running loglikelihood requests:  33%|███▎      | 131/400 [00:06<00:12, 22.28it/s]
Running loglikelihood requests:  42%|████▏     | 164/400 [00:07<00:10, 22.40it/s]
Running loglikelihood requests:  50%|█████     | 200/400 [00:09<00:08, 22.60it/s]
Running loglikelihood requests:  58%|█████▊    | 233/400 [00:10<00:07, 22.46it/s]
Running loglikelihood requests:  66%|██████▌   | 263/400 [00:11<00:06, 22.51it/s]
Running loglikelihood requests:  74%|███████▍  | 296/400 [00:13<00:04, 22.45it/s]
Running loglikelihood requests:  82%|████████▏ | 326/400 [00:14<00:03, 22.63it/s]
Running loglikelihood requests:  90%|████████▉ | 356/400 [00:16<00:01, 22.82it/s]
Running loglikelihood requests:  97%|█████████▋| 389/400 [00:17<00:00, 23.11it/s]
Running loglikelihood requests: 100%|██████████| 400/400 [00:17<00:00, 22.27it/s]
0: fatal: not a git repository (or any of the parent directories): .git
0: 2024-05-07:01:12:39,479 INFO [eval.py:69] vllm (pretrained=meta-llama/Meta-Llama-3-70B,tensor_parallel_size=8,dtype=auto,gpu_memory_utilization=0.8,data_parallel_size=1), gen_kwargs: (None), limit: 100.0, num_fewshot: None, batch_size: 1
0: 2024-05-07:01:12:39,536 INFO [eval.py:70] | Tasks |Version|Filter|n-shot| Metric |Value|   |Stderr|
0: |---------|------:|------|-----:|--------|----:|---|-----:|
0: |hellaswag|      1|none  |     0|acc     | 0.56|±  |0.0499|
0: |         |       |none  |     0|acc_norm| 0.75|±  |0.0435|
0:
```
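As the log shows, the job evaluates HellaSwag with the EleutherAI lm-evaluation-harness on a vLLM backend, limited to 100 examples. If you want to rerun or tweak the evaluation outside the sbatch wrapper, a roughly equivalent standalone command (a sketch based on the parameters printed above; the exact entry point used by `evaluate.sbatch` may differ) would be:

```bash
# Standalone lm-eval run mirroring the settings reported in the log.
lm_eval --model vllm \
  --model_args pretrained=meta-llama/Meta-Llama-3-70B,tensor_parallel_size=8,dtype=auto,gpu_memory_utilization=0.8 \
  --tasks hellaswag \
  --limit 100 \
  --batch_size 1
```
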
## 6. Quantization

In a production setting, it is often not feasible to deploy a large model as-is; this requires
@@ -0,0 +1,101 @@
# Config for multi-device LoRA in lora_finetune_distributed.py
# using a Llama3 70B model
#
# This config assumes that you've run the following command before launching
# this run:
#   tune download meta-llama/Meta-Llama-3-70B-Instruct --hf-token <TOKEN> --output-dir /tmp/Meta-Llama-3-70b --ignore-patterns "original/consolidated*"
#
# This config needs 8 GPUs to run
#   tune run --nproc_per_node 8 lora_finetune_distributed --config recipes/configs/llama3/70B_lora.yaml
#

# Model Arguments
model:
  _component_: torchtune.models.llama3.lora_llama3_70b
  lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
  apply_lora_to_mlp: False
  apply_lora_to_output: False
  lora_rank: 16
  lora_alpha: 32

tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: None

checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: None
  checkpoint_files: [
    model-00001-of-00030.safetensors,
    model-00002-of-00030.safetensors,
    model-00003-of-00030.safetensors,
    model-00004-of-00030.safetensors,
    model-00005-of-00030.safetensors,
    model-00006-of-00030.safetensors,
    model-00007-of-00030.safetensors,
    model-00008-of-00030.safetensors,
    model-00009-of-00030.safetensors,
    model-00010-of-00030.safetensors,
    model-00011-of-00030.safetensors,
    model-00012-of-00030.safetensors,
    model-00013-of-00030.safetensors,
    model-00014-of-00030.safetensors,
    model-00015-of-00030.safetensors,
    model-00016-of-00030.safetensors,
    model-00017-of-00030.safetensors,
    model-00018-of-00030.safetensors,
    model-00019-of-00030.safetensors,
    model-00020-of-00030.safetensors,
    model-00021-of-00030.safetensors,
    model-00022-of-00030.safetensors,
    model-00023-of-00030.safetensors,
    model-00024-of-00030.safetensors,
    model-00025-of-00030.safetensors,
    model-00026-of-00030.safetensors,
    model-00027-of-00030.safetensors,
    model-00028-of-00030.safetensors,
    model-00029-of-00030.safetensors,
    model-00030-of-00030.safetensors,
  ]
  recipe_checkpoint: null
  output_dir: None
  model_type: LLAMA3
resume_from_checkpoint: False

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  train_on_input: True
seed: null
shuffle: True
batch_size: 2

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  weight_decay: 0.01
  lr: 3e-4
lr_scheduler:
  _component_: torchtune.modules.get_cosine_schedule_with_warmup
  num_warmup_steps: 100

loss:
  _component_: torch.nn.CrossEntropyLoss

# Training
epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 1

# Logging
output_dir: None
metric_logger:
  _component_: torchtune.utils.metric_logging.WandBLogger
  log_dir: None
log_every_n_steps: 1
log_peak_memory_stats: False

# Environment
device: cuda
dtype: bf16
enable_activation_checkpointing: True
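
Several fields above (`tokenizer.path`, `checkpointer.checkpoint_dir`, `checkpointer.output_dir`, `output_dir`, `metric_logger.log_dir`) are left as `None` placeholders; the sbatch script fills them in at launch time through command-line overrides. A hand-written equivalent, reusing the model paths from the full-finetuning log earlier in this tutorial (adjust them to wherever your checkpoints live), would look roughly like:

```bash
tune run --nnodes 1 --nproc_per_node 8 lora_finetune_distributed \
  --config tutorials/e2e-llama3-70b-development/configs/lora_finetune_distributed.yaml \
  tokenizer.path=/fsx/ubuntu/models/torchtune/meta-llama/Meta-Llama-3-70B/original/tokenizer.model \
  checkpointer.checkpoint_dir=/fsx/ubuntu/models/torchtune/meta-llama/Meta-Llama-3-70B \
  checkpointer.output_dir=/fsx/ubuntu/models/torchtune/meta-llama/Meta-Llama-3-70B-tuned \
  output_dir=/fsx/ubuntu/models/torchtune/meta-llama/Meta-Llama-3-70B-tuned/log \
  metric_logger.log_dir=/fsx/ubuntu/models/torchtune/meta-llama/Meta-Llama-3-70B-tuned/log/metrics
```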