Llama3.1 config
khatwanimohit committed Sep 20, 2024
1 parent c7c3f4e commit 2ab924d
Showing 16 changed files with 438 additions and 46 deletions.
29 changes: 29 additions & 0 deletions MaxText/configs/models/llama3.1-405b.yml
@@ -0,0 +1,29 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for llama3.1-405b

base_emb_dim: 16384
base_num_query_heads: 128
base_num_kv_heads: 8
base_num_decoder_layers: 126
base_mlp_dim: 53248
head_dim: 128
mlp_activations: ["silu","linear"]
vocab_size: 128256
enable_dropout: False
logits_via_embedding: False
normalization_layer_epsilon: 1.0e-5
rope_max_timescale: 500_000
decoder_block: "llama2" # Uses the same decoder block as llama2
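
As a rough sanity check (not part of this commit), the values above are consistent with the 405B parameter count. A back-of-the-envelope estimate in Python, assuming grouped-query attention, a gated (SiLU) MLP as implied by mlp_activations, and an untied output head as implied by logits_via_embedding: False:

emb, layers, mlp = 16384, 126, 53248
q_heads, kv_heads, head_dim, vocab = 128, 8, 128, 128256

attn = emb * q_heads * head_dim          # query projection
attn += 2 * emb * kv_heads * head_dim    # key and value projections (GQA)
attn += q_heads * head_dim * emb         # output projection
ffn = 3 * emb * mlp                      # gate, up and down projections
embed = 2 * vocab * emb                  # input embedding plus separate output head

total = layers * (attn + ffn) + embed
print(f"{total / 1e9:.0f}B")             # ~406B ignoring norms, close to the nominal 405B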
29 changes: 29 additions & 0 deletions MaxText/configs/models/llama3.1-70b.yml
@@ -0,0 +1,29 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for llama3.1-70b

base_emb_dim: 8192
base_num_query_heads: 64
base_num_kv_heads: 8
base_num_decoder_layers: 80
base_mlp_dim: 28672
head_dim: 128
mlp_activations: ["silu","linear"]
vocab_size: 128256
enable_dropout: False
logits_via_embedding: False
normalization_layer_epsilon: 1.0e-5
rope_max_timescale: 500_000
decoder_block: "llama2" # Uses the same decoder block as llama2
29 changes: 29 additions & 0 deletions MaxText/configs/models/llama3.1-8b.yml
@@ -0,0 +1,29 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for llama3.1-8b

base_emb_dim: 4096
base_num_query_heads: 32
base_num_kv_heads: 8
base_num_decoder_layers: 32
base_mlp_dim: 14336
head_dim: 128
mlp_activations: ["silu","linear"]
vocab_size: 128256
enable_dropout: False
logits_via_embedding: False
normalization_layer_epsilon: 1.0e-5
rope_max_timescale: 500_000
decoder_block: "llama2" # Uses the same decoder block as llama2
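
Only four values differ across the three new configs: base_emb_dim, base_num_query_heads, base_num_decoder_layers and base_mlp_dim; the KV head count, head_dim, vocab size and RoPE timescale are shared. A small sketch (values transcribed from the files above) that makes the grouped-query-attention ratios explicit:

# Values copied from the three config files above; all share base_num_kv_heads=8 and head_dim=128.
llama31 = {
    "8b":   {"emb": 4096,  "q_heads": 32,  "layers": 32,  "mlp": 14336},
    "70b":  {"emb": 8192,  "q_heads": 64,  "layers": 80,  "mlp": 28672},
    "405b": {"emb": 16384, "q_heads": 128, "layers": 126, "mlp": 53248},
}
for name, cfg in llama31.items():
  print(name, "query heads per KV head:", cfg["q_heads"] // 8)                    # 4, 8, 16
  print(name, "q_heads * head_dim == emb:", cfg["q_heads"] * 128 == cfg["emb"])   # True for all three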
21 changes: 21 additions & 0 deletions MaxText/llama_or_mistral_ckpt.py
@@ -89,6 +89,27 @@
        "dims_per_head": 128,
        "vocab": 128256,
    },
    "llama3.1-8b": {
        "num_layers": 32,
        "num_heads": 32,
        "num_kv_heads": 8,
        "dims_per_head": 128,
        "vocab": 128256,
    },
    "llama3.1-70b": {
        "num_layers": 80,
        "num_heads": 64,
        "num_kv_heads": 8,
        "dims_per_head": 128,
        "vocab": 128256,
    },
    "llama3.1-405b": {
        "num_layers": 126,
        "num_heads": 128,
        "num_kv_heads": 8,
        "dims_per_head": 128,
        "vocab": 128256,
    },
    "mistral-7b": {
        "num_layers": 32,
        "num_heads": 32,
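
The new MODEL_PARAMS_DICT entries carry the same shapes as the yml configs above; the checkpoint-conversion script only needs layer, head and vocab dimensions. A hedged sketch of a consistency check one could run locally (not part of this commit; assumes PyYAML is installed and the repository root is the working directory):

import yaml

ckpt_params = {"num_layers": 126, "num_heads": 128, "num_kv_heads": 8, "dims_per_head": 128, "vocab": 128256}
with open("MaxText/configs/models/llama3.1-405b.yml") as f:
  yml = yaml.safe_load(f)

assert ckpt_params["num_layers"] == yml["base_num_decoder_layers"]
assert ckpt_params["num_heads"] == yml["base_num_query_heads"]
assert ckpt_params["num_kv_heads"] == yml["base_num_kv_heads"]
assert ckpt_params["dims_per_head"] == yml["head_dim"]
assert ckpt_params["vocab"] == yml["vocab_size"]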
4 changes: 4 additions & 0 deletions MaxText/optimizers.py
@@ -45,6 +45,10 @@ def get_optimizer(config, learning_rate_schedule):
        epsilon_root=config.adam_eps_root,
        weight_decay=config.adam_weight_decay,
    )
  elif config.opt_type == "sgd":
    return optax.sgd(
        learning_rate_schedule
    )
  else:
    raise ValueError(f"{config.opt_type=} is not a supported optimizer type.")

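
The new branch maps opt_type="sgd" to plain optax.sgd driven by the existing learning-rate schedule, with no momentum or weight decay, so it keeps almost no optimizer state — which is why the 405B finetuning command further down passes opt_type=sgd to fit on a v4-128. A standalone sketch of what this branch returns (a constant schedule stands in for MaxText's real schedule; assumes optax is installed):

import optax

learning_rate_schedule = optax.constant_schedule(3e-5)
optimizer = optax.sgd(learning_rate_schedule)

params = {"w": 1.0}
opt_state = optimizer.init(params)   # unlike AdamW, SGD keeps no per-parameter moment estimates
grads = {"w": 0.5}
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)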
5 changes: 4 additions & 1 deletion MaxText/pyconfig.py
@@ -145,6 +145,9 @@ def validate_model_name(s: str) -> bool:
"llama2-70b",
"llama3-8b",
"llama3-70b",
"llama3.1-8b",
"llama3.1-70b",
"llama3.1-405b",
"mistral-7b",
"mixtral-8x7b",
"mixtral-8x22b",
@@ -159,7 +162,7 @@ def validate_model_name(s: str) -> bool:
"gpt3-52k",
)
if s not in valid_model_names:
raise ValueError("Invalid model name was passed. Valid options ", valid_model_names)
raise ValueError(f"Invalid model name was passed. Got {s}, Valid options {valid_model_names}")


def validate_no_keys_overwritten_twice(keys1: list[str], keys2: list[str]):
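
The rewritten error now interpolates both the rejected name and the valid options; the old call passed two positional arguments to ValueError, so the options tuple rendered awkwardly and the offending name never appeared. A minimal illustration (model list abbreviated):

valid_model_names = ("llama3.1-8b", "llama3.1-70b", "llama3.1-405b")  # abbreviated
s = "llama4-8b"
if s not in valid_model_names:
  # Old: ValueError("Invalid model name was passed. Valid options ", valid_model_names)
  # New message names the culprit directly:
  print(f"Invalid model name was passed. Got {s}, Valid options {valid_model_names}")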
103 changes: 59 additions & 44 deletions MaxText/scratch_code/golden_llama3-70b_export.py
@@ -12,53 +12,17 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Usage: python3 golden_llama3-70b_export.py --model-id meta-llama/Meta-Llama-3-70B --output-path llama3-70b/golden_logits/golden_data_llama3-70b.jsonl
"""

import torch
import torch
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
import jsonlines
from google.cloud import storage

# Load the tokenizer and model from Hugging Face

model_id = "meta-llama/Meta-Llama-3-70B"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float32,
)


# Your prompt text
prompt_texts = ["I love to"]
all_data_to_save = []

output_path = 'golden_data_llama3-70b.jsonl'


for prompt_text in prompt_texts:
  # Encode the prompt text
  input_ids = tokenizer.encode(prompt_text, return_tensors='pt')

  # Get the logits for the prompt + completion
  with torch.no_grad():
    outputs = model(input_ids)
    logits = outputs.logits

  # Convert logits to fp32
  logits = logits.cpu().numpy().astype('float32')

  # Prepare data to be saved
  data_to_save = {
      "prompt": prompt_text,
      "tokens": input_ids.tolist()[0],
      "logits": logits.tolist()[0]  # Convert numpy array to list for JSON serialization
  }
  all_data_to_save.append(data_to_save)

with jsonlines.open(output_path, 'w') as f:
  f.write_all(all_data_to_save)

def upload_blob(bucket_name, source_file_name, destination_blob_name):
  """Uploads a file to the bucket."""
@@ -68,8 +32,59 @@ def upload_blob(bucket_name, source_file_name, destination_blob_name):

  blob.upload_from_filename(source_file_name)

upload_blob('maxtext-llama', output_path, 'llama3-70b/golden-logits/' + output_path)
print('File {} uploaded to {}.'.format(
    output_path,
    'llama3-70b/golden-logits/' + output_path))
def convert(model_id, output_path):

  tokenizer = AutoTokenizer.from_pretrained(model_id)
  model = AutoModelForCausalLM.from_pretrained(
      model_id,
      torch_dtype=torch.float32,
  )

  # Your prompt text
  prompt_texts = ["I love to"]
  all_data_to_save = []

  output_filename = output_path.split('/')[-1]

  for prompt_text in prompt_texts:
    # Encode the prompt text
    input_ids = tokenizer.encode(prompt_text, return_tensors='pt')

    # Get the logits for the prompt + completion
    with torch.no_grad():
      outputs = model(input_ids)
      logits = outputs.logits

    # Convert logits to fp32
    logits = logits.cpu().numpy().astype('float32')

    # Prepare data to be saved
    data_to_save = {
        "prompt": prompt_text,
        "tokens": input_ids.tolist()[0],
        "logits": logits.tolist()[0]  # Convert numpy array to list for JSON serialization
    }
    all_data_to_save.append(data_to_save)

  with jsonlines.open(output_filename, 'w') as f:
    f.write_all(all_data_to_save)

  upload_blob('maxtext-llama', output_filename, output_path)
  print('File {} uploaded to {}.'.format(
      output_filename,
      output_path))


if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument("--model-id", type=str, required=True)
  parser.add_argument("--output-path", type=str, required=True)

  args = parser.parse_args()

  convert(args.model_id, args.output_path)
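
With the logic moved into convert(), the same script now serves the Llama 3.1 checkpoints by switching CLI flags, e.g. python3 MaxText/scratch_code/golden_llama3-70b_export.py --model-id meta-llama/Llama-3.1-8B --output-path llama3.1-8b/golden-logits/golden_data_llama3.1-8b.jsonl (model id and GCS prefix here are illustrative, not defined by this commit). Each output record carries "prompt", "tokens" and "logits" keys, which is the format of the golden_data_llama3.1-*.jsonl assets below. A small sketch of consuming such a file locally, assuming jsonlines and numpy are installed and the file has been downloaded:

import jsonlines
import numpy as np

with jsonlines.open("golden_data_llama3.1-8b.jsonl") as reader:
  for record in reader:
    tokens = np.asarray(record["tokens"], dtype=np.int32)
    logits = np.asarray(record["logits"], dtype=np.float32)
    # One row of logits per input token; the last axis is vocab-sized (128256 for Llama 3.1).
    print(record["prompt"], tokens.shape, logits.shape)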


1 change: 1 addition & 0 deletions MaxText/test_assets/golden_data_llama3.1-405b.jsonl

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions MaxText/test_assets/golden_data_llama3.1-70b.jsonl

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions MaxText/test_assets/golden_data_llama3.1-8b.jsonl

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion MaxText/tests/forward_pass_logit_checker.py
@@ -57,7 +57,7 @@ def get_data(golden_data, golden_data_index, config):
  s = (config.global_batch_size_to_train_on, config.max_target_length)
  ids = np.asarray(golden_data[golden_data_index]['tokens'], dtype=np.int32)

  logits = np.asarray(golden_data[golden_data_index]['logits'], dtype=np.float32)
  logits = jnp.asarray(golden_data[golden_data_index]['logits'], dtype=config.dtype)
  max_logging.log(f" prompt=\"{golden_data[golden_data_index]['prompt']}\" raw ids={ids}, logits.shape = {logits.shape}")


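
The golden logits are now materialised as a jnp array in config.dtype rather than float32 numpy, presumably so the comparison happens in the same dtype the model produces. Downstream, the checker compares model output against the golden logits under a KL-divergence budget — the 405B test below passes --max_kl_div=0.15. A simplified sketch of such a per-token check (not the test's exact implementation), assuming both arrays are [seq_len, vocab] logits:

import jax
import jax.numpy as jnp

def kl_div_per_token(golden_logits, model_logits):
  # KL(golden || model) at each position, computed from log-softmaxed logits.
  golden_logp = jax.nn.log_softmax(golden_logits, axis=-1)
  model_logp = jax.nn.log_softmax(model_logits, axis=-1)
  return jnp.sum(jnp.exp(golden_logp) * (golden_logp - model_logp), axis=-1)

golden = jnp.zeros((4, 128256))
model = jnp.zeros((4, 128256))
assert jnp.all(kl_div_per_token(golden, model) <= 0.15)  # mirrors --max_kl_div=0.15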
44 changes: 44 additions & 0 deletions end_to_end/tpu/llama3.1/405b/2_test_llama3.1_405b.sh
@@ -0,0 +1,44 @@
#!/bin/bash

# This file is both an integration test that runs once a day on a v4-128 and documentation for how to get started with Llama3.1-405b.
# Please make sure you have run end_to_end/tpu/llama3.1/405b/1_test_llama3.1_405b.sh before running commands from this file.

# The flow of this file is as follows:
# 1. Run decoding, finetuning of Llama3.1-405B with the converted checkpoint obtained from end_to_end/tpu/llama3.1/405b/1_test_llama3.1_405b.sh. Also, run pretraining of Llama3.1-70B
# 2. Convert the scanned checkpoint from step 1 into unscanned checkpoint format and run more efficient decoding.
# 3. Run decoding from the finetuned checkpoint from step 1


# Example Usage: export BASE_OUTPUT_PATH=/path/to/GCS/bucket; bash end_to_end/tpu/llama3.1/405b/2_test_llama3.1_405b.sh
# Use the same BASE_OUTPUT_PATH as end_to_end/tpu/llama3.1/405b/1_test_llama3.1_405b.sh
# Please note that in these two scripts (1_test_llama3.1_405b.sh and 2_test_llama3.1_405b.sh) BASE_OUTPUT_PATH is assumed to be already a unique path across multiple runs and
# the subfolder names (aka RUN_NAMEs) are static. Please remember to change BASE_OUTPUT_PATH across different runs.

set -ex
export MODEL_VARIATION='llama3.1-405b'

if [ -z "${BASE_OUTPUT_PATH}" ]; then
# Non-Googlers please remember to point `BASE_OUTPUT_PATH` to a GCS bucket that you own, this bucket will store all the files generated by MaxText during a run
# Use the same BASE_OUTPUT_PATH as end_to_end/tpu/llama3.1/70b/1_test_llama3.1_70b.sh
export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M)
echo "BASE_OUTPUT_PATH is not set, using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}"
fi


# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data
export DATASET_PATH=gs://maxtext-dataset

# We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory. This way it is easier to use this path in the `train.py` and `decode.py` commands
export CONVERTED_CHECKPOINT=gs://maxtext-llama/llama3.1_405b_instruct/maxtext-ckpt/0/items

# We run finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`. Note that scanned checkpoint helps with efficient finetuning.
# We use a small per_device_batch_size and SGD optimizer so that the model fits on a v4-128.
export FINETUNE_RUN_NAME=runner_finetune
# python MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_type=synthetic tokenizer_path=assets/tokenizer_llama3.tiktoken per_device_batch_size=0.25 ici_tensor_parallelism=4 run_name=${FINETUNE_RUN_NAME} steps=10 enable_checkpointing=false model_name=${MODEL_VARIATION} logits_dot_in_fp32=false weight_dtype=bfloat16 opt_type=sgd

# We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`
# python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer_llama3.tiktoken load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=0.0625 ici_tensor_parallelism=4 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false model_name=${MODEL_VARIATION} attention=dot_product weight_dtype=bfloat16 prompt="I love to"

# We also test whether the forward pass logits match the golden logits for Llama3.1-405B
python3 MaxText/tests/forward_pass_logit_checker.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} tokenizer_path=assets/tokenizer_llama3.tiktoken load_parameters_path=${CONVERTED_CHECKPOINT} run_name=forward_pass_test per_device_batch_size=0.0625 ici_tensor_parallelism=4 model_name=${MODEL_VARIATION} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic logits_dot_in_fp32=false weight_dtype=bfloat16 opt_type=sgd async_checkpointing=false --max_kl_div=0.15
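
For context on the fractional batch sizes above: MaxText derives the global batch from per_device_batch_size × number of devices, so values below 1.0 keep the 405B activations small enough for a v4-128. A rough, hedged calculation, assuming a v4-128 exposes 64 JAX devices (one per chip):

num_devices = 64             # assumption: one JAX device per TPU v4 chip on a v4-128
print(0.0625 * num_devices)  # 4 sequences per step for the decode / logit-check commands
print(0.25 * num_devices)    # 16 sequences per step for the finetuning command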

48 changes: 48 additions & 0 deletions end_to_end/tpu/llama3.1/70b/1_test_llama3.1_70b.sh
@@ -0,0 +1,48 @@
#!/bin/bash

# This file, combined with step 2 in the same directory, demonstrates converting a Llama3.1-70B checkpoint from Meta and running various MaxText operations on it.
# This step is tested nightly on an ordinary CPU VM.

# The flow of this file is as follows:
# 1. Pull the checkpoint from a GCS bucket and upload the new MaxText-compatible checkpoint to the destination GCS bucket.
# 2. Convert the scanned checkpoint from step 1 into unscanned checkpoint format and run more efficient decoding.

# Example Usage: export BASE_OUTPUT_PATH=/path/to/GCS/bucket; bash end_to_end/tpu/llama3/70b/1_test_llama3_70b.sh
# Use the same BASE_OUTPUT_PATH as end_to_end/tpu/llama3/70b/2_test_llama3_70b.sh.
# Please note that in these two scripts (1_test_llama3_70b.sh and 2_test_llama3_70b.sh) BASE_OUTPUT_PATH is assumed to be already a unique path across multiple runs and
# the subfolder names (aka RUN_NAMEs) are static. Please remember to change BASE_OUTPUT_PATH across different runs.

set -ex
MODEL_VARIATION='llama3.1-70b'

# We install torch CPU because the checkpoint conversion script MaxText/llama_or_mistral_ckpt.py does not need a TPU/GPU
pip install torch --index-url https://download.pytorch.org/whl/cpu

# We define a var for the path to the Meta checkpoint. Non-Googlers please remember to update the source `META_CHECKPOINT_PATH` to the GCS bucket where you have your Meta checkpoint
export META_CHECKPOINT_PATH=gs://maxtext-llama/llama3.1_70b/meta-ckpt

# In the following command, we are copying Meta's checkpoint into a local directory `tmp`.
# You can use a different local directory than /tmp/, if you do so, please use the same local path for `base-model-path` when running `python3 MaxText/llama_or_mistral_ckpt.py`
gcloud storage cp -r ${META_CHECKPOINT_PATH} /tmp/

if [ -z "${BASE_OUTPUT_PATH}" ]; then
# Non-Googlers please remember to point BASE_OUTPUT_PATH to GCS buckets that you own, this script uses internal buckets for testing.
# Use the same BASE_OUTPUT_PATH as end_to_end/tpu/llama3/70b/2_test_llama3_70b
export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M)
echo "BASE_OUTPUT_PATH is not set, using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}"
fi

echo "Converted checkpoints will be stored at ${BASE_OUTPUT_PATH}"

#Next, run the conversion script `MaxText/llama_or_mistral_ckpt.py` to convert Meta's PyTorch checkpoint in `base-model-path` and save the new converted (Orbax) checkpoint in the `maxtext-model-path`
JAX_PLATFORMS=cpu python3 MaxText/llama_or_mistral_ckpt.py --base-model-path /tmp/meta-ckpt --maxtext-model-path ${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_chkpt --model-size ${MODEL_VARIATION}

echo "Wrote MaxText compatible checkpoint to ${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_chkpt"

# We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory.
export CONVERTED_CHECKPOINT=${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_chkpt/0/items
# Note that the `CONVERTED_CHECKPOINT` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format.
# We can do this by running `MaxText/generate_param_only_checkpoint.py` on `CONVERTED_CHECKPOINT` with `force_unroll=true`.
export RUN_NAME=unscanned_chkpt
JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name=${MODEL_VARIATION} force_unroll=true
echo "Written MaxText compatible unscanned checkpoint to ${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items"
