From 91ef7a86db89a8651b127f7f714da4b5b7b8c2ec Mon Sep 17 00:00:00 2001 From: Ajinkya Tejankar Date: Tue, 30 Jul 2024 15:30:51 -0700 Subject: [PATCH] Support FP8 for Mistral (#559) Co-authored-by: Travis Addair --- server/Makefile-vllm | 2 +- server/lorax_server/layers/fp8.py | 58 +++++++------- server/lorax_server/layers/linear.py | 7 +- server/lorax_server/layers/tensor_parallel.py | 16 +++- .../custom_modeling/flash_mistral_modeling.py | 16 +++- server/lorax_server/utils/layers.py | 14 +++- server/lorax_server/utils/paged_attention.py | 8 +- server/lorax_server/utils/torch_utils.py | 10 +++ server/lorax_server/utils/weights.py | 40 ++++++++-- server/tests/utils/test_weights.py | 78 +++++++++++++++++++ 10 files changed, 204 insertions(+), 45 deletions(-) create mode 100644 server/tests/utils/test_weights.py diff --git a/server/Makefile-vllm b/server/Makefile-vllm index 3e322379b..4c92391b3 100644 --- a/server/Makefile-vllm +++ b/server/Makefile-vllm @@ -4,7 +4,7 @@ vllm-cuda: git clone https://github.com/vllm-project/vllm.git vllm build-vllm-cuda: vllm-cuda - cd vllm && git fetch && git checkout 5448f67 + cd vllm && git fetch && git checkout 766435e660a786933392eb8ef0a873bc38cf0c8b cd vllm && python setup.py build install-vllm-cuda: build-vllm-cuda diff --git a/server/lorax_server/layers/fp8.py b/server/lorax_server/layers/fp8.py index dcd6d0114..1af3d36d5 100644 --- a/server/lorax_server/layers/fp8.py +++ b/server/lorax_server/layers/fp8.py @@ -1,20 +1,26 @@ +from typing import Optional + import torch +from vllm import _custom_ops as ops + +####### from vLLM code ####### + +def apply_fp8_linear( + input: torch.Tensor, + qweight: torch.Tensor, + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, + input_scale_ub: Optional[torch.Tensor] = None, + qbias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + qinput, x_scale = ops.scaled_fp8_quant(input, input_scale, scale_ub=input_scale_ub, use_per_token_if_dynamic=False) -def fp8_quantize(weight, qdtype=torch.float8_e4m3fn): - # weight, scale = quant_weights(weight, torch.int8, False) - finfo = torch.finfo(qdtype) - # Calculate the scale as dtype max divided by absmax - scale = finfo.max / weight.abs().max().clamp(min=1e-12) - # scale and clamp the tensor to bring it to - # the representative range of float8 data type - # (as default cast is unsaturated) - qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max) - # Return both float8 data and the inverse scale (as float), - # as both required as inputs to torch._scaled_mm - qweight = qweight.to(qdtype) - scale = scale.float().reciprocal() - return qweight, scale + output = ops.cutlass_scaled_mm( + qinput, qweight, out_dtype=input.dtype, scale_a=x_scale, scale_b=weight_scale, bias=qbias + ) + + return output class Fp8Linear(torch.nn.Module): @@ -22,24 +28,24 @@ def __init__( self, weight, bias, + weight_scale, + input_scale, ) -> None: super().__init__() self.dtype = weight.dtype - self.qweight, self.scale = fp8_quantize(weight) - - self.bias = bias if bias is not None else None + self.qweight = weight.t() + self.weight_scale = weight_scale.view(1, -1).contiguous() + self.qbias = bias if bias is not None else None + self.input_scale = input_scale def forward(self, input: torch.Tensor) -> torch.Tensor: - qinput, scale = fp8_quantize(input) - output, _ = torch._scaled_mm( - qinput, - self.qweight.t(), - out_dtype=self.dtype, - scale_a=scale, - scale_b=self.scale, - bias=self.bias, + return apply_fp8_linear( + input=input, + qweight=self.qweight, + weight_scale=self.weight_scale, + input_scale=self.input_scale, + qbias=self.qbias, ) - return output @property def weight(self) -> torch.Tensor: diff --git a/server/lorax_server/layers/linear.py b/server/lorax_server/layers/linear.py index 910b0d5e5..274694bbf 100644 --- a/server/lorax_server/layers/linear.py +++ b/server/lorax_server/layers/linear.py @@ -89,18 +89,17 @@ def forward(self, inp: torch.Tensor) -> torch.Tensor: return F.linear(inp, self.weight, self.bias) -def get_linear(weight, bias, quantize, fan_in_fan_out=False): +def get_linear(weight, bias, quantize, fan_in_fan_out=False, weight_scale=None, input_scale=None): # https://huggingface.co/docs/peft/package_reference/tuners#peft.LoraConfig.fan_in_fan_out # Set to True if replacing a Conv1D layer with a Linear layer if fan_in_fan_out: weight = weight.T.contiguous() - if quantize is None: + if quantize is None or (quantize == 'fp8' and weight_scale is None): linear = FastLinear(weight, bias) elif quantize == "fp8": from lorax_server.layers.fp8 import Fp8Linear - - linear = Fp8Linear(weight, bias) + linear = Fp8Linear(weight, bias, weight_scale=weight_scale, input_scale=input_scale) elif quantize == "bitsandbytes": from lorax_server.layers.bnb import Linear8bitLt diff --git a/server/lorax_server/layers/tensor_parallel.py b/server/lorax_server/layers/tensor_parallel.py index 49f2fe3b7..78f6a8d88 100644 --- a/server/lorax_server/layers/tensor_parallel.py +++ b/server/lorax_server/layers/tensor_parallel.py @@ -37,7 +37,7 @@ def load(config, prefix: str, weights): should_gather = False # GPTQ,AWQ,EETQ don't quantize heads (nor embeddings) - if config.quantize in ["gptq", "awq", "eetq"]: + if config.quantize in ["gptq", "awq", "eetq", "fp8"]: quantize = None else: quantize = config.quantize @@ -110,12 +110,24 @@ def load(cls, config, prefix: str, weights, bias: bool): def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int): weight = weights.get_multi_weights_col(prefixes, quantize=config.quantize, dim=dim) + input_scale, weight_scale = None, None + if type(weight) is tuple: + weight, input_scale, weight_scale = weight + if bias: b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes] bias = torch.cat(b, dim=dim) else: bias = None - linear = get_linear(weight, bias, config.quantize) + + linear = get_linear( + weight, + bias, + config.quantize, + weight_scale=weight_scale, + input_scale=input_scale, + ) + return cls(linear) diff --git a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py index 3f1a0f5ec..125f2b4d9 100644 --- a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py @@ -201,7 +201,11 @@ def _load_gqa(config, prefix: str, weights): dim=0, ) - if config.quantize not in ["gptq", "awq"]: + input_scale, weight_scale = None, None + if type(weight) is tuple: + weight, input_scale, weight_scale = weight + + if config.quantize not in ["gptq", "awq", "fp8"]: weight = weight.to(dtype=weights.dtype).to(device=weights.device) head_size = config.hidden_size // config.num_attention_heads @@ -212,7 +216,15 @@ def _load_gqa(config, prefix: str, weights): config.hidden_size, ], f"{list(weight.shape)} != {[(num_heads + 2 * num_key_value_heads) * head_size, config.hidden_size]}" - return TensorParallelColumnLinear(get_linear(weight, bias=None, quantize=config.quantize)) + return TensorParallelColumnLinear( + get_linear( + weight, + bias=None, + quantize=config.quantize, + weight_scale=weight_scale, + input_scale=input_scale, + ) + ) class MistralAttention(torch.nn.Module): diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py index adbf9ce50..7ce38a13f 100644 --- a/server/lorax_server/utils/layers.py +++ b/server/lorax_server/utils/layers.py @@ -296,13 +296,25 @@ def load( ): weight = weights.get_multi_weights_row(prefix, quantize=config.quantize) + input_scale, weight_scale = None, None + if type(weight) is tuple: + weight, input_scale, weight_scale = weight + if bias and weights.process_group.rank() == 0: # Rank is only on the first rank process bias = weights.get_tensor(f"{prefix}.bias") else: bias = None + return cls( - get_linear(weight, bias, config.quantize, fan_in_fan_out=fan_in_fan_out), + get_linear( + weight, + bias, + config.quantize, + fan_in_fan_out=fan_in_fan_out, + weight_scale=weight_scale, + input_scale=input_scale, + ), process_group=weights.process_group, all_reduce=all_reduce, ) diff --git a/server/lorax_server/utils/paged_attention.py b/server/lorax_server/utils/paged_attention.py index 1d568c8cc..2587cd4b3 100644 --- a/server/lorax_server/utils/paged_attention.py +++ b/server/lorax_server/utils/paged_attention.py @@ -34,7 +34,9 @@ def reshape_and_cache( if SYSTEM == "xpu": ipex.llm.modules.PagedAttention.reshape_and_cache(key, value, key_cache, value_cache, slots) else: - torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "fp8" if fp8_supported else "auto", 1.0, 1.0) + torch.ops._C_cache_ops.reshape_and_cache( + key, value, key_cache, value_cache, slots, "fp8" if fp8_supported else "auto", 1.0, 1.0 + ) def attention( @@ -108,7 +110,7 @@ def attention( None, "fp8" if fp8_supported else "auto", 1.0, - 1.0 + 1.0, ) else: # Run PagedAttention V2. @@ -142,5 +144,5 @@ def attention( None, "fp8" if fp8_supported else "auto", 1.0, - 1.0 + 1.0, ) diff --git a/server/lorax_server/utils/torch_utils.py b/server/lorax_server/utils/torch_utils.py index 7171e5222..6919d8677 100644 --- a/server/lorax_server/utils/torch_utils.py +++ b/server/lorax_server/utils/torch_utils.py @@ -8,3 +8,13 @@ def is_bf16_supported() -> bool: True if supported, False otherwise. """ return torch.cuda.is_available() and torch.cuda.is_bf16_supported() + + +def is_fp8_quantized(config, layer_name): + # check if quantization is fp8 and either of the fused layers is not ignored + # typically, either all qkv will be quantized or none so just check for one + if config.quantize == "fp8" and hasattr(config, "quantization_config"): + ignored_layers = set(config.quantization_config.get("ignored_layers", [])) + if layer_name not in ignored_layers: + return "fp8" + return None diff --git a/server/lorax_server/utils/weights.py b/server/lorax_server/utils/weights.py index 91ce2390a..1a153dbee 100644 --- a/server/lorax_server/utils/weights.py +++ b/server/lorax_server/utils/weights.py @@ -69,7 +69,7 @@ def get_partial_sharded(self, tensor_name: str, dim: int, range: Optional[Tuple[ raise NotImplementedError("Let's make that generic when needed") # Special case for gptq which shouldn't convert # u4 which are disguised as int32 - if tensor.dtype != torch.int32: + if tensor.dtype not in [torch.int32, torch.int64, torch.float8_e4m3fn, torch.float8_e5m2]: tensor = tensor.to(dtype=self.dtype) tensor = tensor.to(device=self.device) return tensor @@ -116,8 +116,27 @@ def get_multi_weights_col(self, prefixes: List[Union[str, Tuple]], quantize: str bits, groupsize = self._get_bits_and_groupsize() weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False) else: - w = self.get_sharded_list("weight", prefixes, dim=0) - weight = torch.cat(w, dim=dim) + weight_list = self.get_sharded_list("weight", prefixes, dim=0) + if quantize == "fp8" and weight_list[0].dtype == torch.float8_e4m3fn: + # Since there is no kernel for concatenating two tensors in PyTorch + # for fp8 datatypes, we have to cast to fp16, concat, cast back to fp8 + fp16_weight_list = [w.to(torch.float16) for w in weight_list] + weight = torch.cat(fp16_weight_list, dim=dim).to(torch.float8_e4m3fn) + input_scale = None + if self.has_tensor(f"{prefixes[0]}.input_scale"): + # if the layers are being fused, then they have the same inputs + # hence their input scales will have to be the same so we pick the first one + input_scale = self.get_tensor(f"{prefixes[0]}.input_scale", use_self_dtype=False) + weight_scale_list = [self.get_tensor(f"{p}.weight_scale", use_self_dtype=False) for p in prefixes] + if len(weight_scale_list[0].shape) > 1: + weight_scale_list = self.get_sharded_list("weight_scale", prefixes, dim=0) + else: + weight_scale_list = [si.repeat(wi.shape[dim]) for si, wi in zip(weight_scale_list, weight_list)] + # weight scales are in fp32 already so no problem with concatenating them + weight_scale = torch.cat(weight_scale_list, dim=0) + return weight, input_scale, weight_scale + weight = torch.cat(weight_list, dim=dim) + return weight def get_multi_weights_row(self, prefix: str, quantize: str): @@ -201,6 +220,14 @@ def get_multi_weights_row(self, prefix: str, quantize: str): weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) else: weight = self.get_sharded(f"{prefix}.weight", dim=1) + if quantize == "fp8" and weight.dtype == torch.float8_e4m3fn: + # weight_scale could be a tensor but if we're sharding row-wise then no + # need to shard the weight_scale as its row dimension would be 1 + weight_scale = self.get_tensor(f"{prefix}.weight_scale", use_self_dtype=False) + input_scale = None + if self.has_tensor(f"{prefix}.input_scale"): + input_scale = self.get_tensor(f"{prefix}.input_scale", use_self_dtype=False) + return weight, input_scale, weight_scale return weight def _get_bits_and_groupsize(self) -> Tuple[int, int]: @@ -354,14 +381,15 @@ def get_slice(self, tensor_name: str): def get_slice_shape(self, slice) -> torch.Size: return slice.get_shape() - def get_tensor(self, tensor_name: str): + def get_tensor(self, tensor_name: str, use_self_dtype: bool = True): filename, tensor_name = self.get_filename(tensor_name) f = self._get_handle(filename) tensor = f.get_tensor(tensor_name) # Special case for gptq which shouldn't convert # u4 which are disguised as int32 - if tensor.dtype not in [torch.int32, torch.int64]: - tensor = tensor.to(dtype=self.dtype) + if tensor.dtype not in [torch.int32, torch.int64, torch.float8_e4m3fn, torch.float8_e5m2]: + if use_self_dtype: + tensor = tensor.to(dtype=self.dtype) tensor = tensor.to(device=self.device) return tensor diff --git a/server/tests/utils/test_weights.py b/server/tests/utils/test_weights.py new file mode 100644 index 000000000..36073e02b --- /dev/null +++ b/server/tests/utils/test_weights.py @@ -0,0 +1,78 @@ +import pytest +import torch +from transformers.models.qwen2 import Qwen2Config + +from lorax_server.utils.dist import initialize_torch_distributed +from lorax_server.utils.sources.hub import ( + download_weights, + weight_hub_files, +) +from lorax_server.utils.weights import Weights + + +@pytest.mark.parametrize( + 'model_id', [ + 'neuralmagic/Qwen2-0.5B-Instruct-FP8', + 'Qwen/Qwen2-0.5B-Instruct' + ] +) +@pytest.mark.parametrize( + 'prefixes', [ + ['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj'], + ['mlp.gate_proj', 'mlp.up_proj'] + ] +) +def test_get_multi_weights_col(model_id, prefixes): + process_group, _, _ = initialize_torch_distributed() + filenames = weight_hub_files(model_id, 'main', '.safetensors') + local_filenames = download_weights(filenames, model_id, 'main') + config = Qwen2Config.from_pretrained(model_id, revision='main', trust_remote_code=False) + quantize = None + if hasattr(config, 'quantization_config'): + quantize = config.quantization_config['quant_method'] + + weights = Weights(local_filenames, 'cpu', torch.bfloat16, process_group=process_group) + prefix = 'model.layers.0' + prefixes = [f'{prefix}.{k}' for k in prefixes] + weight = weights.get_multi_weights_col( + prefixes=prefixes, + quantize=quantize, + dim=0, + ) + if quantize is not None: + assert type(weight) is tuple + weight, input_scale, weight_scale = weight + assert weight.dtype == torch.float8_e4m3fn + assert input_scale.dtype == torch.float + assert weight_scale.dtype == torch.float + else: + assert weight.dtype == torch.bfloat16 + +@pytest.mark.parametrize( + 'model_id', [ + 'neuralmagic/Qwen2-0.5B-Instruct-FP8', + 'Qwen/Qwen2-0.5B-Instruct' + ] +) +@pytest.mark.parametrize( + 'prefix', ['self_attn.o_proj', 'mlp.down_proj'], +) +def test_get_multi_weights_row(model_id, prefix): + process_group, _, _ = initialize_torch_distributed() + filenames = weight_hub_files(model_id, 'main', '.safetensors') + local_filenames = download_weights(filenames, model_id, 'main') + config = Qwen2Config.from_pretrained(model_id, revision='main', trust_remote_code=False) + quantize = None + if hasattr(config, 'quantization_config'): + quantize = config.quantization_config['quant_method'] + + weights = Weights(local_filenames, 'cpu', torch.bfloat16, process_group=process_group) + weight = weights.get_multi_weights_row(f'model.layers.0.{prefix}', quantize=quantize) + if quantize is not None: + assert type(weight) is tuple + weight, input_scale, weight_scale = weight + assert weight.dtype == torch.float8_e4m3fn + assert input_scale.dtype == torch.float + assert weight_scale.dtype == torch.float + else: + assert weight.dtype == torch.bfloat16