Commit 2ab924d (1 parent: c7c3f4e). Showing 16 changed files with 438 additions and 46 deletions.
New file (+29 lines):
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for llama3.1-405b

base_emb_dim: 16384
base_num_query_heads: 128
base_num_kv_heads: 8
base_num_decoder_layers: 126
base_mlp_dim: 53248
head_dim: 128
mlp_activations: ["silu","linear"]
vocab_size: 128256
enable_dropout: False
logits_via_embedding: False
normalization_layer_epsilon: 1.0e-5
rope_max_timescale: 500_000
decoder_block: "llama2" # Uses the same decoder block as llama2
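As a quick sanity check, an approximate parameter count can be derived directly from the values above. The Python sketch below is illustrative only (it is not MaxText code); it assumes the standard Llama architecture details implied by the config, i.e. a SwiGLU MLP with gate/up/down projections (mlp_activations: ["silu","linear"]), grouped-query attention, and untied input and output embeddings (logits_via_embedding: False), and it ignores the small RMSNorm parameters.

# Back-of-the-envelope parameter count for the llama3.1-405b config above.
# Illustrative sketch only; values copied from the config, architecture assumed.
emb_dim = 16384
num_query_heads = 128
num_kv_heads = 8
head_dim = 128
num_layers = 126
mlp_dim = 53248
vocab_size = 128256

attn = emb_dim * num_query_heads * head_dim       # Q projection
attn += 2 * emb_dim * num_kv_heads * head_dim     # K and V projections (only 8 KV heads)
attn += num_query_heads * head_dim * emb_dim      # output projection
mlp = 3 * emb_dim * mlp_dim                       # gate, up and down projections (SwiGLU)
embeddings = 2 * vocab_size * emb_dim             # untied input embedding + output head

total = num_layers * (attn + mlp) + embeddings
print(f"~{total / 1e9:.0f}B parameters")          # prints ~406B, consistent with the model name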
New file (+29 lines):
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for llama3.1-70b

base_emb_dim: 8192
base_num_query_heads: 64
base_num_kv_heads: 8
base_num_decoder_layers: 80
base_mlp_dim: 28672
head_dim: 128
mlp_activations: ["silu","linear"]
vocab_size: 128256
enable_dropout: False
logits_via_embedding: False
normalization_layer_epsilon: 1.0e-5
rope_max_timescale: 500_000
decoder_block: "llama2" # Uses the same decoder block as llama2
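With base_num_query_heads: 64 and base_num_kv_heads: 8, this config uses grouped-query attention: each key/value head is shared by 64 / 8 = 8 query heads, which shrinks the KV cache roughly 8x at decode time. Below is a minimal NumPy illustration of that sharing; the shapes are hypothetical and causal masking is omitted, so treat it as a sketch rather than the MaxText attention implementation.

import numpy as np

batch, seq, n_q_heads, n_kv_heads, head_dim = 1, 4, 64, 8, 128
q = np.random.randn(batch, seq, n_q_heads, head_dim)
k = np.random.randn(batch, seq, n_kv_heads, head_dim)   # only 8 KV heads are stored
v = np.random.randn(batch, seq, n_kv_heads, head_dim)

# Each KV head serves n_q_heads // n_kv_heads = 8 query heads, so K and V are
# repeated along the head axis before ordinary dot-product attention.
group = n_q_heads // n_kv_heads
k_full = np.repeat(k, group, axis=2)                     # (1, 4, 64, 128)
v_full = np.repeat(v, group, axis=2)

scores = np.einsum("bqhd,bkhd->bhqk", q, k_full) / np.sqrt(head_dim)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)           # softmax over key positions
out = np.einsum("bhqk,bkhd->bqhd", weights, v_full)
print(out.shape)                                         # (1, 4, 64, 128)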
New file (+29 lines):
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for llama3.1-8b

base_emb_dim: 4096
base_num_query_heads: 32
base_num_kv_heads: 8
base_num_decoder_layers: 32
base_mlp_dim: 14336
head_dim: 128
mlp_activations: ["silu","linear"]
vocab_size: 128256
enable_dropout: False
logits_via_embedding: False
normalization_layer_epsilon: 1.0e-5
rope_max_timescale: 500_000
decoder_block: "llama2" # Uses the same decoder block as llama2
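All three Llama 3.1 configs set rope_max_timescale: 500_000, a much larger rotary-embedding base (rope theta) than the 10_000 used by Llama 2, which lowers the rotation frequencies and supports longer contexts. The standalone sketch below shows rotary position embeddings with that timescale; implementations differ in exactly how they pair up feature dimensions, so this is illustrative rather than the MaxText code.

import numpy as np

def rope_angles(positions, head_dim, max_timescale=500_000):
    # Per-pair frequencies fall geometrically from 1 down to 1/max_timescale.
    inv_freq = 1.0 / (max_timescale ** (np.arange(0, head_dim, 2) / head_dim))
    return np.outer(positions, inv_freq)                 # (seq, head_dim // 2)

def apply_rope(x, positions, max_timescale=500_000):
    # x: (seq, head_dim). Rotate each consecutive pair of features by its angle.
    angles = rope_angles(positions, x.shape[-1], max_timescale)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

x = np.random.randn(16, 128)                             # 16 tokens, head_dim 128
y = apply_rope(x, np.arange(16))
print(np.allclose(np.linalg.norm(y, axis=-1), np.linalg.norm(x, axis=-1)))  # True: pure rotation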
Several of the other changed files in this commit have large diffs that are not rendered here.
New file (+44 lines):
#!/bin/bash

# This file is both an integration test that runs once a day on a v4-128 and documentation for how to get started with Llama3.1-405b.
# Please make sure you have run end_to_end/tpu/llama3.1/405b/1_test_llama3.1_405b.sh before running commands from this file.

# The flow of this file is as follows:
# 1. Run decoding and finetuning of Llama3.1-405B with the converted checkpoint obtained from end_to_end/tpu/llama3.1/405b/1_test_llama3.1_405b.sh. Also, run pretraining of Llama3.1-70B.
# 2. Convert the scanned checkpoint from step 1 into unscanned checkpoint format and run more efficient decoding.
# 3. Run decoding from the finetuned checkpoint from step 1.

# Example Usage: export BASE_OUTPUT_PATH=/path/to/GCS/bucket; bash end_to_end/tpu/llama3.1/405b/2_test_llama3.1_405b.sh
# Use the same BASE_OUTPUT_PATH as end_to_end/tpu/llama3.1/405b/1_test_llama3.1_405b.sh.
# Please note that in these two scripts (1_test_llama3.1_405b.sh and 2_test_llama3.1_405b.sh) BASE_OUTPUT_PATH is assumed to already be a unique path across multiple runs,
# and the subfolder names (aka RUN_NAMEs) are static. Please remember to change BASE_OUTPUT_PATH across different runs.

set -ex
export MODEL_VARIATION='llama3.1-405b'

if [ -z "${BASE_OUTPUT_PATH}" ]; then
  # Non-Googlers please remember to point `BASE_OUTPUT_PATH` to a GCS bucket that you own; this bucket will store all the files generated by MaxText during a run.
  # Use the same BASE_OUTPUT_PATH as end_to_end/tpu/llama3.1/405b/1_test_llama3.1_405b.sh.
  export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M)
  echo "BASE_OUTPUT_PATH is not set, using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}"
fi

# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data.
export DATASET_PATH=gs://maxtext-dataset

# We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory. This way it is easier to use this path in the `train.py` and `decode.py` commands.
export CONVERTED_CHECKPOINT=gs://maxtext-llama/llama3.1_405b_instruct/maxtext-ckpt/0/items

# We run finetuning using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying `load_parameters_path=${CONVERTED_CHECKPOINT}`. Note that the scanned checkpoint format helps with efficient finetuning.
# We use a small per_device_batch_size and the SGD optimizer so that the model fits on a v4-128.
export FINETUNE_RUN_NAME=runner_finetune
# python MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_type=synthetic tokenizer_path=assets/tokenizer_llama3.tiktoken per_device_batch_size=0.25 ici_tensor_parallelism=4 run_name=${FINETUNE_RUN_NAME} steps=10 enable_checkpointing=false model_name=${MODEL_VARIATION} logits_dot_in_fp32=false weight_dtype=bfloat16 opt_type=sgd

# We can also run decoding (albeit in a somewhat unoptimized way) using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state, so we use it by specifying `load_parameters_path=${CONVERTED_CHECKPOINT}`.
# python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer_llama3.tiktoken load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=0.0625 ici_tensor_parallelism=4 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false model_name=${MODEL_VARIATION} attention=dot_product weight_dtype=bfloat16 prompt="I love to"

# We also test whether the forward pass logits match the golden logits for Llama3.1-405B.
python3 MaxText/tests/forward_pass_logit_checker.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} tokenizer_path=assets/tokenizer_llama3.tiktoken load_parameters_path=${CONVERTED_CHECKPOINT} run_name=forward_pass_test per_device_batch_size=0.0625 ici_tensor_parallelism=4 model_name=${MODEL_VARIATION} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic logits_dot_in_fp32=false weight_dtype=bfloat16 opt_type=sgd async_checkpointing=false --max_kl_div=0.15
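The forward_pass_logit_checker.py command above compares the logits from a MaxText forward pass against pre-computed golden logits for the same prompts and fails if the KL divergence exceeds --max_kl_div=0.15. Conceptually the check looks like the following sketch (standalone illustrative code with random stand-in logits, not the actual test):

import numpy as np

def kl_divergence(golden_logits, test_logits):
    # Per-token KL(P_golden || P_test), computed from raw logits via log-softmax.
    def log_softmax(z):
        z = z - z.max(axis=-1, keepdims=True)
        return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    log_p, log_q = log_softmax(golden_logits), log_softmax(test_logits)
    return (np.exp(log_p) * (log_p - log_q)).sum(axis=-1)

golden = np.random.randn(4, 128256)                     # 4 tokens x vocab_size
test = golden + 0.01 * np.random.randn(*golden.shape)   # slightly perturbed logits
print(kl_divergence(golden, test).max() < 0.15)         # True, analogous to --max_kl_div=0.15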
New file (+48 lines):
#!/bin/bash

# This file, combined with step 2 in the same directory, demonstrates converting a Llama3-70B checkpoint from Meta and running various MaxText operations on it.
# This step is tested nightly on an ordinary CPU VM.

# The flow of this file is as follows:
# 1. Pull the Meta checkpoint from a GCS bucket and upload the new MaxText-compatible checkpoint to a destination GCS bucket.
# 2. Convert the scanned checkpoint from step 1 into unscanned checkpoint format and run more efficient decoding.

# Example Usage: export BASE_OUTPUT_PATH=/path/to/GCS/bucket; bash end_to_end/tpu/llama3/70b/1_test_llama3_70b.sh
# Use the same BASE_OUTPUT_PATH as end_to_end/tpu/llama3/70b/2_test_llama3_70b.sh.
# Please note that in these two scripts (1_test_llama3_70b.sh and 2_test_llama3_70b.sh) BASE_OUTPUT_PATH is assumed to already be a unique path across multiple runs,
# and the subfolder names (aka RUN_NAMEs) are static. Please remember to change BASE_OUTPUT_PATH across different runs.

set -ex
MODEL_VARIATION='llama3.1-70b'

# We install CPU-only torch because the checkpoint conversion script MaxText/llama_or_mistral_ckpt.py does not need a TPU/GPU.
pip install torch --index-url https://download.pytorch.org/whl/cpu

# We define a variable for the path to the Meta checkpoint. Non-Googlers please remember to update `META_CHECKPOINT_PATH` to the GCS bucket where you have your Meta checkpoint.
export META_CHECKPOINT_PATH=gs://maxtext-llama/llama3.1_70b/meta-ckpt

# In the following command, we copy Meta's checkpoint into the local directory `/tmp/`.
# You can use a different local directory than /tmp/; if you do so, please use the same local path for `base-model-path` when running `python3 MaxText/llama_or_mistral_ckpt.py`.
gcloud storage cp -r ${META_CHECKPOINT_PATH} /tmp/

if [ -z "${BASE_OUTPUT_PATH}" ]; then
  # Non-Googlers please remember to point BASE_OUTPUT_PATH to a GCS bucket that you own; this script uses internal buckets for testing.
  # Use the same BASE_OUTPUT_PATH as end_to_end/tpu/llama3/70b/2_test_llama3_70b.sh.
  export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M)
  echo "BASE_OUTPUT_PATH is not set, using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}"
fi

echo "Converted checkpoints will be stored at ${BASE_OUTPUT_PATH}"

# Next, run the conversion script `MaxText/llama_or_mistral_ckpt.py` to convert Meta's PyTorch checkpoint in `base-model-path` and save the new converted (Orbax) checkpoint in `maxtext-model-path`.
JAX_PLATFORMS=cpu python3 MaxText/llama_or_mistral_ckpt.py --base-model-path /tmp/meta-ckpt --maxtext-model-path ${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_chkpt --model-size ${MODEL_VARIATION}

echo "Wrote MaxText compatible checkpoint to ${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_chkpt"

# We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory.
export CONVERTED_CHECKPOINT=${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_chkpt/0/items
# Note that `CONVERTED_CHECKPOINT` is in a `scanned` format, which is great for training, but for efficient decoding performance we want the checkpoint in an `unscanned` format.
# We can do this by running `MaxText/generate_param_only_checkpoint.py` on `CONVERTED_CHECKPOINT` with `force_unroll=true`.
export RUN_NAME=unscanned_chkpt
JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name=${MODEL_VARIATION} force_unroll=true
echo "Wrote MaxText compatible unscanned checkpoint to ${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items"
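The scanned vs. unscanned distinction above refers to how the per-layer decoder weights are stored: in the scanned layout they are stacked along a leading layer axis so the decoder can run as a single jax.lax.scan, which is convenient for training, while force_unroll=true rewrites them into separate per-layer arrays, which the comments above note is better for decoding performance. The toy JAX sketch below uses hypothetical shapes, not the real MaxText parameter tree, and simply shows that the two layouts compute the same thing.

import jax
import jax.numpy as jnp

num_layers, emb_dim = 4, 8
scanned_w = jnp.ones((num_layers, emb_dim, emb_dim))     # one stacked array for all layers

def layer(x, w):
    # A stand-in for a decoder layer: any per-layer computation works here.
    return jnp.tanh(x @ w)

def scanned_forward(x):
    # lax.scan walks the leading (layer) axis, compiling the layer body once.
    y, _ = jax.lax.scan(lambda carry, w: (layer(carry, w), None), x, scanned_w)
    return y

# "Unscanned" / unrolled layout: separate per-layer arrays, looped over explicitly.
unscanned_w = [scanned_w[i] for i in range(num_layers)]

def unrolled_forward(x):
    for w in unscanned_w:
        x = layer(x, w)
    return x

x = jnp.ones((emb_dim,))
print(jnp.allclose(scanned_forward(x), unrolled_forward(x)))  # True: same result, different layout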