From b41b614b1f7322e8ebc1fced52bf4a7b92c071c7 Mon Sep 17 00:00:00 2001 From: Alessandro Pappalardo Date: Tue, 27 Jun 2023 20:38:07 +0100 Subject: [PATCH] Fix example export Signed-off-by: Alessandro Pappalardo --- .../llm/llm_quant/mlir_custom_mm.py | 19 +++---------------- .../llm/test_linear_mlir_export.py | 6 ++++-- 2 files changed, 7 insertions(+), 18 deletions(-) diff --git a/src/brevitas_examples/llm/llm_quant/mlir_custom_mm.py b/src/brevitas_examples/llm/llm_quant/mlir_custom_mm.py index 734f15ea0..c4a23f123 100644 --- a/src/brevitas_examples/llm/llm_quant/mlir_custom_mm.py +++ b/src/brevitas_examples/llm/llm_quant/mlir_custom_mm.py @@ -62,13 +62,7 @@ def matmul_rhs_group_quant( brevitas_lib.impl("matmul_rhs_group_quant", matmul_rhs_group_quant) -def brevitas〇matmul_rhs_group_quant〡shape( - lhs: List[int], - rhs: List[int], - rhs_scale: List[int], - rhs_zero_point: List[int], - rhs_bit_width: int, - rhs_group_size: int) -> List[int]: +def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]: if len(lhs) == 3 and len(rhs) == 2: return [lhs[0], lhs[1], rhs[0]] elif len(lhs) == 2 and len(rhs) == 2: @@ -77,20 +71,13 @@ def brevitas〇matmul_rhs_group_quant〡shape( raise ValueError("Input shapes not supported.") -def brevitas〇matmul_rhs_group_quant〡dtype( - lhs_rank_dtype: Tuple[int, int], - rhs_rank_dtype: Tuple[int, int], - rhs_scale_rank_dtype: Tuple[int, int], - rhs_zero_point_rank_dtype: Tuple[int, int], - rhs_bit_width: int, - rhs_group_size: int) -> int: +def brevitas〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int: # output dtype is the dtype of the lhs float input lhs_rank, lhs_dtype = lhs_rank_dtype return lhs_dtype -def brevitas〇matmul_rhs_group_quant〡has_value_semantics( - lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None: +def brevitas〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None: return diff --git a/src/brevitas_examples/llm/test_linear_mlir_export.py b/src/brevitas_examples/llm/test_linear_mlir_export.py index cab502519..870ec9406 100644 --- a/src/brevitas_examples/llm/test_linear_mlir_export.py +++ b/src/brevitas_examples/llm/test_linear_mlir_export.py @@ -54,12 +54,14 @@ def quantize_and_export(args): # Run quantization quantize_model( model, + dtype=torch.float32, weight_quant_type=args.weight_quant_type, weight_bit_width=args.weight_bit_width, weight_group_size=args.weight_group_size, weight_param_method='stats', - weight_scale_type='float32', - weight_quant_granularity='per_group') + weight_scale_type='float', + weight_quant_granularity='per_group', + quantize_weight_zero_point=False) # Run a test forward pass model(torch.randn(2, 128))