From b19c86273b59283f8d912752cf9851492336ccac Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 16 Jul 2024 23:25:15 +0000 Subject: [PATCH 1/7] seeing accuracy jumps with ifeval --- .../run-lm-eval-gsm-vllm-baseline.sh | 2 +- vllm/config.py | 3 +- .../compressed_tensors/compressed_tensors.py | 21 +++-- .../schemes/compressed_tensors_w8a8_fp8.py | 79 ++++++++++++------- .../quantization/compressed_tensors/utils.py | 5 ++ 5 files changed, 72 insertions(+), 38 deletions(-) diff --git a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh index 2f04cc1283df3..614642fed2e2a 100644 --- a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh +++ b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh @@ -47,5 +47,5 @@ done lm_eval --model vllm \ --model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true,distributed_executor_backend="ray",trust_remote_code=true,max_model_len=4096 \ - --tasks gsm8k --num_fewshot $FEWSHOT --limit $LIMIT \ + --tasks ifeval --num_fewshot $FEWSHOT --limit $LIMIT \ --batch_size $BATCH_SIZE diff --git a/vllm/config.py b/vllm/config.py index de7bb3943a45f..b478409188d8b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -237,7 +237,8 @@ def _verify_quantization(self) -> None: f"{self.quantization} quantization is currently not " f"supported in ROCm.") if (self.quantization - not in ("fp8", "marlin", "gptq_marlin_24", "gptq_marlin")): + not in ("fp8", "marlin", "gptq_marlin_24", "gptq_marlin", + "compressed-tesnors")): logger.warning( "%s quantization is not fully " "optimized yet. The speed can be slower than " diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 524b4c894b9b5..4212bea8ae92b 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -132,10 +132,11 @@ def _is_fp8_w8a8(self, weight_quant: BaseModel, # Confirm weight scheme is supported. is_symmetric_weight = weight_quant.symmetric is_static_weight = not weight_quant.dynamic - is_per_tensor_weight = ( - weight_quant.strategy == QuantizationStrategy.TENSOR) + is_per_tensor_or_channel_weight = ( + weight_quant.strategy == QuantizationStrategy.TENSOR or + weight_quant.strategy == QuantizationStrategy.CHANNEL) if not (is_symmetric_weight and is_static_weight - and is_per_tensor_weight): + and is_per_tensor_or_channel_weight): return False # Dynamic quantization is always supported if weights supported. @@ -166,7 +167,8 @@ def _is_wNa16_group_channel(self, weight_quant: BaseModel, def _get_schema(self, weight_quant: BaseModel, input_quant: BaseModel) -> "CompressedTensorsScheme": - + + # Detect If Mixed Precision if self._is_wNa16_group_channel(weight_quant, input_quant): self._check_gptq_and_marlin_can_run() if (self.quant_format == CompressionFormat.marlin_24.value @@ -181,12 +183,15 @@ def _get_schema(self, weight_quant: BaseModel, num_bits=weight_quant.num_bits, strategy=weight_quant.strategy, group_size=weight_quant.group_size) - - if (self.quant_format == CompressionFormat.int_quantized.value or - self.quant_format == CompressionFormat.float_quantized.value): + + # Detect If Activation Quantization. 
+ if (self.quant_format == CompressionFormat.naive_quantized.value or + self.quant_format == CompressionFormat.int_quantized.value or + self.quant_format == CompressionFormat.float_quantized.value): if self._is_fp8_w8a8(weight_quant, input_quant): return CompressedTensorsW8A8Fp8( - input_dynamic=input_quant.dynamic) + strategy=weight_quant.strategy, + is_static_input_scheme=(not input_quant.dynamic)) if self._is_static_tensor_w8a8(weight_quant, input_quant): return CompressedTensorsW8A8Int8( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index b93425fb2d629..f5cc7b06ec8d8 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -1,52 +1,70 @@ from typing import Callable, List, Optional import torch +from torch.nn import Parameter from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + QuantizationStrategy) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - apply_fp8_linear, create_per_tensor_scale_param, cutlass_fp8_supported, - requantize_with_max_scale) + apply_fp8_linear, create_per_tensor_scale_param, create_per_channel_scale_param, + cutlass_fp8_supported, requantize_with_max_scale) from vllm.model_executor.utils import set_weight_attrs __all__ = ["CompressedTensorsW8A8Fp8"] - class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): - def __init__(self, input_dynamic: bool): - self.input_dynamic = input_dynamic + def __init__(self, strategy: str, is_static_input_scheme: bool): + self.strategy = strategy + self.is_static_input_scheme = is_static_input_scheme self.cutlass_fp8_supported = cutlass_fp8_supported() + + # On Lovelace, fall ba + if (not self.cutlass_fp8_supported and + self.strategy == QuantizationStrategy.CHANNEL): + raise ValueError( + "Channelwise fp8 quantization requires vLLM's custom " + "cutlass kernels, which are not supported on your device." + "Consider quantizing with per tensor scales or upgrading " + "to Hopper.") - # W8A8-Fp8 kernels support only per-tensor and per-channel cases. - # So if we have a fused module (QKV, MLP) with per tensor scales (thus N - # scales being passed to the kernel), we requantize with a single scale. def process_weights_after_loading(self, layer) -> None: - # Dequant -> Quant with max scale. - max_w_scale, weight = requantize_with_max_scale( - weight=layer.weight, - weight_scale=layer.weight_scale, - logical_widths=layer.logical_widths, - ) + # If per tensor, when we have a fused module (e.g. QKV) with per + # tensor scales (thus N scales being passed to the kernel), + # requantize so we can always run per tensor with torch._scaled_mm + if (self.strategy == QuantizationStrategy.TENSOR or + not self.cutlass_fp8_supported): + max_w_scale, weight = requantize_with_max_scale( + weight=layer.weight, + weight_scale=layer.weight_scale, + logical_widths=layer.logical_widths,) + + layer.weight = Parameter(weight.t(), requires_grad=False) + layer.weight_scale = Parameter(max_w_scale, requires_grad=False) + + # If channelwise, scales are already lined up, so just transpose. 
+ elif self.strategy == QuantizationStrategy.CHANNEL: + assert self.cutlass_fp8_supported + weight = layer.weight + layer.weight = Parameter(weight.t(), requires_grad=False) - # Update layer with new values. - layer.weight = torch.nn.Parameter(weight.t(), requires_grad=False) - layer.weight_scale = torch.nn.Parameter(max_w_scale, - requires_grad=False) - if self.input_dynamic: - layer.input_scale = None else: - layer.input_scale = torch.nn.Parameter(layer.input_scale.max(), - requires_grad=False) + raise ValueError(f"Unknown quantization strategy {self.strategy}") + + # INPUT SCALE + if self.is_static_input_scheme: + layer.input_scale = Parameter(layer.input_scale.max(), + requires_grad=False) + else: + layer.input_scale = None def create_weights(self, layer: torch.nn.Module, output_partition_sizes: List[int], input_size_per_partition: int, params_dtype: torch.dtype, weight_loader: Callable, **kwargs): - - del params_dtype - output_size_per_partition = sum(output_partition_sizes) layer.logical_widths = output_partition_sizes @@ -63,12 +81,17 @@ def create_weights(self, layer: torch.nn.Module, }) # WEIGHT SCALE - weight_scale = create_per_tensor_scale_param( - output_partition_sizes, weight_loader=weight_loader) + if self.strategy == QuantizationStrategy.CHANNEL: + weight_scale = create_per_channel_scale_param( + output_partition_sizes, weight_loader=weight_loader) + else: + assert self.strategy == QuantizationStrategy.TENSOR + weight_scale = create_per_tensor_scale_param( + output_partition_sizes, weight_loader=weight_loader) layer.register_parameter("weight_scale", weight_scale) # INPUT SCALE - if not self.input_dynamic: + if self.is_static_input_scheme: input_scale = create_per_tensor_scale_param( output_partition_sizes, weight_loader=weight_loader) layer.register_parameter("input_scale", input_scale) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py index 5b44c215535b5..ede3ad8100a91 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py @@ -8,10 +8,15 @@ class CompressionFormat(Enum): dense = "dense" + # Sparsity sparse_bitmask = "sparse-bitmask" + # For Activation Quantization + naive_quantized = "naive-quantized" float_quantized = "float-quantized" int_quantized = "int-quantized" + # For Marlin pack_quantized = "pack-quantized" + # For Marlin 2:4 marlin_24 = "marlin-24" From 433c160490976d76533b7e386a06cd3769af006e Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Tue, 16 Jul 2024 19:41:01 -0400 Subject: [PATCH 2/7] Update run-lm-eval-gsm-vllm-baseline.sh --- .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh index 614642fed2e2a..2f04cc1283df3 100644 --- a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh +++ b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh @@ -47,5 +47,5 @@ done lm_eval --model vllm \ --model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true,distributed_executor_backend="ray",trust_remote_code=true,max_model_len=4096 \ - --tasks ifeval --num_fewshot $FEWSHOT --limit $LIMIT \ + --tasks gsm8k --num_fewshot $FEWSHOT --limit $LIMIT \ --batch_size $BATCH_SIZE From 
39f89b951e6aa296a1f8a031fde6286582c8a586 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Tue, 16 Jul 2024 23:42:56 +0000 Subject: [PATCH 3/7] format --- .../compressed_tensors/compressed_tensors.py | 15 ++++++----- .../schemes/compressed_tensors_w8a8_fp8.py | 27 ++++++++++--------- 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 4212bea8ae92b..1c78df96db6d2 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -133,8 +133,8 @@ def _is_fp8_w8a8(self, weight_quant: BaseModel, is_symmetric_weight = weight_quant.symmetric is_static_weight = not weight_quant.dynamic is_per_tensor_or_channel_weight = ( - weight_quant.strategy == QuantizationStrategy.TENSOR or - weight_quant.strategy == QuantizationStrategy.CHANNEL) + weight_quant.strategy == QuantizationStrategy.TENSOR + or weight_quant.strategy == QuantizationStrategy.CHANNEL) if not (is_symmetric_weight and is_static_weight and is_per_tensor_or_channel_weight): return False @@ -167,7 +167,7 @@ def _is_wNa16_group_channel(self, weight_quant: BaseModel, def _get_schema(self, weight_quant: BaseModel, input_quant: BaseModel) -> "CompressedTensorsScheme": - + # Detect If Mixed Precision if self._is_wNa16_group_channel(weight_quant, input_quant): self._check_gptq_and_marlin_can_run() @@ -183,11 +183,12 @@ def _get_schema(self, weight_quant: BaseModel, num_bits=weight_quant.num_bits, strategy=weight_quant.strategy, group_size=weight_quant.group_size) - + # Detect If Activation Quantization. 
- if (self.quant_format == CompressionFormat.naive_quantized.value or - self.quant_format == CompressionFormat.int_quantized.value or - self.quant_format == CompressionFormat.float_quantized.value): + if (self.quant_format == CompressionFormat.naive_quantized.value + or self.quant_format == CompressionFormat.int_quantized.value + or self.quant_format + == CompressionFormat.float_quantized.value): if self._is_fp8_w8a8(weight_quant, input_quant): return CompressedTensorsW8A8Fp8( strategy=weight_quant.strategy, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index f5cc7b06ec8d8..8a8c6087430f4 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -8,22 +8,24 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( QuantizationStrategy) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - apply_fp8_linear, create_per_tensor_scale_param, create_per_channel_scale_param, - cutlass_fp8_supported, requantize_with_max_scale) + apply_fp8_linear, create_per_channel_scale_param, + create_per_tensor_scale_param, cutlass_fp8_supported, + requantize_with_max_scale) from vllm.model_executor.utils import set_weight_attrs __all__ = ["CompressedTensorsW8A8Fp8"] + class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): def __init__(self, strategy: str, is_static_input_scheme: bool): self.strategy = strategy self.is_static_input_scheme = is_static_input_scheme self.cutlass_fp8_supported = cutlass_fp8_supported() - + # On Lovelace, fall ba - if (not self.cutlass_fp8_supported and - self.strategy == QuantizationStrategy.CHANNEL): + if (not self.cutlass_fp8_supported + and self.strategy == QuantizationStrategy.CHANNEL): raise ValueError( "Channelwise fp8 quantization requires vLLM's custom " "cutlass kernels, which are not supported on your device." @@ -31,15 +33,16 @@ def __init__(self, strategy: str, is_static_input_scheme: bool): "to Hopper.") def process_weights_after_loading(self, layer) -> None: - # If per tensor, when we have a fused module (e.g. QKV) with per - # tensor scales (thus N scales being passed to the kernel), + # If per tensor, when we have a fused module (e.g. 
QKV) with per + # tensor scales (thus N scales being passed to the kernel), # requantize so we can always run per tensor with torch._scaled_mm - if (self.strategy == QuantizationStrategy.TENSOR or - not self.cutlass_fp8_supported): + if (self.strategy == QuantizationStrategy.TENSOR + or not self.cutlass_fp8_supported): max_w_scale, weight = requantize_with_max_scale( weight=layer.weight, weight_scale=layer.weight_scale, - logical_widths=layer.logical_widths,) + logical_widths=layer.logical_widths, + ) layer.weight = Parameter(weight.t(), requires_grad=False) layer.weight_scale = Parameter(max_w_scale, requires_grad=False) @@ -52,10 +55,10 @@ def process_weights_after_loading(self, layer) -> None: else: raise ValueError(f"Unknown quantization strategy {self.strategy}") - + # INPUT SCALE if self.is_static_input_scheme: - layer.input_scale = Parameter(layer.input_scale.max(), + layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False) else: layer.input_scale = None From 0ca238ad6a1ae968e73469fce3da362fcdf7faf3 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Wed, 17 Jul 2024 00:13:19 +0000 Subject: [PATCH 4/7] fix typo --- vllm/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index b478409188d8b..d3949c5e486c9 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -238,7 +238,7 @@ def _verify_quantization(self) -> None: f"supported in ROCm.") if (self.quantization not in ("fp8", "marlin", "gptq_marlin_24", "gptq_marlin", - "compressed-tesnors")): + "compressed_tensors")): logger.warning( "%s quantization is not fully " "optimized yet. The speed can be slower than " From 3e6174e162c052eca78ad53cab4cf9f36a6eb2e8 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Wed, 17 Jul 2024 00:14:29 +0000 Subject: [PATCH 5/7] typo --- .../compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 8a8c6087430f4..892ee04be5350 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -23,7 +23,8 @@ def __init__(self, strategy: str, is_static_input_scheme: bool): self.is_static_input_scheme = is_static_input_scheme self.cutlass_fp8_supported = cutlass_fp8_supported() - # On Lovelace, fall ba + # On Lovelace, fail for now if channelwise. 
+ # TODO: (@tms) fallback if (not self.cutlass_fp8_supported and self.strategy == QuantizationStrategy.CHANNEL): raise ValueError( From 8becbb3fee6f5fc9f09ca04bbdea5e5284767e7e Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Wed, 17 Jul 2024 18:49:29 +0000 Subject: [PATCH 6/7] format --- .../compressed_tensors/compressed_tensors.py | 14 ++++++-------- .../quantization/compressed_tensors/utils.py | 13 +++++++++---- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 1c78df96db6d2..1424c620ae675 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -13,7 +13,8 @@ CompressedTensorsWNA16) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( CompressionFormat, QuantizationArgs, QuantizationStrategy, - QuantizationType, find_first_name_or_class_match) + QuantizationType, find_first_name_or_class_match, + is_activation_quantization_format) from vllm.platforms import current_platform @@ -132,9 +133,9 @@ def _is_fp8_w8a8(self, weight_quant: BaseModel, # Confirm weight scheme is supported. is_symmetric_weight = weight_quant.symmetric is_static_weight = not weight_quant.dynamic - is_per_tensor_or_channel_weight = ( - weight_quant.strategy == QuantizationStrategy.TENSOR - or weight_quant.strategy == QuantizationStrategy.CHANNEL) + is_per_tensor_or_channel_weight = (weight_quant.strategy in [ + QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL + ]) if not (is_symmetric_weight and is_static_weight and is_per_tensor_or_channel_weight): return False @@ -185,10 +186,7 @@ def _get_schema(self, weight_quant: BaseModel, group_size=weight_quant.group_size) # Detect If Activation Quantization. 
- if (self.quant_format == CompressionFormat.naive_quantized.value - or self.quant_format == CompressionFormat.int_quantized.value - or self.quant_format - == CompressionFormat.float_quantized.value): + if is_activation_quantization_format(self.quant_format): if self._is_fp8_w8a8(weight_quant, input_quant): return CompressedTensorsW8A8Fp8( strategy=weight_quant.strategy, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py index ede3ad8100a91..25db308753eee 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py @@ -8,15 +8,11 @@ class CompressionFormat(Enum): dense = "dense" - # Sparsity sparse_bitmask = "sparse-bitmask" - # For Activation Quantization naive_quantized = "naive-quantized" float_quantized = "float-quantized" int_quantized = "int-quantized" - # For Marlin pack_quantized = "pack-quantized" - # For Marlin 2:4 marlin_24 = "marlin-24" @@ -81,6 +77,15 @@ class QuantizationArgs(BaseModel): ) +def is_activation_quantization_format(format: str) -> bool: + _ACTIVATION_QUANTIZATION_FORMATS = [ + CompressionFormat.naive_quantized.value, + CompressionFormat.int_quantized.value, + CompressionFormat.float_quantized.value + ] + return format in _ACTIVATION_QUANTIZATION_FORMATS + + def find_first_name_or_class_match( name: str, module: Module, From 5f2cb452526b8e948d96b57c28b02a0f52cf9754 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Wed, 17 Jul 2024 18:52:53 +0000 Subject: [PATCH 7/7] update condition slightly --- .../schemes/compressed_tensors_w8a8_fp8.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 892ee04be5350..f1ca9510d92aa 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -36,9 +36,8 @@ def __init__(self, strategy: str, is_static_input_scheme: bool): def process_weights_after_loading(self, layer) -> None: # If per tensor, when we have a fused module (e.g. QKV) with per # tensor scales (thus N scales being passed to the kernel), - # requantize so we can always run per tensor with torch._scaled_mm - if (self.strategy == QuantizationStrategy.TENSOR - or not self.cutlass_fp8_supported): + # requantize so we can always run per tensor + if self.strategy == QuantizationStrategy.TENSOR: max_w_scale, weight = requantize_with_max_scale( weight=layer.weight, weight_scale=layer.weight_scale,
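The comment the last patch rewrites ("requantize so we can always run per tensor") is the core idea of the series: a fused module such as QKV loads with one FP8 scale per logical shard, and the per-tensor path collapses those N scales into a single max scale via a dequant/quant round trip, while the channelwise path keeps one scale per output row and needs no requantization because the rows of the shards simply concatenate. The sketch below is a standalone illustration of that idea only, not the vLLM implementation (which lives in requantize_with_max_scale and create_per_channel_scale_param in w8a8_utils and runs through cutlass or torch._scaled_mm); the quantize/dequantize helpers, the shard sizes, and the direct use of torch.float8_e4m3fn (PyTorch >= 2.1) are assumptions made for the example.

    # Standalone sketch of per-tensor vs. per-channel FP8 scale handling.
    # NOT the vLLM implementation; helper names and shapes are made up for
    # illustration. Requires PyTorch >= 2.1 for torch.float8_e4m3fn.
    import torch

    FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

    def quantize(w: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        # Scale into the representable range, then cast to FP8.
        return (w / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)

    def dequantize(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        return q.to(torch.float32) * scale

    # A fused "QKV"-style weight made of three logical shards.
    logical_widths = [4, 4, 8]
    weight = torch.randn(sum(logical_widths), 16)

    # Per-tensor checkpoints store one scale per shard, so the fused layer
    # would otherwise hand N scales to a kernel that expects exactly one.
    shards = weight.split(logical_widths, dim=0)
    shard_scales = [s.abs().max() / FP8_MAX for s in shards]
    q_shards = [quantize(s, sc) for s, sc in zip(shards, shard_scales)]

    # Dequant -> quant with the max of the shard scales: the fused weight now
    # shares a single per-tensor scale (the requantize_with_max_scale idea).
    max_scale = torch.stack(shard_scales).max()
    dequant = torch.cat(
        [dequantize(q, sc) for q, sc in zip(q_shards, shard_scales)], dim=0)
    q_fused = quantize(dequant, max_scale)

    # Channelwise: one scale per output row. Shard rows concatenate directly,
    # so the scales already line up and no requantization is needed.
    channel_scales = weight.abs().amax(dim=1, keepdim=True) / FP8_MAX
    q_channel = quantize(weight, channel_scales)

    print(q_fused.shape, float(max_scale), q_channel.shape,
          tuple(channel_scales.shape))

Run as-is this just prints the fused and channelwise shapes; the point it makes is why the per-tensor branch must requantize after loading while the channelwise branch only transposes the weight, and why the channelwise case is gated on the cutlass kernels, since torch._scaled_mm takes a single per-tensor scale.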