From 6f887aabc4f67314e94120c29460f4003686bad7 Mon Sep 17 00:00:00 2001 From: Ajinkya Tejankar Date: Tue, 1 Oct 2024 16:16:29 -0700 Subject: [PATCH 01/25] (feat) : support fp8 kv cache --- launcher/src/main.rs | 4 +++ server/lorax_server/cli.py | 1 + server/lorax_server/layers/linear.py | 4 +-- server/lorax_server/layers/tensor_parallel.py | 2 +- .../custom_modeling/flash_llama_modeling.py | 26 ++++++++++++++-- .../custom_modeling/flash_mistral_modeling.py | 28 +++++++++++++++-- server/lorax_server/utils/paged_attention.py | 31 ++++++++++--------- server/lorax_server/utils/torch_utils.py | 2 +- server/lorax_server/utils/weights.py | 4 +-- 9 files changed, 78 insertions(+), 24 deletions(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 0edc5c684..96b17babe 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -32,6 +32,7 @@ enum Quantization { Hqq_3bit, Hqq_2bit, Fp8, + Fp8_KV, } impl std::fmt::Display for Quantization { @@ -68,6 +69,9 @@ impl std::fmt::Display for Quantization { Quantization::Fp8 => { write!(f, "fp8") } + Quantization::Fp8_KV => { + write!(f, "fp8_kv") + } } } } diff --git a/server/lorax_server/cli.py b/server/lorax_server/cli.py index c553b0dd1..bca7400c0 100644 --- a/server/lorax_server/cli.py +++ b/server/lorax_server/cli.py @@ -23,6 +23,7 @@ class Quantization(str, Enum): hqq_3bit = "hqq-3bit" hqq_2bit = "hqq-2bit" fp8 = "fp8" + fp8_kv = "fp8_kv" class Dtype(str, Enum): diff --git a/server/lorax_server/layers/linear.py b/server/lorax_server/layers/linear.py index 5af0a509e..68b4c9609 100644 --- a/server/lorax_server/layers/linear.py +++ b/server/lorax_server/layers/linear.py @@ -95,9 +95,9 @@ def get_linear(weight, bias, quantize, fan_in_fan_out=False, weight_scale=None, if fan_in_fan_out: weight = weight.T.contiguous() - if quantize is None or (quantize == "fp8" and weight_scale is None): + if quantize is None or (quantize.startswith("fp8") and weight_scale is None): linear = FastLinear(weight, bias) - elif quantize == "fp8": + elif quantize.startswith("fp8"): from lorax_server.layers.fp8 import Fp8Linear linear = Fp8Linear(weight, bias, weight_scale=weight_scale, input_scale=input_scale) diff --git a/server/lorax_server/layers/tensor_parallel.py b/server/lorax_server/layers/tensor_parallel.py index 78f6a8d88..c2a140fad 100644 --- a/server/lorax_server/layers/tensor_parallel.py +++ b/server/lorax_server/layers/tensor_parallel.py @@ -37,7 +37,7 @@ def load(config, prefix: str, weights): should_gather = False # GPTQ,AWQ,EETQ don't quantize heads (nor embeddings) - if config.quantize in ["gptq", "awq", "eetq", "fp8"]: + if config.quantize in ["gptq", "awq", "eetq", "fp8", "fp8_kv"]: quantize = None else: quantize = config.quantize diff --git a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py index bc21484a2..d66af60f0 100644 --- a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py @@ -199,7 +199,7 @@ def _load_gqa(config, prefix: str, weights): if isinstance(weight, tuple): weight, input_scale, weight_scale = weight - if config.quantize not in ["gptq", "awq", "fp8"]: + if config.quantize not in ["gptq", "awq", "fp8", "fp8_kv"]: weight = weight.to(dtype=weights.dtype).to(device=weights.device) head_size = config.hidden_size // config.num_attention_heads @@ -251,6 +251,16 @@ def __init__( ) self.num_heads = self.num_heads // weights.process_group.size() 
self.num_key_value_heads = config.num_key_value_heads // weights.process_group.size() + # todo(ajinkya): only supports the default 'fp8' dtype in vLLM for kv cache but + # we can also support other dtypes like f8_e4m3 + if paged_attention.is_fp8_supported() and config.quantize and config.quantize.endswith('_kv'): + self.kv_cache_dtype = 'fp8' + self.k_scale = weights.get_tensor(f"{prefix}.k_scale", use_self_dtype=False) + self.v_scale = weights.get_tensor(f"{prefix}.v_scale", use_self_dtype=False) + else: + self.kv_cache_dtype = 'auto' + self.k_scale = 1.0 + self.v_scale = 1.0 self.query_key_value = load_attention(config, prefix, weights, layer_id) @@ -318,7 +328,16 @@ def forward( self.rotary_emb(query, cos, sin) self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) - paged_attention.reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + paged_attention.reshape_and_cache( + kv[:, 0], + kv[:, 1], + kv_cache[0], + kv_cache[1], + slots, + # self.kv_cache_dtype, + # self.k_scale, + # self.v_scale, + ) # Prefill if cu_seqlen_prefill is not None: @@ -346,6 +365,9 @@ def forward( block_tables, input_lengths, max_s, + # self.kv_cache_dtype, + # self.k_scale, + # self.v_scale, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size), adapter_data) diff --git a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py index cc6f93dc6..cdc51f079 100644 --- a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py @@ -204,7 +204,7 @@ def _load_gqa(config, prefix: str, weights, head_size): if type(weight) is tuple: weight, input_scale, weight_scale = weight - if config.quantize not in ["gptq", "awq", "fp8"]: + if config.quantize not in ["gptq", "awq", "fp8", "fp8_kv"]: weight = weight.to(dtype=weights.dtype).to(device=weights.device) num_heads = config.num_attention_heads // weights.process_group.size() @@ -259,6 +259,16 @@ def __init__( ) self.num_heads = self.num_heads // weights.process_group.size() self.num_key_value_heads = config.num_key_value_heads // weights.process_group.size() + # todo(ajinkya): only supports the default 'fp8' dtype in vLLM for kv cache but + # we can also support other dtypes like f8_e4m3 + if paged_attention.is_fp8_supported() and config.quantize and config.quantize.endswith('_kv'): + self.kv_cache_dtype = 'fp8' + self.k_scale = weights.get_tensor(f"{prefix}.k_scale", use_self_dtype=False) + self.v_scale = weights.get_tensor(f"{prefix}.v_scale", use_self_dtype=False) + else: + self.kv_cache_dtype = 'auto' + self.k_scale = 1.0 + self.v_scale = 1.0 self.query_key_value = load_attention(config, prefix, weights, layer_id, self.head_size) @@ -332,11 +342,21 @@ def forward( else: kv_to_cache = kv - paged_attention.reshape_and_cache(kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots) + paged_attention.reshape_and_cache( + kv_to_cache[:, 0], + kv_to_cache[:, 1], + kv_cache[0], + kv_cache[1], + slots, + self.kv_cache_dtype, + self.k_scale, + self.v_scale, + ) # Prefill if cu_seqlen_prefill is not None: # flash attention + # note: flashinfer backend + fp8 kv cache can cause problems attn_output = flash_attn.attention( query, torch.select(kv, dim=1, index=0), @@ -360,6 +380,10 @@ def forward( block_tables, input_lengths, max_s, + None, + self.kv_cache_dtype, + self.k_scale, + self.v_scale, ) return self.o_proj(attn_output.view(-1, self.num_heads * 
self.head_size), adapter_data) diff --git a/server/lorax_server/utils/paged_attention.py b/server/lorax_server/utils/paged_attention.py index 34d11e843..bc12935d7 100644 --- a/server/lorax_server/utils/paged_attention.py +++ b/server/lorax_server/utils/paged_attention.py @@ -18,13 +18,10 @@ f"Could not import vllm paged attention. Make sure your installation is correct. Error: {e}" ) from e -# TODO(travis): fix for CUDA 8.9 (Lovelace) and 9.0 (Hopper) -# if torch.cuda.is_available(): -# fp8_supported = ( -# torch.cuda.get_device_capability()[0] >= 9 -# ) # or (torch.cuda.get_device_capability()[0] == 8 and torch.cuda.get_device_capability()[1] >= 9) -# else: -fp8_supported = False + +def is_fp8_supported(): + return (torch.cuda.get_device_capability()[0] >= 9) \ + or (torch.cuda.get_device_capability()[0] == 8 and torch.cuda.get_device_capability()[1] >= 9) def reshape_and_cache( @@ -33,6 +30,9 @@ def reshape_and_cache( key_cache: torch.Tensor, value_cache: torch.Tensor, slots: torch.Tensor, + kv_cache_dtype: str = 'auto', + k_scale: torch.Tensor = 1.0, + v_scale: torch.Tensor = 1.0, ): if FLASH_INFER: shape = key_cache.shape @@ -42,7 +42,7 @@ def reshape_and_cache( ipex.llm.modules.PagedAttention.reshape_and_cache(key, value, key_cache, value_cache, slots) else: torch.ops._C_cache_ops.reshape_and_cache( - key, value, key_cache, value_cache, slots, "fp8" if fp8_supported else "auto", 1.0, 1.0 + key, value, key_cache, value_cache, slots, kv_cache_dtype, k_scale, v_scale ) @@ -57,6 +57,9 @@ def attention( input_lengths: torch.Tensor, max_s: int, softcap: Optional[float] = None, + kv_cache_dtype: str = 'auto', + k_scale: torch.Tensor = 1.0, + v_scale: torch.Tensor = 1.0, ): if FLASH_INFER: from lorax_server.utils.flashinfer_attention import decode_state @@ -128,9 +131,9 @@ def attention( block_size, max_s, None, - "fp8" if fp8_supported else "auto", - 1.0, - 1.0, + kv_cache_dtype, + k_scale, + v_scale, ) else: # Run PagedAttention V2. 
@@ -162,9 +165,9 @@ def attention( block_size, max_s, None, - "fp8" if fp8_supported else "auto", - 1.0, - 1.0, + kv_cache_dtype, + k_scale, + v_scale, ) return out diff --git a/server/lorax_server/utils/torch_utils.py b/server/lorax_server/utils/torch_utils.py index 6919d8677..c76d1c44c 100644 --- a/server/lorax_server/utils/torch_utils.py +++ b/server/lorax_server/utils/torch_utils.py @@ -13,7 +13,7 @@ def is_bf16_supported() -> bool: def is_fp8_quantized(config, layer_name): # check if quantization is fp8 and either of the fused layers is not ignored # typically, either all qkv will be quantized or none so just check for one - if config.quantize == "fp8" and hasattr(config, "quantization_config"): + if config.quantize and config.quantize.startswith("fp8") and hasattr(config, "quantization_config"): ignored_layers = set(config.quantization_config.get("ignored_layers", [])) if layer_name not in ignored_layers: return "fp8" diff --git a/server/lorax_server/utils/weights.py b/server/lorax_server/utils/weights.py index a71a09a42..56e632821 100644 --- a/server/lorax_server/utils/weights.py +++ b/server/lorax_server/utils/weights.py @@ -119,7 +119,7 @@ def get_multi_weights_col(self, prefixes: List[Union[str, Tuple]], quantize: str weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False) else: weight_list = self.get_sharded_list("weight", prefixes, dim=0) - if quantize == "fp8" and weight_list[0].dtype == torch.float8_e4m3fn: + if quantize and quantize.startswith("fp8") and weight_list[0].dtype == torch.float8_e4m3fn: # Since there is no kernel for concatenating two tensors in PyTorch # for fp8 datatypes, we have to cast to fp16, concat, cast back to fp8 fp16_weight_list = [w.to(torch.float16) for w in weight_list] @@ -222,7 +222,7 @@ def get_multi_weights_row(self, prefix: str, quantize: str): weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) else: weight = self.get_sharded(f"{prefix}.weight", dim=1) - if quantize == "fp8" and weight.dtype == torch.float8_e4m3fn: + if quantize and quantize.startswith("fp8") and weight.dtype == torch.float8_e4m3fn: # weight_scale could be a tensor but if we're sharding row-wise then no # need to shard the weight_scale as its row dimension would be 1 weight_scale = self.get_tensor(f"{prefix}.weight_scale", use_self_dtype=False) From 93fc0d17026e7447a40b91814e0fc2534f6786e9 Mon Sep 17 00:00:00 2001 From: Ajinkya Tejankar Date: Wed, 16 Oct 2024 18:50:51 -0700 Subject: [PATCH 02/25] fix a few things --- .../models/custom_modeling/flash_mistral_modeling.py | 6 +++--- server/lorax_server/models/flash_causal_lm.py | 4 ++-- server/lorax_server/utils/paged_attention.py | 8 ++++---- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py index cdc51f079..e4a13b41e 100644 --- a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py @@ -262,9 +262,9 @@ def __init__( # todo(ajinkya): only supports the default 'fp8' dtype in vLLM for kv cache but # we can also support other dtypes like f8_e4m3 if paged_attention.is_fp8_supported() and config.quantize and config.quantize.endswith('_kv'): - self.kv_cache_dtype = 'fp8' - self.k_scale = weights.get_tensor(f"{prefix}.k_scale", use_self_dtype=False) - self.v_scale = weights.get_tensor(f"{prefix}.v_scale", use_self_dtype=False) + self.kv_cache_dtype = 'fp8_e4m3' + 
self.k_scale = weights.get_tensor(f"{prefix}.k_scale", use_self_dtype=False).item() + self.v_scale = weights.get_tensor(f"{prefix}.v_scale", use_self_dtype=False).item() else: self.kv_cache_dtype = 'auto' self.k_scale = 1.0 diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 99821bed7..51fb4b717 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -979,12 +979,12 @@ def init_kv_cache( ( torch.empty( (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x), - dtype=dtype, + dtype=torch.uint8, device=device, ), torch.empty( (num_blocks, num_heads, head_size, BLOCK_SIZE), - dtype=dtype, + dtype=torch.uint8, device=device, ), ) diff --git a/server/lorax_server/utils/paged_attention.py b/server/lorax_server/utils/paged_attention.py index bc12935d7..3d207a3a6 100644 --- a/server/lorax_server/utils/paged_attention.py +++ b/server/lorax_server/utils/paged_attention.py @@ -31,8 +31,8 @@ def reshape_and_cache( value_cache: torch.Tensor, slots: torch.Tensor, kv_cache_dtype: str = 'auto', - k_scale: torch.Tensor = 1.0, - v_scale: torch.Tensor = 1.0, + k_scale: float = 1.0, + v_scale: float = 1.0, ): if FLASH_INFER: shape = key_cache.shape @@ -58,8 +58,8 @@ def attention( max_s: int, softcap: Optional[float] = None, kv_cache_dtype: str = 'auto', - k_scale: torch.Tensor = 1.0, - v_scale: torch.Tensor = 1.0, + k_scale: float = 1.0, + v_scale: float = 1.0, ): if FLASH_INFER: from lorax_server.utils.flashinfer_attention import decode_state From 939b479ab73041b2a74dbd7a06e7fa2c553de246 Mon Sep 17 00:00:00 2001 From: Ajinkya Tejankar Date: Thu, 17 Oct 2024 00:54:32 -0700 Subject: [PATCH 03/25] add support for fp8 kv cache using flash infer --- .../custom_modeling/flash_mistral_modeling.py | 16 +++++++++++++++- server/lorax_server/models/flash_causal_lm.py | 4 ++-- server/lorax_server/utils/flash_attn.py | 4 ++++ server/lorax_server/utils/paged_attention.py | 15 +++++++++++++++ 4 files changed, 36 insertions(+), 3 deletions(-) diff --git a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py index cc6f93dc6..eb8cecf7f 100644 --- a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py @@ -259,6 +259,8 @@ def __init__( ) self.num_heads = self.num_heads // weights.process_group.size() self.num_key_value_heads = config.num_key_value_heads // weights.process_group.size() + self.k_scale = weights.get_tensor(f"{prefix}.k_scale", use_self_dtype=False).item() + self.v_scale = weights.get_tensor(f"{prefix}.v_scale", use_self_dtype=False).item() self.query_key_value = load_attention(config, prefix, weights, layer_id, self.head_size) @@ -332,7 +334,15 @@ def forward( else: kv_to_cache = kv - paged_attention.reshape_and_cache(kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots) + paged_attention.reshape_and_cache( + kv_to_cache[:, 0], + kv_to_cache[:, 1], + kv_cache[0], + kv_cache[1], + slots, + self.k_scale, + self.v_scale + ) # Prefill if cu_seqlen_prefill is not None: @@ -347,6 +357,8 @@ def forward( max_s, self.softmax_scale, window_size_left=self.max_past, + k_scale=self.k_scale, + v_scale=self.v_scale, ) # Decode else: @@ -360,6 +372,8 @@ def forward( block_tables, input_lengths, max_s, + k_scale=self.k_scale, + v_scale=self.v_scale, ) return self.o_proj(attn_output.view(-1, self.num_heads * 
self.head_size), adapter_data) diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 2626d30b8..381ed8789 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -963,12 +963,12 @@ def init_kv_cache( ( torch.empty( (num_blocks, BLOCK_SIZE, num_heads, head_size), - dtype=dtype, + dtype=torch.float8_e4m3fn, device=device, ), torch.empty( (num_blocks, BLOCK_SIZE, num_heads, head_size), - dtype=dtype, + dtype=torch.float8_e4m3fn, device=device, ), ) diff --git a/server/lorax_server/utils/flash_attn.py b/server/lorax_server/utils/flash_attn.py index 404cec60c..bd023967b 100644 --- a/server/lorax_server/utils/flash_attn.py +++ b/server/lorax_server/utils/flash_attn.py @@ -127,6 +127,8 @@ def attention( window_size_left=-1, causal=True, softcap=0.0, + k_scale=1.0, + v_scale=1.0, ): assert window_size_left == -1, "Windowing is not supported with flash infer when using kv cache" from lorax_server.utils.flashinfer_attention import prefill_state, prefill_with_paged_kv_state @@ -149,6 +151,8 @@ def attention( paged_kv_cache=(key_cache, value_cache), logits_soft_cap=softcap, sm_scale=softmax_scale, + k_scale=k_scale, + v_scale=v_scale, ) elif HAS_FLASH_ATTN_V2_CUDA: diff --git a/server/lorax_server/utils/paged_attention.py b/server/lorax_server/utils/paged_attention.py index 34d11e843..4cbba1f04 100644 --- a/server/lorax_server/utils/paged_attention.py +++ b/server/lorax_server/utils/paged_attention.py @@ -26,6 +26,11 @@ # else: fp8_supported = False +def static_per_tensor_quantize(tensor: torch.Tensor, inv_scale: float) -> torch.Tensor: + finfo = torch.finfo(torch.float8_e4m3fn) + qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max) + return qweight.to(torch.float8_e4m3fn) + def reshape_and_cache( key: torch.Tensor, @@ -33,8 +38,14 @@ def reshape_and_cache( key_cache: torch.Tensor, value_cache: torch.Tensor, slots: torch.Tensor, + k_scale: float = 1.0, + v_scale: float = 1.0, ): if FLASH_INFER: + key = static_per_tensor_quantize(key, k_scale).view(torch.uint8) + value = static_per_tensor_quantize(value, v_scale).view(torch.uint8) + key_cache = key_cache.view(torch.uint8) + value_cache = value_cache.view(torch.uint8) shape = key_cache.shape key_cache.view(-1, shape[-2], shape[-1])[slots] = key value_cache.view(-1, shape[-2], shape[-1])[slots] = value @@ -57,6 +68,8 @@ def attention( input_lengths: torch.Tensor, max_s: int, softcap: Optional[float] = None, + k_scale: float = 1.0, + v_scale: float = 1.0, ): if FLASH_INFER: from lorax_server.utils.flashinfer_attention import decode_state @@ -66,6 +79,8 @@ def attention( paged_kv_cache=(key_cache, value_cache), logits_soft_cap=softcap, sm_scale=softmax_scale, + k_scale=k_scale, + v_scale=v_scale, ) # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py From 56589ed84930412464419d4e7357124dd41e4444 Mon Sep 17 00:00:00 2001 From: Ajinkya Tejankar Date: Thu, 17 Oct 2024 20:21:09 -0700 Subject: [PATCH 04/25] add logging --- .../models/custom_modeling/flash_mistral_modeling.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py index eb8cecf7f..79e23ecfb 100644 --- a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py @@ 
-19,6 +19,7 @@ # limitations under the License. from typing import List, Optional, Tuple +from loguru import logger # Flash attention imports import dropout_layer_norm @@ -261,6 +262,7 @@ def __init__( self.num_key_value_heads = config.num_key_value_heads // weights.process_group.size() self.k_scale = weights.get_tensor(f"{prefix}.k_scale", use_self_dtype=False).item() self.v_scale = weights.get_tensor(f"{prefix}.v_scale", use_self_dtype=False).item() + logger.info('load kv scales') self.query_key_value = load_attention(config, prefix, weights, layer_id, self.head_size) From 1517b161433ecbf25268f51c0c9a92dadb291c4c Mon Sep 17 00:00:00 2001 From: Ajinkya Tejankar Date: Fri, 18 Oct 2024 11:04:03 -0700 Subject: [PATCH 05/25] remove unnecessary comments --- .../models/custom_modeling/flash_mistral_modeling.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py index 77a30e013..3520803d5 100644 --- a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py @@ -260,8 +260,6 @@ def __init__( ) self.num_heads = self.num_heads // weights.process_group.size() self.num_key_value_heads = config.num_key_value_heads // weights.process_group.size() - # todo(ajinkya): only supports the default 'fp8' dtype in vLLM for kv cache but - # we can also support other dtypes like f8_e4m3 if paged_attention.is_fp8_kv_supported(config.quantize): self.k_scale = weights.get_tensor(f"{prefix}.k_scale", use_self_dtype=False).item() self.v_scale = weights.get_tensor(f"{prefix}.v_scale", use_self_dtype=False).item() @@ -354,7 +352,6 @@ def forward( # Prefill if cu_seqlen_prefill is not None: # flash attention - # note: flashinfer backend + fp8 kv cache can cause problems attn_output = flash_attn.attention( query, torch.select(kv, dim=1, index=0), From a5f1c253ecbeb8630deb4ec78d4cf9ef2d23320d Mon Sep 17 00:00:00 2001 From: Ajinkya Tejankar Date: Fri, 18 Oct 2024 11:11:36 -0700 Subject: [PATCH 06/25] remove is_fp8_kv_supported function --- .../models/custom_modeling/flash_mistral_modeling.py | 2 +- server/lorax_server/models/flash_causal_lm.py | 8 +++----- server/lorax_server/utils/paged_attention.py | 4 ---- 3 files changed, 4 insertions(+), 10 deletions(-) diff --git a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py index 3520803d5..deb992561 100644 --- a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py @@ -260,7 +260,7 @@ def __init__( ) self.num_heads = self.num_heads // weights.process_group.size() self.num_key_value_heads = config.num_key_value_heads // weights.process_group.size() - if paged_attention.is_fp8_kv_supported(config.quantize): + if paged_attention.is_fp8_supported() and config.quantize and config.quantize.endswith('_kv'): self.k_scale = weights.get_tensor(f"{prefix}.k_scale", use_self_dtype=False).item() self.v_scale = weights.get_tensor(f"{prefix}.v_scale", use_self_dtype=False).item() else: diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index d93669fea..2cda456c7 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -958,13 +958,11 @@ def init_kv_cache( element_size = 
torch.tensor([], dtype=dtype).element_size() x = BLOCK_SIZE // element_size - - if paged_attention.is_fp8_kv_supported(self.config.quantize): - kv_dtype = torch.float8_e4m3fn - else: - kv_dtype = dtype + kv_dtype = dtype if FLASH_INFER: + if paged_attention.is_fp8_supported() and self.config.quantize and self.config.quantize.endswith('_kv'): + kv_dtype = torch.float8_e4m3fn self.kv_cache = [ ( torch.empty( diff --git a/server/lorax_server/utils/paged_attention.py b/server/lorax_server/utils/paged_attention.py index baa9e503c..93279a20a 100644 --- a/server/lorax_server/utils/paged_attention.py +++ b/server/lorax_server/utils/paged_attention.py @@ -25,10 +25,6 @@ def is_fp8_supported(): or (torch.cuda.get_device_capability()[0] == 8 and torch.cuda.get_device_capability()[1] >= 9) -def is_fp8_kv_supported(quantization_type): - return FLASH_INFER and is_fp8_supported() and quantization_type and quantization_type.endswith('_kv') - - def static_per_tensor_quantize(tensor: torch.Tensor, inv_scale: float) -> torch.Tensor: finfo = torch.finfo(torch.float8_e4m3fn) qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max) From 52021c0eb8ecac0202e82893456adb9692bad333 Mon Sep 17 00:00:00 2001 From: Ajinkya Tejankar Date: Fri, 18 Oct 2024 14:27:08 -0700 Subject: [PATCH 07/25] fix attention api and kv_dtype --- server/lorax_server/models/flash_causal_lm.py | 20 ++++++++++--------- server/lorax_server/utils/flash_attn.py | 8 ++++++++ 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 2cda456c7..4421d07e6 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -812,6 +812,11 @@ def __init__( config = config_cls.from_pretrained(model_id, revision=revision, trust_remote_code=trust_remote_code) config.quantize = quantize + if paged_attention.is_fp8_supported() and config.quantize and config.quantize.endswith('_kv'): + self.kv_dtype = torch.float8_e4m3fn + logger.info('Enabling FP8 KV Cache') + else: + self.kv_dtype = dtype torch.distributed.barrier(group=self.process_group) @@ -958,21 +963,18 @@ def init_kv_cache( element_size = torch.tensor([], dtype=dtype).element_size() x = BLOCK_SIZE // element_size - kv_dtype = dtype if FLASH_INFER: - if paged_attention.is_fp8_supported() and self.config.quantize and self.config.quantize.endswith('_kv'): - kv_dtype = torch.float8_e4m3fn self.kv_cache = [ ( torch.empty( (num_blocks, BLOCK_SIZE, num_heads, head_size), - dtype=kv_dtype, + dtype=dtype, device=device, ), torch.empty( (num_blocks, BLOCK_SIZE, num_heads, head_size), - dtype=kv_dtype, + dtype=dtype, device=device, ), ) @@ -983,12 +985,12 @@ def init_kv_cache( ( torch.empty( (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x), - dtype=kv_dtype, + dtype=dtype, device=device, ), torch.empty( (num_blocks, num_heads, head_size, BLOCK_SIZE), - dtype=kv_dtype, + dtype=dtype, device=device, ), ) @@ -1010,7 +1012,7 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model self.num_layers, self.num_kv_heads, self.head_size, - self.dtype, + self.kv_dtype, self.device, ) @@ -1089,7 +1091,7 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model self.num_layers, self.num_kv_heads, self.head_size, - self.dtype, + self.kv_dtype, self.device, ) diff --git a/server/lorax_server/utils/flash_attn.py b/server/lorax_server/utils/flash_attn.py index bd023967b..132fff468 100644 --- 
a/server/lorax_server/utils/flash_attn.py +++ b/server/lorax_server/utils/flash_attn.py @@ -169,6 +169,8 @@ def attention( window_size_left=-1, causal=True, softcap=0.0, + k_scale=1.0, + v_scale=1.0, ): if window_size_left <= 0 and window_size_left != -1: raise ValueError("`window_size_left` must be > 0 or -1") @@ -210,6 +212,8 @@ def attention( window_size_left=-1, causal=True, softcap=0.0, + k_scale=1.0, + v_scale=1.0, ): if window_size_left <= 0 and window_size_left != -1: raise ValueError("`window_size_left` must be > 0 or -1") @@ -251,6 +255,8 @@ def attention( window_size_left=-1, causal=True, softcap=0.0, + k_scale=1.0, + v_scale=1.0, ): out = torch.empty_like(q) output, _ = triton_attention( @@ -281,6 +287,8 @@ def attention( window_size_left=-1, causal=True, softcap=0.0, + k_scale=1.0, + v_scale=1.0, ): if window_size_left != -1: raise NotImplementedError("window_size_left is only available with flash attn v2") From 7dd10f3b64c200ec18d4c92bf141e5b8c0973ea9 Mon Sep 17 00:00:00 2001 From: Ajinkya Tejankar Date: Fri, 18 Oct 2024 14:53:50 -0700 Subject: [PATCH 08/25] keep ruff happy --- .../models/custom_modeling/flash_mistral_modeling.py | 1 - server/lorax_server/models/flash_causal_lm.py | 3 +-- server/lorax_server/utils/paged_attention.py | 1 - 3 files changed, 1 insertion(+), 4 deletions(-) diff --git a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py index deb992561..3af8918a0 100644 --- a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py @@ -19,7 +19,6 @@ # limitations under the License. from typing import List, Optional, Tuple -from loguru import logger # Flash attention imports import dropout_layer_norm diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 4421d07e6..ebc165c2b 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -23,8 +23,7 @@ PrefillTokens, ) from lorax_server.pb import generate_pb2 -from lorax_server.utils import HeterogeneousNextTokenChooser, StoppingCriteria -from lorax_server.utils import paged_attention +from lorax_server.utils import HeterogeneousNextTokenChooser, StoppingCriteria, paged_attention from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID, create_merged_weight_files from lorax_server.utils.attention.utils import block_tables_to_ragged from lorax_server.utils.dist import MEMORY_FRACTION, MEMORY_WIGGLE_ROOM, initialize_torch_distributed diff --git a/server/lorax_server/utils/paged_attention.py b/server/lorax_server/utils/paged_attention.py index 93279a20a..c5df00ac4 100644 --- a/server/lorax_server/utils/paged_attention.py +++ b/server/lorax_server/utils/paged_attention.py @@ -2,7 +2,6 @@ import torch -from lorax_server.utils import paged_attention from lorax_server.utils.import_utils import SYSTEM from lorax_server.utils.state import FLASH_INFER From 787b58e488bc3d95a5106e6c85b369ae67ad7401 Mon Sep 17 00:00:00 2001 From: Ajinkya Tejankar Date: Thu, 24 Oct 2024 00:51:24 -0700 Subject: [PATCH 09/25] use fp16 prefill with fp8 kv --- .../custom_modeling/flash_mistral_modeling.py | 4 ++-- server/lorax_server/models/flash_causal_lm.py | 14 ++++---------- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py 
b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py index 3af8918a0..4c5202753 100644 --- a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py @@ -355,8 +355,8 @@ def forward( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), - kv_cache[0], - kv_cache[1], + None, + None, cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index ebc165c2b..e6ceaa58c 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -1124,25 +1124,19 @@ def _forward_context( from lorax_server.utils.flashinfer_attention import ( use_decode_state, + use_prefill_state, use_prefill_with_paged_kv_state, ) # has_prefix_lens = any(prefix_len > 0 for prefix_len in prefix_lens) if cu_seqlen_prefill is not None: - return use_prefill_with_paged_kv_state( - state=(state if state is not None else self.prefill_with_paged_kv_state), - # block_tables=block_tables_to_ragged( - # block_tables=block_tables, - # input_lengths=input_lengths, - # prefix_lens=prefix_lens, - # ), - block_tables=block_tables, + return use_prefill_state( + state=(state if state is not None else self.prefill_state), cu_seqlens=cu_seqlen_prefill, - input_lengths=input_lengths_tensor, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_size=self.head_size, - page_size=BLOCK_SIZE, + query_dtype='bfloat16' ) else: assert input_lengths_tensor is not None From 8dae2d8ac969a8c63775616a11dc2dfd672af99d Mon Sep 17 00:00:00 2001 From: Ajinkya Tejankar Date: Thu, 24 Oct 2024 17:30:35 -0700 Subject: [PATCH 10/25] use fp16 prefill for fp8 kv cache (without prefix caching) --- .../custom_modeling/flash_mistral_modeling.py | 6 ++-- server/lorax_server/models/flash_causal_lm.py | 33 ++++++++++++++----- 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py index 4c5202753..4159cb764 100644 --- a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py @@ -262,9 +262,11 @@ def __init__( if paged_attention.is_fp8_supported() and config.quantize and config.quantize.endswith('_kv'): self.k_scale = weights.get_tensor(f"{prefix}.k_scale", use_self_dtype=False).item() self.v_scale = weights.get_tensor(f"{prefix}.v_scale", use_self_dtype=False).item() + self.kv_dtype = 'fp8' else: self.k_scale = 1.0 self.v_scale = 1.0 + self.kv_dtype = 'auto' self.query_key_value = load_attention(config, prefix, weights, layer_id, self.head_size) @@ -355,8 +357,8 @@ def forward( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), - None, - None, + None if self.kv_dtype == 'fp8' else kv_cache[0], + None if self.kv_dtype == 'fp8' else kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index e6ceaa58c..e6693bdb1 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -1130,14 +1130,31 @@ def _forward_context( # has_prefix_lens = any(prefix_len > 0 for prefix_len in prefix_lens) if cu_seqlen_prefill is not None: - return use_prefill_state( - state=(state if state is not None else 
self.prefill_state), - cu_seqlens=cu_seqlen_prefill, - num_heads=self.num_heads, - num_kv_heads=self.num_kv_heads, - head_size=self.head_size, - query_dtype='bfloat16' - ) + if self.kv_dtype == torch.float8_e4m3fn: + return use_prefill_state( + state=(state if state is not None else self.prefill_state), + cu_seqlens=cu_seqlen_prefill, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + query_dtype=self.dtype, + ) + else: + return use_prefill_with_paged_kv_state( + state=(state if state is not None else self.prefill_with_paged_kv_state), + # block_tables=block_tables_to_ragged( + # block_tables=block_tables, + # input_lengths=input_lengths, + # prefix_lens=prefix_lens, + # ), + block_tables=block_tables, + cu_seqlens=cu_seqlen_prefill, + input_lengths=input_lengths_tensor, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + page_size=BLOCK_SIZE, + ) else: assert input_lengths_tensor is not None return use_decode_state( From d88046425cab76ef9dc7aa860993810993a76bb4 Mon Sep 17 00:00:00 2001 From: Ajinkya Tejankar Date: Thu, 24 Oct 2024 18:36:07 -0700 Subject: [PATCH 11/25] add window_left option --- server/lorax_server/utils/flashinfer_attention.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/server/lorax_server/utils/flashinfer_attention.py b/server/lorax_server/utils/flashinfer_attention.py index 2accc60ca..87bbbd218 100644 --- a/server/lorax_server/utils/flashinfer_attention.py +++ b/server/lorax_server/utils/flashinfer_attention.py @@ -108,6 +108,7 @@ def use_prefill_state( num_kv_heads: int, head_size: int, query_dtype: str = "float16", + window_left: int, ): """ Context manager to set the active flashinfer prefill state to the given @@ -124,6 +125,7 @@ def use_prefill_state( num_kv_heads=num_kv_heads, head_dim=head_size, q_data_type=query_dtype, + # window_left=window_left, TODO ) yield finally: From bb7573006b8dff77998a2902349a9a6c0ad1844a Mon Sep 17 00:00:00 2001 From: Ajinkya Tejankar Date: Thu, 24 Oct 2024 22:41:03 -0700 Subject: [PATCH 12/25] fix merge conflicts --- server/lorax_server/models/flash_causal_lm.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 5ab03995d..27cd6b0b8 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -1285,13 +1285,14 @@ def _forward_context( num_kv_heads=self.num_kv_heads, head_size=self.head_size, query_dtype=self.dtype, + window_left=self.sliding_window, ) else: - assert input_lengths_tensor is not None - return use_decode_state( - state=state if state is not None else self.decode_state, - input_lengths=input_lengths_tensor + cache_lengths_tensor, + return use_prefill_with_paged_kv_state( + state=(state if state is not None else self.prefill_with_paged_kv_state), block_tables=block_tables, + cu_seqlens=cu_seqlen_prefill, + input_lengths=input_lengths_tensor + cache_lengths_tensor, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_size=self.head_size, From 3031ba47b47b62f596e6769059a12535262bfd4e Mon Sep 17 00:00:00 2001 From: Ajinkya Tejankar Date: Fri, 25 Oct 2024 13:28:25 -0700 Subject: [PATCH 13/25] move paged_attention import location --- server/lorax_server/models/flash_causal_lm.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 
27cd6b0b8..eb83efb6c 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -21,7 +21,7 @@ NextTokens, ) from lorax_server.pb import generate_pb2 -from lorax_server.utils import HeterogeneousNextTokenChooser, StoppingCriteria, paged_attention +from lorax_server.utils import HeterogeneousNextTokenChooser, StoppingCriteria from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID, create_merged_weight_files from lorax_server.utils.attention.common import Seqlen from lorax_server.utils.attention.utils import block_tables_to_ragged @@ -955,6 +955,8 @@ def __init__( config = config_cls.from_pretrained(model_id, revision=revision, trust_remote_code=trust_remote_code) config.quantize = quantize + + from lorax_server.utils import paged_attention if paged_attention.is_fp8_supported() and config.quantize and config.quantize.endswith('_kv'): self.kv_dtype = torch.float8_e4m3fn logger.info('Enabling FP8 KV Cache') From 1742cdeeaff66be5ac596fcc53d20fc3caed6875 Mon Sep 17 00:00:00 2001 From: Ajinkya Tejankar Date: Fri, 25 Oct 2024 13:58:27 -0700 Subject: [PATCH 14/25] add llama and qwen models --- .../custom_modeling/flash_llama_modeling.py | 26 ++++++++++++++--- .../custom_modeling/flash_qwen2_modeling.py | 28 ++++++++++++++++--- .../custom_modeling/flash_qwen_modeling.py | 26 +++++++++++++++-- 3 files changed, 69 insertions(+), 11 deletions(-) diff --git a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py index 524fa792d..ee4b6175f 100644 --- a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py @@ -200,7 +200,7 @@ def _load_gqa(config, prefix: str, weights): if isinstance(weight, tuple): weight, input_scale, weight_scale = weight - if config.quantize not in ["gptq", "awq", "fp8"]: + if config.quantize not in ["gptq", "awq", "fp8", "fp8_kv"]: weight = weight.to(dtype=weights.dtype).to(device=weights.device) head_size = config.hidden_size // config.num_attention_heads @@ -252,6 +252,14 @@ def __init__( ) self.num_heads = self.num_heads // weights.process_group.size() self.num_key_value_heads = config.num_key_value_heads // weights.process_group.size() + if paged_attention.is_fp8_supported() and config.quantize and config.quantize.endswith('_kv'): + self.k_scale = weights.get_tensor(f"{prefix}.k_scale", use_self_dtype=False).item() + self.v_scale = weights.get_tensor(f"{prefix}.v_scale", use_self_dtype=False).item() + self.kv_dtype = 'fp8' + else: + self.k_scale = 1.0 + self.v_scale = 1.0 + self.kv_dtype = 'auto' self.query_key_value = load_attention(config, prefix, weights, layer_id) @@ -319,7 +327,15 @@ def forward( self.rotary_emb(query, cos, sin) self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) - paged_attention.reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + paged_attention.reshape_and_cache( + kv[:, 0], + kv[:, 1], + kv_cache[0], + kv_cache[1], + slots, + self.k_scale, + self.v_scale + ) # Prefill if cu_seqlen_prefill is not None: @@ -328,8 +344,8 @@ def forward( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), - kv_cache[0], - kv_cache[1], + None if self.kv_dtype == 'fp8' else kv_cache[0], + None if self.kv_dtype == 'fp8' else kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, @@ -347,6 +363,8 @@ def forward( block_tables, seqlen, max_s, + k_scale=self.k_scale, + v_scale=self.v_scale, ) return 
self.o_proj(attn_output.view(-1, self.num_heads * self.head_size), adapter_data) diff --git a/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py b/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py index 261c6cff5..3cbd9ccf4 100644 --- a/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py @@ -133,7 +133,7 @@ def _load_gqa(config, prefix: str, weights): if isinstance(weight, tuple): weight, input_scale, weight_scale = weight - if config.quantize not in ["gptq", "awq", "fp8"]: + if config.quantize not in ["gptq", "awq", "fp8", "fp8_kv"]: weight = weight.to(dtype=weights.dtype).to(device=weights.device) head_size = config.hidden_size // config.num_attention_heads @@ -192,6 +192,14 @@ def __init__( ) self.num_heads = self.num_heads // weights.process_group.size() self.num_key_value_heads = config.num_key_value_heads // weights.process_group.size() + if paged_attention.is_fp8_supported() and config.quantize and config.quantize.endswith('_kv'): + self.k_scale = weights.get_tensor(f"{prefix}.k_scale", use_self_dtype=False).item() + self.v_scale = weights.get_tensor(f"{prefix}.v_scale", use_self_dtype=False).item() + self.kv_dtype = 'fp8' + else: + self.k_scale = 1.0 + self.v_scale = 1.0 + self.kv_dtype = 'auto' self.query_key_value = load_attention(config, prefix, weights, layer_id) @@ -244,7 +252,15 @@ def forward( else: kv_to_cache = kv - paged_attention.reshape_and_cache(kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots) + paged_attention.reshape_and_cache( + kv_to_cache[:, 0], + kv_to_cache[:, 1], + kv_cache[0], + kv_cache[1], + slots, + self.k_scale, + self.v_scale + ) # Prefill if cu_seqlen_prefill is not None: @@ -253,12 +269,14 @@ def forward( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), - kv_cache[0], - kv_cache[1], + None if self.kv_dtype == 'fp8' else kv_cache[0], + None if self.kv_dtype == 'fp8' else kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, window_size_left=self.max_past, + k_scale=self.k_scale, + v_scale=self.v_scale, ) # Decode else: @@ -273,6 +291,8 @@ def forward( block_tables, seqlen, max_s, + k_scale=self.k_scale, + v_scale=self.v_scale, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size), adapter_data) diff --git a/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py b/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py index e55c10544..da3eb581c 100644 --- a/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py @@ -200,6 +200,14 @@ def __init__( ) self.num_heads = self.num_heads // weights.process_group.size() self.num_key_value_heads = self.num_heads + if paged_attention.is_fp8_supported() and config.quantize and config.quantize.endswith('_kv'): + self.k_scale = weights.get_tensor(f"{prefix}.k_scale", use_self_dtype=False).item() + self.v_scale = weights.get_tensor(f"{prefix}.v_scale", use_self_dtype=False).item() + self.kv_dtype = 'fp8' + else: + self.k_scale = 1.0 + self.v_scale = 1.0 + self.kv_dtype = 'auto' self.c_attn = load_attention(config, prefix, weights, layer_id) @@ -246,7 +254,15 @@ def forward( self.rotary_emb(query, cos, sin) self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) - paged_attention.reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + paged_attention.reshape_and_cache( + kv[:, 0], + kv[:, 1], + kv_cache[0], 
+ kv_cache[1], + slots, + self.k_scale, + self.v_scale + ) # Prefill if cu_seqlen_prefill is not None: @@ -255,11 +271,13 @@ def forward( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), - kv_cache[0], - kv_cache[1], + None if self.kv_dtype == 'fp8' else kv_cache[0], + None if self.kv_dtype == 'fp8' else kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, + k_scale=self.k_scale, + v_scale=self.v_scale, ) # Decode else: @@ -274,6 +292,8 @@ def forward( block_tables, seqlen, max_s, + k_scale=self.k_scale, + v_scale=self.v_scale, ) return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size), adapter_data) From c732810f183757fd5894ac691e442788b8202182 Mon Sep 17 00:00:00 2001 From: Ajinkya Tejankar Date: Fri, 25 Oct 2024 17:11:01 -0700 Subject: [PATCH 15/25] refactor according to Travis' comments --- server/lorax_server/layers/linear.py | 6 +++-- .../custom_modeling/flash_mistral_modeling.py | 14 ++++++----- server/lorax_server/models/flash_causal_lm.py | 4 ++-- server/lorax_server/utils/flash_attn.py | 7 +++++- server/lorax_server/utils/paged_attention.py | 5 ---- server/lorax_server/utils/torch_utils.py | 23 ++++++++++++------- server/lorax_server/utils/weights.py | 5 ++-- 7 files changed, 38 insertions(+), 26 deletions(-) diff --git a/server/lorax_server/layers/linear.py b/server/lorax_server/layers/linear.py index 68b4c9609..115f3bccf 100644 --- a/server/lorax_server/layers/linear.py +++ b/server/lorax_server/layers/linear.py @@ -3,6 +3,7 @@ from torch.nn import functional as F from lorax_server.utils.import_utils import SYSTEM +from lorax_server.utils.torch_utils import is_fp8 if SYSTEM == "rocm": try: @@ -95,9 +96,10 @@ def get_linear(weight, bias, quantize, fan_in_fan_out=False, weight_scale=None, if fan_in_fan_out: weight = weight.T.contiguous() - if quantize is None or (quantize.startswith("fp8") and weight_scale is None): + if quantize is None: linear = FastLinear(weight, bias) - elif quantize.startswith("fp8"): + + elif is_fp8(quantize): from lorax_server.layers.fp8 import Fp8Linear linear = Fp8Linear(weight, bias, weight_scale=weight_scale, input_scale=input_scale) diff --git a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py index 538c42ea0..fe57836d5 100644 --- a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py @@ -53,6 +53,7 @@ UP_PROJ, V_PROJ, ) +from lorax_server.utils.torch_utils import is_fp8_kv, is_quantized if not HAS_FLASH_ATTN_V2_CUDA: raise ImportError("Mistral model requires flash attn v2") @@ -205,7 +206,7 @@ def _load_gqa(config, prefix: str, weights, head_size): if type(weight) is tuple: weight, input_scale, weight_scale = weight - if config.quantize not in ["gptq", "awq", "fp8", "fp8_kv"]: + if not is_quantized(config.quantize): weight = weight.to(dtype=weights.dtype).to(device=weights.device) num_heads = config.num_attention_heads // weights.process_group.size() @@ -260,14 +261,14 @@ def __init__( ) self.num_heads = self.num_heads // weights.process_group.size() self.num_key_value_heads = config.num_key_value_heads // weights.process_group.size() - if paged_attention.is_fp8_supported() and config.quantize and config.quantize.endswith('_kv'): + if is_fp8_kv(config.quantize): self.k_scale = weights.get_tensor(f"{prefix}.k_scale", use_self_dtype=False).item() self.v_scale = weights.get_tensor(f"{prefix}.v_scale", 
use_self_dtype=False).item() - self.kv_dtype = 'fp8' + self.fp8_kv = True else: self.k_scale = 1.0 self.v_scale = 1.0 - self.kv_dtype = 'auto' + self.fp8_kv = False self.query_key_value = load_attention(config, prefix, weights, layer_id, self.head_size) @@ -358,14 +359,15 @@ def forward( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), - None if self.kv_dtype == 'fp8' else kv_cache[0], - None if self.kv_dtype == 'fp8' else kv_cache[1], + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, window_size_left=self.max_past, k_scale=self.k_scale, v_scale=self.v_scale, + fp8_kv=self.fp8_kv, ) # Decode else: diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index eb83efb6c..763f2ce93 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -40,6 +40,7 @@ warmup_mode, ) from lorax_server.utils.tokenizer import TokenizerManager +from lorax_server.utils.torch_utils import is_fp8_kv from lorax_server.utils.weights import Weights ADAPTER_MEMORY_FRACTION = float(os.getenv("ADAPTER_MEMORY_FRACTION", "0.1")) @@ -956,8 +957,7 @@ def __init__( config = config_cls.from_pretrained(model_id, revision=revision, trust_remote_code=trust_remote_code) config.quantize = quantize - from lorax_server.utils import paged_attention - if paged_attention.is_fp8_supported() and config.quantize and config.quantize.endswith('_kv'): + if is_fp8_kv(config.quantize): self.kv_dtype = torch.float8_e4m3fn logger.info('Enabling FP8 KV Cache') else: diff --git a/server/lorax_server/utils/flash_attn.py b/server/lorax_server/utils/flash_attn.py index 56d3c97c3..a08583f9b 100644 --- a/server/lorax_server/utils/flash_attn.py +++ b/server/lorax_server/utils/flash_attn.py @@ -129,10 +129,11 @@ def attention( softcap=0.0, k_scale=1.0, v_scale=1.0, + fp8_kv=False, ): from lorax_server.utils.flashinfer_attention import prefill_state, prefill_with_paged_kv_state - if key_cache is None or value_cache is None: + if fp8_kv or (key_cache is None or value_cache is None): return prefill_state.get().forward( q, k, @@ -171,6 +172,7 @@ def attention( softcap=0.0, k_scale=1.0, v_scale=1.0, + fp8_kv=False, ): if window_size_left <= 0 and window_size_left != -1: raise ValueError("`window_size_left` must be > 0 or -1") @@ -214,6 +216,7 @@ def attention( softcap=0.0, k_scale=1.0, v_scale=1.0, + fp8_kv=False, ): if window_size_left <= 0 and window_size_left != -1: raise ValueError("`window_size_left` must be > 0 or -1") @@ -257,6 +260,7 @@ def attention( softcap=0.0, k_scale=1.0, v_scale=1.0, + fp8_kv=False, ): out = torch.empty_like(q) output, _ = triton_attention( @@ -289,6 +293,7 @@ def attention( softcap=0.0, k_scale=1.0, v_scale=1.0, + fp8_kv=False, ): if window_size_left != -1: raise NotImplementedError("window_size_left is only available with flash attn v2") diff --git a/server/lorax_server/utils/paged_attention.py b/server/lorax_server/utils/paged_attention.py index 5e9fbee75..fa13959f1 100644 --- a/server/lorax_server/utils/paged_attention.py +++ b/server/lorax_server/utils/paged_attention.py @@ -20,11 +20,6 @@ ) from e -def is_fp8_supported(): - return (torch.cuda.get_device_capability()[0] >= 9) \ - or (torch.cuda.get_device_capability()[0] == 8 and torch.cuda.get_device_capability()[1] >= 9) - - def static_per_tensor_quantize(tensor: torch.Tensor, inv_scale: float) -> torch.Tensor: finfo = torch.finfo(torch.float8_e4m3fn) qweight = (tensor / inv_scale).clamp(min=finfo.min, 
max=finfo.max) diff --git a/server/lorax_server/utils/torch_utils.py b/server/lorax_server/utils/torch_utils.py index c76d1c44c..18cb4dfda 100644 --- a/server/lorax_server/utils/torch_utils.py +++ b/server/lorax_server/utils/torch_utils.py @@ -10,11 +10,18 @@ def is_bf16_supported() -> bool: return torch.cuda.is_available() and torch.cuda.is_bf16_supported() -def is_fp8_quantized(config, layer_name): - # check if quantization is fp8 and either of the fused layers is not ignored - # typically, either all qkv will be quantized or none so just check for one - if config.quantize and config.quantize.startswith("fp8") and hasattr(config, "quantization_config"): - ignored_layers = set(config.quantization_config.get("ignored_layers", [])) - if layer_name not in ignored_layers: - return "fp8" - return None +def is_quantized(quantize): + return quantize and quantize in ["gptq", "awq", "fp8", "fp8_kv"] + + +def is_fp8_supported(): + return (torch.cuda.get_device_capability()[0] >= 9) \ + or (torch.cuda.get_device_capability()[0] == 8 and torch.cuda.get_device_capability()[1] >= 9) + + +def is_fp8_kv(quantize): + return is_fp8_supported() and quantize and quantize == 'fp8_kv' + + +def is_fp8(quantize): + return is_fp8_supported() and quantize and quantize.startswith('fp8') diff --git a/server/lorax_server/utils/weights.py b/server/lorax_server/utils/weights.py index 56e632821..68147ed9a 100644 --- a/server/lorax_server/utils/weights.py +++ b/server/lorax_server/utils/weights.py @@ -12,6 +12,7 @@ from safetensors import safe_open from lorax_server.utils.sources import PBASE, S3, map_pbase_model_id_to_s3 +from lorax_server.utils.torch_utils import is_fp8 class AbstractWeights(ABC): @@ -119,7 +120,7 @@ def get_multi_weights_col(self, prefixes: List[Union[str, Tuple]], quantize: str weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False) else: weight_list = self.get_sharded_list("weight", prefixes, dim=0) - if quantize and quantize.startswith("fp8") and weight_list[0].dtype == torch.float8_e4m3fn: + if is_fp8(quantize) and weight_list[0].dtype == torch.float8_e4m3fn: # Since there is no kernel for concatenating two tensors in PyTorch # for fp8 datatypes, we have to cast to fp16, concat, cast back to fp8 fp16_weight_list = [w.to(torch.float16) for w in weight_list] @@ -222,7 +223,7 @@ def get_multi_weights_row(self, prefix: str, quantize: str): weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) else: weight = self.get_sharded(f"{prefix}.weight", dim=1) - if quantize and quantize.startswith("fp8") and weight.dtype == torch.float8_e4m3fn: + if is_fp8(quantize) and weight.dtype == torch.float8_e4m3fn: # weight_scale could be a tensor but if we're sharding row-wise then no # need to shard the weight_scale as its row dimension would be 1 weight_scale = self.get_tensor(f"{prefix}.weight_scale", use_self_dtype=False) From 90e85f5a61492d1716372bc861f45b78ecddaa3c Mon Sep 17 00:00:00 2001 From: Ajinkya Tejankar Date: Fri, 25 Oct 2024 17:54:09 -0700 Subject: [PATCH 16/25] refactor llama and qwen models --- .../models/custom_modeling/flash_llama_modeling.py | 8 ++++++-- .../models/custom_modeling/flash_qwen2_modeling.py | 6 ++++-- .../models/custom_modeling/flash_qwen_modeling.py | 4 +++- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py index ee4b6175f..93cbb1df9 100644 --- 
a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py @@ -52,6 +52,7 @@ UP_PROJ, V_PROJ, ) +from lorax_server.utils.torch_utils import is_fp8_kv, is_quantized class LlamaConfig(PretrainedConfig): @@ -200,7 +201,7 @@ def _load_gqa(config, prefix: str, weights): if isinstance(weight, tuple): weight, input_scale, weight_scale = weight - if config.quantize not in ["gptq", "awq", "fp8", "fp8_kv"]: + if not is_quantized(config.quantize): weight = weight.to(dtype=weights.dtype).to(device=weights.device) head_size = config.hidden_size // config.num_attention_heads @@ -252,7 +253,7 @@ def __init__( ) self.num_heads = self.num_heads // weights.process_group.size() self.num_key_value_heads = config.num_key_value_heads // weights.process_group.size() - if paged_attention.is_fp8_supported() and config.quantize and config.quantize.endswith('_kv'): + if is_fp8_kv(config.quantize): self.k_scale = weights.get_tensor(f"{prefix}.k_scale", use_self_dtype=False).item() self.v_scale = weights.get_tensor(f"{prefix}.v_scale", use_self_dtype=False).item() self.kv_dtype = 'fp8' @@ -349,6 +350,9 @@ def forward( cu_seqlen_prefill, max_s, self.softmax_scale, + k_scale=self.k_scale, + v_scale=self.v_scale, + fp8_kv=self.fp8_kv, ) # Decode else: diff --git a/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py b/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py index 3cbd9ccf4..1b62ae3cf 100644 --- a/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py @@ -29,6 +29,7 @@ get_linear, ) from lorax_server.utils.lora import LM_HEAD +from lorax_server.utils.torch_utils import is_fp8_kv, is_quantized ATTN_Q_PROJ = "self_attn.q_proj" ATTN_K_PROJ = "self_attn.k_proj" @@ -133,7 +134,7 @@ def _load_gqa(config, prefix: str, weights): if isinstance(weight, tuple): weight, input_scale, weight_scale = weight - if config.quantize not in ["gptq", "awq", "fp8", "fp8_kv"]: + if not is_quantized(config.quantize): weight = weight.to(dtype=weights.dtype).to(device=weights.device) head_size = config.hidden_size // config.num_attention_heads @@ -192,7 +193,7 @@ def __init__( ) self.num_heads = self.num_heads // weights.process_group.size() self.num_key_value_heads = config.num_key_value_heads // weights.process_group.size() - if paged_attention.is_fp8_supported() and config.quantize and config.quantize.endswith('_kv'): + if is_fp8_kv(config.quantize): self.k_scale = weights.get_tensor(f"{prefix}.k_scale", use_self_dtype=False).item() self.v_scale = weights.get_tensor(f"{prefix}.v_scale", use_self_dtype=False).item() self.kv_dtype = 'fp8' @@ -277,6 +278,7 @@ def forward( window_size_left=self.max_past, k_scale=self.k_scale, v_scale=self.v_scale, + fp8_kv=self.fp8_kv, ) # Decode else: diff --git a/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py b/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py index da3eb581c..6af8ca2e2 100644 --- a/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py @@ -29,6 +29,7 @@ TensorParallelRowLinear, ) from lorax_server.utils.lora import LM_HEAD +from lorax_server.utils.torch_utils import is_fp8_kv, is_quantized ATTN_C_ATTN = "attn.c_attn" ATTN_C_PROJ = "attn.c_proj" @@ -200,7 +201,7 @@ def __init__( ) self.num_heads = self.num_heads // weights.process_group.size() self.num_key_value_heads = 
self.num_heads - if paged_attention.is_fp8_supported() and config.quantize and config.quantize.endswith('_kv'): + if is_fp8_kv(config.quantize): self.k_scale = weights.get_tensor(f"{prefix}.k_scale", use_self_dtype=False).item() self.v_scale = weights.get_tensor(f"{prefix}.v_scale", use_self_dtype=False).item() self.kv_dtype = 'fp8' @@ -278,6 +279,7 @@ def forward( self.softmax_scale, k_scale=self.k_scale, v_scale=self.v_scale, + fp8_kv=self.fp8_kv, ) # Decode else: From be6e9fff834030d1441d6463ce09d5df71b31e1b Mon Sep 17 00:00:00 2001 From: Ajinkya Tejankar Date: Fri, 25 Oct 2024 17:56:11 -0700 Subject: [PATCH 17/25] fix flash_attn call bug in qwen --- .../models/custom_modeling/flash_qwen2_modeling.py | 4 ++-- .../models/custom_modeling/flash_qwen_modeling.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py b/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py index 1b62ae3cf..c6eb82f51 100644 --- a/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py @@ -270,8 +270,8 @@ def forward( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), - None if self.kv_dtype == 'fp8' else kv_cache[0], - None if self.kv_dtype == 'fp8' else kv_cache[1], + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, diff --git a/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py b/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py index 6af8ca2e2..849a30249 100644 --- a/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py @@ -272,8 +272,8 @@ def forward( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), - None if self.kv_dtype == 'fp8' else kv_cache[0], - None if self.kv_dtype == 'fp8' else kv_cache[1], + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, From f2f9a0801ae7a938bb8df798ac32653a17bda2e2 Mon Sep 17 00:00:00 2001 From: Ajinkya Tejankar Date: Fri, 25 Oct 2024 17:57:01 -0700 Subject: [PATCH 18/25] fix flash_attn call bug in llama --- .../models/custom_modeling/flash_llama_modeling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py index 93cbb1df9..b3a804571 100644 --- a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py @@ -345,8 +345,8 @@ def forward( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), - None if self.kv_dtype == 'fp8' else kv_cache[0], - None if self.kv_dtype == 'fp8' else kv_cache[1], + kv_cache[0], + kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale, From 7974ff97648c27f9c5292174b238a966ff54da9c Mon Sep 17 00:00:00 2001 From: Ajinkya Tejankar Date: Fri, 25 Oct 2024 17:59:43 -0700 Subject: [PATCH 19/25] refactor flash_causal_lm for better git diff --- server/lorax_server/models/flash_causal_lm.py | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 763f2ce93..31e6b123d 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -1289,19 
+1289,18 @@ def _forward_context( query_dtype=self.dtype, window_left=self.sliding_window, ) - else: - return use_prefill_with_paged_kv_state( - state=(state if state is not None else self.prefill_with_paged_kv_state), - block_tables=block_tables, - cu_seqlens=cu_seqlen_prefill, - input_lengths=input_lengths_tensor + cache_lengths_tensor, - num_heads=self.num_heads, - num_kv_heads=self.num_kv_heads, - head_size=self.head_size, - page_size=BLOCK_SIZE, - dtype=self.dtype, - window_left=self.sliding_window, - ) + return use_prefill_with_paged_kv_state( + state=(state if state is not None else self.prefill_with_paged_kv_state), + block_tables=block_tables, + cu_seqlens=cu_seqlen_prefill, + input_lengths=input_lengths_tensor + cache_lengths_tensor, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + page_size=BLOCK_SIZE, + dtype=self.dtype, + window_left=self.sliding_window, + ) else: assert input_lengths_tensor is not None return use_decode_state( From cd29d93378c8ff4e48f79219e4a7742173a2522f Mon Sep 17 00:00:00 2001 From: Ajinkya Tejankar Date: Fri, 25 Oct 2024 18:05:17 -0700 Subject: [PATCH 20/25] refactor to use fp8_kv in reshape_and_cache --- .../models/custom_modeling/flash_llama_modeling.py | 3 ++- .../models/custom_modeling/flash_mistral_modeling.py | 3 ++- .../models/custom_modeling/flash_qwen2_modeling.py | 3 ++- .../lorax_server/models/custom_modeling/flash_qwen_modeling.py | 3 ++- server/lorax_server/utils/paged_attention.py | 3 ++- 5 files changed, 10 insertions(+), 5 deletions(-) diff --git a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py index b3a804571..ebf255728 100644 --- a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py @@ -335,7 +335,8 @@ def forward( kv_cache[1], slots, self.k_scale, - self.v_scale + self.v_scale, + self.fp8_kv, ) # Prefill diff --git a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py index fe57836d5..92ab41b1d 100644 --- a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py @@ -349,7 +349,8 @@ def forward( kv_cache[1], slots, self.k_scale, - self.v_scale + self.v_scale, + self.fp8_kv, ) # Prefill diff --git a/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py b/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py index c6eb82f51..490ab1acd 100644 --- a/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py @@ -260,7 +260,8 @@ def forward( kv_cache[1], slots, self.k_scale, - self.v_scale + self.v_scale, + self.fp8_kv, ) # Prefill diff --git a/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py b/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py index 849a30249..dae9844b9 100644 --- a/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py @@ -262,7 +262,8 @@ def forward( kv_cache[1], slots, self.k_scale, - self.v_scale + self.v_scale, + self.fp8_kv, ) # Prefill diff --git a/server/lorax_server/utils/paged_attention.py b/server/lorax_server/utils/paged_attention.py index fa13959f1..9e41d4595 100644 --- 
a/server/lorax_server/utils/paged_attention.py +++ b/server/lorax_server/utils/paged_attention.py @@ -34,9 +34,10 @@ def reshape_and_cache( slots: torch.Tensor, k_scale: float = 1.0, v_scale: float = 1.0, + fp8_kv: bool = False, ): if FLASH_INFER: - if key_cache.dtype == torch.float8_e4m3fn and value_cache.dtype == torch.float8_e4m3fn: + if fp8_kv: key = static_per_tensor_quantize(key, k_scale).view(torch.uint8) value = static_per_tensor_quantize(value, v_scale).view(torch.uint8) key_cache = key_cache.view(torch.uint8) From 488d42870546d8415bb7f79fd1ac17449b8aba44 Mon Sep 17 00:00:00 2001 From: Ajinkya Tejankar Date: Fri, 25 Oct 2024 18:44:33 -0700 Subject: [PATCH 21/25] add support for gemma models and fix ruff error in qwen --- .../custom_modeling/flash_gemma2_modeling.py | 27 +++++++++++++++++-- .../custom_modeling/flash_gemma_modeling.py | 27 +++++++++++++++++-- .../custom_modeling/flash_qwen_modeling.py | 2 +- 3 files changed, 51 insertions(+), 5 deletions(-) diff --git a/server/lorax_server/models/custom_modeling/flash_gemma2_modeling.py b/server/lorax_server/models/custom_modeling/flash_gemma2_modeling.py index e51e26ed1..eec7fddf8 100644 --- a/server/lorax_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_gemma2_modeling.py @@ -51,6 +51,7 @@ UP_PROJ, V_PROJ, ) +from lorax_server.utils.torch_utils import is_fp8_kv, is_quantized class Gemma2Config(PretrainedConfig): @@ -170,7 +171,7 @@ def _load_gqa(config, prefix: str, weights): dim=0, ) - if config.quantize not in ["gptq", "awq", "marlin"]: + if not is_quantized(config.quantize): weight = weight.to(dtype=weights.dtype).to(device=weights.device) head_size = config.head_dim @@ -212,6 +213,14 @@ def __init__(self, layer_id: int, prefix: str, config, weights, causal: bool, is ) self.num_heads = self.num_heads // weights.process_group.size() self.num_key_value_heads = config.num_key_value_heads // weights.process_group.size() + if is_fp8_kv(config.quantize): + self.k_scale = weights.get_tensor(f"{prefix}.k_scale", use_self_dtype=False).item() + self.v_scale = weights.get_tensor(f"{prefix}.v_scale", use_self_dtype=False).item() + self.fp8_kv = True + else: + self.k_scale = 1.0 + self.v_scale = 1.0 + self.fp8_kv = False self.query_key_value = load_attention(config, prefix, weights, layer_id) @@ -257,7 +266,16 @@ def forward( self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - paged_attention.reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + paged_attention.reshape_and_cache( + kv[:, 0], + kv[:, 1], + kv_cache[0], + kv_cache[1], + slots, + self.k_scale, + self.v_scale, + self.fp8_kv, + ) # Prefill if cu_seqlen_prefill is not None: @@ -273,6 +291,9 @@ def forward( self.softmax_scale, causal=self.causal, window_size_left=self.window_size, + k_scale=self.k_scale, + v_scale=self.v_scale, + fp8_kv=self.fp8_kv, ) # Decode else: @@ -286,6 +307,8 @@ def forward( block_tables, seqlen, max_s, + k_scale=self.k_scale, + v_scale=self.v_scale, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size), adapter_data) diff --git a/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py b/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py index 139f7e822..2e8b6cba9 100644 --- a/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_gemma_modeling.py @@ -44,6 +44,7 @@ UP_PROJ, V_PROJ, ) +from lorax_server.utils.torch_utils import is_fp8_kv, 
is_quantized class GemmaConfig(PretrainedConfig): @@ -153,7 +154,7 @@ def _load_gqa(config, prefix: str, weights): dim=0, ) - if config.quantize not in ["gptq", "awq"]: + if not is_quantized(config.quantize): weight = weight.to(dtype=weights.dtype).to(device=weights.device) head_size = config.head_dim @@ -197,6 +198,14 @@ def __init__( ) self.num_heads = self.num_heads // weights.process_group.size() self.num_key_value_heads = config.num_key_value_heads // weights.process_group.size() + if is_fp8_kv(config.quantize): + self.k_scale = weights.get_tensor(f"{prefix}.k_scale", use_self_dtype=False).item() + self.v_scale = weights.get_tensor(f"{prefix}.v_scale", use_self_dtype=False).item() + self.fp8_kv = True + else: + self.k_scale = 1.0 + self.v_scale = 1.0 + self.fp8_kv = False self.query_key_value = load_attention(config, prefix, weights, layer_id) @@ -264,7 +273,16 @@ def forward( self.rotary_emb(query, cos, sin) self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) - paged_attention.reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + paged_attention.reshape_and_cache( + kv[:, 0], + kv[:, 1], + kv_cache[0], + kv_cache[1], + slots, + self.k_scale, + self.v_scale, + self.fp8_kv, + ) # Prefill if cu_seqlen_prefill is not None: @@ -278,6 +296,9 @@ def forward( cu_seqlen_prefill, max_s, self.softmax_scale, + k_scale=self.k_scale, + v_scale=self.v_scale, + fp8_kv=self.fp8_kv, ) # Decode else: @@ -291,6 +312,8 @@ def forward( block_tables, seqlen, max_s, + k_scale=self.k_scale, + v_scale=self.v_scale, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size), adapter_data) diff --git a/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py b/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py index dae9844b9..ef9d5a052 100644 --- a/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py +++ b/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py @@ -29,7 +29,7 @@ TensorParallelRowLinear, ) from lorax_server.utils.lora import LM_HEAD -from lorax_server.utils.torch_utils import is_fp8_kv, is_quantized +from lorax_server.utils.torch_utils import is_fp8_kv ATTN_C_ATTN = "attn.c_attn" ATTN_C_PROJ = "attn.c_proj" From 813b64b96cf1dff8be58d2f94b36a96d93ecbe76 Mon Sep 17 00:00:00 2001 From: Ajinkya Tejankar Date: Mon, 28 Oct 2024 14:20:19 -0700 Subject: [PATCH 22/25] fix is_fp8_supported function --- server/lorax_server/utils/torch_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/server/lorax_server/utils/torch_utils.py b/server/lorax_server/utils/torch_utils.py index 18cb4dfda..bedb58730 100644 --- a/server/lorax_server/utils/torch_utils.py +++ b/server/lorax_server/utils/torch_utils.py @@ -15,8 +15,9 @@ def is_quantized(quantize): def is_fp8_supported(): - return (torch.cuda.get_device_capability()[0] >= 9) \ - or (torch.cuda.get_device_capability()[0] == 8 and torch.cuda.get_device_capability()[1] >= 9) + return torch.cuda.is_available() and \ + (torch.cuda.get_device_capability()[0] >= 9) or \ + (torch.cuda.get_device_capability()[0] == 8 and torch.cuda.get_device_capability()[1] >= 9) def is_fp8_kv(quantize): From 5dfb5e56c736be4838c148fce569859927a33735 Mon Sep 17 00:00:00 2001 From: Ajinkya Tejankar Date: Mon, 28 Oct 2024 16:14:15 -0700 Subject: [PATCH 23/25] fix fp8 error handling --- server/lorax_server/models/flash_causal_lm.py | 5 ++++- server/lorax_server/utils/torch_utils.py | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git 
a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 31e6b123d..0b64f83a1 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -40,7 +40,7 @@ warmup_mode, ) from lorax_server.utils.tokenizer import TokenizerManager -from lorax_server.utils.torch_utils import is_fp8_kv +from lorax_server.utils.torch_utils import is_fp8_kv, is_fp8_supported, is_fp8 from lorax_server.utils.weights import Weights ADAPTER_MEMORY_FRACTION = float(os.getenv("ADAPTER_MEMORY_FRACTION", "0.1")) @@ -957,6 +957,9 @@ def __init__( config = config_cls.from_pretrained(model_id, revision=revision, trust_remote_code=trust_remote_code) config.quantize = quantize + if is_fp8(config.quantize) and not is_fp8_supported(): + raise ValueError('FP8 quantization is only supported on hardware that supports FP8') + if is_fp8_kv(config.quantize): self.kv_dtype = torch.float8_e4m3fn logger.info('Enabling FP8 KV Cache') diff --git a/server/lorax_server/utils/torch_utils.py b/server/lorax_server/utils/torch_utils.py index bedb58730..ab33cebce 100644 --- a/server/lorax_server/utils/torch_utils.py +++ b/server/lorax_server/utils/torch_utils.py @@ -21,8 +21,8 @@ def is_fp8_supported(): def is_fp8_kv(quantize): - return is_fp8_supported() and quantize and quantize == 'fp8_kv' + return quantize and quantize == 'fp8_kv' def is_fp8(quantize): - return is_fp8_supported() and quantize and quantize.startswith('fp8') + return quantize and quantize.startswith('fp8') From bca4daf90abc7132ed7f9bb9c8cde233b8b9e5e0 Mon Sep 17 00:00:00 2001 From: Ajinkya Tejankar Date: Mon, 28 Oct 2024 16:15:45 -0700 Subject: [PATCH 24/25] keep ruff happy --- server/lorax_server/models/flash_causal_lm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index 0b64f83a1..d39909de5 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -40,7 +40,7 @@ warmup_mode, ) from lorax_server.utils.tokenizer import TokenizerManager -from lorax_server.utils.torch_utils import is_fp8_kv, is_fp8_supported, is_fp8 +from lorax_server.utils.torch_utils import is_fp8, is_fp8_kv, is_fp8_supported from lorax_server.utils.weights import Weights ADAPTER_MEMORY_FRACTION = float(os.getenv("ADAPTER_MEMORY_FRACTION", "0.1")) From 422d9a47b3553e4c23d64665ef7f26983ed9cf58 Mon Sep 17 00:00:00 2001 From: Ajinkya Tejankar Date: Mon, 28 Oct 2024 17:52:57 -0700 Subject: [PATCH 25/25] rename fp8_kv quantization to fp8-kv --- launcher/src/main.rs | 2 +- server/lorax_server/cli.py | 2 +- server/lorax_server/layers/tensor_parallel.py | 2 +- server/lorax_server/models/flash_causal_lm.py | 4 +++- server/lorax_server/utils/torch_utils.py | 4 ++-- 5 files changed, 8 insertions(+), 6 deletions(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 1fdc095c7..bac9f04a8 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -70,7 +70,7 @@ impl std::fmt::Display for Quantization { write!(f, "fp8") } Quantization::Fp8_KV => { - write!(f, "fp8_kv") + write!(f, "fp8-kv") } } } diff --git a/server/lorax_server/cli.py b/server/lorax_server/cli.py index bca7400c0..c7f2c7c5d 100644 --- a/server/lorax_server/cli.py +++ b/server/lorax_server/cli.py @@ -23,7 +23,7 @@ class Quantization(str, Enum): hqq_3bit = "hqq-3bit" hqq_2bit = "hqq-2bit" fp8 = "fp8" - fp8_kv = "fp8_kv" + fp8_kv = "fp8-kv" class Dtype(str, Enum): diff --git 
a/server/lorax_server/layers/tensor_parallel.py b/server/lorax_server/layers/tensor_parallel.py index c2a140fad..a6e0a01f6 100644 --- a/server/lorax_server/layers/tensor_parallel.py +++ b/server/lorax_server/layers/tensor_parallel.py @@ -37,7 +37,7 @@ def load(config, prefix: str, weights): should_gather = False # GPTQ,AWQ,EETQ don't quantize heads (nor embeddings) - if config.quantize in ["gptq", "awq", "eetq", "fp8", "fp8_kv"]: + if config.quantize in ["gptq", "awq", "eetq", "fp8", "fp8-kv"]: quantize = None else: quantize = config.quantize diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py index d39909de5..1c74cbcac 100644 --- a/server/lorax_server/models/flash_causal_lm.py +++ b/server/lorax_server/models/flash_causal_lm.py @@ -961,8 +961,10 @@ def __init__( raise ValueError('FP8 quantization is only supported on hardware that supports FP8') if is_fp8_kv(config.quantize): + if not FLASH_INFER: + raise ValueError('FP8 KV cache requires FLASH_INFER backend') self.kv_dtype = torch.float8_e4m3fn - logger.info('Enabling FP8 KV Cache') + logger.info('Enabling FP8 KV cache. Prefix caching will not work.') else: self.kv_dtype = dtype diff --git a/server/lorax_server/utils/torch_utils.py b/server/lorax_server/utils/torch_utils.py index ab33cebce..682402cc1 100644 --- a/server/lorax_server/utils/torch_utils.py +++ b/server/lorax_server/utils/torch_utils.py @@ -11,7 +11,7 @@ def is_bf16_supported() -> bool: def is_quantized(quantize): - return quantize and quantize in ["gptq", "awq", "fp8", "fp8_kv"] + return quantize and quantize in ["gptq", "awq", "fp8", "fp8-kv"] def is_fp8_supported(): @@ -21,7 +21,7 @@ def is_fp8_supported(): def is_fp8_kv(quantize): - return quantize and quantize == 'fp8_kv' + return quantize and quantize == 'fp8-kv' def is_fp8(quantize):
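A few sketches of how the pieces above fit together once the series is applied.

The torch_utils helpers end up split into one hardware check and two string predicates on the --quantize value. A minimal standalone sketch of that contract (the explicit parentheses around the Hopper-or-Ada test are mine, spelling out the intended grouping):

    import torch

    def is_fp8_supported() -> bool:
        # FP8 needs Hopper (sm90+) or Ada (sm89); guard the capability query so
        # CPU-only hosts return False instead of raising.
        if not torch.cuda.is_available():
            return False
        major, minor = torch.cuda.get_device_capability()
        return major >= 9 or (major == 8 and minor >= 9)

    def is_fp8(quantize) -> bool:
        # True for both weight-only fp8 and fp8-kv.
        return bool(quantize) and quantize.startswith("fp8")

    def is_fp8_kv(quantize) -> bool:
        # True only when the KV cache itself is stored in fp8.
        return quantize == "fp8-kv"

After PATCH 23 and PATCH 25 the hardware check runs once in FlashCausalLM (raise if is_fp8(config.quantize) and not is_fp8_supported(); additionally require the flashinfer backend for is_fp8_kv), so per-layer code only consults the string predicates. On the launcher side this presumably surfaces as --quantize fp8-kv, matching the Display impl in PATCH 25.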
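What the per-layer k_scale/v_scale tensors encode: assuming the usual static per-tensor recipe (scale = calibrated amax / e4m3 max, with the quantizer dividing by the scale), a small worked example; the concrete amax value is made up for illustration:

    import torch

    finfo = torch.finfo(torch.float8_e4m3fn)   # finfo.max == 448.0
    observed_amax = 14.0                       # hypothetical calibration statistic for K
    k_scale = observed_amax / finfo.max        # 0.03125, what the checkpoint stores

    k = torch.tensor([3.5, -7.0, 14.0])
    q = (k / k_scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)  # cached as fp8
    k_hat = q.to(torch.float32) * k_scale      # what attention effectively reads back

The attention constructors call .item() on these tensors because the downstream calls take plain Python floats (reshape_and_cache declares k_scale: float = 1.0).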
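The cache write itself is keyed off the explicit fp8_kv flag after PATCH 20 instead of sniffing the cache dtype. A sketch of the flashinfer branch of reshape_and_cache when fp8_kv=True; the body of static_per_tensor_quantize is reconstructed from its clamp-to-finfo tail, and the wrapper name write_kv_fp8 exists only for this sketch:

    import torch

    def static_per_tensor_quantize(t: torch.Tensor, scale: float) -> torch.Tensor:
        # Divide by the static scale, clamp into the e4m3 range, then cast.
        finfo = torch.finfo(torch.float8_e4m3fn)
        return (t / scale).clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)

    def write_kv_fp8(key, value, key_cache, value_cache, k_scale, v_scale):
        # Quantize K/V with their static scales, then reinterpret everything as
        # uint8 so the existing flashinfer cache-append kernel (elided) can move
        # the bytes without needing a float8 overload.
        key = static_per_tensor_quantize(key, k_scale).view(torch.uint8)
        value = static_per_tensor_quantize(value, v_scale).view(torch.uint8)
        return key, value, key_cache.view(torch.uint8), value_cache.view(torch.uint8)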
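Lastly, the get_multi_weights_col hunk above works around torch.cat having no float8_e4m3fn kernel (in the PyTorch version targeted here) by round-tripping through fp16; the round trip is value-preserving because every e4m3 value is exactly representable in fp16. A sketch, with cat_fp8 as a hypothetical helper name:

    import torch

    def cat_fp8(shards, dim=0):
        # Up-cast the fp8 shards to fp16, concatenate, and cast back; each
        # e4m3 value survives the round trip unchanged.
        fp16 = [s.to(torch.float16) for s in shards]
        return torch.cat(fp16, dim=dim).to(torch.float8_e4m3fn)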