Commit
Merge pull request #616 from RissyRan:separate_mixtral
PiperOrigin-RevId: 628127953
Showing 3 changed files with 56 additions and 22 deletions.
end_to_end/tpu/mixtral/8x7b/1_test_mixtral.sh
@@ -0,0 +1,25 @@
#!/bin/bash

# This file, combined with step 2 in the same directory, runs on a daily basis and demonstrates:
# 1. Converting the Mistral PyTorch checkpoint to MaxText (Orbax) format using a CPU VM.
# 2. Taking the MaxText (Orbax) checkpoint to run inference and fine-tuning on a TPU VM.

# The flow of this file is to convert the Mistral PyTorch checkpoint to MaxText (Orbax) format using a CPU VM.

# Example Usage: export BASE_OUTPUT_PATH=/path/to/GCS/bucket; bash end_to_end/tpu/mixtral/8x7b/1_test_mixtral.sh
# Use the same BASE_OUTPUT_PATH for both 1_test_mixtral.sh & 2_test_mixtral.sh.

set -ex
MODEL_VARIATION='8x7b'

if [ -z "${BASE_OUTPUT_PATH}" ]; then
  # Non-Googlers, please point BASE_OUTPUT_PATH to a GCS bucket that you own; this script uses internal buckets for testing.
  export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M)/
  echo "BASE_OUTPUT_PATH is not set, using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}"
fi

# Download the checkpoint and convert it to MaxText (Orbax) format
pip3 install torch
gcloud storage cp -r gs://maxtext-external/mixtral-8x7B-v0.1-Instruct /tmp
JAX_PLATFORMS=cpu python3 MaxText/llama_or_mistral_ckpt.py --base-model-path /tmp/mixtral-8x7B-v0.1-Instruct --model-size mixtral-8x7b --maxtext-model-path ${BASE_OUTPUT_PATH}${MODEL_VARIATION}/decode-ckpt-maxtext/
echo "Wrote MaxText-compatible checkpoint to ${BASE_OUTPUT_PATH}${MODEL_VARIATION}/decode-ckpt-maxtext"
end_to_end/tpu/mixtral/8x7b/2_test_mixtral.sh
@@ -0,0 +1,31 @@
#!/bin/bash

# This file, combined with step 1 in the same directory, runs on a daily basis and demonstrates:
# 1. Converting the Mistral PyTorch checkpoint to MaxText (Orbax) format using a CPU VM.
# 2. Taking the MaxText (Orbax) checkpoint to run inference and fine-tuning on a TPU VM.

# The flow of this file is to take the MaxText (Orbax) checkpoint and run inference and fine-tuning on a TPU VM.
# Please make sure you have run end_to_end/tpu/mixtral/8x7b/1_test_mixtral.sh before running commands from this file.

# Example Usage: export BASE_OUTPUT_PATH=/path/to/GCS/bucket; bash end_to_end/tpu/mixtral/8x7b/2_test_mixtral.sh
# Use the same BASE_OUTPUT_PATH for both 1_test_mixtral.sh & 2_test_mixtral.sh.

set -ex
MODEL_VARIATION='8x7b'

if [ -z "${BASE_OUTPUT_PATH}" ]; then
  # Non-Googlers, please point BASE_OUTPUT_PATH to a GCS bucket that you own; this script uses internal buckets for testing.
  # The trailing slash matches 1_test_mixtral.sh, so ${BASE_OUTPUT_PATH}${MODEL_VARIATION} resolves consistently in both scripts.
  export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M)/
  echo "BASE_OUTPUT_PATH is not set, using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}"
fi

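# MaxText reads M_-prefixed environment variables as overrides for the config values
# in MaxText/configs/base.yml, so the exports below enable checkpointing, set the
# output directory and dataset path, and disable async checkpointing for this test.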
export M_ENABLE_CHECKPOINTING=true
export M_BASE_OUTPUT_DIRECTORY=${BASE_OUTPUT_PATH}${MODEL_VARIATION}
export M_DATASET_PATH=gs://maxtext-dataset
export M_ASYNC_CHECKPOINTING=false

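# The fixed prompt plus autoregressive_decode_assert below act as a correctness check:
# decoding is expected to reproduce the asserted continuation from the converted checkpoint.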
# Run decoding
python3 MaxText/decode.py MaxText/configs/base.yml load_parameters_path=${BASE_OUTPUT_PATH}${MODEL_VARIATION}/decode-ckpt-maxtext/0/items run_name=decoding per_device_batch_size=1 model_name=mixtral-8x7b tokenizer_path=gs://maxtext-external/mixtral-8x7B-v0.1-Instruct/tokenizer.mistral ici_tensor_parallelism=4 ici_fsdp_parallelism=16 max_prefill_predict_length=11 max_target_length=24 prompt="[INST] I love to [/INST]" autoregressive_decode_assert="That's great to hear! I love to learn new things" attention=dot_product

# Run fine-tuning
python3 MaxText/train.py MaxText/configs/base.yml load_parameters_path=${BASE_OUTPUT_PATH}${MODEL_VARIATION}/decode-ckpt-maxtext/0/items run_name=fine_tuning per_device_batch_size=1 model_name=mixtral-8x7b ici_tensor_parallelism=4 ici_fsdp_parallelism=16 steps=10 max_target_length=1024 tokenizer_path=gs://maxtext-external/mixtral-8x7B-v0.1-Instruct/tokenizer.mistral
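To exercise the full flow end to end, the two scripts can be chained under one shared output location. A minimal sketch, assuming a GCS bucket you own (gs://my-bucket is a placeholder, not a bucket from this commit):

export BASE_OUTPUT_PATH=gs://my-bucket/mixtral-test/  # placeholder bucket; keep the trailing slash
bash end_to_end/tpu/mixtral/8x7b/1_test_mixtral.sh    # step 1: convert the checkpoint on a CPU VM
bash end_to_end/tpu/mixtral/8x7b/2_test_mixtral.sh    # step 2: decode and fine-tune on a TPU VM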
This file was deleted.