Commit

Merge pull request #845 from google:rdyro-add-mixtral-maxtext
PiperOrigin-RevId: 671554810
maxtext authors committed Sep 5, 2024
2 parents aef1bb0 + 3eca7f3 commit b0b1772
Showing 12 changed files with 644 additions and 154 deletions.
6 changes: 6 additions & 0 deletions MaxText/checkpointing.py
@@ -62,9 +62,15 @@ def create_orbax_checkpoint_manager(
else:
item_names = ("items",)

# local storage checkpoint needs parent directory created
p.mkdir(exist_ok=True, parents=True)
# we need to use ocdbt and zarr3 to control max file size in the checkpoint
# omitting `iter` uses default handler for `iter`
item_handlers = {"items": PyTreeCheckpointHandler(use_ocdbt=True, use_zarr3=True)}
mngr = CheckpointManager(
p,
item_names=item_names,
item_handlers=item_handlers,
options=CheckpointManagerOptions(
create=True,
save_interval_steps=save_interval_steps,
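
For context, a minimal standalone sketch of the checkpoint-manager construction introduced above (the directory path and save interval are illustrative assumptions; the ocdbt/zarr3 handler mirrors the hunk):

import orbax.checkpoint as ocp
from etils import epath

checkpoint_dir = epath.Path("/tmp/example_ckpt")  # hypothetical local directory
# local storage checkpoints need the parent directory created up front
checkpoint_dir.mkdir(exist_ok=True, parents=True)

# ocdbt + zarr3 let orbax cap the maximum file size inside the checkpoint
item_handlers = {"items": ocp.PyTreeCheckpointHandler(use_ocdbt=True, use_zarr3=True)}
mngr = ocp.CheckpointManager(
    checkpoint_dir,
    item_names=("items",),
    item_handlers=item_handlers,
    options=ocp.CheckpointManagerOptions(create=True, save_interval_steps=100),
)
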
31 changes: 31 additions & 0 deletions MaxText/configs/models/mixtral-8x22b.yml
@@ -0,0 +1,31 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for mixtral-8x22b

base_emb_dim: 6144
base_num_query_heads: 48
base_num_kv_heads: 8
base_mlp_dim: 16384
base_num_decoder_layers: 56
head_dim: 128
mlp_activations: ["silu","linear"]
vocab_size: 32768
enable_dropout: False
logits_via_embedding: False
normalization_layer_epsilon: 1.0e-5
num_experts: 8
num_experts_per_tok: 2
rope_max_timescale: 1_000_000
decoder_block: "mistral"
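
As a quick sanity check, a hedged sketch of reading this config back as plain YAML (assumes PyYAML is installed and the repo-relative path; this is not how pyconfig itself loads model configs):

import yaml

with open("MaxText/configs/models/mixtral-8x22b.yml") as f:
    cfg = yaml.safe_load(f)

# 8 experts with 2 active per token, mistral-style decoder blocks
assert cfg["num_experts"] == 8 and cfg["num_experts_per_tok"] == 2
assert cfg["decoder_block"] == "mistral"
print(cfg["base_emb_dim"], cfg["base_mlp_dim"], cfg["base_num_decoder_layers"])  # 6144 16384 56
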
423 changes: 277 additions & 146 deletions MaxText/llama_or_mistral_ckpt.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions MaxText/pyconfig.py
@@ -147,6 +147,7 @@ def validate_model_name(s: str) -> bool:
"llama3-70b",
"mistral-7b",
"mixtral-8x7b",
"mixtral-8x22b",
"gemma-7b",
"gemma-2b",
"gemma2-2b",
161 changes: 161 additions & 0 deletions MaxText/scratch_code/golden_mixtral-8x22b_export.ipynb
@@ -0,0 +1,161 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "0d13ebbb",
"metadata": {},
"outputs": [],
"source": [
"!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu\n",
"!pip3 install tokenizers -U\n",
"!pip3 install transformers -U"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "6a8a4bb6",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/rdyro/devel/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import torch \n",
"from transformers import AutoTokenizer, AutoModelForCausalLM \n",
"import jsonlines"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "ff804403",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading checkpoint shards: 100%|██████████| 59/59 [03:54<00:00, 3.97s/it]\n"
]
}
],
"source": [
"# Load the tokenizer and model from Hugging Face \n",
" \n",
"model_id = \"mistralai/Mixtral-8x22B-Instruct-v0.1\"\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
" model_id,\n",
" torch_dtype=torch.float16,\n",
")\n"
]
},
{
"cell_type": "markdown",
"id": "9f218ba6",
"metadata": {},
"source": [
"## looping over multiple prompts and logits"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "c567f8d9",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Data saved to golden_data_mixtral-8x22b.jsonl\n"
]
}
],
"source": [
"# Save to disk \n",
"output_path = \"golden_data_mixtral-8x22b.jsonl\" \n",
" \n",
" \n",
"# Your prompt text \n",
"prompt_texts = [\"[INST] I love to [/INST]\", \"[INST] Today is a [/INST]\", \"[INST] What is the [/INST]\"]\n",
"all_data_to_save = []\n",
"\n",
"\n",
"for prompt_text in prompt_texts:\n",
" # Encode the prompt text \n",
" input_ids = tokenizer.encode(prompt_text, return_tensors='pt') \n",
"\n",
" # Get the logits for the prompt + completion \n",
" with torch.no_grad(): \n",
" outputs = model(input_ids) \n",
" logits = outputs.logits \n",
"\n",
" # Convert logits to fp32 \n",
" logits = logits.cpu().numpy().astype('float32') \n",
"\n",
" # Prepare data to be saved \n",
" data_to_save = { \n",
" \"prompt\": prompt_text, \n",
" \"tokens\": input_ids.tolist()[0], \n",
" \"logits\": logits.tolist()[0] # Convert numpy array to list for JSON serialization \n",
" } \n",
" all_data_to_save.append(data_to_save)\n",
" \n",
"with jsonlines.open(output_path,'w') as f: \n",
" f.write_all(all_data_to_save)\n",
"\n",
"\n",
"\n",
"print(f\"Data saved to {output_path}\") \n",
"\n",
" \n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "82c6e1f7",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
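
A small read-back sketch for the golden data the notebook writes (the record keys come from the cell above; the vocab-size check of 32768 is taken from the mixtral-8x22b model config and is an assumption about the saved logits):

import numpy as np
import jsonlines

with jsonlines.open("golden_data_mixtral-8x22b.jsonl") as reader:
    for record in reader:
        logits = np.asarray(record["logits"], dtype=np.float32)
        # one row of logits per prompt token, one column per vocab entry
        assert logits.shape == (len(record["tokens"]), 32768)
        print(record["prompt"], logits.shape)
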
3 changes: 3 additions & 0 deletions MaxText/test_assets/golden_data_mixtral-8x22b.jsonl

Large diffs are not rendered by default.

20 changes: 17 additions & 3 deletions MaxText/train.py
@@ -66,6 +66,7 @@

Transformer = models.Transformer
EPS = 1e-8
_CHUNK_BYTE_SIZE = 2 * 1024 **3


def validate_train_config(config):
@@ -167,22 +168,35 @@ def clear_buffered_metrics():

def save_checkpoint(checkpoint_manager, step, state, dataset_type="c4", data_iterator=None):
"""Wrapper for saving checkpoint"""
# specify chunk_byte_size to force orbax to control maximum file size in checkpoint
save_args = jax.tree.map(
lambda _: orbax.checkpoint.SaveArgs(chunk_byte_size=_CHUNK_BYTE_SIZE), state
)

if isinstance(checkpoint_manager, emergency_checkpoint_manager.CheckpointManager):
return checkpoint_manager.save(
step, args=orbax.checkpoint.args.PyTreeSave(state)
step, args=orbax.checkpoint.args.PyTreeSave(
item=state, save_args=save_args, ocdbt_target_data_file_size=_CHUNK_BYTE_SIZE
)
)

if dataset_type == "grain":
return checkpoint_manager.save(
step,
args=orbax.checkpoint.args.Composite(
items=orbax.checkpoint.args.PyTreeSave(item=state),
items=orbax.checkpoint.args.PyTreeSave(
item=state, save_args=save_args, ocdbt_target_data_file_size=_CHUNK_BYTE_SIZE
),
iter=grain.PyGrainCheckpointSave(data_iterator.local_iterator),
),
)
else:
return checkpoint_manager.save(
step, args=orbax.checkpoint.args.Composite(items=orbax.checkpoint.args.PyTreeSave(item=state))
step, args=orbax.checkpoint.args.Composite(
items=orbax.checkpoint.args.PyTreeSave(
item=state, save_args=save_args, ocdbt_target_data_file_size=_CHUNK_BYTE_SIZE
)
)
)


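
The chunk size constant works out to 2 * 1024^3 = 2,147,483,648 bytes (2 GiB) per checkpoint data file. A hedged sketch of how the new save path applies it (the state pytree and step are placeholders; the Composite/PyTreeSave arguments mirror the hunk above):

import jax
import jax.numpy as jnp
import orbax.checkpoint

_CHUNK_BYTE_SIZE = 2 * 1024 ** 3  # 2 GiB per checkpoint data file

state = {"params": {"w": jnp.zeros((8, 8))}, "step": 0}  # placeholder train state
# one SaveArgs per leaf tells orbax to cap each array's file chunks at 2 GiB
save_args = jax.tree.map(lambda _: orbax.checkpoint.SaveArgs(chunk_byte_size=_CHUNK_BYTE_SIZE), state)

composite = orbax.checkpoint.args.Composite(
    items=orbax.checkpoint.args.PyTreeSave(
        item=state, save_args=save_args, ocdbt_target_data_file_size=_CHUNK_BYTE_SIZE
    )
)
# checkpoint_manager.save(step, args=composite)  # checkpoint_manager as built in checkpointing.py
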
File renamed without changes.
Binary file added assets/tokenizer.mistral-v3
Binary file not shown.
34 changes: 34 additions & 0 deletions end_to_end/tpu/mixtral/8x22b/1_test_mixtral.sh
@@ -0,0 +1,34 @@
#!/bin/bash

# This file, combined with step 2 in the same directory, runs on a daily basis and demonstrates:
# 1. Converts the Mistral PyTorch checkpoint to MaxText(orbax) format using a CPU VM.
# 2. Takes the MaxText(orbax) checkpoint to run inference, fine-tuning, and pre-training on a TPU VM.

# The flow of this file is to convert the Mistral PyTorch checkpoint to MaxText (orbax) format using a CPU VM.

# Example Usage: export BASE_OUTPUT_PATH=/path/to/GCS/bucket; bash end_to_end/tpu/mixtral/8x22b/1_test_mixtral.sh
# Use the same BASE_OUTPUT_PATH for both 1_test_mixtral.sh & 2_test_mixtral.sh.

set -ex
MODEL_VARIATION='8x22b'

if [ -z "${BASE_OUTPUT_PATH}" ]; then
# Non-Googlers please remember to point BASE_OUTPUT_PATH to GCS buckets that you own, this script uses internal buckets for testing.
export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M)/
echo "BASE_OUTPUT_PATH is not set, using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}"
fi

# Download checkpoint
pip3 install torch
MODEL_NAME=Mixtral-8x22B-Instruct-v0.1
gcloud storage cp -r "gs://maxtext-external/$MODEL_NAME" /tmp

# Convert it to MaxText(orbax) format - scanned ckpt
JAX_PLATFORMS=cpu python3 MaxText/llama_or_mistral_ckpt.py --base-model-path="/tmp/$MODEL_NAME" --model-size=mixtral-8x22b --maxtext-model-path=${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_ckpt/
echo "Wrote MaxText compatible scanned checkpoint to ${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_ckpt"

# Generate unscanned ckpt for efficient decoding test
export SCANNED_CHECKPOINT=${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_ckpt/0/items
export RUN_NAME=unscanned_ckpt
JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${SCANNED_CHECKPOINT} run_name=${RUN_NAME} model_name='mixtral-8x22b' force_unroll=true megablox=false dtype=float16 weight_dtype=bfloat16 per_device_batch_size=1 max_target_length=1
echo "Wrote MaxText compatible unscanned checkpoint to ${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints"
