[ Misc ] fbgemm checkpoints #6559 (Merged)
Changes from 4 commits (39 commits in total, primarily by robertgshaw2-neuralmagic):

2426e29  stash
7665b7b  format
ef27613  tweak arg name
2f96157  fix test
e748554  format
3ef571b  working e2e with our cutlass kernels
ad83666  added fp8 gemm
eb7d48c  remove
90bd839  format
15cc823  Merge branch 'main' into turn-on-fp8-dyn-per-token
d064dd7  stash
6aa37e5  dynamic per token
c9d819a  format
08cbaf7  reenable cutlass
f4cdda1  cleanup comment
2971f4d  format
b601033  added dynamic per token test case
5d8edf9  Merge branch 'turn-on-fp8-dyn-per-token' into fbgemm-checkpoints
8b5d638  added use per token
006ccf0  format
1884acf  format
fe14072  Make optional ubs none
254dcff  format
919d866  Merge branch 'fp8-dpt-fpgemm' into fbgemm-checkpoints
227a277  hook up end to end with varun's ub quant kernel
951834a  formatted
9aa66d3  updated for nonuniform
458a410  formatting after passing prefix around
278f6d6  Merge branch 'main' into fbgemm-checkpoints
3e4aaad  fixed bad merge
de2a764  updated message
268fe94  Merge branch 'main' into fbgemm-checkpoints
c88fe34  merged varun's pr
bb02a3f  fixed
1c8f71c  cleanup pr
6970e50  Update config.py
94617f0  fixed config
f9d569c  updated for new ckpt format, turned on ada lovelace, and added test case
ae45615  format
New file added by this PR (diff @@ -0,0 +1,115 @@):

```python
from typing import Any, Dict, List, Optional

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    apply_fp8_linear, create_per_channel_scale_param)
from vllm.model_executor.utils import set_weight_attrs

logger = init_logger(__name__)


class FBGEMMFp8Config(QuantizationConfig):
    """Config class for FBGEMM Fp8."""

    @classmethod
    def get_name(cls) -> str:
        return "fbgemm_fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 90

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "FBGEMMFp8Config":
        return cls()

    def get_quant_method(
            self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
        if isinstance(layer, LinearBase):
            return FBGEMMFp8LinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


class FBGEMMFp8LinearMethod(LinearMethodBase):

    def __init__(self, quant_config: FBGEMMFp8Config):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        del input_size, output_size
        output_size_per_partition = sum(output_partition_sizes)

        layer.logical_widths = output_partition_sizes

        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype

        # WEIGHT
        weight = Parameter(torch.empty(output_size_per_partition,
                                       input_size_per_partition,
                                       dtype=torch.float8_e4m3fn),
                           requires_grad=False)
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, {
            "input_dim": 1,
            "output_dim": 0,
            **extra_weight_attrs,
        })

        # WEIGHT SCALE
        weight_scale = create_per_channel_scale_param(output_partition_sizes,
                                                      **extra_weight_attrs)
        layer.register_parameter("weight_scale", weight_scale)

        # NOT USED FOR INFERENCE
        input_scale_ub = torch.nn.Parameter(
            torch.zeros((1), dtype=torch.float8_e4m3fn))
        layer.register_parameter("input_scale_ub", input_scale_ub)
        set_weight_attrs(input_scale_ub, {
            "ignore_warning": True,
            **extra_weight_attrs
        })

    def process_weights_after_loading(self, layer: Module) -> None:
        weight = layer.weight
        layer.weight = Parameter(weight.t(), requires_grad=False)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:

        return apply_fp8_linear(input=x,
                                weight=layer.weight,
                                weight_scale=layer.weight_scale,
                                input_scale=None,
                                bias=bias,
                                cutlass_fp8_supported=True)
```
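For context, here is a hedged usage sketch of how a checkpoint quantized with this method might be loaded through vLLM's offline API. The model path is made up, and passing quantization="fbgemm_fp8" simply mirrors FBGEMMFp8Config.get_name() above; how the method is auto-detected from a checkpoint's own quantization config is not shown in this diff.

```python
from vllm import LLM, SamplingParams

# Hypothetical usage: the model path below is illustrative only;
# "fbgemm_fp8" matches FBGEMMFp8Config.get_name() in the diff above.
llm = LLM(model="example-org/Llama-3-8B-Instruct-FBGEMM-FP8",
          quantization="fbgemm_fp8")
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)
```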
Review thread on the apply_fp8_linear call:

I don't think that apply_fp8_linear would work here; using fbgemm, the call would look different. In particular, input_scale=None is most likely wrong. Here is a reference implementation for quantize_fp8_per_row.
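For readers without the fbgemm reference implementation at hand, here is a minimal plain-PyTorch sketch of what per-row (per-token) fp8 quantization with an optional scale upper bound does conceptually. It is not the fbgemm kernel; the function name, the epsilon, and the example values are illustrative.

```python
from typing import Optional, Tuple
import torch

def quantize_fp8_per_row_sketch(
        x: torch.Tensor,
        scale_ub: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Toy per-row (per-token) fp8 quantization with an optional scale upper bound."""
    finfo = torch.finfo(torch.float8_e4m3fn)
    # One absolute max per row (i.e. per token).
    row_max = x.abs().amax(dim=-1, keepdim=True).float()
    if scale_ub is not None:
        # Clamp the per-row max so a single outlier token cannot inflate its scale.
        row_max = torch.minimum(row_max, scale_ub.float())
    scale = (row_max / finfo.max).clamp(min=1e-12)                     # [num_tokens, 1]
    x_q = (x.float() / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return x_q, scale

# Example: 4 tokens, hidden size 8, one token carrying an outlier value.
x = torch.randn(4, 8, dtype=torch.bfloat16)
x[0, 0] = 100.0
x_q, scale = quantize_fp8_per_row_sketch(x, scale_ub=torch.tensor(12.0))
print(x_q.dtype, scale.shape)  # torch.float8_e4m3fn torch.Size([4, 1])
```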
Yup, we are updating the quant kernel right now to use the activation_scale_ub.
@varun-sundar-rabindranath
In combination with https://github.com/vllm-project/vllm/pull/6547/files, which enables per-token scales.
Besides, both torch._scaled_mm and ops.scaled_fp8_quant expect the scale to be a scalar.
#6547 extends ops.scaled_fp8_quant to accept per-token scales (a vector of scales).
torch._scaled_mm will not be used; cutlass_scaled_mm accepts per-channel weights and per-token activations.
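To make the scale shapes concrete, here is a small sketch in plain PyTorch, with random floats standing in for fp8 values, of how per-token activation scales and per-channel weight scales enter the matmul that a cutlass-style scaled_mm performs. It emulates the arithmetic only and is not the CUTLASS kernel; all sizes are illustrative.

```python
import torch

# Illustrative sizes: 4 tokens, hidden size 16, 8 output channels.
num_tokens, hidden, out_features = 4, 16, 8
a_q = torch.randint(-8, 8, (num_tokens, hidden)).float()     # stand-in for quantized activations
w_q = torch.randint(-8, 8, (out_features, hidden)).float()   # stand-in for quantized weights
a_scale = torch.rand(num_tokens, 1)      # one scale per token (per row of activations)
w_scale = torch.rand(1, out_features)    # one scale per output channel of the weight

# Dequantized output: out[i, j] = (a_q[i] . w_q[j]) * a_scale[i] * w_scale[j]
out = (a_q @ w_q.t()) * a_scale * w_scale
print(out.shape)  # torch.Size([4, 8])
```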
I pulled #6547 into this PR so you can see.
#6593 adds the ub (the scale upper bound).
Okay, this PR now has the scale_ub.