diff --git a/MaxText/configs/a3/llama_2_7b/16vm.sh b/MaxText/configs/a3/llama_2_7b/16vm.sh index e7c814d93..e63a5d816 100644 --- a/MaxText/configs/a3/llama_2_7b/16vm.sh +++ b/MaxText/configs/a3/llama_2_7b/16vm.sh @@ -10,6 +10,7 @@ set -e export OUTPUT_PATH="gs://maxtext-experiments-multipod" export RUN_NAME="llama-2-16vm-$(date +%Y-%m-%d-%H-%M)" +export EXECUTABLE="train.py" # Set environment variables for ARGUMENT in "$@"; do @@ -29,5 +30,5 @@ export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/ --xla_disable_hlo_passes=rematerialization" # 16 nodes -python MaxText/train.py MaxText/configs/models/gpu/llama2_7b.yml run_name=$RUN_NAME \ +python MaxText/$EXECUTABLE MaxText/configs/models/gpu/llama2_7b.yml run_name=$RUN_NAME \ dcn_data_parallelism=16 ici_fsdp_parallelism=8 base_output_directory=$OUTPUT_PATH profiler=xplane