Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ Misc ] fbgemm checkpoints #6559

Merged
merged 39 commits into from
Jul 20, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
2426e29
stash
robertgshaw2-neuralmagic Jul 18, 2024
7665b7b
format
robertgshaw2-neuralmagic Jul 18, 2024
ef27613
tweak arg name
robertgshaw2-neuralmagic Jul 18, 2024
2f96157
fix test
robertgshaw2-neuralmagic Jul 18, 2024
e748554
format
robertgshaw2-neuralmagic Jul 18, 2024
3ef571b
working e2e with our cutlass kernels
robertgshaw2-neuralmagic Jul 19, 2024
ad83666
added fp8 gemm
robertgshaw2-neuralmagic Jul 19, 2024
eb7d48c
remove
robertgshaw2-neuralmagic Jul 19, 2024
90bd839
format
robertgshaw2-neuralmagic Jul 19, 2024
15cc823
Merge branch 'main' into turn-on-fp8-dyn-per-token
robertgshaw2-neuralmagic Jul 19, 2024
d064dd7
stash
robertgshaw2-neuralmagic Jul 19, 2024
6aa37e5
dynamic per token
robertgshaw2-neuralmagic Jul 19, 2024
c9d819a
format
robertgshaw2-neuralmagic Jul 19, 2024
08cbaf7
reenable cutlass
robertgshaw2-neuralmagic Jul 19, 2024
f4cdda1
cleanup comment
robertgshaw2-neuralmagic Jul 19, 2024
2971f4d
format
robertgshaw2-neuralmagic Jul 19, 2024
b601033
added dynamic per token test case
robertgshaw2-neuralmagic Jul 19, 2024
5d8edf9
Merge branch 'turn-on-fp8-dyn-per-token' into fbgemm-checkpoints
robertgshaw2-neuralmagic Jul 19, 2024
8b5d638
added use per token
robertgshaw2-neuralmagic Jul 19, 2024
006ccf0
format
Jul 19, 2024
1884acf
format
Jul 19, 2024
fe14072
Make optional ubs none
Jul 19, 2024
254dcff
format
Jul 19, 2024
919d866
Merge branch 'fp8-dpt-fpgemm' into fbgemm-checkpoints
robertgshaw2-neuralmagic Jul 19, 2024
227a277
hook up end to end with varun's ub quant kernel
robertgshaw2-neuralmagic Jul 19, 2024
951834a
formatted
robertgshaw2-neuralmagic Jul 19, 2024
9aa66d3
updated for nonuniform
robertgshaw2-neuralmagic Jul 19, 2024
458a410
formatting after passing prefix around
robertgshaw2-neuralmagic Jul 19, 2024
278f6d6
Merge branch 'main' into fbgemm-checkpoints
robertgshaw2-neuralmagic Jul 19, 2024
3e4aaad
fixed bad merge
robertgshaw2-neuralmagic Jul 19, 2024
de2a764
updated message
robertgshaw2-neuralmagic Jul 20, 2024
268fe94
Merge branch 'main' into fbgemm-checkpoints
robertgshaw2-neuralmagic Jul 20, 2024
c88fe34
merged varun's pr
robertgshaw2-neuralmagic Jul 20, 2024
bb02a3f
fixed
robertgshaw2-neuralmagic Jul 20, 2024
1c8f71c
cleanup pr
robertgshaw2-neuralmagic Jul 20, 2024
6970e50
Update config.py
robertgshaw2-neuralmagic Jul 20, 2024
94617f0
fixed config
robertgshaw2-neuralmagic Jul 20, 2024
f9d569c
updated for new ckpt format, turned on ada lovelace, and added test case
robertgshaw2-neuralmagic Jul 20, 2024
ae45615
format
robertgshaw2-neuralmagic Jul 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions vllm/model_executor/layers/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
CompressedTensorsConfig)
from vllm.model_executor.layers.quantization.deepspeedfp import (
DeepSpeedFPConfig)
from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
Expand All @@ -24,6 +25,7 @@
"awq": AWQConfig,
"deepspeedfp": DeepSpeedFPConfig,
"fp8": Fp8Config,
"fbgemm_fp8": FBGEMMFp8Config,
# The order of gptq methods is important for config.py iteration over
# override_quantization_method(..)
"marlin": MarlinConfig,
Expand Down
115 changes: 115 additions & 0 deletions vllm/model_executor/layers/quantization/fbgemm_fp8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
from typing import Any, Dict, List, Optional

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear, create_per_channel_scale_param)
from vllm.model_executor.utils import set_weight_attrs

logger = init_logger(__name__)


class FBGEMMFp8Config(QuantizationConfig):
"""Config class for FBGEMM Fp8."""

@classmethod
def get_name(cls) -> str:
return "fbgemm_fp8"

@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.bfloat16]

@classmethod
def get_min_capability(cls) -> int:
return 90

@classmethod
def get_config_filenames(cls) -> List[str]:
return []

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "FBGEMMFp8Config":
return cls()

def get_quant_method(
self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, LinearBase):
return FBGEMMFp8LinearMethod(self)
return None

def get_scaled_act_names(self) -> List[str]:
return []


class FBGEMMFp8LinearMethod(LinearMethodBase):

def __init__(self, quant_config: FBGEMMFp8Config):
self.quant_config = quant_config

def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
del input_size, output_size
output_size_per_partition = sum(output_partition_sizes)

layer.logical_widths = output_partition_sizes

layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.orig_dtype = params_dtype

# WEIGHT
weight = Parameter(torch.empty(output_size_per_partition,
input_size_per_partition,
dtype=torch.float8_e4m3fn),
requires_grad=False)
layer.register_parameter("weight", weight)
set_weight_attrs(weight, {
"input_dim": 1,
"output_dim": 0,
**extra_weight_attrs,
})

# WEIGHT SCALE
weight_scale = create_per_channel_scale_param(output_partition_sizes,
**extra_weight_attrs)
layer.register_parameter("weight_scale", weight_scale)

# NOT USED FOR INFERNECE
input_scale_ub = torch.nn.Parameter(
torch.zeros((1), dtype=torch.float8_e4m3fn))
layer.register_parameter("input_scale_ub", input_scale_ub)
set_weight_attrs(input_scale_ub, {
"ignore_warning": True,
**extra_weight_attrs
})

def process_weights_after_loading(self, layer: Module) -> None:
weight = layer.weight
layer.weight = Parameter(weight.t(), requires_grad=False)

def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:

return apply_fp8_linear(input=x,
Copy link

@vlasenkoalexey vlasenkoalexey Jul 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that apply_fp8_linear would work here. Using fbgemm it would look like:

        xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
            x, num_tokens, self.activation_scale_ub
        )
        y = torch.ops.fbgemm.f8f8bf16_rowwise(
            xq, layer.weight, x_scale, layer.weight_scale, use_fast_accum=True
        )

Particularly input_scale=None is most likely wrong, here is a reference implementation for quantize_fp8_per_row

def fp8_row_quantize_ref(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # Quantize an input tensor and return the fp8 tensor and its inverse scale.
    x_row_max = torch.max(torch.abs(x), dim=1).values
    max_scaling_factor = E4M3_MAX_POS * 512.0  # Match kernel logics
    scale = torch.Tensor(E4M3_MAX_POS / x_row_max).clamp(max=max_scaling_factor)
    xq = (x * scale.unsqueeze(1)).to(fp8_e4m3)
    return xq, scale.reciprocal().to(torch.float32)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yup, we are updating the quant kernel right now to use the activation_scale_ub

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in combo with https://github.com/vllm-project/vllm/pull/6547/files

which enables per token scales

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Besides both torch._scaled_mm and ops.scaled_fp8_quant expect scale to be scalar

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#6547 extends ops.scaled_fp8_quant to accepts per token scales (vector of scales)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torch._scaled_mm will not be used.

cutlass_scaled_mm accepts per channel weights and per token activations

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I pulled in #6547 to this PR so you can see

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adds #6593 adds the ub

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, this PR currently now has the

  • scale_ub
  • uses dynamic per token activation scales

weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=None,
mgoin marked this conversation as resolved.
Show resolved Hide resolved
bias=bias,
cutlass_fp8_supported=True)
Loading