Adds a new end-to-end test for Mistral 7b:
- A notebook in scratch_code that runs the HF model on sample inputs, generates
logits, and stores them in a file.
- A test that downloads the HF checkpoint, converts it to MaxText-compatible format,
and uses it to run one forward pass on a sample input.
- Compares the logits obtained in the two steps above to make sure they match.
The old test, which asserted that the generated text matched rather than comparing logits, has been removed.
shralex committed Sep 19, 2024
1 parent 48cb7b0 commit f00d1e2
Showing 8 changed files with 342 additions and 31 deletions.
288 changes: 288 additions & 0 deletions MaxText/scratch_code/golden_mistral-7b_export.ipynb
@@ -0,0 +1,288 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "0d13ebbb",
"metadata": {},
"outputs": [],
"source": [
"!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu\n",
"!pip3 install tokenizers -U\n",
"!pip3 install transformers -U"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "6a8a4bb6",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
"import jsonlines"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "ff804403",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "587cc338332e42cd8438f831d6fcf2f7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"tokenizer_config.json: 0%| | 0.00/996 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b99b7430db7b467d8f20dba4710b9ce7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"tokenizer.model: 0%| | 0.00/493k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d30f164b3a5a489b9294b161f0e2cd5d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"tokenizer.json: 0%| | 0.00/1.80M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e9dadc8ebd57401cb2d25987f44ebb1b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"special_tokens_map.json: 0%| | 0.00/414 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d3328850136a4198b094b290b265b6bb",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"config.json: 0%| | 0.00/571 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a90a1f56913249c49031f83ea404c35b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model.safetensors.index.json: 0%| | 0.00/25.1k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5556a29be4e445ceba17f606036ba868",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading shards: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7d02fb922c074fee866cf21db143bc9a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model-00001-of-00002.safetensors: 0%| | 0.00/9.94G [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ad1529db8bfe42619e8e5c16e0ff1e97",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model-00002-of-00002.safetensors: 0%| | 0.00/4.54G [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "208f48342b8b43f9a0ea59ab260e6452",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "06309016d9ef4aca984b97e5fe983ba9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"generation_config.json: 0%| | 0.00/116 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Load the tokenizer and model from Hugging Face\n",
"\n",
"model_id = \"mistralai/Mistral-7B-v0.1\"\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
" model_id,\n",
" torch_dtype=torch.float32,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "9f218ba6",
"metadata": {},
"source": [
"## looping over multiple prompts and logits"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "c567f8d9",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Data saved to golden_data_mistral-7b.jsonl\n"
]
}
],
"source": [
"# Save to disk\n",
"output_path = \"golden_data_mistral-7b.jsonl\"\n",
"\n",
"\n",
"# Your prompt text\n",
"prompt_texts = [\"[INST] I love to [/INST]\", \"[INST] Today is a [/INST]\", \"[INST] What is the [/INST]\"]\n",
"all_data_to_save = []\n",
"\n",
"\n",
"for prompt_text in prompt_texts:\n",
" # Encode the prompt text\n",
" input_ids = tokenizer.encode(prompt_text, return_tensors=\"pt\")\n",
"\n",
" # Get the logits for the prompt + completion\n",
" with torch.no_grad():\n",
" outputs = model(input_ids)\n",
" logits = outputs.logits\n",
"\n",
" # Convert logits to fp32\n",
" logits = logits.cpu().numpy().astype(\"float32\")\n",
"\n",
" # Prepare data to be saved\n",
" data_to_save = {\n",
" \"prompt\": prompt_text,\n",
" \"tokens\": input_ids.tolist()[0],\n",
" \"logits\": logits.tolist()[0], # Convert numpy array to list for JSON serialization\n",
" }\n",
" all_data_to_save.append(data_to_save)\n",
"\n",
"with jsonlines.open(output_path, \"w\") as f:\n",
" f.write_all(all_data_to_save)\n",
"\n",
"\n",
"print(f\"Data saved to {output_path}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
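For reference, a minimal sketch of how the golden data file written by the notebook above can be read back for a quick sanity check. The field names match those written by the notebook; the local file path is an assumption for illustration.

import jsonlines
import numpy as np

# Read back the golden data written by the notebook above and print
# the token count and logits shape for each stored prompt.
with jsonlines.open("golden_data_mistral-7b.jsonl") as reader:
    for record in reader:
        tokens = record["tokens"]                                # prompt token ids
        logits = np.asarray(record["logits"], dtype=np.float32)  # [seq_len, vocab_size]
        print(record["prompt"], len(tokens), logits.shape)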
3 changes: 3 additions & 0 deletions MaxText/test_assets/golden_data_mistral-7b.jsonl

Large diffs are not rendered by default.

42 changes: 42 additions & 0 deletions end_to_end/tpu/mistral/7b/test_mistral-7b.sh
@@ -0,0 +1,42 @@
#!/bin/bash

# This file runs on a daily basis and demonstrates:
# 1. Converting the Mistral PyTorch checkpoint to MaxText(orbax) format.
# 2. Loading the MaxText(orbax) checkpoint to run inference, with one forward pass on a given input.
# 3. Comparing the logits to pre-computed logits obtained by running the HF checkpoint directly;
#    see scratch_code/golden_mistral-7b_export.ipynb and the resulting test_assets/golden_data_mistral-7b.jsonl

# Example Usage: export BASE_OUTPUT_PATH=/path/to/GCS/bucket; bash end_to_end/tpu/mistral/7b/test_mistral-7b.sh

set -ex
MODEL_VARIATION='7b'

if [ -z "${BASE_OUTPUT_PATH}" ]; then
# Non-Googlers please remember to point BASE_OUTPUT_PATH to GCS buckets that you own, this script uses internal buckets for testing.
export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d)
echo "BASE_OUTPUT_PATH is not set, using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}"
fi

# Download checkpoint
pip3 install torch
gcloud storage cp -r gs://maxtext-external/mistral-7B-v0.1 /tmp

# Convert it to MaxText(orbax) format - scanned ckpt
JAX_PLATFORMS=cpu python3 MaxText/llama_or_mistral_ckpt.py --base-model-path=/tmp/mistral-7B-v0.1 --model-size=mistral-7b --maxtext-model-path=${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_ckpt/
echo "Wrote MaxText compatible scanned checkpoint to ${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_ckpt"

# `SCANNED_CHECKPOINT` refers to the checkpoint that is used for both `train.py` and `decode.py`
export SCANNED_CHECKPOINT=${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_ckpt/0/items

# Generate unscanned ckpt for efficient decoding test
export RUN_NAME=unscanned_ckpt
JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${SCANNED_CHECKPOINT} run_name=${RUN_NAME} model_name='mistral-7b' force_unroll=true
echo "Wrote MaxText compatible unscanned checkpoint to ${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints"

export DATASET_PATH=gs://maxtext-dataset

# Run decoding with converted ckpt - matmul implementation
python3 MaxText/decode.py MaxText/configs/base.yml load_parameters_path=${SCANNED_CHECKPOINT} run_name=scanned_decoding per_device_batch_size=1 model_name=mistral-7b async_checkpointing=false tokenizer_path=assets/tokenizer.mistral-v1 max_prefill_predict_length=11 max_target_length=16 prompt="[INST] I love to [/INST]" attention=dot_product megablox=False

# Test whether the forward pass logits match the golden logits - matmul implementation
python3 MaxText/tests/forward_pass_logit_checker.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${SCANNED_CHECKPOINT} run_name=matmul_forward_pass_test per_device_batch_size=1 model_name=mistral-7b tokenizer_path=assets/tokenizer.mistral-v1 max_prefill_predict_length=11 max_target_length=11 dataset_type=synthetic dtype=float32 megablox=False --atol=3 --rtol=1 --token_size=4
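For context, a minimal sketch of the comparison that MaxText/tests/forward_pass_logit_checker.py is expected to perform against the golden file, assuming it checks per-token logits with numpy tolerances driven by the --atol/--rtol/--token_size flags; the real script also builds the model from the MaxText config and runs the forward pass itself.

import jsonlines
import numpy as np

def check_logits_against_golden(model_logits, golden_path, atol=3.0, rtol=1.0, token_size=4):
    """Sketch: compare the first `token_size` positions of the model's logits
    with the golden HF logits stored for each prompt."""
    with jsonlines.open(golden_path) as reader:
        for i, record in enumerate(reader):
            golden = np.asarray(record["logits"], dtype=np.float32)[:token_size]
            actual = np.asarray(model_logits[i], dtype=np.float32)[:token_size]
            assert np.allclose(actual, golden, atol=atol, rtol=rtol), (
                f"Logits mismatch for prompt: {record['prompt']}")

# Hypothetical usage: `maxtext_logits` would come from one forward pass of the
# converted checkpoint on the same prompts used in the notebook.
# check_logits_against_golden(maxtext_logits, "MaxText/test_assets/golden_data_mistral-7b.jsonl")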
4 changes: 2 additions & 2 deletions end_to_end/tpu/mixtral/8x22b/1_test_mixtral.sh
@@ -14,7 +14,7 @@ MODEL_VARIATION='8x22b'

if [ -z "${BASE_OUTPUT_PATH}" ]; then
# Non-Googlers please remember to point BASE_OUTPUT_PATH to GCS buckets that you own, this script uses internal buckets for testing.
-export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M)/
+export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d)
echo "BASE_OUTPUT_PATH is not set, using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}"
fi

@@ -31,4 +31,4 @@ echo "Wrote MaxText compatible scanned checkpoint to ${BASE_OUTPUT_PATH}/${MODEL
export SCANNED_CHECKPOINT=${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_ckpt/0/items
export RUN_NAME=unscanned_ckpt
JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${SCANNED_CHECKPOINT} run_name=${RUN_NAME} model_name='mixtral-8x22b' force_unroll=true megablox=false dtype=float16 weight_dtype=bfloat16 per_device_batch_size=1 max_target_length=1
echo "Wrote MaxText compatible unscanned checkpoint to ${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints"
echo "Wrote MaxText compatible unscanned checkpoint to ${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints"
2 changes: 1 addition & 1 deletion end_to_end/tpu/mixtral/8x22b/2_test_mixtral.sh
@@ -18,7 +18,7 @@ RTOL=10.0

if [ -z "${BASE_OUTPUT_PATH}" ]; then
# Non-Googlers please remember to point BASE_OUTPUT_PATH to GCS buckets that you own, this script uses internal buckets for testing.
-export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M)
+export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d)
echo "BASE_OUTPUT_PATH is not set, using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}"
fi

8 changes: 4 additions & 4 deletions end_to_end/tpu/mixtral/8x7b/1_test_mixtral.sh
@@ -14,7 +14,7 @@ MODEL_VARIATION='8x7b'

if [ -z "${BASE_OUTPUT_PATH}" ]; then
# Non-Googlers please remember to point BASE_OUTPUT_PATH to GCS buckets that you own, this script uses internal buckets for testing.
-export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M)/
+export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d)
echo "BASE_OUTPUT_PATH is not set, using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}"
fi

@@ -23,11 +23,11 @@ pip3 install torch
gcloud storage cp -r gs://maxtext-external/mixtral-8x7B-v0.1-Instruct /tmp

# Convert it to MaxText(orbax) format - scanned ckpt
-JAX_PLATFORMS=cpu python3 MaxText/llama_or_mistral_ckpt.py --base-model-path=/tmp/mixtral-8x7B-v0.1-Instruct --model-size=mixtral-8x7b --maxtext-model-path=${BASE_OUTPUT_PATH}${MODEL_VARIATION}/scanned_ckpt/
-echo "Wrote MaxText compatible scanned checkpoint to ${BASE_OUTPUT_PATH}${MODEL_VARIATION}/scanned_ckpt"
+JAX_PLATFORMS=cpu python3 MaxText/llama_or_mistral_ckpt.py --base-model-path=/tmp/mixtral-8x7B-v0.1-Instruct --model-size=mixtral-8x7b --maxtext-model-path=${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_ckpt/
+echo "Wrote MaxText compatible scanned checkpoint to ${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_ckpt"

# Generate unscanned ckpt for efficient decoding test
-export SCANNED_CHECKPOINT=${BASE_OUTPUT_PATH}${MODEL_VARIATION}/scanned_ckpt/0/items
+export SCANNED_CHECKPOINT=${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_ckpt/0/items
export RUN_NAME=unscanned_ckpt
JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${SCANNED_CHECKPOINT} run_name=${RUN_NAME} model_name='mixtral-8x7b' force_unroll=true
echo "Wrote MaxText compatible unscanned checkpoint to ${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints"
4 changes: 2 additions & 2 deletions end_to_end/tpu/mixtral/8x7b/2_test_mixtral.sh
@@ -15,14 +15,14 @@ MODEL_VARIATION='8x7b'

if [ -z "${BASE_OUTPUT_PATH}" ]; then
# Non-Googlers please remember to point BASE_OUTPUT_PATH to GCS buckets that you own, this script uses internal buckets for testing.
-export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M)
+export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d)
echo "BASE_OUTPUT_PATH is not set, using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}"
fi

export DATASET_PATH=gs://maxtext-dataset

# `SCANNED_CHECKPOINT` refers to the checkpoint that is used for both `train.py` and `decode.py`
-export SCANNED_CHECKPOINT=${BASE_OUTPUT_PATH}${MODEL_VARIATION}/scanned_ckpt/0/items
+export SCANNED_CHECKPOINT=${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_ckpt/0/items

# Run decoding with converted ckpt - matmul implementation
# TODO(ranran): add decoding test for megablox implementation
