From 9e2265022bdee8547a43b42159deb6edb054fa23 Mon Sep 17 00:00:00 2001
From: "Wang,Zhe"
Date: Wed, 17 Jul 2024 16:25:11 +0800
Subject: [PATCH] remove clip postfix

---
 docs/qbits.md                                     |  6 +++---
 examples/vllm/vllm_acceleration_example.py        |  2 +-
 .../include/bestla_weightonly_dispatcher.hpp      | 13 ++++---------
 .../qbits/dispatcher/src/bestla_packq_impl.cpp    | 15 +++++++--------
 .../src/bestla_weightonly_dispatcher.cpp          |  3 +--
 .../qbits/qbits_ut/test_packq.py                  |  2 +-
 .../qbits/qbits_ut/test_weightonly.py             |  8 ++++----
 7 files changed, 21 insertions(+), 28 deletions(-)

diff --git a/docs/qbits.md b/docs/qbits.md
index 6eba5fe6e08..1527765b831 100644
--- a/docs/qbits.md
+++ b/docs/qbits.md
@@ -16,7 +16,7 @@ import intel_extension_for_transformers.qbits as qbits
   transpose (bool): Whether to transpose the weight tensor (required for quantize_to_packed_weight with KxN weight shape).
   blocksize (int): Blocksize for weight-only quantization.
   compute_type (str): Computation type (fp32/bf16/int8). fp32 will leverage AVX2/AVX512F to compute, bf16 will be AMX_BF16, int8 will be VNNI/AMX_INT8.
-  weight_type (str): Quantization type (int8/int4_clip/int4_fullrange/nf4/fp4_e2m1).
+  weight_type (str): Quantization type (int8/int4/int3/int2/nf4/fp4_e2m1).
   scale_type (str): Scale type (fp32/bf16).
   asym (bool): Whether to use asymmetric quantization.
@@ -37,7 +37,7 @@ pack_weight = qbits.quantize_to_packed_weight(
   g_idx (torch.Tensor): shuffle index used by GPTQ, dtype must be int32.
   blocksize (int): Blocksize for weight-only quantization.
   compute_type (str): Computation type (fp32/bf16/int8). fp32 will leverage AVX2/AVX512F to compute, bf16 will be AMX_BF16, int8 will be VNNI/AMX_INT8.
-  weight_type (str): Quantization type (int8/int4_clip/int4_fullrange/nf4/fp4_e2m1).
+  weight_type (str): Quantization type (int8/int4/int3/int2/nf4/fp4_e2m1).
   scale_type (str): Scale type (fp32/bf16).
   asym (bool): Whether to use asymmetric quantization.
@@ -57,7 +57,7 @@ pack_weight = qbits.repack_quantized_weight(
   bias (torch.Tensor): Bias tensor, must be fp32, if bias is empty woq_linear will not add bias.
   output (torch.Tensor): Output tensor, support fp32/bf16, shape must be MxN.
   compute_type (str): Computation type (fp32/bf16/int8).fp32 will leverage AVX2/AVX512F to compute, bf16 will leverage AMX_BF16 to compute, int8 will leverage VNNI/AMX_INT8 to compute.
-  weight_type (str): Quantization type (int8/int4_clip/int4_fullrange/nf4/fp4_e2m1).
+  weight_type (str): Quantization type (int8/int4/int3/int2/nf4/fp4_e2m1).
   scale_type (str): Scale type (fp32/bf16).
   asym (bool): Whether to use asymmetric quantization.
 """
diff --git a/examples/vllm/vllm_acceleration_example.py b/examples/vllm/vllm_acceleration_example.py
index 468ef26c7cb..21d22475b88 100644
--- a/examples/vllm/vllm_acceleration_example.py
+++ b/examples/vllm/vllm_acceleration_example.py
@@ -41,7 +41,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
     config = RtnConfig(compute_dtype="int8",
                        group_size=128,
                        scale_dtype="bf16",
-                       weight_dtype="int4_clip",
+                       weight_dtype="int4",
                        bits=4)
     print(config)
     prompts = [args.prompt]
diff --git a/intel_extension_for_transformers/qbits/dispatcher/include/bestla_weightonly_dispatcher.hpp b/intel_extension_for_transformers/qbits/dispatcher/include/bestla_weightonly_dispatcher.hpp
index 784c512220f..973a6345b8b 100644
--- a/intel_extension_for_transformers/qbits/dispatcher/include/bestla_weightonly_dispatcher.hpp
+++ b/intel_extension_for_transformers/qbits/dispatcher/include/bestla_weightonly_dispatcher.hpp
@@ -59,15 +59,10 @@ struct woq_runtime_ctx {
   bestla::storage::gemm::IWeightBase* deseries_wei;
 };
 
-static std::map<std::string, BTLA_DTYPE> wei2bestladt_map{{"int8", BTLA_DTYPE::S8},
-                                                          {"int4_clip", BTLA_DTYPE::S4_CLIP},
-                                                          {"int3_clip", BTLA_DTYPE::S3_CLIP},
-                                                          {"int2_clip", BTLA_DTYPE::S2_CLIP},
-                                                          {"nf4", BTLA_DTYPE::F4_NF4},
-                                                          {"fp4_e2m1_bnb", BTLA_DTYPE::F4_BNB},
-                                                          {"fp4_e2m1", BTLA_DTYPE::F4_E2M1},
-                                                          {"fp8_e4m3", BTLA_DTYPE::F8_E4M3},
-                                                          {"fp8_e5m2", BTLA_DTYPE::F8_E5M2}};
+static std::map<std::string, BTLA_DTYPE> wei2bestladt_map{
+    {"int8", BTLA_DTYPE::S8},      {"int4", BTLA_DTYPE::S4_CLIP},     {"int3", BTLA_DTYPE::S3_CLIP},
+    {"int2", BTLA_DTYPE::S2_CLIP}, {"nf4", BTLA_DTYPE::F4_NF4},       {"fp4_e2m1_bnb", BTLA_DTYPE::F4_BNB},
+    {"fp4_e2m1", BTLA_DTYPE::F4_E2M1}, {"fp8_e4m3", BTLA_DTYPE::F8_E4M3}, {"fp8_e5m2", BTLA_DTYPE::F8_E5M2}};
 
 static std::map<std::string, BTLA_DTYPE> scale2bestladt_map{
     {"fp32", BTLA_DTYPE::F32}, {"bf16", BTLA_DTYPE::BF16}, {"fp8_e8m0", BTLA_DTYPE::F8_E8M0}};
diff --git a/intel_extension_for_transformers/qbits/dispatcher/src/bestla_packq_impl.cpp b/intel_extension_for_transformers/qbits/dispatcher/src/bestla_packq_impl.cpp
index cf6889a9f15..075f382daa1 100644
--- a/intel_extension_for_transformers/qbits/dispatcher/src/bestla_packq_impl.cpp
+++ b/intel_extension_for_transformers/qbits/dispatcher/src/bestla_packq_impl.cpp
@@ -42,8 +42,7 @@ void execute_qpack(repack_quantized_weight_param* p, repack_quantized_weight_ctx
 
 template
 void parse_prob(repack_quantized_weight_param* p, repack_quantized_weight_ctx* ctx, WOQ_TASK task) {
-  if (p->weight_type == "int8" || p->weight_type == "int4_clip" || p->weight_type == "int3_clip" ||
-      p->weight_type == "int2_clip") {
+  if (p->weight_type == "int8" || p->weight_type == "int4" || p->weight_type == "int3" || p->weight_type == "int2") {
     return execute_qpack>(p, ctx, task);
   }
   if (p->weight_type == "nf4" || p->weight_type == "fp4_e2m1_bnb" || p->weight_type == "fp4_e2m1") {
@@ -61,11 +60,11 @@ std::string get_dtype_str(BTLA_DTYPE dtype) {
     case BTLA_DTYPE::BF16:
       return "bf16";
     case BTLA_DTYPE::S4_CLIP:
-      return "int4_clip";
+      return "int4";
     case BTLA_DTYPE::S3_CLIP:
-      return "int3_clip";
+      return "int3";
    case BTLA_DTYPE::S2_CLIP:
-      return "int2_clip";
+      return "int2";
     case BTLA_DTYPE::F4_NF4:
       return "nf4";
     case BTLA_DTYPE::F4_E2M1:
@@ -205,9 +204,9 @@ torch::Tensor get_packw_info(torch::Tensor& packw, PACKW_ACQUIRE_TYPE ACQ_T) {
 
 void bestla_packq(repack_quantized_weight_param* p, repack_quantized_weight_ctx* ctx, WOQ_TASK task) {
   if (p->compute_type == "int8") {
-    TORCH_CHECK(p->weight_type == "int8" || p->weight_type == "int4_clip" || p->weight_type == "int3_clip" ||
-                    p->weight_type == "int2_clip",
-                "Qbits: only support Integer weight-type with int8 compute-type");
+    TORCH_CHECK(
+        p->weight_type == "int8" || p->weight_type == "int4" || p->weight_type == "int3" || p->weight_type == "int2",
+        "Qbits: only support Integer weight-type with int8 compute-type");
     if (dispatcher_utils::check_amx() && p->blocksize % bestla::gemm::ICoreRowNAmxint8KBlock<64, 16>::KTILE == 0) {
       return parse_prob, BTLA_ISA::AMX_INT8>(p, ctx, task);
     }
diff --git a/intel_extension_for_transformers/qbits/dispatcher/src/bestla_weightonly_dispatcher.cpp b/intel_extension_for_transformers/qbits/dispatcher/src/bestla_weightonly_dispatcher.cpp
index c04e652a4aa..f22c6ca3f71 100644
--- a/intel_extension_for_transformers/qbits/dispatcher/src/bestla_weightonly_dispatcher.cpp
+++ b/intel_extension_for_transformers/qbits/dispatcher/src/bestla_weightonly_dispatcher.cpp
@@ -276,8 +276,7 @@ void parse_activation(woq_config_param* p, woq_runtime_ctx* ctx) {
 template
 void parse_weight(woq_config_param* p, woq_runtime_ctx* ctx) {
   using namespace bestla::prologue_b::gemm;
-  if (p->weight_type == "int8" || p->weight_type == "int4_clip" || p->weight_type == "int3_clip" ||
-      p->weight_type == "int2_clip") {
+  if (p->weight_type == "int8" || p->weight_type == "int4" || p->weight_type == "int3" || p->weight_type == "int2") {
     return parse_activation(p, ctx);
   }
   if (p->weight_type == "nf4" || p->weight_type == "fp4_e2m1_bnb" || p->weight_type == "fp4_e2m1" ||
diff --git a/intel_extension_for_transformers/qbits/qbits_ut/test_packq.py b/intel_extension_for_transformers/qbits/qbits_ut/test_packq.py
index b87e9933967..acedd6edd05 100644
--- a/intel_extension_for_transformers/qbits/qbits_ut/test_packq.py
+++ b/intel_extension_for_transformers/qbits/qbits_ut/test_packq.py
@@ -48,7 +48,7 @@ class acquire_type(Enum):
 @pytest.mark.parametrize("k", [512])
 @pytest.mark.parametrize("blocksize", [128])
 @pytest.mark.parametrize("compute_type", ["fp32", "bf16", "int8"])
-@pytest.mark.parametrize("weight_type", ["int8", "int4_clip"])
+@pytest.mark.parametrize("weight_type", ["int8", "int4"])
 @pytest.mark.parametrize("scale_type", ["fp32"])
 @pytest.mark.parametrize("asym", [True, False])
 def test(m, k, n, weight_type, scale_type, compute_type, asym, blocksize, dump_tensor=False):
diff --git a/intel_extension_for_transformers/qbits/qbits_ut/test_weightonly.py b/intel_extension_for_transformers/qbits/qbits_ut/test_weightonly.py
index 343cd4f8e74..048a765b5b1 100644
--- a/intel_extension_for_transformers/qbits/qbits_ut/test_weightonly.py
+++ b/intel_extension_for_transformers/qbits/qbits_ut/test_weightonly.py
@@ -17,14 +17,14 @@ from ut_utils import *
 
-cmpt_configs = {"int8": {"int8", "bf16", "fp32"}, "int4_clip": {"int8", "fp32", "bf16"}, "int3_clip": {"int8", "fp32", "bf16"}, "int2_clip": {"int8", "fp32", "bf16"}, "fp4_e2m1_bnb": {"fp32", "bf16"}, "fp4_e2m1": {"fp32", "bf16"}, "nf4": {"fp32", "bf16"},
+cmpt_configs = {"int8": {"int8", "bf16", "fp32"}, "int4": {"int8", "fp32", "bf16"}, "int3": {"int8", "fp32", "bf16"}, "int2": {"int8", "fp32", "bf16"}, "fp4_e2m1_bnb": {"fp32", "bf16"}, "fp4_e2m1": {"fp32", "bf16"}, "nf4": {"fp32", "bf16"},
                 "fp8_e5m2": {"fp32", "bf16"}, "fp8_e4m3": {"fp32", "bf16"}
                 }
 
-scale_configs = {"int8": {"fp32", "bf16"}, "int4_clip": {"fp32", "bf16"}, "int3_clip": {"fp32", "bf16"}, "int2_clip": {"fp32", "bf16"}, "fp4_e2m1_bnb": {"fp32", "bf16"}, "fp4_e2m1": {"fp32", "bf16"}, "nf4": {"fp32", "bf16"},
+scale_configs = {"int8": {"fp32", "bf16"}, "int4": {"fp32", "bf16"}, "int3": {"fp32", "bf16"}, "int2": {"fp32", "bf16"}, "fp4_e2m1_bnb": {"fp32", "bf16"}, "fp4_e2m1": {"fp32", "bf16"}, "nf4": {"fp32", "bf16"},
                  "fp8_e5m2": {"fp32", "fp8_e8m0"}, "fp8_e4m3": {"fp32", "fp8_e8m0"}}
 
-asym_configs = {"int8", "int4_clip", "int3_clip", "int2_clip"}
+asym_configs = {"int8", "int4", "int3", "int2"}
 
 
 @capture_args
@@ -33,7 +33,7 @@
 @pytest.mark.parametrize("k", [512])
 @pytest.mark.parametrize("blocksize", [128, -1])
 @pytest.mark.parametrize("compute_type", ["int8", "bf16", "fp32"])
-@pytest.mark.parametrize("weight_type", ["int8", "int4_clip", "int3_clip", "int2_clip", "nf4", "fp4_e2m1_bnb", "fp4_e2m1", "fp8_e5m2", "fp8_e4m3"])
+@pytest.mark.parametrize("weight_type", ["int8", "int4", "int3", "int2", "nf4", "fp4_e2m1_bnb", "fp4_e2m1", "fp8_e5m2", "fp8_e4m3"])
 @pytest.mark.parametrize("scale_type", ["fp32", "bf16", "fp8_e8m0"])
 @pytest.mark.parametrize("asym", [True, False])
 @pytest.mark.parametrize("transpose", [True, False])