From a5f1c253ecbeb8630deb4ec78d4cf9ef2d23320d Mon Sep 17 00:00:00 2001
From: Ajinkya Tejankar
Date: Fri, 18 Oct 2024 11:11:36 -0700
Subject: [PATCH] remove is_fp8_kv_supported function

---
 .../models/custom_modeling/flash_mistral_modeling.py | 2 +-
 server/lorax_server/models/flash_causal_lm.py        | 8 +++-----
 server/lorax_server/utils/paged_attention.py         | 4 ----
 3 files changed, 4 insertions(+), 10 deletions(-)

diff --git a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py
index 3520803d5..deb992561 100644
--- a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py
@@ -260,7 +260,7 @@ def __init__(
         )
         self.num_heads = self.num_heads // weights.process_group.size()
         self.num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
-        if paged_attention.is_fp8_kv_supported(config.quantize):
+        if paged_attention.is_fp8_supported() and config.quantize and config.quantize.endswith('_kv'):
             self.k_scale = weights.get_tensor(f"{prefix}.k_scale", use_self_dtype=False).item()
             self.v_scale = weights.get_tensor(f"{prefix}.v_scale", use_self_dtype=False).item()
         else:
diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py
index d93669fea..2cda456c7 100644
--- a/server/lorax_server/models/flash_causal_lm.py
+++ b/server/lorax_server/models/flash_causal_lm.py
@@ -958,13 +958,11 @@ def init_kv_cache(
 
         element_size = torch.tensor([], dtype=dtype).element_size()
         x = BLOCK_SIZE // element_size
-
-        if paged_attention.is_fp8_kv_supported(self.config.quantize):
-            kv_dtype = torch.float8_e4m3fn
-        else:
-            kv_dtype = dtype
+        kv_dtype = dtype
 
         if FLASH_INFER:
+            if paged_attention.is_fp8_supported() and self.config.quantize and self.config.quantize.endswith('_kv'):
+                kv_dtype = torch.float8_e4m3fn
             self.kv_cache = [
                 (
                     torch.empty(
diff --git a/server/lorax_server/utils/paged_attention.py b/server/lorax_server/utils/paged_attention.py
index baa9e503c..93279a20a 100644
--- a/server/lorax_server/utils/paged_attention.py
+++ b/server/lorax_server/utils/paged_attention.py
@@ -25,10 +25,6 @@ def is_fp8_supported():
         or (torch.cuda.get_device_capability()[0] == 8 and torch.cuda.get_device_capability()[1] >= 9)
 
 
-def is_fp8_kv_supported(quantization_type):
-    return FLASH_INFER and is_fp8_supported() and quantization_type and quantization_type.endswith('_kv')
-
-
 def static_per_tensor_quantize(tensor: torch.Tensor, inv_scale: float) -> torch.Tensor:
     finfo = torch.finfo(torch.float8_e4m3fn)
     qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
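
Note (not part of the patch): with is_fp8_kv_supported removed, each call site now combines the remaining checks directly. The sketch below is an illustration only, assuming the FLASH_INFER flag and the is_fp8_supported() helper that remain in server/lorax_server/utils/paged_attention.py; the function name _select_kv_cache_dtype is hypothetical and does not exist in the repository.

import torch

def _select_kv_cache_dtype(quantize, default_dtype, flash_infer, fp8_supported):
    # Hypothetical helper mirroring the condition inlined into init_kv_cache:
    # the KV cache is allocated as float8_e4m3fn only under FLASH_INFER, when
    # the quantization type ends with '_kv' and the GPU supports fp8;
    # otherwise the model dtype is kept.
    if flash_infer and fp8_supported and quantize and quantize.endswith('_kv'):
        return torch.float8_e4m3fn
    return default_dtype

# Example (assumed inputs): _select_kv_cache_dtype("fp8_kv", torch.float16, True, True)
# returns torch.float8_e4m3fn, while any non-'_kv' quantize value returns torch.float16.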