[ Kernel ] Enable fp8-marlin for fbgemm-fp8 models #6606

Merged Jul 20, 2024, with 47 commits.

Commits
2426e29  stash (robertgshaw2-neuralmagic, Jul 18, 2024)
7665b7b  format (robertgshaw2-neuralmagic, Jul 18, 2024)
ef27613  tweak arg name (robertgshaw2-neuralmagic, Jul 18, 2024)
2f96157  fix test (robertgshaw2-neuralmagic, Jul 18, 2024)
e748554  format (robertgshaw2-neuralmagic, Jul 18, 2024)
3ef571b  working e2e with our cutlass kernels (robertgshaw2-neuralmagic, Jul 19, 2024)
ad83666  added fp8 gemm (robertgshaw2-neuralmagic, Jul 19, 2024)
eb7d48c  remove (robertgshaw2-neuralmagic, Jul 19, 2024)
90bd839  format (robertgshaw2-neuralmagic, Jul 19, 2024)
15cc823  Merge branch 'main' into turn-on-fp8-dyn-per-token (robertgshaw2-neuralmagic, Jul 19, 2024)
d064dd7  stash (robertgshaw2-neuralmagic, Jul 19, 2024)
6aa37e5  dynamic per token (robertgshaw2-neuralmagic, Jul 19, 2024)
c9d819a  format (robertgshaw2-neuralmagic, Jul 19, 2024)
08cbaf7  reenable cutlass (robertgshaw2-neuralmagic, Jul 19, 2024)
f4cdda1  cleanup comment (robertgshaw2-neuralmagic, Jul 19, 2024)
2971f4d  format (robertgshaw2-neuralmagic, Jul 19, 2024)
b601033  added dynamic per token test case (robertgshaw2-neuralmagic, Jul 19, 2024)
5d8edf9  Merge branch 'turn-on-fp8-dyn-per-token' into fbgemm-checkpoints (robertgshaw2-neuralmagic, Jul 19, 2024)
8b5d638  added use per token (robertgshaw2-neuralmagic, Jul 19, 2024)
006ccf0  format (Jul 19, 2024)
1884acf  format (Jul 19, 2024)
fe14072  Make optional ubs none (Jul 19, 2024)
254dcff  format (Jul 19, 2024)
919d866  Merge branch 'fp8-dpt-fpgemm' into fbgemm-checkpoints (robertgshaw2-neuralmagic, Jul 19, 2024)
227a277  hook up end to end with varun's ub quant kernel (robertgshaw2-neuralmagic, Jul 19, 2024)
951834a  formatted (robertgshaw2-neuralmagic, Jul 19, 2024)
9aa66d3  updated for nonuniform (robertgshaw2-neuralmagic, Jul 19, 2024)
458a410  formatting after passing prefix around (robertgshaw2-neuralmagic, Jul 19, 2024)
278f6d6  Merge branch 'main' into fbgemm-checkpoints (robertgshaw2-neuralmagic, Jul 19, 2024)
3e4aaad  fixed bad merge (robertgshaw2-neuralmagic, Jul 19, 2024)
de2a764  updated message (robertgshaw2-neuralmagic, Jul 20, 2024)
268fe94  Merge branch 'main' into fbgemm-checkpoints (robertgshaw2-neuralmagic, Jul 20, 2024)
c88fe34  merged varun's pr (robertgshaw2-neuralmagic, Jul 20, 2024)
bb02a3f  fixed (robertgshaw2-neuralmagic, Jul 20, 2024)
1c8f71c  cleanup pr (robertgshaw2-neuralmagic, Jul 20, 2024)
6970e50  Update config.py (robertgshaw2-neuralmagic, Jul 20, 2024)
94617f0  fixed config (robertgshaw2-neuralmagic, Jul 20, 2024)
f9d569c  updated for new ckpt format, turned on ada lovelace, and added test case (robertgshaw2-neuralmagic, Jul 20, 2024)
ae45615  format (robertgshaw2-neuralmagic, Jul 20, 2024)
e2a1eda  add marlin support to fbgemm (robertgshaw2-neuralmagic, Jul 20, 2024)
a4abc78  fix configs (robertgshaw2-neuralmagic, Jul 20, 2024)
5008ecb  fix configs (robertgshaw2-neuralmagic, Jul 20, 2024)
615a2ed  added marlin nonuniform test (robertgshaw2-neuralmagic, Jul 20, 2024)
7ea9025  Merge branch 'main' into fbgemm-fp8-marlin (robertgshaw2-neuralmagic, Jul 20, 2024)
da37598  fixed (robertgshaw2-neuralmagic, Jul 20, 2024)
a14116c  Merge branch 'main' into fbgemm-fp8-marlin (robertgshaw2-neuralmagic, Jul 20, 2024)
183bfe7  use marlin remove: (robertgshaw2-neuralmagic, Jul 20, 2024)
Files changed

.buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml (new file, 11 additions)
@@ -0,0 +1,11 @@
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform -b auto -l 1000 -f 5
+model_name: "nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform"
+tasks:
+- name: "gsm8k"
+  metrics:
+  - name: "exact_match,strict-match"
+    value: 0.905
+  - name: "exact_match,flexible-extract"
+    value: 0.905
+limit: 1000
+num_fewshot: 5
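The comment at the top of the config records the baseline command; the config itself just pins expected GSM8K scores. A minimal sketch, assuming PyYAML and a hypothetical measured-results dict (the real buildkite comparison logic is not part of this diff), of how such thresholds can be checked:

import yaml  # assumption: PyYAML is available

# Hypothetical measured results; in practice they come from running the
# lm-eval-harness command in the comment at the top of the config.
measured = {
    "exact_match,strict-match": 0.901,
    "exact_match,flexible-extract": 0.903,
}

with open("Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml") as f:
    config = yaml.safe_load(f)

RTOL = 0.05  # assumption: tolerate a few percent of run-to-run noise
for task in config["tasks"]:
    for metric in task["metrics"]:
        expected = metric["value"]
        got = measured[metric["name"]]
        assert abs(got - expected) <= RTOL * expected, (task["name"], metric["name"])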
1 change: 1 addition & 0 deletions .buildkite/lm-eval-harness/configs/models-large.txt
@@ -1,3 +1,4 @@
+Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml
 Meta-Llama-3-70B-Instruct.yaml
 Mixtral-8x7B-Instruct-v0.1.yaml
 Qwen2-57B-A14-Instruct.yaml
26 changes: 25 additions & 1 deletion vllm/model_executor/layers/quantization/fbgemm_fp8.py
@@ -9,9 +9,12 @@
     UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
+    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     apply_fp8_linear, create_per_channel_scale_param)
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform

 logger = init_logger(__name__)

@@ -31,6 +34,12 @@ def __init__(self, ignore_list: List[str], input_scale_ub: float):
         self.ignore_list = ignore_list
         self.input_scale_ub = input_scale_ub

+        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
+        # kernel for fast weight-only FP8 quantization
+        capability = current_platform.get_device_capability()
+        capability = capability[0] * 10 + capability[1]
+        self.use_marlin = capability < 89
+
     @classmethod
     def get_name(cls) -> str:
         return "fbgemm_fp8"
@@ -41,7 +50,7 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]:

     @classmethod
     def get_min_capability(cls) -> int:
-        return 89
+        return 80

     @classmethod
     def get_config_filenames(cls) -> List[str]:
@@ -143,11 +152,26 @@ def process_weights_after_loading(self, layer: Module) -> None:
         weight = layer.weight
         layer.weight = Parameter(weight.t(), requires_grad=False)

+        if self.quant_config.use_marlin:
+            prepare_fp8_layer_for_marlin(layer)
+            # Activations not quantized for marlin.
+            del layer.input_scale_ub
+
     def apply(self,
               layer: torch.nn.Module,
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:

+        if self.quant_config.use_marlin:
+            return apply_fp8_marlin_linear(
+                input=x,
+                weight=layer.weight,
+                weight_scale=layer.weight_scale,
+                workspace=layer.workspace,
+                size_n=layer.output_size_per_partition,
+                size_k=layer.input_size_per_partition,
+                bias=bias)
+
         return apply_fp8_linear(input=x,
                                 weight=layer.weight,
                                 weight_scale=layer.weight_scale,
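With the dispatch above in place, an fbgemm_fp8 checkpoint can be served on pre-Ada hardware. A hedged usage sketch (the model name is taken from the lm-eval config added in this PR; the tensor-parallel size is an assumption for a 70B checkpoint, not something this diff specifies):

from vllm import LLM, SamplingParams

llm = LLM(
    model="nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform",
    tensor_parallel_size=4,  # assumption: enough 80GB GPUs for a 70B model
)
outputs = llm.generate(
    ["What is 2 + 2?"],
    SamplingParams(temperature=0.0, max_tokens=32),
)
print(outputs[0].outputs[0].text)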
vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
@@ -76,8 +76,13 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None:
     # WEIGHT SCALES
     # Currently Marlin doesn't support per-tensor scales, so we
     # expand it to channelwise
-    scales = layer.weight_scale.repeat(1, part_size_n).to(
-        layer.orig_dtype).to(device)
+    is_channelwise = layer.weight_scale.shape[0] == part_size_n
+    if is_channelwise:
+        scales = layer.weight_scale
+    else:
+        scales = layer.weight_scale.repeat(1, part_size_n)
+    scales = scales.to(layer.orig_dtype).to(device)
+
     # Permute scales
     marlin_scales = marlin_permute_scales(s=scales,
                                           size_k=part_size_k,
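The branch above is the key change for nonuniform checkpoints: scales that are already stored per output channel are passed through, while a per-tensor scale is expanded so Marlin always receives channelwise scales. A small illustration (not vLLM code) of the same logic:

import torch

def expand_weight_scale(weight_scale: torch.Tensor, part_size_n: int) -> torch.Tensor:
    # Already channelwise: one scale per output channel, pass through.
    if weight_scale.shape[0] == part_size_n:
        return weight_scale
    # Per-tensor: repeat the single scale across the output dimension.
    return weight_scale.repeat(1, part_size_n)

print(expand_weight_scale(torch.tensor([[0.02]]), part_size_n=4))    # per-tensor case
print(expand_weight_scale(torch.full((4, 1), 0.02), part_size_n=4))  # channelwise case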