From 9e01becda8fa437f897110e0305739a905bfe81e Mon Sep 17 00:00:00 2001 From: Guang Yang <42389959+guangy10@users.noreply.github.com> Date: Wed, 17 Apr 2024 16:29:37 -0700 Subject: [PATCH] Fix float16 with int4 in CI (#248) --- .ci/scripts/validate.sh | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/.ci/scripts/validate.sh b/.ci/scripts/validate.sh index 4fc7a1f8c..fb7160406 100644 --- a/.ci/scripts/validate.sh +++ b/.ci/scripts/validate.sh @@ -85,10 +85,14 @@ function generate_compiled_model_output() { echo "******************************************" echo "******** INT4 group-wise quantized *******" echo "******************************************" - python -W ignore generate.py --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_eager" || exit 1 - cat "$MODEL_DIR/output_eager" - python -W ignore generate.py --dtype ${DTYPE} --compile --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_compiled" || exit 1 - cat "$MODEL_DIR/output_compiled" + if [ "$DTYPE" = float16 ]; then + echo "Skipping INT4 groupwise quantization for float16 because torch.compile fails" + else + python -W ignore generate.py --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_eager" || exit 1 + cat "$MODEL_DIR/output_eager" + python -W ignore generate.py --dtype ${DTYPE} --compile --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_compiled" || exit 1 + cat "$MODEL_DIR/output_compiled" + fi done } @@ -153,12 +157,16 @@ function generate_aoti_model_output() { echo "******************************************" echo "******** INT4 group-wise quantized *******" echo "******************************************" - if [ $(uname -s) == "Linux" ]; then - echo "Skipping INT4 groupwise quantization because AOTI fails" - else - python -W ignore export.py --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1 - python -W ignore generate.py --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1 - cat "$MODEL_DIR/output_aoti" + if [ "$DTYPE" = float16 ]; then + echo "Skipping INT4 groupwise quantization for float16 because AOTI fails" + else + if [ $(uname -s) == "Linux" ]; then + echo "Skipping INT4 groupwise quantization because AOTI fails" + else + python -W ignore export.py --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1 + python -W ignore generate.py --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1 + cat "$MODEL_DIR/output_aoti" + fi fi done }