Fix example export
Signed-off-by: Alessandro Pappalardo <[email protected]>
volcacius committed Jun 27, 2023
1 parent 32bfbe9 commit b41b614
Showing 2 changed files with 7 additions and 18 deletions.
19 changes: 3 additions & 16 deletions src/brevitas_examples/llm/llm_quant/mlir_custom_mm.py
@@ -62,13 +62,7 @@ def matmul_rhs_group_quant(
 brevitas_lib.impl("matmul_rhs_group_quant", matmul_rhs_group_quant)
 
 
-def brevitas〇matmul_rhs_group_quant〡shape(
-        lhs: List[int],
-        rhs: List[int],
-        rhs_scale: List[int],
-        rhs_zero_point: List[int],
-        rhs_bit_width: int,
-        rhs_group_size: int) -> List[int]:
+def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
     if len(lhs) == 3 and len(rhs) == 2:
         return [lhs[0], lhs[1], rhs[0]]
     elif len(lhs) == 2 and len(rhs) == 2:
@@ -77,20 +77,13 @@ def brevitas〇matmul_rhs_group_quant〡shape(
         raise ValueError("Input shapes not supported.")
 
 
-def brevitas〇matmul_rhs_group_quant〡dtype(
-        lhs_rank_dtype: Tuple[int, int],
-        rhs_rank_dtype: Tuple[int, int],
-        rhs_scale_rank_dtype: Tuple[int, int],
-        rhs_zero_point_rank_dtype: Tuple[int, int],
-        rhs_bit_width: int,
-        rhs_group_size: int) -> int:
+def brevitas〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
     # output dtype is the dtype of the lhs float input
     lhs_rank, lhs_dtype = lhs_rank_dtype
     return lhs_dtype
 
 
-def brevitas〇matmul_rhs_group_quant〡has_value_semantics(
-        lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
+def brevitas〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
     return


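For reference, a quick sanity check of the reformatted shape and dtype functions above (plain Python; the shapes, group size, and dtype code below are made up for illustration, and the functions are assumed importable from mlir_custom_mm.py):

    # Shape inference: a 3D float lhs [batch, seq, in_features] against a 2D
    # quantized rhs [out_features, in_features] maps to [batch, seq, out_features].
    out_shape = brevitas〇matmul_rhs_group_quant〡shape(
        lhs=[2, 128, 512],
        rhs=[256, 512],
        rhs_scale=[256, 4],
        rhs_zero_point=[256, 4],
        rhs_bit_width=4,
        rhs_group_size=128)
    assert out_shape == [2, 128, 256]

    # Dtype inference: the result dtype is taken from the lhs (rank, dtype) pair,
    # ignoring the quantized rhs. Per the annotation, dtype codes are plain ints;
    # 6 here is only a stand-in value.
    lhs_dtype = 6
    out_dtype = brevitas〇matmul_rhs_group_quant〡dtype(
        lhs_rank_dtype=(3, lhs_dtype),
        rhs_rank_dtype=(2, 0),
        rhs_scale_rank_dtype=(3, lhs_dtype),
        rhs_zero_point_rank_dtype=(3, lhs_dtype),
        rhs_bit_width=4,
        rhs_group_size=128)
    assert out_dtype == lhs_dtype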
6 changes: 4 additions & 2 deletions src/brevitas_examples/llm/test_linear_mlir_export.py
@@ -54,12 +54,14 @@ def quantize_and_export(args):
     # Run quantization
     quantize_model(
         model,
+        dtype=torch.float32,
         weight_quant_type=args.weight_quant_type,
         weight_bit_width=args.weight_bit_width,
         weight_group_size=args.weight_group_size,
         weight_param_method='stats',
-        weight_scale_type='float32',
-        weight_quant_granularity='per_group')
+        weight_scale_type='float',
+        weight_quant_granularity='per_group',
+        quantize_weight_zero_point=False)
 
     # Run a test forward pass
     model(torch.randn(2, 128))
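Put together, the quantization call in the test now reads as below (a sketch assembled from the hunk above; quantize_model and args come from the surrounding file, and the comments are mine):

    quantize_model(
        model,
        dtype=torch.float32,  # newly added: float dtype of the model under quantization
        weight_quant_type=args.weight_quant_type,
        weight_bit_width=args.weight_bit_width,
        weight_group_size=args.weight_group_size,
        weight_param_method='stats',
        weight_scale_type='float',  # value renamed; previously 'float32'
        weight_quant_granularity='per_group',
        quantize_weight_zero_point=False)  # newly added: weight zero-point is not quantized

    # Smoke-test forward pass, unchanged from the file
    model(torch.randn(2, 128))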
