-
Notifications
You must be signed in to change notification settings - Fork 274
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds a new end-to-end test for Mistral 7b:
- A notebook in scratch_pad that runs the model from HF, for sample inputs and generates logits, which it stores in a file. - A test that downloads the HF checkpoint, converts it to MaxText-compatible format, and uses it to do one forward pass on a sample input. - Compares the logits obtained in both steps above to make sure they're equal. The old test, which asserted on the generated text being equal, rather than logits, is removed.
- Loading branch information
Showing
8 changed files
with
342 additions
and
31 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,288 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "0d13ebbb", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu\n", | ||
"!pip3 install tokenizers -U\n", | ||
"!pip3 install transformers -U" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "6a8a4bb6", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import torch\n", | ||
"from transformers import AutoTokenizer, AutoModelForCausalLM\n", | ||
"import jsonlines" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"id": "ff804403", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "587cc338332e42cd8438f831d6fcf2f7", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"tokenizer_config.json: 0%| | 0.00/996 [00:00<?, ?B/s]" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "b99b7430db7b467d8f20dba4710b9ce7", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"tokenizer.model: 0%| | 0.00/493k [00:00<?, ?B/s]" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "d30f164b3a5a489b9294b161f0e2cd5d", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"tokenizer.json: 0%| | 0.00/1.80M [00:00<?, ?B/s]" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "e9dadc8ebd57401cb2d25987f44ebb1b", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"special_tokens_map.json: 0%| | 0.00/414 [00:00<?, ?B/s]" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "d3328850136a4198b094b290b265b6bb", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"config.json: 0%| | 0.00/571 [00:00<?, ?B/s]" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "a90a1f56913249c49031f83ea404c35b", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"model.safetensors.index.json: 0%| | 0.00/25.1k [00:00<?, ?B/s]" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "5556a29be4e445ceba17f606036ba868", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"Downloading shards: 0%| | 0/2 [00:00<?, ?it/s]" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "7d02fb922c074fee866cf21db143bc9a", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"model-00001-of-00002.safetensors: 0%| | 0.00/9.94G [00:00<?, ?B/s]" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "ad1529db8bfe42619e8e5c16e0ff1e97", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"model-00002-of-00002.safetensors: 0%| | 0.00/4.54G [00:00<?, ?B/s]" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "208f48342b8b43f9a0ea59ab260e6452", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "06309016d9ef4aca984b97e5fe983ba9", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"generation_config.json: 0%| | 0.00/116 [00:00<?, ?B/s]" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
} | ||
], | ||
"source": [ | ||
"# Load the tokenizer and model from Hugging Face\n", | ||
"\n", | ||
"model_id = \"mistralai/Mistral-7B-v0.1\"\n", | ||
"\n", | ||
"tokenizer = AutoTokenizer.from_pretrained(model_id)\n", | ||
"model = AutoModelForCausalLM.from_pretrained(\n", | ||
" model_id,\n", | ||
" torch_dtype=torch.float32,\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "9f218ba6", | ||
"metadata": {}, | ||
"source": [ | ||
"## Loop over the prompts and save the logits for each one" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"id": "c567f8d9", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)\n" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Data saved to golden_data_mistral-7b.jsonl\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Save to disk\n", | ||
"output_path = \"golden_data_mistral-7b.jsonl\"\n", | ||
"\n", | ||
"\n", | ||
"# Your prompt text\n", | ||
"prompt_texts = [\"[INST] I love to [/INST]\", \"[INST] Today is a [/INST]\", \"[INST] What is the [/INST]\"]\n", | ||
"all_data_to_save = []\n", | ||
"\n", | ||
"\n", | ||
"for prompt_text in prompt_texts:\n", | ||
" # Encode the prompt text\n", | ||
" input_ids = tokenizer.encode(prompt_text, return_tensors=\"pt\")\n", | ||
"\n", | ||
" # Get the logits for the prompt + completion\n", | ||
" with torch.no_grad():\n", | ||
" outputs = model(input_ids)\n", | ||
" logits = outputs.logits\n", | ||
"\n", | ||
" # Convert logits to fp32\n", | ||
" logits = logits.cpu().numpy().astype(\"float32\")\n", | ||
"\n", | ||
" # Prepare data to be saved\n", | ||
" data_to_save = {\n", | ||
" \"prompt\": prompt_text,\n", | ||
" \"tokens\": input_ids.tolist()[0],\n", | ||
" \"logits\": logits.tolist()[0], # Convert numpy array to list for JSON serialization\n", | ||
" }\n", | ||
" all_data_to_save.append(data_to_save)\n", | ||
"\n", | ||
"with jsonlines.open(output_path, \"w\") as f:\n", | ||
" f.write_all(all_data_to_save)\n", | ||
"\n", | ||
"\n", | ||
"print(f\"Data saved to {output_path}\")" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.12" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
#!/bin/bash | ||
|
||
# This script runs on a daily basis and does the following: | ||
# 1. Converts the Mistral PyTorch checkpoint to MaxText(orbax) format. | ||
# 2. Loads the MaxText(orbax) checkpoint to run inference, and runs one forward pass on a given input. | ||
# 3. Compares the logits to pre-computed logits obtained by running the HF checkpoint directly, | ||
# see scratch_code/golden-mistral-7b_export.ipynb and the resulting test_assets/golden_data_mistral-7b.jsonl | ||
|
||
# Example Usage: export BASE_OUTPUT_PATH=/path/to/GCS/bucket; bash end_to_end/tpu/mistral/7b/test_mistral-7b.sh | ||
|
||
# -e: stop at the first failing command; -x: print each command before it runs. | ||
set -ex | ||
MODEL_VARIATION='7b' | ||
|
||
if [ -z "${BASE_OUTPUT_PATH}" ]; then | ||
  # Non-Googlers please remember to point BASE_OUTPUT_PATH to GCS buckets that you own, this script uses internal buckets for testing. | ||
  export BASE_OUTPUT_PATH="gs://runner-maxtext-logs/$(date +%Y-%m-%d)" | ||
  echo "BASE_OUTPUT_PATH is not set, using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}" | ||
fi | ||
|
||
# Download checkpoint | ||
pip3 install torch | ||
gcloud storage cp -r gs://maxtext-external/mistral-7B-v0.1 /tmp | ||
|
||
# Convert it to MaxText(orbax) format - scanned ckpt | ||
JAX_PLATFORMS=cpu python3 MaxText/llama_or_mistral_ckpt.py --base-model-path=/tmp/mistral-7B-v0.1 --model-size=mistral-7b --maxtext-model-path="${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_ckpt/" | ||
echo "Wrote MaxText compatible scanned checkpoint to ${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_ckpt" | ||
|
||
# `SCANNED_CHECKPOINT` refers to the checkpoint that is used for both `train.py` and `decode.py` | ||
export SCANNED_CHECKPOINT="${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_ckpt/0/items" | ||
|
||
# Generate unscanned ckpt for efficient decoding test | ||
export RUN_NAME=unscanned_ckpt | ||
JAX_PLATFORMS=cpu python3 MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml async_checkpointing=false base_output_directory="${BASE_OUTPUT_PATH}" load_parameters_path="${SCANNED_CHECKPOINT}" run_name="${RUN_NAME}" model_name='mistral-7b' force_unroll=true | ||
echo "Wrote MaxText compatible unscanned checkpoint to ${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints" | ||
|
||
export DATASET_PATH=gs://maxtext-dataset | ||
|
||
# Run decoding with converted ckpt - matmul implementation | ||
python3 MaxText/decode.py MaxText/configs/base.yml load_parameters_path="${SCANNED_CHECKPOINT}" run_name=scanned_decoding per_device_batch_size=1 model_name=mistral-7b async_checkpointing=false tokenizer_path=assets/tokenizer.mistral-v1 max_prefill_predict_length=11 max_target_length=16 prompt="[INST] I love to [/INST]" attention=dot_product megablox=False | ||
|
||
# Test whether the forward pass logits match the golden logits - matmul implementation | ||
python3 MaxText/tests/forward_pass_logit_checker.py MaxText/configs/base.yml base_output_directory="${BASE_OUTPUT_PATH}" load_parameters_path="${SCANNED_CHECKPOINT}" run_name=matmul_forward_pass_test per_device_batch_size=1 model_name=mistral-7b tokenizer_path=assets/tokenizer.mistral-v1 max_prefill_predict_length=11 max_target_length=11 dataset_type=synthetic dtype=float32 megablox=False --atol=3 --rtol=1 --token_size=4 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.