From 82da168af30548c5a65465834f226996e562c087 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Sat, 20 Jul 2024 14:50:10 -0400 Subject: [PATCH] [ Kernel ] Enable `fp8-marlin` for `fbgemm-fp8` models (#6606) --- ...lama-3-70B-Instruct-FBGEMM-nonuniform.yaml | 11 ++++++++ .../lm-eval-harness/configs/models-large.txt | 1 + .../layers/quantization/fbgemm_fp8.py | 26 ++++++++++++++++++- .../quantization/utils/marlin_utils_fp8.py | 9 +++++-- 4 files changed, 44 insertions(+), 3 deletions(-) create mode 100644 .buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml new file mode 100644 index 0000000000000..4397effa82cc8 --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml @@ -0,0 +1,11 @@ +# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform -b auto -l 1000 -f 5 +model_name: "nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.905 + - name: "exact_match,flexible-extract" + value: 0.905 +limit: 1000 +num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/models-large.txt b/.buildkite/lm-eval-harness/configs/models-large.txt index 94b15a87235b9..37eeac85c933b 100644 --- a/.buildkite/lm-eval-harness/configs/models-large.txt +++ b/.buildkite/lm-eval-harness/configs/models-large.txt @@ -1,3 +1,4 @@ +Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml Meta-Llama-3-70B-Instruct.yaml Mixtral-8x7B-Instruct-v0.1.yaml Qwen2-57B-A14-Instruct.yaml diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index e6e8d28e3e16a..e84564714171a 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -9,9 +9,12 @@ UnquantizedLinearMethod) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( + apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( apply_fp8_linear, create_per_channel_scale_param) from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform logger = init_logger(__name__) @@ -31,6 +34,12 @@ def __init__(self, ignore_list: List[str], input_scale_ub: float): self.ignore_list = ignore_list self.input_scale_ub = input_scale_ub + # For GPUs that lack FP8 hardware support, we can leverage the Marlin + # kernel for fast weight-only FP8 quantization + capability = current_platform.get_device_capability() + capability = capability[0] * 10 + capability[1] + self.use_marlin = capability < 89 + @classmethod def get_name(cls) -> str: return "fbgemm_fp8" @@ -41,7 +50,7 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]: @classmethod def get_min_capability(cls) -> int: - return 89 + return 80 @classmethod def get_config_filenames(cls) -> List[str]: @@ -143,11 +152,26 @@ def process_weights_after_loading(self, layer: Module) -> None: weight = layer.weight layer.weight = Parameter(weight.t(), requires_grad=False) + if self.quant_config.use_marlin: + prepare_fp8_layer_for_marlin(layer) + # Activations not quantized for marlin. + del layer.input_scale_ub + def apply(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.quant_config.use_marlin: + return apply_fp8_marlin_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + workspace=layer.workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias) + return apply_fp8_linear(input=x, weight=layer.weight, weight_scale=layer.weight_scale, diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py index e93eb747ba2eb..aabd46e64536f 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py @@ -76,8 +76,13 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None: # WEIGHT SCALES # Currently Marlin doesn't support per-tensor scales, so we # expand it to channelwise - scales = layer.weight_scale.repeat(1, part_size_n).to( - layer.orig_dtype).to(device) + is_channelwise = layer.weight_scale.shape[0] == part_size_n + if is_channelwise: + scales = layer.weight_scale + else: + scales = layer.weight_scale.repeat(1, part_size_n) + scales = scales.to(layer.orig_dtype).to(device) + # Permute scales marlin_scales = marlin_permute_scales(s=scales, size_k=part_size_k,