Support FP8 KV Cache #652

Merged
27 commits merged on Oct 29, 2024
Commits (27)
6f887aa
(feat) : support fp8 kv cache
ajtejankar Oct 1, 2024
93fc0d1
fix a few things
ajtejankar Oct 17, 2024
939b479
add support for fp8 kv cache using flash infer
ajtejankar Oct 17, 2024
56589ed
add logging
ajtejankar Oct 18, 2024
9590a71
merge back previous code to provide fp8_kv as a quantization option
ajtejankar Oct 18, 2024
1517b16
remove unnecessary comments
ajtejankar Oct 18, 2024
a5f1c25
remove is_fp8_kv_supported function
ajtejankar Oct 18, 2024
52021c0
fix attention api and kv_dtype
ajtejankar Oct 18, 2024
7dd10f3
keep ruff happy
ajtejankar Oct 18, 2024
787b58e
use fp16 prefill with fp8 kv
ajtejankar Oct 24, 2024
8dae2d8
use fp16 prefill for fp8 kv cache (without prefix caching)
ajtejankar Oct 25, 2024
8540461
Merge branch 'main' into fp8-kv-flash-infer
ajtejankar Oct 25, 2024
d880464
add window_left option
ajtejankar Oct 25, 2024
bb75730
fix merge conflicts
ajtejankar Oct 25, 2024
3031ba4
move paged_attention import location
ajtejankar Oct 25, 2024
1742cde
add llama and qwen models
ajtejankar Oct 25, 2024
c732810
refactor according to Travis' comments
ajtejankar Oct 26, 2024
90e85f5
refactor llama and qwen models
ajtejankar Oct 26, 2024
be6e9ff
fix flash_attn call bug in qwen
ajtejankar Oct 26, 2024
f2f9a08
fix flash_attn call bug in llama
ajtejankar Oct 26, 2024
7974ff9
refactor flash_causal_lm for better git diff
ajtejankar Oct 26, 2024
cd29d93
refactor to use fp8_kv in reshape_and_cache
ajtejankar Oct 26, 2024
488d428
add support for gemma models and fix ruff error in qwen
ajtejankar Oct 26, 2024
813b64b
fix is_fp8_supported function
ajtejankar Oct 28, 2024
5dfb5e5
fix fp8 error handling
ajtejankar Oct 28, 2024
bca4daf
keep ruff happy
ajtejankar Oct 28, 2024
422d9a4
rename fp8_kv quantization to fp8-kv
ajtejankar Oct 29, 2024
Changes from all commits
4 changes: 4 additions & 0 deletions launcher/src/main.rs
@@ -32,6 +32,7 @@ enum Quantization {
Hqq_3bit,
Hqq_2bit,
Fp8,
Fp8_KV,
}

impl std::fmt::Display for Quantization {
@@ -68,6 +69,9 @@ impl std::fmt::Display for Quantization {
Quantization::Fp8 => {
write!(f, "fp8")
}
Quantization::Fp8_KV => {
write!(f, "fp8-kv")
}
}
}
}
1 change: 1 addition & 0 deletions server/lorax_server/cli.py
@@ -23,6 +23,7 @@ class Quantization(str, Enum):
hqq_3bit = "hqq-3bit"
hqq_2bit = "hqq-2bit"
fp8 = "fp8"
fp8_kv = "fp8-kv"


class Dtype(str, Enum):
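The new `fp8-kv` value mirrors the launcher's `Fp8_KV` variant above, which displays as the same string. As a rough illustration (the helper below is hypothetical and not part of this diff), the value is expected to reach model code as a plain string that downstream checks compare against:

```python
# Illustration only: hypothetical helper showing how the CLI enum value would
# flow through as the plain string "fp8-kv". The enum mirrors
# server/lorax_server/cli.py; resolve_quantize itself is not part of this PR.
from enum import Enum
from typing import Optional


class Quantization(str, Enum):
    fp8 = "fp8"
    fp8_kv = "fp8-kv"


def resolve_quantize(cli_value: Optional[str]) -> Optional[str]:
    # "fp8-kv" leaves weights in the model dtype but stores the KV cache in
    # FP8, so downstream checks must treat it differently from plain "fp8".
    return Quantization(cli_value).value if cli_value else None


assert resolve_quantize("fp8-kv") == "fp8-kv"
```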
6 changes: 4 additions & 2 deletions server/lorax_server/layers/linear.py
@@ -3,6 +3,7 @@
from torch.nn import functional as F

from lorax_server.utils.import_utils import SYSTEM
from lorax_server.utils.torch_utils import is_fp8

if SYSTEM == "rocm":
try:
@@ -95,9 +96,10 @@ def get_linear(weight, bias, quantize, fan_in_fan_out=False, weight_scale=None,
if fan_in_fan_out:
weight = weight.T.contiguous()

if quantize is None or (quantize == "fp8" and weight_scale is None):
if quantize is None:
linear = FastLinear(weight, bias)
elif quantize == "fp8":

elif is_fp8(quantize):
from lorax_server.layers.fp8 import Fp8Linear

linear = Fp8Linear(weight, bias, weight_scale=weight_scale, input_scale=input_scale)
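`is_fp8` (and the `is_fp8_kv` / `is_quantized` helpers imported in the model files below) come from `lorax_server.utils.torch_utils`, which is not included in this excerpt. A hedged sketch of what they plausibly look like, inferred only from the call sites in this diff:

```python
# Hypothetical sketch of the lorax_server/utils/torch_utils.py helpers; the
# real module is not shown in this diff, so the exact membership of each set
# is an assumption based on the checks these calls replace.
def is_fp8_kv(quantize) -> bool:
    # FP8 applies only to the KV cache; weights stay in the model dtype.
    return quantize == "fp8-kv"


def is_fp8(quantize) -> bool:
    # Covers both FP8 weights and the FP8 KV cache.
    return quantize in ("fp8", "fp8-kv")


def is_quantized(quantize) -> bool:
    # Formats whose weights are stored packed/quantized and must not be cast
    # back to the model dtype at load time (mirrors the lists this replaces).
    return quantize in ("gptq", "awq", "marlin", "fp8")
```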
2 changes: 1 addition & 1 deletion server/lorax_server/layers/tensor_parallel.py
@@ -37,7 +37,7 @@ def load(config, prefix: str, weights):
should_gather = False

# GPTQ,AWQ,EETQ don't quantize heads (nor embeddings)
if config.quantize in ["gptq", "awq", "eetq", "fp8"]:
if config.quantize in ["gptq", "awq", "eetq", "fp8", "fp8-kv"]:
quantize = None
else:
quantize = config.quantize
server/lorax_server/models/custom_modeling/flash_gemma2_modeling.py
@@ -51,6 +51,7 @@
UP_PROJ,
V_PROJ,
)
from lorax_server.utils.torch_utils import is_fp8_kv, is_quantized


class Gemma2Config(PretrainedConfig):
@@ -170,7 +171,7 @@ def _load_gqa(config, prefix: str, weights):
dim=0,
)

if config.quantize not in ["gptq", "awq", "marlin"]:
if not is_quantized(config.quantize):
weight = weight.to(dtype=weights.dtype).to(device=weights.device)

head_size = config.head_dim
@@ -212,6 +213,14 @@ def __init__(self, layer_id: int, prefix: str, config, weights, causal: bool, is
)
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
if is_fp8_kv(config.quantize):
self.k_scale = weights.get_tensor(f"{prefix}.k_scale", use_self_dtype=False).item()
self.v_scale = weights.get_tensor(f"{prefix}.v_scale", use_self_dtype=False).item()
self.fp8_kv = True
else:
self.k_scale = 1.0
self.v_scale = 1.0
self.fp8_kv = False

self.query_key_value = load_attention(config, prefix, weights, layer_id)

@@ -257,7 +266,16 @@ def forward(

self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

paged_attention.reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
paged_attention.reshape_and_cache(
kv[:, 0],
kv[:, 1],
kv_cache[0],
kv_cache[1],
slots,
self.k_scale,
self.v_scale,
self.fp8_kv,
)

# Prefill
if cu_seqlen_prefill is not None:
@@ -273,6 +291,9 @@
self.softmax_scale,
causal=self.causal,
window_size_left=self.window_size,
k_scale=self.k_scale,
v_scale=self.v_scale,
fp8_kv=self.fp8_kv,
)
# Decode
else:
@@ -286,6 +307,8 @@
block_tables,
seqlen,
max_s,
k_scale=self.k_scale,
v_scale=self.v_scale,
)

return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size), adapter_data)
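`reshape_and_cache` now receives the per-layer `k_scale`/`v_scale` and an `fp8_kv` flag. The write path is a fused kernel, but the math it performs when `fp8_kv` is set amounts to dividing K/V by their scales and storing the result as FP8 E4M3; a stand-alone sketch of just that quantization step:

```python
import torch

FP8_E4M3_MAX = 448.0  # largest finite value of torch.float8_e4m3fn


def quantize_kv_for_cache(k: torch.Tensor, v: torch.Tensor,
                          k_scale: float, v_scale: float):
    """Sketch of the quantization the FP8 KV-cache write performs; the real
    reshape_and_cache is a fused kernel that also scatters the result into
    the paged cache slots."""
    k_fp8 = (k / k_scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)
    v_fp8 = (v / v_scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)
    return k_fp8, v_fp8
```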
27 changes: 25 additions & 2 deletions server/lorax_server/models/custom_modeling/flash_gemma_modeling.py
@@ -44,6 +44,7 @@
UP_PROJ,
V_PROJ,
)
from lorax_server.utils.torch_utils import is_fp8_kv, is_quantized


class GemmaConfig(PretrainedConfig):
@@ -153,7 +154,7 @@ def _load_gqa(config, prefix: str, weights):
dim=0,
)

if config.quantize not in ["gptq", "awq"]:
if not is_quantized(config.quantize):
weight = weight.to(dtype=weights.dtype).to(device=weights.device)

head_size = config.head_dim
@@ -197,6 +198,14 @@ def __init__(
)
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
if is_fp8_kv(config.quantize):
self.k_scale = weights.get_tensor(f"{prefix}.k_scale", use_self_dtype=False).item()
self.v_scale = weights.get_tensor(f"{prefix}.v_scale", use_self_dtype=False).item()
self.fp8_kv = True
else:
self.k_scale = 1.0
self.v_scale = 1.0
self.fp8_kv = False

self.query_key_value = load_attention(config, prefix, weights, layer_id)

@@ -264,7 +273,16 @@ def forward(
self.rotary_emb(query, cos, sin)
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)

paged_attention.reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
paged_attention.reshape_and_cache(
kv[:, 0],
kv[:, 1],
kv_cache[0],
kv_cache[1],
slots,
self.k_scale,
self.v_scale,
self.fp8_kv,
)

# Prefill
if cu_seqlen_prefill is not None:
@@ -278,6 +296,9 @@
cu_seqlen_prefill,
max_s,
self.softmax_scale,
k_scale=self.k_scale,
v_scale=self.v_scale,
fp8_kv=self.fp8_kv,
)
# Decode
else:
@@ -291,6 +312,8 @@
block_tables,
seqlen,
max_s,
k_scale=self.k_scale,
v_scale=self.v_scale,
)

return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size), adapter_data)
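The `k_scale`/`v_scale` scalars are read per layer from the checkpoint (`{prefix}.k_scale`, `{prefix}.v_scale`), so they must already be present, typically produced by a calibration pass over representative data. A hedged sketch of what the stored scalars represent (the calibration tooling itself is outside this PR):

```python
import torch

FP8_E4M3_MAX = 448.0  # largest finite value of torch.float8_e4m3fn


def per_tensor_kv_scale(observed: torch.Tensor) -> float:
    """Hypothetical calibration step: choose the scale so that the largest
    observed |K| (or |V|) activation maps onto the FP8 E4M3 dynamic range.
    Checkpoints served with fp8-kv are expected to already ship these scalars."""
    return (observed.abs().max() / FP8_E4M3_MAX).item()


# At decode time the cached FP8 values are multiplied back by k_scale/v_scale,
# which is why the scales are also passed to paged_attention.attention above.
```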
27 changes: 25 additions & 2 deletions server/lorax_server/models/custom_modeling/flash_llama_modeling.py
@@ -52,6 +52,7 @@
UP_PROJ,
V_PROJ,
)
from lorax_server.utils.torch_utils import is_fp8_kv, is_quantized


class LlamaConfig(PretrainedConfig):
@@ -200,7 +201,7 @@ def _load_gqa(config, prefix: str, weights):
if isinstance(weight, tuple):
weight, input_scale, weight_scale = weight

if config.quantize not in ["gptq", "awq", "fp8"]:
if not is_quantized(config.quantize):
weight = weight.to(dtype=weights.dtype).to(device=weights.device)

head_size = config.hidden_size // config.num_attention_heads
@@ -252,6 +253,14 @@ def __init__(
)
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
if is_fp8_kv(config.quantize):
self.k_scale = weights.get_tensor(f"{prefix}.k_scale", use_self_dtype=False).item()
self.v_scale = weights.get_tensor(f"{prefix}.v_scale", use_self_dtype=False).item()
self.fp8_kv = True
else:
self.k_scale = 1.0
self.v_scale = 1.0
self.fp8_kv = False

self.query_key_value = load_attention(config, prefix, weights, layer_id)

@@ -319,7 +328,16 @@ def forward(
self.rotary_emb(query, cos, sin)
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)

paged_attention.reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
paged_attention.reshape_and_cache(
kv[:, 0],
kv[:, 1],
kv_cache[0],
kv_cache[1],
slots,
self.k_scale,
self.v_scale,
self.fp8_kv,
)

# Prefill
if cu_seqlen_prefill is not None:
@@ -333,6 +351,9 @@
cu_seqlen_prefill,
max_s,
self.softmax_scale,
k_scale=self.k_scale,
v_scale=self.v_scale,
fp8_kv=self.fp8_kv,
)
# Decode
else:
@@ -347,6 +368,8 @@
block_tables,
seqlen,
max_s,
k_scale=self.k_scale,
v_scale=self.v_scale,
)

return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size), adapter_data)
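Per the commits above ("use fp16 prefill with fp8 kv", "use fp16 prefill for fp8 kv cache (without prefix caching)"), prefill attends over the freshly computed fp16/bf16 K and V while the FP8 cache is only written, and decode reads the cache back and dequantizes it with the scales. A toy sketch of that split in plain PyTorch, not the fused kernels used here:

```python
import torch


def toy_attention(q, k, v, softmax_scale):
    # Minimal attention used by both paths in this sketch.
    scores = torch.softmax((q @ k.transpose(-1, -2)) * softmax_scale, dim=-1)
    return scores @ v


def prefill_attention(q, k, v, softmax_scale):
    # Prefill: attend over the current-precision K/V directly; with prefix
    # caching disabled, the FP8 cache is only written during this step.
    return toy_attention(q, k, v, softmax_scale)


def decode_attention(q, cached_k_fp8, cached_v_fp8, k_scale, v_scale, softmax_scale):
    # Decode: dequantize the FP8 cache with the per-layer scales before
    # attending; the fused paged-attention kernel does this internally.
    k = cached_k_fp8.to(q.dtype) * k_scale
    v = cached_v_fp8.to(q.dtype) * v_scale
    return toy_attention(q, k, v, softmax_scale)
```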
server/lorax_server/models/custom_modeling/flash_mistral_modeling.py
@@ -53,6 +53,7 @@
UP_PROJ,
V_PROJ,
)
from lorax_server.utils.torch_utils import is_fp8_kv, is_quantized

if not HAS_FLASH_ATTN_V2_CUDA:
raise ImportError("Mistral model requires flash attn v2")
@@ -205,7 +206,7 @@ def _load_gqa(config, prefix: str, weights, head_size):
if type(weight) is tuple:
weight, input_scale, weight_scale = weight

if config.quantize not in ["gptq", "awq", "fp8"]:
if not is_quantized(config.quantize):
weight = weight.to(dtype=weights.dtype).to(device=weights.device)

num_heads = config.num_attention_heads // weights.process_group.size()
@@ -260,6 +261,14 @@ def __init__(
)
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
if is_fp8_kv(config.quantize):
self.k_scale = weights.get_tensor(f"{prefix}.k_scale", use_self_dtype=False).item()
self.v_scale = weights.get_tensor(f"{prefix}.v_scale", use_self_dtype=False).item()
self.fp8_kv = True
else:
self.k_scale = 1.0
self.v_scale = 1.0
self.fp8_kv = False

self.query_key_value = load_attention(config, prefix, weights, layer_id, self.head_size)

@@ -333,7 +342,16 @@ def forward(
else:
kv_to_cache = kv

paged_attention.reshape_and_cache(kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots)
paged_attention.reshape_and_cache(
kv_to_cache[:, 0],
kv_to_cache[:, 1],
kv_cache[0],
kv_cache[1],
slots,
self.k_scale,
self.v_scale,
self.fp8_kv,
)

# Prefill
if cu_seqlen_prefill is not None:
Expand All @@ -348,6 +366,9 @@ def forward(
max_s,
self.softmax_scale,
window_size_left=self.max_past,
k_scale=self.k_scale,
v_scale=self.v_scale,
fp8_kv=self.fp8_kv,
)
# Decode
else:
Expand All @@ -361,6 +382,8 @@ def forward(
block_tables,
seqlen,
max_s,
k_scale=self.k_scale,
v_scale=self.v_scale,
)

return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size), adapter_data)
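The remaining piece, not shown in this excerpt, is allocating the paged KV cache itself in an FP8 dtype when `fp8-kv` is selected. A hedged sketch of that dtype selection (the function name and location are assumptions):

```python
import torch


def kv_cache_dtype(quantize, model_dtype: torch.dtype) -> torch.dtype:
    # Assumption: with "fp8-kv" the paged KV cache blocks are allocated as
    # float8_e4m3fn while weights and activations keep the model dtype; all
    # other modes keep the cache in the model dtype.
    if quantize == "fp8-kv":
        return torch.float8_e4m3fn
    return model_dtype


# Example: a bf16 model served with --quantize fp8-kv allocates its KV cache
# in float8_e4m3fn, roughly halving KV-cache memory per token.
assert kv_cache_dtype("fp8-kv", torch.bfloat16) == torch.float8_e4m3fn
assert kv_cache_dtype(None, torch.float16) == torch.float16
```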