Merge pull request #639 from google:anisha-llama2-70b-test
PiperOrigin-RevId: 630417921
Showing 4 changed files with 111 additions and 1 deletion.

end_to_end/tpu/llama2/70b/1_test_llama2_70b.sh
@@ -0,0 +1,49 @@
#!/bin/bash

# This file, combined with step 2 in the same directory, demonstrates converting a Llama2-70B checkpoint from Meta and running various MaxText operations on it.
# This step is tested nightly on an ordinary CPU VM.

# The flow of this file is as follows:
# 1. Pull the checkpoint from a GCS bucket and upload the new MaxText-compatible checkpoint to a destination GCS bucket.
# 2. Convert the scanned checkpoint from step 1 into the unscanned checkpoint format and run more efficient decoding.

# Example Usage: export BASE_OUTPUT_PATH=/path/to/GCS/bucket; bash end_to_end/tpu/llama2/70b/1_test_llama2_70b.sh
# Use the same BASE_OUTPUT_PATH as end_to_end/tpu/llama2/70b/2_test_llama2_70b.sh.
# Please note that in these two scripts (1_test_llama2_70b.sh and 2_test_llama2_70b.sh) BASE_OUTPUT_PATH is assumed to already be a unique path across multiple runs and
# the subfolder names, aka RUN_NAMEs, are static. Please remember to change BASE_OUTPUT_PATH across different runs; see the example below.
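# For example, one way to keep BASE_OUTPUT_PATH unique across runs (hypothetical bucket name; use a bucket you own) is to append a timestamp, mirroring the fallback used later in this script:
#   export BASE_OUTPUT_PATH=gs://my-maxtext-bucket/llama2-70b/$(date +%Y-%m-%d-%H-%M)
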
set -ex
MODEL_VARIATION='llama2-70b'

# We install the CPU build of torch because the checkpoint conversion script MaxText/llama_or_mistral_ckpt.py does not need a TPU/GPU.
pip install torch --index-url https://download.pytorch.org/whl/cpu

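# Optionally, you can sanity-check that the CPU wheel installed and imports cleanly (purely illustrative, not required by the flow):
#   python3 -c "import torch; print(torch.__version__)"
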
# We define a variable for the path to the Meta checkpoint. Non-Googlers, please remember to update `META_CHECKPOINT_PATH` to point to the GCS bucket where you have your Meta checkpoint.
export META_CHECKPOINT_PATH=gs://maxtext-llama/llama2-70b/meta-ckpt

# In the following command, we copy Meta's checkpoint into the local directory `/tmp/`.
# You can use a different local directory than /tmp/; if you do so, please use the same local path for `base-model-path` when running `python3 MaxText/llama_or_mistral_ckpt.py`, as sketched below.
gcloud storage cp -r ${META_CHECKPOINT_PATH} /tmp/

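# For example, to stage the checkpoint under a different directory (hypothetical path), you could instead run:
#   gcloud storage cp -r ${META_CHECKPOINT_PATH} /your/local/dir/
#   # ...and then pass --base-model-path /your/local/dir/meta-ckpt to MaxText/llama_or_mistral_ckpt.py below.
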
if [ -z "${BASE_OUTPUT_PATH}" ]; then
    # Non-Googlers, please remember to point BASE_OUTPUT_PATH to a GCS bucket that you own; this script uses internal buckets for testing.
    # Use the same BASE_OUTPUT_PATH as end_to_end/tpu/llama2/70b/2_test_llama2_70b.sh.
    export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M)
    echo "BASE_OUTPUT_PATH is not set, using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}"
fi

echo "Converted checkpoints are stored at ${BASE_OUTPUT_PATH}" | ||
|
||
# Next, run the conversion script `MaxText/llama_or_mistral_ckpt.py` to convert Meta's PyTorch checkpoint in `base-model-path` and save the new converted (Orbax) checkpoint in `maxtext-model-path`.
JAX_PLATFORMS=cpu python3 MaxText/llama_or_mistral_ckpt.py --base-model-path /tmp/meta-ckpt --maxtext-model-path ${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_chkpt --model-size ${MODEL_VARIATION}

echo "Wrote MaxText compatible checkpoint to ${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_chkpt"

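# Optionally, verify that the scanned checkpoint landed where expected (assumes you have list access to the bucket):
#   gcloud storage ls ${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_chkpt/0/items
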
# We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory.
export CONVERTED_CHECKPOINT=${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_chkpt/0/items
# Note that `CONVERTED_CHECKPOINT` is in a `scanned` format, which is great for training, but for efficient decoding performance we want the checkpoint in an `unscanned` format.
# We can do this by running `MaxText/generate_param_only_checkpoint.py` on `CONVERTED_CHECKPOINT` with `force_unroll=true`.
export RUN_NAME=unscanned_chkpt
JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name='llama2-70b' force_unroll=true
echo "Wrote MaxText compatible unscanned checkpoint to ${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items"
end_to_end/tpu/llama2/70b/2_test_llama2_70b.sh
@@ -0,0 +1,61 @@
#!/bin/bash

# This file is both an integration test that runs once a day on a v4-128 and documentation for how to get started with Llama2-70B.
# Please make sure you have run end_to_end/tpu/llama2/70b/1_test_llama2_70b.sh before running commands from this file.

# The flow of this file is as follows:
# 1. Run decoding and finetuning of Llama2-70B using the converted checkpoint obtained from end_to_end/tpu/llama2/70b/1_test_llama2_70b.sh. Also, run pretraining of Llama2-70B.
# 2. Convert the scanned, full-state checkpoint from the finetuning run in step 1 into a parameter-only, unscanned checkpoint format for more efficient decoding.
# 3. Run decoding from the finetuned checkpoint produced in step 2.

# Example Usage: export BASE_OUTPUT_PATH=/path/to/GCS/bucket; bash end_to_end/tpu/llama2/70b/2_test_llama2_70b.sh
# Use the same BASE_OUTPUT_PATH as end_to_end/tpu/llama2/70b/1_test_llama2_70b.sh.
# Please note that in these two scripts (1_test_llama2_70b.sh and 2_test_llama2_70b.sh) BASE_OUTPUT_PATH is assumed to already be a unique path across multiple runs and
# the subfolder names, aka RUN_NAMEs, are static. Please remember to change BASE_OUTPUT_PATH across different runs.

set -ex
export MODEL_VARIATION='llama2-70b'

if [ -z "${BASE_OUTPUT_PATH}" ]; then
    # Non-Googlers, please remember to point `BASE_OUTPUT_PATH` to a GCS bucket that you own; this bucket will store all the files generated by MaxText during a run.
    # Use the same BASE_OUTPUT_PATH as end_to_end/tpu/llama2/70b/1_test_llama2_70b.sh.
    export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M)
    echo "BASE_OUTPUT_PATH is not set, using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}"
fi

# Non-Googlers, please remember to point `DATASET_PATH` to the GCS bucket where you have your training data.
export DATASET_PATH=gs://maxtext-dataset

# We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory. This way, it is easier to use this path in the `train.py` and `decode.py` commands.
export CONVERTED_CHECKPOINT=${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_chkpt/0/items
export RUN_NAME=unscanned_chkpt
# We define the path to the unscanned checkpoint created in 1_test_llama2_70b.sh.
export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items

# We run decoding on `UNSCANNED_CKPT_PATH` for efficient decoding with the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state,
# so we use it by specifying `load_parameters_path=${UNSCANNED_CKPT_PATH}`.
# We compare our decoded results against golden outputs using `autoregressive_decode_assert`.
python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.llama2 load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" autoregressive_decode_assert="read. I love to read books, magazines,"

# We can also run decoding (albeit in a somewhat unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state, so we use it by specifying `load_parameters_path=${CONVERTED_CHECKPOINT}`.
python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.llama2 load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" autoregressive_decode_assert="read. I love to read books, magazines,"

# Alternatively, we can skip ahead to finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying `load_parameters_path=${CONVERTED_CHECKPOINT}`. Note that the scanned checkpoint helps with efficient finetuning.
export FINETUNE_RUN_NAME=runner_finetune
python MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path=assets/tokenizer.llama2 load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} steps=10 async_checkpointing=false model_name=${MODEL_VARIATION} checkpoint_period=5

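# With checkpoint_period=5, the finetune run above writes a full-state checkpoint at step 5; that is the checkpoint referenced by `load_full_state_path` further below.
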
# We also run pretraining; this is similar to the finetuning command except we don't pass any checkpoint directory to load parameters from.
python MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path=assets/tokenizer.llama2 per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) steps=5 enable_checkpointing=false model_name=${MODEL_VARIATION}

# Note that the finetune run generates a `full state` checkpoint, which has both parameters and optimizer state. For decoding, we only need the parameters,
# so we can use `MaxText/generate_param_only_checkpoint.py` to convert the full-state checkpoint into a parameter-only checkpoint for more efficient memory use. Note that the path provided to the flag `load_full_state_path` is the path to the checkpoint subdirectory inside `BASE_OUTPUT_PATH` from our previous finetuning run.
# `force_unroll=true` converts the output parameter-only checkpoint into an unscanned format for efficient decoding.
export PARAM_RUN_NAME=param_chkpt
python MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_full_state_path=${BASE_OUTPUT_PATH}/${FINETUNE_RUN_NAME}/checkpoints/5/items run_name=${PARAM_RUN_NAME} model_name=${MODEL_VARIATION} force_unroll=true

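# The parameter-only checkpoint is written under ${BASE_OUTPUT_PATH}/${PARAM_RUN_NAME}/checkpoints/0/items, which is the path the final decode command loads below.
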
# Now, run decoding on the parameter-only checkpoint generated from our finetune run.
python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.llama2 load_parameters_path=${BASE_OUTPUT_PATH}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to"
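# Note that, unlike the earlier decode commands, this one does not pass `autoregressive_decode_assert`, so the output is not checked against a fixed golden string.
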
File renamed without changes.