Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 67d905a

Browse files
guangy10 and malfet
authored and committed
Add doc for script (#239)
1 parent 06524bc commit 67d905a

File tree

4 files changed

+21
-242
lines changed

4 files changed

+21
-242
lines changed

.ci/scripts/validate.sh

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ function generate_compiled_model_output() {
2525
local MODEL_DIR="${CHECKPOINT_PATH%/*}"
2626
local MODEL_NAME=$(basename "$CHECKPOINT_PATH" | sed 's/\.[^.]*$//')
2727

28-
for DTYPE in float32 bfloat16; do
28+
for DTYPE in float32 bfloat16 float16; do
2929
echo ""############### Run inference with torch.compile for dtype $DTYPE "###############"
3030
echo ""
3131
echo "******************************************"
@@ -98,7 +98,7 @@ function generate_aoti_model_output() {
9898
local MODEL_DIR="${CHECKPOINT_PATH%/*}"
9999
local MODEL_NAME=$(basename "$CHECKPOINT_PATH" | sed 's/\.[^.]*$//')
100100

101-
for DTYPE in float32 bfloat16; do
101+
for DTYPE in float32 bfloat16 float16; do
102102
echo ""############### Run inference with AOT Inductor for dtype $DTYPE "###############"
103103
echo ""
104104
echo "******************************************"
@@ -150,12 +150,16 @@ function generate_aoti_model_output() {
150150
python -W ignore generate.py --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
151151
cat "$MODEL_DIR/output_aoti"
152152

153-
# echo "******************************************"
154-
# echo "******** INT4 group-wise quantized *******"
155-
# echo "******************************************"
156-
# python -W ignore export.py --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
157-
# python -W ignore generate.py --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
158-
# cat "$MODEL_DIR/output_aoti"
153+
echo "******************************************"
154+
echo "******** INT4 group-wise quantized *******"
155+
echo "******************************************"
156+
if [ $(uname -s) == "Linux" ]; then
157+
echo "Skipping INT4 groupwise quantization because AOTI fails"
158+
else
159+
python -W ignore export.py --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
160+
python -W ignore generate.py --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
161+
cat "$MODEL_DIR/output_aoti"
162+
fi
159163
done
160164
}
161165

.github/workflows/compile-dtype.yml

Lines changed: 0 additions & 118 deletions
This file was deleted.

.github/workflows/compile_t4-dtype.yml

Lines changed: 0 additions & 115 deletions
This file was deleted.

scripts/workflow.sh

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,14 @@
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
77

8+
################################################################################
9+
# Usage:
10+
# bash script.sh [cpu|cuda] [model_repo] [optional_command]
11+
# Arguments:
12+
# cpu|cuda: Specify the device to run validation on (cpu or cuda).
13+
# model_repo: Model repository name to validate (e.g., tinyllamas/stories15M).
14+
# optional_command: (optional) Specify additional command "compile", "aoti" or "executorch" to run the selected validation.
15+
################################################################################
816

917
set -eu
1018

@@ -75,7 +83,7 @@ MODEL_REPOS=(
7583
"mistralai/Mistral-7B-v0.1"
7684
"mistralai/Mistral-7B-Instruct-v0.1"
7785
"mistralai/Mistral-7B-Instruct-v0.2"
78-
# "openlm-research/open_llama_7b"
86+
"openlm-research/open_llama_7b"
7987
"codellama/CodeLlama-7b-Python-hf"
8088
"codellama/CodeLlama-34b-Python-hf"
8189
# "meta-llama/Llama-2-7b-chat-hf"

0 commit comments

Comments (0)