From e685da136b83c9462d5346346c7b8646ce234ff4 Mon Sep 17 00:00:00 2001
From: Jon Bolin
Date: Tue, 13 Aug 2024 21:32:16 +0000
Subject: [PATCH] Support AoT in 16-vm GPU Llama2 train script

---
 MaxText/configs/a3/llama_2_7b/16vm.sh | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/MaxText/configs/a3/llama_2_7b/16vm.sh b/MaxText/configs/a3/llama_2_7b/16vm.sh
index e7c814d93..e63a5d816 100644
--- a/MaxText/configs/a3/llama_2_7b/16vm.sh
+++ b/MaxText/configs/a3/llama_2_7b/16vm.sh
@@ -10,6 +10,7 @@ set -e
 
 export OUTPUT_PATH="gs://maxtext-experiments-multipod"
 export RUN_NAME="llama-2-16vm-$(date +%Y-%m-%d-%H-%M)"
+export EXECUTABLE="train.py"
 
 # Set environment variables
 for ARGUMENT in "$@"; do
@@ -29,5 +30,5 @@ export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/
 --xla_disable_hlo_passes=rematerialization"
 
 # 16 nodes
-python MaxText/train.py MaxText/configs/models/gpu/llama2_7b.yml run_name=$RUN_NAME \
+python MaxText/$EXECUTABLE MaxText/configs/models/gpu/llama2_7b.yml run_name=$RUN_NAME \
 dcn_data_parallelism=16 ici_fsdp_parallelism=8 base_output_directory=$OUTPUT_PATH profiler=xplane