End-to-End LLM Model Development with Torchtitan and Torchtune #341

Status: Open. Wants to merge 856 commits into base: main.
Commit 563e8071bbaa9e4e7c61a6df7e48b1a4dd9bb193 ("update") by KeitaW, Jun 5, 2024
2 changes: 1 addition & 1 deletion 3.test_cases/torchtune/slurm/torchtune.dockerfile
@@ -19,7 +19,7 @@
# # Load image to local docker registry -> on head node, or new compute/build node.
# docker load < /fsx/nvidia-pt-od__latest.tar
####################################################################################################
FROM nvcr.io/nvidia/pytorch:24.04-py3
FROM nvcr.io/nvidia/pytorch:24.05-py3
ENV DEBIAN_FRONTEND=noninteractive

# The three must-be-built packages.
@@ -68,18 +68,62 @@ By following these steps, you ensure that the necessary model components are in

## 3. Continuous Pretraining

In this step, you will fine-tune Llama3 model from the orinal checkpoint. Specifically, the finetune process in this step is called Full-parameter finetuning, which will update all the parameters in the original model. One of the problem we encounter in such training is memory consumption. A typical model trained in mixed precision with AdamW requires 18 bytes per model parameter plus activation memory (6 bytes for parameters for mixed precision training, 8 bytes for AdamW, 4 bytes).For more details of the anatomy, see [huggingface blog post](https://huggingface.co/docs/transformers/model_memory_anatomy). This means that 70B parameter model training would require more than 1.12 TB of accelerated memory, which is way bigger than 80 GB of H100 accelerated memory size. To tackle the problem, `torchtune` integrates PyTorch Fully Distributed Data Parallel (FSDP). In this framework. PyTorch Fully Sharded Data Parallel (FSDP) is a distributed training feature designed to efficiently handle large model training by sharding model parameters, gradients, and optimizer states across multiple devices. This approach significantly reduces memory consumption and optimizes resource utilization, making it possible to train models that are too large to fit on a single GPU.
In this step, you will fine-tune the Llama3 model starting from the original checkpoint using the WikiText dataset. This process, known as Full-Parameter Finetuning, updates all the parameters in the original model. The configuration file used for this process is `./tutorials/e2e-llama3-70b-development/full_finetune_distributed.yaml`.

### Memory Consumption Challenges
One of the primary challenges during such training is memory consumption. A typical model trained in mixed precision with AdamW requires 18 bytes per model parameter plus activation memory (6 bytes for parameters in mixed precision training, 8 bytes for AdamW, and 4 bytes for other overheads). For more details on the anatomy, see the [Hugging Face blog post](https://huggingface.co/docs/transformers/model_memory_anatomy) blog post. This means that training a 70B parameter model would require more than 1.12 TB of accelerated memory, which far exceeds the 80 GB capacity of H100 accelerated memory. To address this issue, torchtune integrates PyTorch Fully Sharded Data Parallel (FSDP).
Review comment from @pbelevich (Jun 11, 2024):
Suggested change
One of the primary challenges during such training is memory consumption. A typical model trained in mixed precision with AdamW requires 18 bytes per model parameter plus activation memory (6 bytes for parameters in mixed precision training, 8 bytes for AdamW, and 4 bytes for other overheads). For more details on the anatomy, see the [Hugging Face blog post](https://huggingface.co/docs/transformers/model_memory_anatomy) blog post. This means that training a 70B parameter model would require more than 1.12 TB of accelerated memory, which far exceeds the 80 GB capacity of H100 accelerated memory. To address this issue, torchtune integrates PyTorch Fully Sharded Data Parallel (FSDP).
One of the primary challenges during such training is memory consumption. A typical model trained in mixed precision with AdamW requires 18 bytes per model parameter(6 bytes for parameter in mixed precision training, 4 bytes for gradient and 8 bytes for AdamW optimizer states) plus activation memory. For more details on the anatomy, see the [Hugging Face blog post](https://huggingface.co/docs/transformers/model_memory_anatomy) blog post. This means that training a 70B parameter model would require more than 1.12 TiB of accelerator's memory, which far exceeds the 80 GB capacity of H100 memory. To address this issue, torchtune integrates PyTorch Fully Sharded Data Parallel (FSDP).

Review comment from a collaborator:
How was 1.12 TiB calculated?
70_000_000_000 * 18 = 1_260_000_000_000
1_260_000_000_000 / 1024 / 1024 / 1024 / 1024 = 1.15TiB

Reply from @pbelevich (Jun 11, 2024):
memory is not accelerated itself
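The arithmetic the reviewers discuss above is easy to reproduce. A quick sanity check of the 18-bytes-per-parameter estimate (6 bytes for mixed-precision parameters, 4 for gradients, 8 for AdamW optimizer states, following the suggested wording) applied to a 70B model:

```python
params = 70_000_000_000            # Llama3 70B parameter count
bytes_per_param = 6 + 4 + 8        # mixed-precision weights + gradients + AdamW states = 18
total_bytes = params * bytes_per_param
tib = total_bytes / 1024**4        # ~1.15 TiB, versus 80 GB on a single H100
print(f"{tib:.2f} TiB")
```

This matches the reviewer's figure of roughly 1.15 TiB, motivating the switch to FSDP sharding described next.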


### Basic concepts and relevant configuration

**FSDP** is a distributed training feature designed to efficiently handle large model training by sharding model parameters, gradients, and optimizer states across multiple devices. This approach significantly reduces memory consumption and optimizes resource utilization, making it possible to train models that are too large to fit on a single GPU. In `torchtune` users can launch FSDP training job with command `tune run full_finetune_distributed`.
Review comment from a collaborator:
Suggested change
**FSDP** is a distributed training feature designed to efficiently handle large model training by sharding model parameters, gradients, and optimizer states across multiple devices. This approach significantly reduces memory consumption and optimizes resource utilization, making it possible to train models that are too large to fit on a single GPU. In `torchtune` users can launch FSDP training job with command `tune run full_finetune_distributed`.
**FSDP** is a distributed training technique designed to efficiently handle large model training by sharding model parameters, gradients, and optimizer states across multiple devices. This approach significantly reduces memory consumption and optimizes resource utilization, making it possible to train models that are too large to fit on a single GPU. In `torchtune` users can launch FSDP training job with command `tune run full_finetune_distributed`.


**The WikiText language modeling dataset** is a collection of over 100 million tokens extracted from the set of verified Good and Featured articles on Wikipedia. `torchtune` has a module preconfigured for this dataset. The configuration file preconfigures the WikiText dataset as follows:

```yaml
dataset:
  _component_: torchtune.datasets.wikitext_dataset
```
### Submit the training job
Submit the job with the following command:
```bash
sbatch tutorials/e2e-llama3-70b-development/full_finetune_distributed.sbatch
```

By default, this script launches the FSDP training job with two instances. Once the job has been scheduled, you will see the following outputs in the log file named `logs/full-finetuning*`:
Review comment from @pbelevich (Jun 11, 2024):
I don't see where two instances are specified by default, I see only --nnodes 1 --nnodes=1 in sbatch files


```bash
# tail -f logs/full-finetuning*
Executing following command:
tune run --master_addr 10.1.62.14 --master_port 28415 --nproc_per_node=8 --nnodes 2 --rdzv_backend=c10d --rdzv_endpoint=p5-st-p5-1 full_finetune_distributed --config /fsx/ubuntu/awsome-distributed-training/3.test_cases/torchtune/slurm/tutorials/e2e-llama3-70b-development/configs/full_finetune_distributed.yaml tokenizer.path=/fsx/ubuntu/models/torchtune/meta-llama/Meta-Llama-3-70B/original/tokenizer.model checkpointer.checkpoint_dir=/fsx/ubuntu/models/torchtune/meta-llama/Meta-Llama-3-70B checkpointer.output_dir=/fsx/ubuntu/models/torchtune/meta-llama/Meta-Llama-3-70B-tuned output_dir=/fsx/ubuntu/models/torchtune/meta-llama/Meta-Llama-3-70B-tuned/log metric_logger.log_dir=/fsx/ubuntu/models/torchtune/meta-llama/Meta-Llama-3-70B-tuned/log/metrics
...
0: wandb: Currently logged in as: <YOURUSERNAME>. Use `wandb login --relogin` to force relogin
0: wandb: Tracking run with wandb version 0.17.0
0: wandb: Run data is saved locally in /fsx/ubuntu/models/torchtune/meta-llama/Meta-Llama-3-70B-tuned/log/metrics/wandb/run-20240527_001350-oziekm6j
0: wandb: Run `wandb offline` to turn off syncing.
0: wandb: Syncing run helpful-surf-1
0: wandb: ⭐️ View project at https://wandb.ai/<YOURUSERNAME>/torchtune
0: wandb: 🚀 View run at https://wandb.ai/<YOURUSERNAME>/torchtune/runs/oziekm6j
0: 2024-05-27:00:13:50,919 INFO [metric_logging.py:225] Logging /fsx/ubuntu/models/torchtune/meta-llama/Meta-Llama-3-70B/torchtune_config.yaml to W&B under Files
...
```

Notice that the job is being tracked by WANDB because of the following section in the config file:

```yaml
metric_logger:
  _component_: torchtune.utils.metric_logging.WandBLogger
  log_dir: None
```
On the WANDB dashboard (`https://wandb.ai/<YOURUSERNAME>/torchtune`), you can monitor the learning curve, compute resource utilization, log outputs, and more.


## 4. Instruction-tuning

In this step, you will fine-tune the LLaMA model using Low-Rank Adaptation (LoRA) with the Alpaca dataset. We will first cover the basic concepts and relevant configurations found in the [config file](configs/lora_finetune_distributed.yaml), followed by a detailed fine-tuning tutorial.
In this step, you will fine-tune the Llama model using Low-Rank Adaptation (LoRA) with the Alpaca dataset. We will first cover the basic concepts and relevant configurations found in the [config file](configs/lora_finetune_distributed.yaml), followed by a detailed fine-tuning tutorial.
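As a rough illustration of the idea (a sketch, not torchtune's implementation): instead of updating a full weight matrix W, LoRA freezes W and learns two small matrices B (d x r) and A (r x k), applying W + (alpha / rank) * B @ A. With the config's `lora_rank: 16` and `lora_alpha: 32`, the low-rank update is scaled by 32/16 = 2.

```python
def lora_delta(B, A, alpha, rank):
    """Compute the low-rank update (alpha / rank) * B @ A that LoRA adds to a frozen weight."""
    scale = alpha / rank
    r, k = len(A), len(A[0])
    return [[scale * sum(B[i][t] * A[t][j] for t in range(r)) for j in range(k)]
            for i in range(len(B))]

# B is initialized to zeros, so at step 0 the adapted model equals the base model.
B = [[0.0, 0.0], [0.0, 0.0]]   # toy d x r
A = [[1.0, 2.0], [3.0, 4.0]]   # toy r x k
delta = lora_delta(B, A, alpha=32, rank=16)   # all zeros until B is trained
```

Only A and B receive gradients, which is why LoRA needs a fraction of the memory of full-parameter finetuning.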


### Basic Concepts and Relevant Configurations
@@ -115,9 +159,6 @@ dataset:

As the config suggests, we use a predefined dataset class prepared in torchtune.

## 5. Alignment



### Submit Finetuning job

@@ -128,13 +169,12 @@ You can submit the finetuning job with the following command:
sbatch tutorials/e2e-llama3-70b-development/lora_finetune_distributed.sbatch
```

Once the job has been scheduled, you will see following outputs in the log:

Once the job has been scheduled, you will see the following outputs in the log file under `logs/`:

```bash
...
Executing following command:
torchtune run --master_addr 10.1.28.89 --master_port 14280 --nproc_per_node=8 --nnodes 1 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=p5-st-p5-2 lora_finetune_distributed
tune run --master_addr 10.1.28.89 --master_port 14280 --nproc_per_node=8 --nnodes 1 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=p5-st-p5-2 lora_finetune_distributed
...
0: wandb: Currently logged in as: <YOURUSERNAME>. Use `wandb login --relogin` to force relogin
0: wandb: Tracking run with wandb version 0.17.0
@@ -149,7 +189,7 @@ torchtune run --master_addr 10.1.28.89 --master_port 14280 --nproc_per_node=8 --
As the output indicates, we run a single-node distributed training job with 8 GPUs here.
```bash
torchtune run --master_addr 10.1.28.89 --master_port 14280 --nproc_per_node=8 --nnodes 1 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=p5-st-p5-2 lora_finetune_distributed
tune run --master_addr 10.1.28.89 --master_port 14280 --nproc_per_node=8 --nnodes 1 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=p5-st-p5-2 lora_finetune_distributed
```


@@ -191,33 +231,6 @@ You can submit sample evaluation job by:
sbatch evaluate.sbatch
```

You will see:

```
Running loglikelihood requests: 6%|▋ | 23/400 [00:01<00:18, 20.53it/s]
Running loglikelihood requests: 16%|█▌ | 62/400 [00:02<00:15, 22.65it/s]
Running loglikelihood requests: 24%|██▍ | 98/400 [00:04<00:13, 22.50it/s]
Running loglikelihood requests: 33%|███▎ | 131/400 [00:06<00:12, 22.28it/s]
Running loglikelihood requests: 42%|████▏ | 164/400 [00:07<00:10, 22.40it/s]
Running loglikelihood requests: 50%|█████ | 200/400 [00:09<00:08, 22.60it/s]
Running loglikelihood requests: 58%|█████▊ | 233/400 [00:10<00:07, 22.46it/s]
Running loglikelihood requests: 66%|██████▌ | 263/400 [00:11<00:06, 22.51it/s]
Running loglikelihood requests: 74%|███████▍ | 296/400 [00:13<00:04, 22.45it/s]
Running loglikelihood requests: 82%|████████▏ | 326/400 [00:14<00:03, 22.63it/s]
Running loglikelihood requests: 90%|████████▉ | 356/400 [00:16<00:01, 22.82it/s]
Running loglikelihood requests: 97%|█████████▋| 389/400 [00:17<00:00, 23.11it/s]
Running loglikelihood requests: 100%|██████████| 400/400 [00:17<00:00, 22.27it/s]
0: fatal: not a git repository (or any of the parent directories): .git
0: 2024-05-07:01:12:39,479 INFO [eval.py:69] vllm (pretrained=meta-llama/Meta-Llama-3-70B,tensor_parallel_size=8,dtype=auto,gpu_memory_utilization=0.8,data_parallel_size=1), gen_kwargs: (None), limit: 100.0, num_fewshot: None, batch_size: 1
0: 2024-05-07:01:12:39,536 INFO [eval.py:70] | Tasks |Version|Filter|n-shot| Metric |Value| |Stderr|
0: |---------|------:|------|-----:|--------|----:|---|-----:|
0: |hellaswag| 1|none | 0|acc | 0.56|± |0.0499|
0: | | |none | 0|acc_norm| 0.75|± |0.0435|
0:
```
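The hellaswag numbers in the sample output above can be sanity-checked. Assuming the reported standard error is the sample standard error of a 0/1 accuracy metric over the `limit: 100` evaluated examples (with the usual n-1 denominator), the figures reproduce exactly:

```python
import math

def binomial_stderr(acc, n):
    # sample standard error of a 0/1 accuracy metric over n examples
    return math.sqrt(acc * (1 - acc) / (n - 1))

print(round(binomial_stderr(0.56, 100), 4))  # matches the 0.0499 in the acc row
print(round(binomial_stderr(0.75, 100), 4))  # matches the 0.0435 in the acc_norm row
```

With only 100 examples the error bars are wide; raising or removing `limit` tightens them at the cost of longer evaluation.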



## 6. Quantization

In a production setting, it is often not feasible to deploy a large model as-is; this requires
@@ -100,9 +100,9 @@ fsdp_cpu_offload: True
dtype: bf16

# Logging
output_dir: None
metric_logger:
  _component_: torchtune.utils.metric_logging.DiskLogger
  log_dir: ${output_dir}
output_dir: /tmp/alpaca-llama3-finetune
  _component_: torchtune.utils.metric_logging.WandBLogger
  log_dir: None
log_every_n_steps: 1
log_peak_memory_stats: False
@@ -0,0 +1,101 @@
# Config for multi-device LoRA in lora_finetune_distributed.py
# using a Llama3 70B model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download meta-llama/Meta-Llama-3-70B-Instruct --hf-token <TOKEN> --output-dir /tmp/Meta-Llama-3-70b --ignore-patterns "original/consolidated*"
#
# This config needs 8 GPUs to run
# # tune run --nproc_per_node 8 lora_finetune_distributed --config recipes/configs/llama3/70B_lora.yaml
#

# Model Arguments
model:
  _component_: torchtune.models.llama3.lora_llama3_70b
  lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
  apply_lora_to_mlp: False
  apply_lora_to_output: False
  lora_rank: 16
  lora_alpha: 32

tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: None

checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: None
  checkpoint_files: [
    model-00001-of-00030.safetensors,
    model-00002-of-00030.safetensors,
    model-00003-of-00030.safetensors,
    model-00004-of-00030.safetensors,
    model-00005-of-00030.safetensors,
    model-00006-of-00030.safetensors,
    model-00007-of-00030.safetensors,
    model-00008-of-00030.safetensors,
    model-00009-of-00030.safetensors,
    model-00010-of-00030.safetensors,
    model-00011-of-00030.safetensors,
    model-00012-of-00030.safetensors,
    model-00013-of-00030.safetensors,
    model-00014-of-00030.safetensors,
    model-00015-of-00030.safetensors,
    model-00016-of-00030.safetensors,
    model-00017-of-00030.safetensors,
    model-00018-of-00030.safetensors,
    model-00019-of-00030.safetensors,
    model-00020-of-00030.safetensors,
    model-00021-of-00030.safetensors,
    model-00022-of-00030.safetensors,
    model-00023-of-00030.safetensors,
    model-00024-of-00030.safetensors,
    model-00025-of-00030.safetensors,
    model-00026-of-00030.safetensors,
    model-00027-of-00030.safetensors,
    model-00028-of-00030.safetensors,
    model-00029-of-00030.safetensors,
    model-00030-of-00030.safetensors,
  ]
  recipe_checkpoint: null
  output_dir: None
  model_type: LLAMA3
resume_from_checkpoint: False

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  train_on_input: True
seed: null
shuffle: True
batch_size: 2

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  weight_decay: 0.01
  lr: 3e-4
lr_scheduler:
  _component_: torchtune.modules.get_cosine_schedule_with_warmup
  num_warmup_steps: 100

loss:
  _component_: torch.nn.CrossEntropyLoss

# Training
epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 1

# Logging
output_dir: None
metric_logger:
  _component_: torchtune.utils.metric_logging.WandBLogger
  log_dir: None
log_every_n_steps: 1
log_peak_memory_stats: False

# Environment
device: cuda
dtype: bf16
enable_activation_checkpointing: True
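The `lr_scheduler` entry in the config above pairs 100 linear warmup steps with a cosine decay of the base `lr: 3e-4`. A minimal sketch of the learning-rate multiplier such a schedule applies (assuming the common warmup-then-cosine formula; torchtune's `get_cosine_schedule_with_warmup` may differ in details):

```python
import math

def warmup_cosine_multiplier(step, warmup_steps=100, total_steps=1000):
    # linear ramp 0 -> 1 during warmup, then cosine decay 1 -> 0
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

lr = 3e-4 * warmup_cosine_multiplier(50)   # halfway through warmup -> 1.5e-4
```

The warmup avoids large early updates on randomly initialized LoRA matrices, and the cosine tail anneals the rate smoothly to zero by the end of training.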
@@ -78,5 +78,5 @@ export PYTHONPATH=${PWD}/torchtune
export TORCHTUNE=${PWD}/torchtune/torchtune/_cli/tune.py
export TORCHTUNE_COMMAND="eleuther_eval"
echo "Executing following command:"
echo "torchtune" "run" "${TORCHTUNE_COMMAND}" "${TRAIN_ARGS[@]}"
echo "tune" "run" "${TORCHTUNE_COMMAND}" "${TRAIN_ARGS[@]}"
srun -l "${SRUN_ARGS[@]}" python ${TORCHTUNE} run "${TORCHTUNE_COMMAND}" "${TRAIN_ARGS[@]}"
@@ -4,8 +4,8 @@
# SPDX-License-Identifier: MIT-0

#SBATCH --job-name=full-finetuning
#SBATCH --nodes=2
#SBATCH --ntasks=2
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --gpus-per-node=8 # Number of GPU per node
#SBATCH --output=logs/%x_%j.out # logfile for stdout
#SBATCH --error=logs/%x_%j.err # logfile for stderr, remove it to merge both outputs
@@ -91,5 +91,5 @@ export PYTHONPATH=${PWD}/torchtune
export TORCHTUNE=${PWD}/torchtune/torchtune/_cli/tune.py
export TORCHTUNE_COMMAND="full_finetune_distributed"
echo "Executing following command:"
echo "torchtune" "run" "${TORCHRUN_ARGS[@]}" "${TORCHTUNE_COMMAND}" "${TORCHTUNE_ARGS[@]}"
echo "tune" "run" "${TORCHRUN_ARGS[@]}" "${TORCHTUNE_COMMAND}" "${TORCHTUNE_ARGS[@]}" "${TRAIN_ARGS[@]}"
srun -l "${SRUN_ARGS[@]}" python ${TORCHTUNE} run "${TORCHRUN_ARGS[@]}" "${TORCHTUNE_COMMAND}" "${TRAIN_ARGS[@]}"