-
Notifications
You must be signed in to change notification settings - Fork 274
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #845 from google:rdyro-add-mixtral-maxtext
PiperOrigin-RevId: 671554810
- Loading branch information
Showing
12 changed files
with
644 additions
and
154 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
# Copyright 2024 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# model config for mixtral-8x22b | ||
|
||
base_emb_dim: 6144 | ||
base_num_query_heads: 48 | ||
base_num_kv_heads: 8 | ||
base_mlp_dim: 16384 | ||
base_num_decoder_layers: 56 | ||
head_dim: 128 | ||
mlp_activations: ["silu","linear"] | ||
vocab_size: 32768 | ||
enable_dropout: False | ||
logits_via_embedding: False | ||
normalization_layer_epsilon: 1.0e-5 | ||
num_experts: 8 | ||
num_experts_per_tok: 2 | ||
rope_max_timescale: 1_000_000 | ||
decoder_block: "mistral" |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "0d13ebbb", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu\n", | ||
"!pip3 install tokenizers -U\n", | ||
"!pip3 install transformers -U" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "6a8a4bb6", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"/home/rdyro/devel/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", | ||
" from .autonotebook import tqdm as notebook_tqdm\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import torch \n", | ||
"from transformers import AutoTokenizer, AutoModelForCausalLM \n", | ||
"import jsonlines" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"id": "ff804403", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"Loading checkpoint shards: 100%|██████████| 59/59 [03:54<00:00, 3.97s/it]\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Load the tokenizer and model from Hugging Face \n", | ||
" \n", | ||
"model_id = \"mistralai/Mixtral-8x22B-Instruct-v0.1\"\n", | ||
"\n", | ||
"tokenizer = AutoTokenizer.from_pretrained(model_id)\n", | ||
"model = AutoModelForCausalLM.from_pretrained(\n", | ||
" model_id,\n", | ||
" torch_dtype=torch.float16,\n", | ||
")\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "9f218ba6", | ||
"metadata": {}, | ||
"source": [ | ||
"## looping over multiple prompts and logits" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"id": "c567f8d9", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)\n" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Data saved to golden_data_mixtral-8x22b.jsonl\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Save to disk \n", | ||
"output_path = \"golden_data_mixtral-8x22b.jsonl\" \n", | ||
" \n", | ||
" \n", | ||
"# Your prompt text \n", | ||
"prompt_texts = [\"[INST] I love to [/INST]\", \"[INST] Today is a [/INST]\", \"[INST] What is the [/INST]\"]\n", | ||
"all_data_to_save = []\n", | ||
"\n", | ||
"\n", | ||
"for prompt_text in prompt_texts:\n", | ||
" # Encode the prompt text \n", | ||
" input_ids = tokenizer.encode(prompt_text, return_tensors='pt') \n", | ||
"\n", | ||
" # Get the logits for the prompt + completion \n", | ||
" with torch.no_grad(): \n", | ||
" outputs = model(input_ids) \n", | ||
" logits = outputs.logits \n", | ||
"\n", | ||
" # Convert logits to fp32 \n", | ||
" logits = logits.cpu().numpy().astype('float32') \n", | ||
"\n", | ||
" # Prepare data to be saved \n", | ||
" data_to_save = { \n", | ||
" \"prompt\": prompt_text, \n", | ||
" \"tokens\": input_ids.tolist()[0], \n", | ||
" \"logits\": logits.tolist()[0] # Convert numpy array to list for JSON serialization \n", | ||
" } \n", | ||
" all_data_to_save.append(data_to_save)\n", | ||
" \n", | ||
"with jsonlines.open(output_path,'w') as f: \n", | ||
" f.write_all(all_data_to_save)\n", | ||
"\n", | ||
"\n", | ||
"\n", | ||
"print(f\"Data saved to {output_path}\") \n", | ||
"\n", | ||
" \n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "82c6e1f7", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.2" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
#!/bin/bash | ||
|
||
# This file, combined with step 2 in the same directory, runs on daily basis and demonstrates: | ||
# 1. Converts the Mistral PyTorch checkpoint to MaxText(orbax) format using a CPU VM. | ||
# 2. Takes the MaxText(orbax) checkpoint to run inference, fine-tuning, and pre-training on a TPU VM. | ||
|
||
# The flow of this file is to convert the Mistral PyTorch checkpoint to MaxText (orbax) format using a CPU VM. | ||
|
||
# Example Usage: export BASE_OUTPUT_PATH=/path/to/GCS/bucket; bash end_to_end/tpu/mixtral/8x22b/1_test_mixtral.sh | ||
# Use the same BASE_OUTPUT_PATH for both 1_test_mixtral.sh & 2_test_mixtral.sh. | ||
|
||
set -ex | ||
MODEL_VARIATION='8x22b' | ||
|
||
if [ -z "${BASE_OUTPUT_PATH}" ]; then | ||
# Non-Googlers please remember to point BASE_OUTPUT_PATH to GCS buckets that you own, this script uses internal buckets for testing. | ||
export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M)/ | ||
echo "BASE_OUTPUT_PATH is not set, using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}" | ||
fi | ||
|
||
# Download checkpoint | ||
pip3 install torch | ||
MODEL_NAME=Mixtral-8x22B-Instruct-v0.1 | ||
gcloud storage cp -r "gs://maxtext-external/$MODEL_NAME" /tmp | ||
|
||
# Convert it to MaxText(orbax) format - scanned ckpt | ||
JAX_PLATFORMS=cpu python3 MaxText/llama_or_mistral_ckpt.py --base-model-path="/tmp/$MODEL_NAME" --model-size=mixtral-8x22b --maxtext-model-path=${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_ckpt/ | ||
echo "Wrote MaxText compatible scanned checkpoint to ${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_ckpt" | ||
|
||
# Generate unscanned ckpt for efficient decoding test | ||
export SCANNED_CHECKPOINT=${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_ckpt/0/items | ||
export RUN_NAME=unscanned_ckpt | ||
JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${SCANNED_CHECKPOINT} run_name=${RUN_NAME} model_name='mixtral-8x22b' force_unroll=true megablox=false dtype=float16 weight_dtype=bfloat16 per_device_batch_size=1 max_target_length=1 | ||
echo "Wrote MaxText compatible unscanned checkpoint to ${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints" |
Oops, something went wrong.