remove is_fp8_kv_supported function
ajtejankar committed Oct 18, 2024
1 parent 1517b16 commit a5f1c25
Showing 3 changed files with 4 additions and 10 deletions.
@@ -260,7 +260,7 @@ def __init__(
         )
         self.num_heads = self.num_heads // weights.process_group.size()
         self.num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
-        if paged_attention.is_fp8_kv_supported(config.quantize):
+        if paged_attention.is_fp8_supported() and config.quantize and config.quantize.endswith('_kv'):
             self.k_scale = weights.get_tensor(f"{prefix}.k_scale", use_self_dtype=False).item()
             self.v_scale = weights.get_tensor(f"{prefix}.v_scale", use_self_dtype=False).item()
         else:
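
The removed helper combined three conditions: FLASH_INFER, fp8 hardware support, and a quantize mode ending in '_kv'. At this call site the check is now written out inline (the FLASH_INFER part is dropped here, while flash_causal_lm.py below keeps it by nesting its check under "if FLASH_INFER:"). A minimal sketch of an equivalent standalone predicate, using a hypothetical name (fp8_kv_cache_enabled) purely for illustration; the capability thresholds mirror the is_fp8_supported fragment visible in the paged_attention.py hunk further down:

import torch

def fp8_kv_cache_enabled(quantize) -> bool:
    # Hypothetical illustration of the inlined condition: fp8 tensors need an
    # Ada (compute capability 8.9) or Hopper (9.x) GPU, and the configured
    # quantize mode must explicitly opt into an fp8 KV cache via a '_kv' suffix.
    major, minor = torch.cuda.get_device_capability()
    fp8_hw = major >= 9 or (major == 8 and minor >= 9)
    return fp8_hw and bool(quantize) and quantize.endswith("_kv")
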
server/lorax_server/models/flash_causal_lm.py (8 changes: 3 additions & 5 deletions)
@@ -958,13 +958,11 @@ def init_kv_cache(
 
         element_size = torch.tensor([], dtype=dtype).element_size()
         x = BLOCK_SIZE // element_size
 
-        if paged_attention.is_fp8_kv_supported(self.config.quantize):
-            kv_dtype = torch.float8_e4m3fn
-        else:
-            kv_dtype = dtype
-
+        kv_dtype = dtype
         if FLASH_INFER:
+            if paged_attention.is_fp8_supported() and self.config.quantize and self.config.quantize.endswith('_kv'):
+                kv_dtype = torch.float8_e4m3fn
             self.kv_cache = [
                 (
                     torch.empty(
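
With the helper gone, kv_dtype starts as the model dtype and is switched to torch.float8_e4m3fn only inside the FLASH_INFER branch. A minimal sketch of how the selected dtype feeds the per-layer cache allocation; the shapes and sizes below are illustrative assumptions, not lorax's exact flashinfer cache layout:

import torch

# Illustrative sizes; in lorax these come from the model config and scheduler.
num_layers, num_blocks, block_size = 32, 1024, 16
num_kv_heads, head_size = 8, 128

# Chosen when the fp8 KV-cache check passes, otherwise the model dtype (e.g. float16).
kv_dtype = torch.float8_e4m3fn

kv_cache = [
    (
        torch.empty((num_blocks, block_size, num_kv_heads, head_size), dtype=kv_dtype, device="cuda"),
        torch.empty((num_blocks, block_size, num_kv_heads, head_size), dtype=kv_dtype, device="cuda"),
    )
    for _ in range(num_layers)
]
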
server/lorax_server/utils/paged_attention.py (4 changes: 0 additions & 4 deletions)
@@ -25,10 +25,6 @@ def is_fp8_supported():
         or (torch.cuda.get_device_capability()[0] == 8 and torch.cuda.get_device_capability()[1] >= 9)
 
 
-def is_fp8_kv_supported(quantization_type):
-    return FLASH_INFER and is_fp8_supported() and quantization_type and quantization_type.endswith('_kv')
-
-
 def static_per_tensor_quantize(tensor: torch.Tensor, inv_scale: float) -> torch.Tensor:
     finfo = torch.finfo(torch.float8_e4m3fn)
     qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
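
The static_per_tensor_quantize context lines above divide by a per-tensor scale and clamp to the representable float8_e4m3fn range; the final cast and return are below the fold. A hedged usage sketch, assuming the function returns the clamped tensor cast to torch.float8_e4m3fn and that the k_scale/v_scale values loaded in the first hunk are the scales passed here:

import torch
from lorax_server.utils import paged_attention

k_scale = 0.05  # per-tensor scale, e.g. loaded from "{prefix}.k_scale" in the checkpoint
key = torch.randn(16, 8, 128, dtype=torch.float16, device="cuda")

q_key = paged_attention.static_per_tensor_quantize(key, k_scale)  # assumed float8_e4m3fn output
approx_key = q_key.to(torch.float16) * k_scale  # approximate dequantized round-trip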