[ Kernel ] Enable fp8-marlin for fbgemm-fp8 models (vllm-project#…
robertgshaw2-neuralmagic authored and gnpinkert committed Jul 26, 2024
1 parent faf15e1 commit 946658a
Showing 4 changed files with 44 additions and 3 deletions.
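The change routes fbgemm_fp8 checkpoints through the Marlin weight-only FP8 kernel on GPUs that lack native FP8 tensor cores (compute capability below 8.9), which is why get_min_capability drops from 89 to 80 in the diff below. A minimal standalone sketch of the dispatch rule, using plain torch.cuda.get_device_capability() in place of vLLM's current_platform helper:

import torch

def should_use_fp8_marlin() -> bool:
    # Compute capability encoded as major * 10 + minor (A100 -> 80, L40/RTX 4090 -> 89).
    major, minor = torch.cuda.get_device_capability()
    capability = major * 10 + minor
    # Below 8.9 there is no FP8 tensor-core support, so fall back to Marlin:
    # weights stay fp8, activations stay fp16/bf16 (weight-only quantization).
    return capability < 89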
11 changes: 11 additions & 0 deletions .buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml
@@ -0,0 +1,11 @@
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform -b auto -l 1000 -f 5
+model_name: "nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform"
+tasks:
+- name: "gsm8k"
+  metrics:
+  - name: "exact_match,strict-match"
+    value: 0.905
+  - name: "exact_match,flexible-extract"
+    value: 0.905
+limit: 1000
+num_fewshot: 5
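As context for how these thresholds are meant to be read, here is a rough sketch (not part of this commit) of loading such a config and comparing it against measured lm-eval scores; the helper names and the shape of the measured dict are assumptions, only the field names come from the YAML above:

import yaml

def load_eval_config(path: str) -> dict:
    with open(path) as f:
        return yaml.safe_load(f)

def check_results(config: dict, measured: dict, rtol: float = 0.05) -> bool:
    # measured maps "task/metric" -> score, e.g. {"gsm8k/exact_match,strict-match": 0.91}
    for task in config["tasks"]:
        for metric in task["metrics"]:
            key = f'{task["name"]}/{metric["name"]}'
            if abs(measured[key] - metric["value"]) > rtol * metric["value"]:
                return False
    return True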
1 change: 1 addition & 0 deletions .buildkite/lm-eval-harness/configs/models-large.txt
@@ -1,3 +1,4 @@
+Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml
 Meta-Llama-3-70B-Instruct.yaml
 Mixtral-8x7B-Instruct-v0.1.yaml
 Qwen2-57B-A14-Instruct.yaml
26 changes: 25 additions & 1 deletion vllm/model_executor/layers/quantization/fbgemm_fp8.py
@@ -9,9 +9,12 @@
     UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
+    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     apply_fp8_linear, create_per_channel_scale_param)
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform

logger = init_logger(__name__)

@@ -31,6 +34,12 @@ def __init__(self, ignore_list: List[str], input_scale_ub: float):
         self.ignore_list = ignore_list
         self.input_scale_ub = input_scale_ub

+        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
+        # kernel for fast weight-only FP8 quantization
+        capability = current_platform.get_device_capability()
+        capability = capability[0] * 10 + capability[1]
+        self.use_marlin = capability < 89
+
     @classmethod
     def get_name(cls) -> str:
         return "fbgemm_fp8"
@@ -41,7 +50,7 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]:

     @classmethod
     def get_min_capability(cls) -> int:
-        return 89
+        return 80

     @classmethod
     def get_config_filenames(cls) -> List[str]:
@@ -143,11 +152,26 @@ def process_weights_after_loading(self, layer: Module) -> None:
         weight = layer.weight
         layer.weight = Parameter(weight.t(), requires_grad=False)

+        if self.quant_config.use_marlin:
+            prepare_fp8_layer_for_marlin(layer)
+            # Activations not quantized for marlin.
+            del layer.input_scale_ub
+
     def apply(self,
               layer: torch.nn.Module,
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:

+        if self.quant_config.use_marlin:
+            return apply_fp8_marlin_linear(
+                input=x,
+                weight=layer.weight,
+                weight_scale=layer.weight_scale,
+                workspace=layer.workspace,
+                size_n=layer.output_size_per_partition,
+                size_k=layer.input_size_per_partition,
+                bias=bias)
+
         return apply_fp8_linear(input=x,
                                 weight=layer.weight,
                                 weight_scale=layer.weight_scale,
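With the Marlin fallback wired into apply and get_min_capability lowered to 80, an fbgemm_fp8 checkpoint can be served on Ampere-class GPUs. A hypothetical usage sketch (the prompt and tensor_parallel_size are illustrative; the model is the 70B checkpoint from the eval config above):

from vllm import LLM, SamplingParams

llm = LLM(model="nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform",
          tensor_parallel_size=4)  # e.g. 4x A100, compute capability 8.0
params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)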
9 changes: 7 additions & 2 deletions vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
@@ -76,8 +76,13 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None:
     # WEIGHT SCALES
     # Currently Marlin doesn't support per-tensor scales, so we
     # expand it to channelwise
-    scales = layer.weight_scale.repeat(1, part_size_n).to(
-        layer.orig_dtype).to(device)
+    is_channelwise = layer.weight_scale.shape[0] == part_size_n
+    if is_channelwise:
+        scales = layer.weight_scale
+    else:
+        scales = layer.weight_scale.repeat(1, part_size_n)
+    scales = scales.to(layer.orig_dtype).to(device)
+
     # Permute scales
     marlin_scales = marlin_permute_scales(s=scales,
                                           size_k=part_size_k,
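For reference, the scale handling in the hunk above restated as a self-contained helper (the function name and signature are hypothetical; the logic mirrors the new lines): keep the scale tensor as-is when the checkpoint already provides one scale per output channel, otherwise broadcast the single per-tensor scale across the output dimension before casting, and let the subsequent permutation handle the Marlin layout.

import torch

def expand_weight_scale_for_marlin(weight_scale: torch.Tensor,
                                   part_size_n: int,
                                   orig_dtype: torch.dtype,
                                   device: torch.device) -> torch.Tensor:
    # Channelwise checkpoints already carry one scale per output channel;
    # per-tensor checkpoints carry a single scale that must be repeated.
    is_channelwise = weight_scale.shape[0] == part_size_n
    if is_channelwise:
        scales = weight_scale
    else:
        scales = weight_scale.repeat(1, part_size_n)
    return scales.to(orig_dtype).to(device)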
