Skip to content

Commit

Permalink
Fix float16 with int4 in CI (#248)
Browse files: browse the repository at this point in the history
  • Loading branch information
guangy10 authored Apr 17, 2024
1 parent b781741 commit 9e01bec
Showing 1 changed file with 18 additions and 10 deletions.
28 changes: 18 additions & 10 deletions .ci/scripts/validate.sh
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,14 @@ function generate_compiled_model_output() {
echo "******************************************"
echo "******** INT4 group-wise quantized *******"
echo "******************************************"
python -W ignore generate.py --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_eager" || exit 1
cat "$MODEL_DIR/output_eager"
python -W ignore generate.py --dtype ${DTYPE} --compile --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_compiled" || exit 1
cat "$MODEL_DIR/output_compiled"
if [ "$DTYPE" = float16 ]; then
echo "Skipping INT4 groupwise quantization for float16 because torch.compile fails"
else
python -W ignore generate.py --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_eager" || exit 1
cat "$MODEL_DIR/output_eager"
python -W ignore generate.py --dtype ${DTYPE} --compile --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_compiled" || exit 1
cat "$MODEL_DIR/output_compiled"
fi
done
}

Expand Down Expand Up @@ -153,12 +157,16 @@ function generate_aoti_model_output() {
echo "******************************************"
echo "******** INT4 group-wise quantized *******"
echo "******************************************"
if [ $(uname -s) == "Linux" ]; then
echo "Skipping INT4 groupwise quantization because AOTI fails"
else
python -W ignore export.py --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
python -W ignore generate.py --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
cat "$MODEL_DIR/output_aoti"
if [ "$DTYPE" = float16 ]; then
echo "Skipping INT4 groupwise quantization for float16 because AOTI fails"
else
if [ $(uname -s) == "Linux" ]; then
echo "Skipping INT4 groupwise quantization because AOTI fails"
else
python -W ignore export.py --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
python -W ignore generate.py --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
cat "$MODEL_DIR/output_aoti"
fi
fi
done
}
Expand Down

0 comments on commit 9e01bec

Please sign in to comment.