Update and rename 1024b.sh to v5p-12288.sh
Obliviour authored Aug 14, 2024
1 parent f904ede · commit af0db80
Showing 1 changed file with 7 additions and 6 deletions.
MaxText/configs/v5p/1024b.sh → MaxText/configs/v5p/v5p-12288.sh

@@ -1,5 +1,5 @@
-echo "Running 1024b.sh"
-# 1024B parameter model.
+echo "Running v5p-12288.sh"
+# X parameter model.
 # This config will work out of the box for any number of v5p-2048 or v5p-4096 slices.
 #
 # Command Flags:
@@ -8,10 +8,10 @@ echo "Running 1024b.sh"
 # RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE)
 #
 # Example to invoke this script:
-# bash MaxText/configs/v5p/1024b.sh RUN_NAME="<your_run_name>" OUTPUT_PATH="gs://<your_output_path>" DATASET_PATH="gs://<your_dataset_path>"
+# bash MaxText/configs/v5p/v5p-12288.sh STEPS=10000 RUN_NAME="<your_run_name>" OUTPUT_PATH="gs://<your_output_path>" DATASET_PATH="gs://<your_dataset_path>"
 #
 # Example to AOT compile:
-# bash MaxText/configs/v5p/1024b.sh EXECUTABLE=train_compile.py M_COMPILE_TOPOLOGY=v5p-2048 M_COMPILE_TOPOLOGY_NUM_SLICES=2
+# bash MaxText/configs/v5p/v5p-12288.sh EXECUTABLE=train_compile.py M_COMPILE_TOPOLOGY=v5p-2048 M_COMPILE_TOPOLOGY_NUM_SLICES=2
 
 
 # Stop execution if any command exits with error
@@ -39,9 +39,10 @@ bash preflight.sh
 # Train
 export LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
 python3 MaxText/$EXECUTABLE MaxText/configs/base.yml\
-steps=15 per_device_batch_size=2 enable_checkpointing=false\
-remat_policy=full global_parameter_scale=1024\
+steps=$STEPS per_device_batch_size=2 enable_checkpointing=false\
+remat_policy=full global_parameter_scale=512\
 ici_fsdp_parallelism=-1 ici_tensor_parallelism=16\
 max_target_length=2048 base_output_directory=$OUTPUT_PATH\
+base_emb_dim=6144 \
 dataset_path=$DATASET_PATH use_iota_embed=true reuse_example_batch=1\
 dataset_type=synthetic gcs_metrics=true attention='flash' quantization=""
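
Note on the functional change: the hard-coded steps=15 is replaced with steps=$STEPS, so the step count is now supplied on the command line (STEPS=10000 in the updated invocation example). The diff does not show how the script turns KEY=VALUE arguments into environment variables; the snippet below is a minimal sketch of the usual bash pattern for doing so, assuming this script follows it, and is not taken from the commit itself.

# Sketch (assumption, not part of this commit): export KEY=VALUE command-line
# flags such as STEPS=10000 as environment variables so that $STEPS is visible
# to the train command above.
for ARGUMENT in "$@"; do
  IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
  export "$KEY"="$VALUE"   # e.g. STEPS=10000, RUN_NAME=..., OUTPUT_PATH=..., DATASET_PATH=...
done

With a loop like this ahead of the train command, bash MaxText/configs/v5p/v5p-12288.sh STEPS=10000 ... resolves steps=$STEPS to steps=10000.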
