Support AoT in 16-vm GPU Llama2 train script
jonb377 committed Aug 13, 2024
1 parent f904ede commit e685da1
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion MaxText/configs/a3/llama_2_7b/16vm.sh
@@ -10,6 +10,7 @@ set -e

export OUTPUT_PATH="gs://maxtext-experiments-multipod"
export RUN_NAME="llama-2-16vm-$(date +%Y-%m-%d-%H-%M)"
+export EXECUTABLE="train.py"

# Set environment variables
for ARGUMENT in "$@"; do
@@ -29,5 +30,5 @@ export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/
--xla_disable_hlo_passes=rematerialization"

# 16 nodes
-python MaxText/train.py MaxText/configs/models/gpu/llama2_7b.yml run_name=$RUN_NAME \
+python MaxText/$EXECUTABLE MaxText/configs/models/gpu/llama2_7b.yml run_name=$RUN_NAME \
dcn_data_parallelism=16 ici_fsdp_parallelism=8 base_output_directory=$OUTPUT_PATH profiler=xplane
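
This change makes the entry point overridable: the script's argument loop parses KEY=VALUE pairs from the command line, so EXECUTABLE can be swapped at invocation time instead of being hard-coded. A minimal usage sketch, assuming the loop exports each pair as an environment variable and that MaxText's train_compile.py is the ahead-of-time (AoT) compilation entry point:

    # Default: run the standard training entry point (EXECUTABLE="train.py").
    bash MaxText/configs/a3/llama_2_7b/16vm.sh

    # Hypothetical AoT invocation, assuming the script's KEY=VALUE argument
    # loop exports EXECUTABLE and that MaxText's train_compile.py is the
    # ahead-of-time compilation entry point:
    bash MaxText/configs/a3/llama_2_7b/16vm.sh EXECUTABLE=train_compile.py

Parameterizing the executable keeps a single config script for both workflows, so the XLA flag setup and the $OUTPUT_PATH/$RUN_NAME layout are shared rather than duplicated.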
