diff --git a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py
index ebf255728..b308edc5d 100644
--- a/server/lorax_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/lorax_server/models/custom_modeling/flash_llama_modeling.py
@@ -256,11 +256,11 @@ def __init__(
         if is_fp8_kv(config.quantize):
             self.k_scale = weights.get_tensor(f"{prefix}.k_scale", use_self_dtype=False).item()
             self.v_scale = weights.get_tensor(f"{prefix}.v_scale", use_self_dtype=False).item()
-            self.kv_dtype = 'fp8'
+            self.fp8_kv = True
         else:
             self.k_scale = 1.0
             self.v_scale = 1.0
-            self.kv_dtype = 'auto'
+            self.fp8_kv = False
 
         self.query_key_value = load_attention(config, prefix, weights, layer_id)
diff --git a/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py b/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py
index 490ab1acd..068d2e0b6 100644
--- a/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py
+++ b/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py
@@ -196,11 +196,11 @@ def __init__(
         if is_fp8_kv(config.quantize):
             self.k_scale = weights.get_tensor(f"{prefix}.k_scale", use_self_dtype=False).item()
             self.v_scale = weights.get_tensor(f"{prefix}.v_scale", use_self_dtype=False).item()
-            self.kv_dtype = 'fp8'
+            self.fp8_kv = True
         else:
             self.k_scale = 1.0
             self.v_scale = 1.0
-            self.kv_dtype = 'auto'
+            self.fp8_kv = False
 
         self.query_key_value = load_attention(config, prefix, weights, layer_id)
diff --git a/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py b/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py
index ef9d5a052..606248af1 100644
--- a/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py
+++ b/server/lorax_server/models/custom_modeling/flash_qwen_modeling.py
@@ -204,11 +204,11 @@ def __init__(
         if is_fp8_kv(config.quantize):
             self.k_scale = weights.get_tensor(f"{prefix}.k_scale", use_self_dtype=False).item()
             self.v_scale = weights.get_tensor(f"{prefix}.v_scale", use_self_dtype=False).item()
-            self.kv_dtype = 'fp8'
+            self.fp8_kv = True
         else:
             self.k_scale = 1.0
             self.v_scale = 1.0
-            self.kv_dtype = 'auto'
+            self.fp8_kv = False
 
         self.c_attn = load_attention(config, prefix, weights, layer_id)