Support FP8 for Mistral (#559)
Co-authored-by: Travis Addair <[email protected]>
ajtejankar and tgaddair authored Jul 30, 2024
1 parent d1a4d09 commit 91ef7a8
Showing 10 changed files with 204 additions and 45 deletions.
2 changes: 1 addition & 1 deletion server/Makefile-vllm
@@ -4,7 +4,7 @@ vllm-cuda:
git clone https://github.com/vllm-project/vllm.git vllm

build-vllm-cuda: vllm-cuda
cd vllm && git fetch && git checkout 5448f67
cd vllm && git fetch && git checkout 766435e660a786933392eb8ef0a873bc38cf0c8b
cd vllm && python setup.py build

install-vllm-cuda: build-vllm-cuda
58 changes: 32 additions & 26 deletions server/lorax_server/layers/fp8.py
@@ -1,45 +1,51 @@
from typing import Optional

import torch
from vllm import _custom_ops as ops

####### from vLLM code #######


def apply_fp8_linear(
input: torch.Tensor,
qweight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
input_scale_ub: Optional[torch.Tensor] = None,
qbias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
qinput, x_scale = ops.scaled_fp8_quant(input, input_scale, scale_ub=input_scale_ub, use_per_token_if_dynamic=False)

def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
# weight, scale = quant_weights(weight, torch.int8, False)
finfo = torch.finfo(qdtype)
# Calculate the scale as dtype max divided by absmax
scale = finfo.max / weight.abs().max().clamp(min=1e-12)
# scale and clamp the tensor to bring it to
# the representative range of float8 data type
# (as default cast is unsaturated)
qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
# Return both float8 data and the inverse scale (as float),
# as both required as inputs to torch._scaled_mm
qweight = qweight.to(qdtype)
scale = scale.float().reciprocal()
return qweight, scale
output = ops.cutlass_scaled_mm(
qinput, qweight, out_dtype=input.dtype, scale_a=x_scale, scale_b=weight_scale, bias=qbias
)

return output


class Fp8Linear(torch.nn.Module):
def __init__(
self,
weight,
bias,
weight_scale,
input_scale,
) -> None:
super().__init__()
self.dtype = weight.dtype
self.qweight, self.scale = fp8_quantize(weight)

self.bias = bias if bias is not None else None
self.qweight = weight.t()
self.weight_scale = weight_scale.view(1, -1).contiguous()
self.qbias = bias if bias is not None else None
self.input_scale = input_scale

def forward(self, input: torch.Tensor) -> torch.Tensor:
qinput, scale = fp8_quantize(input)
output, _ = torch._scaled_mm(
qinput,
self.qweight.t(),
out_dtype=self.dtype,
scale_a=scale,
scale_b=self.scale,
bias=self.bias,
return apply_fp8_linear(
input=input,
qweight=self.qweight,
weight_scale=self.weight_scale,
input_scale=self.input_scale,
qbias=self.qbias,
)
return output

@property
def weight(self) -> torch.Tensor:
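
The new forward path defers activation quantization and the matmul to vLLM's scaled_fp8_quant and cutlass_scaled_mm kernels. For intuition, here is a minimal sketch of the per-tensor scaling convention behind qweight and weight_scale, runnable on CPU with a recent PyTorch; the tensor shapes are illustrative and not taken from this commit:

```python
import torch

# Per-tensor FP8 quantization: scale = amax / fp8_max, qweight = clamp(weight / scale),
# and the checkpoint's weight_scale is the multiply-back (dequantization) scale.
finfo = torch.finfo(torch.float8_e4m3fn)
weight = torch.randn(128, 256, dtype=torch.float32)

weight_scale = weight.abs().max().clamp(min=1e-12) / finfo.max
qweight = (weight / weight_scale).clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)

# apply_fp8_linear hands qweight and weight_scale directly to cutlass_scaled_mm;
# dequantizing here only to show that the pair approximates the original weight.
dequant = qweight.to(torch.float32) * weight_scale
print((weight - dequant).abs().max())  # small quantization error
```
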
7 changes: 3 additions & 4 deletions server/lorax_server/layers/linear.py
@@ -89,18 +89,17 @@ def forward(self, inp: torch.Tensor) -> torch.Tensor:
return F.linear(inp, self.weight, self.bias)


def get_linear(weight, bias, quantize, fan_in_fan_out=False):
def get_linear(weight, bias, quantize, fan_in_fan_out=False, weight_scale=None, input_scale=None):
# https://huggingface.co/docs/peft/package_reference/tuners#peft.LoraConfig.fan_in_fan_out
# Set to True if replacing a Conv1D layer with a Linear layer
if fan_in_fan_out:
weight = weight.T.contiguous()

if quantize is None:
if quantize is None or (quantize == 'fp8' and weight_scale is None):
linear = FastLinear(weight, bias)
elif quantize == "fp8":
from lorax_server.layers.fp8 import Fp8Linear

linear = Fp8Linear(weight, bias)
linear = Fp8Linear(weight, bias, weight_scale=weight_scale, input_scale=input_scale)

elif quantize == "bitsandbytes":
from lorax_server.layers.bnb import Linear8bitLt
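
A hedged usage sketch of the new keyword arguments (the tensors are dummies, and the call assumes the pinned vLLM build is installed, since fp8.py imports its custom ops at module load):

```python
import torch

from lorax_server.layers.linear import get_linear

# FP8-quantized layer: the weight is stored in float8 with a per-channel weight_scale.
qweight = torch.randn(64, 32, dtype=torch.float16).to(torch.float8_e4m3fn)
weight_scale = torch.ones(64, dtype=torch.float32)
fp8_linear = get_linear(qweight, bias=None, quantize="fp8",
                        weight_scale=weight_scale, input_scale=None)

# A layer from an fp8 checkpoint that ships without a weight_scale (e.g. one listed in
# ignored_layers) now falls back to the plain FastLinear path.
plain = get_linear(torch.randn(64, 32, dtype=torch.float16), bias=None, quantize="fp8")
```
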
16 changes: 14 additions & 2 deletions server/lorax_server/layers/tensor_parallel.py
@@ -37,7 +37,7 @@ def load(config, prefix: str, weights):
should_gather = False

# GPTQ,AWQ,EETQ don't quantize heads (nor embeddings)
if config.quantize in ["gptq", "awq", "eetq"]:
if config.quantize in ["gptq", "awq", "eetq", "fp8"]:
quantize = None
else:
quantize = config.quantize
@@ -110,12 +110,24 @@ def load(cls, config, prefix: str, weights, bias: bool):
def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
weight = weights.get_multi_weights_col(prefixes, quantize=config.quantize, dim=dim)

input_scale, weight_scale = None, None
if type(weight) is tuple:
weight, input_scale, weight_scale = weight

if bias:
b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
bias = torch.cat(b, dim=dim)
else:
bias = None
linear = get_linear(weight, bias, config.quantize)

linear = get_linear(
weight,
bias,
config.quantize,
weight_scale=weight_scale,
input_scale=input_scale,
)

return cls(linear)


server/lorax_server/models/custom_modeling/flash_mistral_modeling.py
@@ -201,7 +201,11 @@ def _load_gqa(config, prefix: str, weights):
dim=0,
)

if config.quantize not in ["gptq", "awq"]:
input_scale, weight_scale = None, None
if type(weight) is tuple:
weight, input_scale, weight_scale = weight

if config.quantize not in ["gptq", "awq", "fp8"]:
weight = weight.to(dtype=weights.dtype).to(device=weights.device)

head_size = config.hidden_size // config.num_attention_heads
@@ -212,7 +216,15 @@ def _load_gqa(config, prefix: str, weights):
config.hidden_size,
], f"{list(weight.shape)} != {[(num_heads + 2 * num_key_value_heads) * head_size, config.hidden_size]}"

return TensorParallelColumnLinear(get_linear(weight, bias=None, quantize=config.quantize))
return TensorParallelColumnLinear(
get_linear(
weight,
bias=None,
quantize=config.quantize,
weight_scale=weight_scale,
input_scale=input_scale,
)
)


class MistralAttention(torch.nn.Module):
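
For a concrete sense of the fused-QKV shape assertion in _load_gqa, plugging in Mistral-7B's published config values (hidden_size 4096, 32 attention heads, 8 key/value heads):

```python
# Fused QKV rows = (num_heads + 2 * num_key_value_heads) * head_size.
hidden_size, num_heads, num_key_value_heads = 4096, 32, 8
head_size = hidden_size // num_heads                          # 128
qkv_rows = (num_heads + 2 * num_key_value_heads) * head_size
assert (qkv_rows, hidden_size) == (6144, 4096)                # expected weight.shape
```
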
14 changes: 13 additions & 1 deletion server/lorax_server/utils/layers.py
@@ -296,13 +296,25 @@ def load(
):
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)

input_scale, weight_scale = None, None
if type(weight) is tuple:
weight, input_scale, weight_scale = weight

if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process
bias = weights.get_tensor(f"{prefix}.bias")
else:
bias = None

return cls(
get_linear(weight, bias, config.quantize, fan_in_fan_out=fan_in_fan_out),
get_linear(
weight,
bias,
config.quantize,
fan_in_fan_out=fan_in_fan_out,
weight_scale=weight_scale,
input_scale=input_scale,
),
process_group=weights.process_group,
all_reduce=all_reduce,
)
8 changes: 5 additions & 3 deletions server/lorax_server/utils/paged_attention.py
@@ -34,7 +34,9 @@ def reshape_and_cache(
if SYSTEM == "xpu":
ipex.llm.modules.PagedAttention.reshape_and_cache(key, value, key_cache, value_cache, slots)
else:
torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "fp8" if fp8_supported else "auto", 1.0, 1.0)
torch.ops._C_cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, "fp8" if fp8_supported else "auto", 1.0, 1.0
)


def attention(
@@ -108,7 +110,7 @@ def attention(
None,
"fp8" if fp8_supported else "auto",
1.0,
1.0
1.0,
)
else:
# Run PagedAttention V2.
@@ -142,5 +144,5 @@ def attention(
None,
"fp8" if fp8_supported else "auto",
1.0,
1.0
1.0,
)
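
The cache and attention ops above receive a kv_cache_dtype of "fp8" whenever fp8_supported is set (the trailing 1.0, 1.0 are the k/v scales). fp8_supported itself is defined outside this diff; the following is a hedged sketch of one common way such a flag is derived, offered purely as an assumption:

```python
import torch

def fp8_kv_cache_supported() -> bool:
    # Assumption for illustration: FP8 kernels typically require Ada/Hopper GPUs,
    # i.e. CUDA compute capability 8.9 or newer.
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability() >= (8, 9)
```
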
10 changes: 10 additions & 0 deletions server/lorax_server/utils/torch_utils.py
@@ -8,3 +8,13 @@ def is_bf16_supported() -> bool:
True if supported, False otherwise.
"""
return torch.cuda.is_available() and torch.cuda.is_bf16_supported()


def is_fp8_quantized(config, layer_name):
# check if quantization is fp8 and either of the fused layers is not ignored
# typically, either all qkv will be quantized or none so just check for one
if config.quantize == "fp8" and hasattr(config, "quantization_config"):
ignored_layers = set(config.quantization_config.get("ignored_layers", []))
if layer_name not in ignored_layers:
return "fp8"
return None
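
A usage sketch for the new helper; SimpleNamespace stands in for the Hugging Face config object, and the layer names are hypothetical:

```python
from types import SimpleNamespace

from lorax_server.utils.torch_utils import is_fp8_quantized

config = SimpleNamespace(
    quantize="fp8",
    quantization_config={"ignored_layers": ["lm_head"]},
)

assert is_fp8_quantized(config, "model.layers.0.self_attn.q_proj") == "fp8"
assert is_fp8_quantized(config, "lm_head") is None  # ignored layers stay unquantized
```
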
40 changes: 34 additions & 6 deletions server/lorax_server/utils/weights.py
@@ -69,7 +69,7 @@ def get_partial_sharded(self, tensor_name: str, dim: int, range: Optional[Tuple[
raise NotImplementedError("Let's make that generic when needed")
# Special case for gptq which shouldn't convert
# u4 which are disguised as int32
if tensor.dtype != torch.int32:
if tensor.dtype not in [torch.int32, torch.int64, torch.float8_e4m3fn, torch.float8_e5m2]:
tensor = tensor.to(dtype=self.dtype)
tensor = tensor.to(device=self.device)
return tensor
@@ -116,8 +116,27 @@ def get_multi_weights_col(self, prefixes: List[Union[str, Tuple]], quantize: str
bits, groupsize = self._get_bits_and_groupsize()
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
else:
w = self.get_sharded_list("weight", prefixes, dim=0)
weight = torch.cat(w, dim=dim)
weight_list = self.get_sharded_list("weight", prefixes, dim=0)
if quantize == "fp8" and weight_list[0].dtype == torch.float8_e4m3fn:
# Since there is no kernel for concatenating two tensors in PyTorch
# for fp8 datatypes, we have to cast to fp16, concat, cast back to fp8
fp16_weight_list = [w.to(torch.float16) for w in weight_list]
weight = torch.cat(fp16_weight_list, dim=dim).to(torch.float8_e4m3fn)
input_scale = None
if self.has_tensor(f"{prefixes[0]}.input_scale"):
# if the layers are being fused, then they have the same inputs
# hence their input scales will have to be the same so we pick the first one
input_scale = self.get_tensor(f"{prefixes[0]}.input_scale", use_self_dtype=False)
weight_scale_list = [self.get_tensor(f"{p}.weight_scale", use_self_dtype=False) for p in prefixes]
if len(weight_scale_list[0].shape) > 1:
weight_scale_list = self.get_sharded_list("weight_scale", prefixes, dim=0)
else:
weight_scale_list = [si.repeat(wi.shape[dim]) for si, wi in zip(weight_scale_list, weight_list)]
# weight scales are in fp32 already so no problem with concatenating them
weight_scale = torch.cat(weight_scale_list, dim=0)
return weight, input_scale, weight_scale
weight = torch.cat(weight_list, dim=dim)

return weight

def get_multi_weights_row(self, prefix: str, quantize: str):
@@ -201,6 +220,14 @@ def get_multi_weights_row(self, prefix: str, quantize: str):
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
else:
weight = self.get_sharded(f"{prefix}.weight", dim=1)
if quantize == "fp8" and weight.dtype == torch.float8_e4m3fn:
# weight_scale could be a tensor but if we're sharding row-wise then no
# need to shard the weight_scale as its row dimension would be 1
weight_scale = self.get_tensor(f"{prefix}.weight_scale", use_self_dtype=False)
input_scale = None
if self.has_tensor(f"{prefix}.input_scale"):
input_scale = self.get_tensor(f"{prefix}.input_scale", use_self_dtype=False)
return weight, input_scale, weight_scale
return weight

def _get_bits_and_groupsize(self) -> Tuple[int, int]:
@@ -354,14 +381,15 @@ def get_slice(self, tensor_name: str):
def get_slice_shape(self, slice) -> torch.Size:
return slice.get_shape()

def get_tensor(self, tensor_name: str):
def get_tensor(self, tensor_name: str, use_self_dtype: bool = True):
filename, tensor_name = self.get_filename(tensor_name)
f = self._get_handle(filename)
tensor = f.get_tensor(tensor_name)
# Special case for gptq which shouldn't convert
# u4 which are disguised as int32
if tensor.dtype not in [torch.int32, torch.int64]:
tensor = tensor.to(dtype=self.dtype)
if tensor.dtype not in [torch.int32, torch.int64, torch.float8_e4m3fn, torch.float8_e5m2]:
if use_self_dtype:
tensor = tensor.to(dtype=self.dtype)
tensor = tensor.to(device=self.device)
return tensor

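
Two details of the fp8 loading path above are easy to miss: float8 shards are concatenated by round-tripping through fp16 (the diff's own comment notes PyTorch has no fp8 concat kernel), and per-tensor weight scales are expanded to per-row before being concatenated so each fused row keeps the scale of its shard. A standalone sketch with illustrative shapes only:

```python
import torch

# Three column shards of a fused layer, already stored in float8.
shards = [torch.randn(16, 32).to(torch.float8_e4m3fn) for _ in range(3)]

# Concatenate via fp16 and cast back, mirroring the fp8 branch of get_multi_weights_col.
fused = torch.cat([s.to(torch.float16) for s in shards], dim=0).to(torch.float8_e4m3fn)
assert fused.shape == (48, 32)

# Per-tensor scales are repeated to per-row before concatenation (cf. the repeat() above).
scales = [torch.tensor(0.01 * (i + 1)) for i in range(3)]
weight_scale = torch.cat([s.repeat(w.shape[0]) for s, w in zip(scales, shards)], dim=0)
assert weight_scale.shape == (48,)
```
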
78 changes: 78 additions & 0 deletions server/tests/utils/test_weights.py
@@ -0,0 +1,78 @@
import pytest
import torch
from transformers.models.qwen2 import Qwen2Config

from lorax_server.utils.dist import initialize_torch_distributed
from lorax_server.utils.sources.hub import (
download_weights,
weight_hub_files,
)
from lorax_server.utils.weights import Weights


@pytest.mark.parametrize(
'model_id', [
'neuralmagic/Qwen2-0.5B-Instruct-FP8',
'Qwen/Qwen2-0.5B-Instruct'
]
)
@pytest.mark.parametrize(
'prefixes', [
['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj'],
['mlp.gate_proj', 'mlp.up_proj']
]
)
def test_get_multi_weights_col(model_id, prefixes):
process_group, _, _ = initialize_torch_distributed()
filenames = weight_hub_files(model_id, 'main', '.safetensors')
local_filenames = download_weights(filenames, model_id, 'main')
config = Qwen2Config.from_pretrained(model_id, revision='main', trust_remote_code=False)
quantize = None
if hasattr(config, 'quantization_config'):
quantize = config.quantization_config['quant_method']

weights = Weights(local_filenames, 'cpu', torch.bfloat16, process_group=process_group)
prefix = 'model.layers.0'
prefixes = [f'{prefix}.{k}' for k in prefixes]
weight = weights.get_multi_weights_col(
prefixes=prefixes,
quantize=quantize,
dim=0,
)
if quantize is not None:
assert type(weight) is tuple
weight, input_scale, weight_scale = weight
assert weight.dtype == torch.float8_e4m3fn
assert input_scale.dtype == torch.float
assert weight_scale.dtype == torch.float
else:
assert weight.dtype == torch.bfloat16

@pytest.mark.parametrize(
'model_id', [
'neuralmagic/Qwen2-0.5B-Instruct-FP8',
'Qwen/Qwen2-0.5B-Instruct'
]
)
@pytest.mark.parametrize(
'prefix', ['self_attn.o_proj', 'mlp.down_proj'],
)
def test_get_multi_weights_row(model_id, prefix):
process_group, _, _ = initialize_torch_distributed()
filenames = weight_hub_files(model_id, 'main', '.safetensors')
local_filenames = download_weights(filenames, model_id, 'main')
config = Qwen2Config.from_pretrained(model_id, revision='main', trust_remote_code=False)
quantize = None
if hasattr(config, 'quantization_config'):
quantize = config.quantization_config['quant_method']

weights = Weights(local_filenames, 'cpu', torch.bfloat16, process_group=process_group)
weight = weights.get_multi_weights_row(f'model.layers.0.{prefix}', quantize=quantize)
if quantize is not None:
assert type(weight) is tuple
weight, input_scale, weight_scale = weight
assert weight.dtype == torch.float8_e4m3fn
assert input_scale.dtype == torch.float
assert weight_scale.dtype == torch.float
else:
assert weight.dtype == torch.bfloat16
