Support FP8 KV Cache #652

Merged: 27 commits, Oct 29, 2024. The diff below shows changes from 9 of the 27 commits.

Commits
6f887aa
(feat) : support fp8 kv cache
ajtejankar Oct 1, 2024
93fc0d1
fix a few things
ajtejankar Oct 17, 2024
939b479
add support for fp8 kv cache using flash infer
ajtejankar Oct 17, 2024
56589ed
add logging
ajtejankar Oct 18, 2024
9590a71
merge back previous code to provide fp8_kv as a quantization option
ajtejankar Oct 18, 2024
1517b16
remove unnecessary comments
ajtejankar Oct 18, 2024
a5f1c25
remove is_fp8_kv_supported function
ajtejankar Oct 18, 2024
52021c0
fix attention api and kv_dtype
ajtejankar Oct 18, 2024
7dd10f3
keep ruff happy
ajtejankar Oct 18, 2024
787b58e
use fp16 prefill with fp8 kv
ajtejankar Oct 24, 2024
8dae2d8
use fp16 prefill for fp8 kv cache (without prefix caching)
ajtejankar Oct 25, 2024
8540461
Merge branch 'main' into fp8-kv-flash-infer
ajtejankar Oct 25, 2024
d880464
add window_left option
ajtejankar Oct 25, 2024
bb75730
fix merge conflicts
ajtejankar Oct 25, 2024
3031ba4
move paged_attention import location
ajtejankar Oct 25, 2024
1742cde
add llama and qwen models
ajtejankar Oct 25, 2024
c732810
refactor according to Travis' comments
ajtejankar Oct 26, 2024
90e85f5
refactor llama and qwen models
ajtejankar Oct 26, 2024
be6e9ff
fix flash_attn call bug in qwen
ajtejankar Oct 26, 2024
f2f9a08
fix flash_attn call bug in llama
ajtejankar Oct 26, 2024
7974ff9
refactor flash_causal_lm for better git diff
ajtejankar Oct 26, 2024
cd29d93
refactor to use fp8_kv in reshape_and_cache
ajtejankar Oct 26, 2024
488d428
add support for gemma models and fix ruff error in qwen
ajtejankar Oct 26, 2024
813b64b
fix is_fp8_supported function
ajtejankar Oct 28, 2024
5dfb5e5
fix fp8 error handling
ajtejankar Oct 28, 2024
bca4daf
keep ruff happy
ajtejankar Oct 28, 2024
422d9a4
rename fp8_kv quantization to fp8-kv
ajtejankar Oct 29, 2024
4 changes: 4 additions & 0 deletions launcher/src/main.rs
@@ -32,6 +32,7 @@ enum Quantization {
Hqq_3bit,
Hqq_2bit,
Fp8,
Fp8_KV,
}

impl std::fmt::Display for Quantization {
@@ -68,6 +69,9 @@ impl std::fmt::Display for Quantization {
Quantization::Fp8 => {
write!(f, "fp8")
}
Quantization::Fp8_KV => {
write!(f, "fp8_kv")
}
}
}
}
1 change: 1 addition & 0 deletions server/lorax_server/cli.py
@@ -23,6 +23,7 @@ class Quantization(str, Enum):
hqq_3bit = "hqq-3bit"
hqq_2bit = "hqq-2bit"
fp8 = "fp8"
fp8_kv = "fp8_kv"


class Dtype(str, Enum):
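For reference, a minimal sketch (not part of the diff) of how the new "fp8_kv" string is matched downstream in this PR: weight handling keys off the "fp8" prefix, while the KV cache path keys off the "_kv" suffix.

    quantize = "fp8_kv"  # value of Quantization.fp8_kv above
    # Reuses the fp8 weight-loading path (see the linear.py / weights.py changes).
    assert quantize.startswith("fp8")
    # Enables the FP8 KV cache (see the flash_causal_lm.py changes).
    assert quantize.endswith("_kv")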
4 changes: 2 additions & 2 deletions server/lorax_server/layers/linear.py
@@ -95,9 +95,9 @@ def get_linear(weight, bias, quantize, fan_in_fan_out=False, weight_scale=None,
if fan_in_fan_out:
weight = weight.T.contiguous()

if quantize is None or (quantize == "fp8" and weight_scale is None):
if quantize is None or (quantize.startswith("fp8") and weight_scale is None):
linear = FastLinear(weight, bias)
elif quantize == "fp8":
elif quantize.startswith("fp8"):
from lorax_server.layers.fp8 import Fp8Linear

linear = Fp8Linear(weight, bias, weight_scale=weight_scale, input_scale=input_scale)
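The effect of the startswith("fp8") change: with quantize="fp8_kv" and an unquantized checkpoint (no weight_scale), layers still get a plain FastLinear, and only genuinely fp8-quantized weights take the Fp8Linear path. A small sketch of that dispatch with the layer classes stubbed out as strings:

    def pick_linear(quantize, weight_scale):
        # Mirrors the branch above; returns a label instead of a real layer.
        if quantize is None or (quantize.startswith("fp8") and weight_scale is None):
            return "FastLinear"   # unquantized weights; fp8 applies to the KV cache only
        elif quantize.startswith("fp8"):
            return "Fp8Linear"    # fp8-quantized weights with a weight scale
        return "other-quantized"  # gptq/awq/eetq/hqq branches not shown

    assert pick_linear("fp8_kv", None) == "FastLinear"
    assert pick_linear("fp8", 0.5) == "Fp8Linear"
    assert pick_linear(None, None) == "FastLinear"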
2 changes: 1 addition & 1 deletion server/lorax_server/layers/tensor_parallel.py
@@ -37,7 +37,7 @@ def load(config, prefix: str, weights):
should_gather = False

# GPTQ,AWQ,EETQ don't quantize heads (nor embeddings)
if config.quantize in ["gptq", "awq", "eetq", "fp8"]:
if config.quantize in ["gptq", "awq", "eetq", "fp8", "fp8_kv"]:
quantize = None
else:
quantize = config.quantize
@@ -204,7 +204,7 @@ def _load_gqa(config, prefix: str, weights, head_size):
if type(weight) is tuple:
weight, input_scale, weight_scale = weight

if config.quantize not in ["gptq", "awq", "fp8"]:
if config.quantize not in ["gptq", "awq", "fp8", "fp8_kv"]:
weight = weight.to(dtype=weights.dtype).to(device=weights.device)

num_heads = config.num_attention_heads // weights.process_group.size()
@@ -259,6 +259,12 @@ def __init__(
)
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
if paged_attention.is_fp8_supported() and config.quantize and config.quantize.endswith('_kv'):
self.k_scale = weights.get_tensor(f"{prefix}.k_scale", use_self_dtype=False).item()
self.v_scale = weights.get_tensor(f"{prefix}.v_scale", use_self_dtype=False).item()
else:
self.k_scale = 1.0
self.v_scale = 1.0

self.query_key_value = load_attention(config, prefix, weights, layer_id, self.head_size)

@@ -332,7 +338,15 @@ def forward(
else:
kv_to_cache = kv

paged_attention.reshape_and_cache(kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots)
paged_attention.reshape_and_cache(
kv_to_cache[:, 0],
kv_to_cache[:, 1],
kv_cache[0],
kv_cache[1],
slots,
self.k_scale,
self.v_scale
)

# Prefill
if cu_seqlen_prefill is not None:
@@ -347,6 +361,8 @@
max_s,
self.softmax_scale,
window_size_left=self.max_past,
k_scale=self.k_scale,
v_scale=self.v_scale,
)
# Decode
else:
@@ -360,6 +376,8 @@
block_tables,
input_lengths,
max_s,
k_scale=self.k_scale,
v_scale=self.v_scale,
)

return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size), adapter_data)
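A condensed sketch of the per-layer scale handling above (helper, prefix, and tensor values are made up for illustration): k_scale and v_scale are read from the checkpoint only when the FP8 KV cache is supported and enabled, and default to 1.0 otherwise so the reshape_and_cache and attention calls behave exactly as before.

    import torch

    def load_kv_scales(get_tensor, prefix, quantize, fp8_supported):
        # get_tensor stands in for weights.get_tensor(..., use_self_dtype=False).
        if fp8_supported and quantize and quantize.endswith("_kv"):
            return (get_tensor(f"{prefix}.k_scale").item(),
                    get_tensor(f"{prefix}.v_scale").item())
        return 1.0, 1.0  # no rescaling for fp16/bf16 KV caches

    k_scale, v_scale = load_kv_scales(
        lambda name: torch.tensor(0.02), "model.layers.0.self_attn", "fp8_kv", True
    )
    assert abs(k_scale - 0.02) < 1e-6 and k_scale == v_scale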
11 changes: 8 additions & 3 deletions server/lorax_server/models/flash_causal_lm.py
@@ -23,7 +23,7 @@
PrefillTokens,
)
from lorax_server.pb import generate_pb2
from lorax_server.utils import HeterogeneousNextTokenChooser, StoppingCriteria
from lorax_server.utils import HeterogeneousNextTokenChooser, StoppingCriteria, paged_attention
from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID, create_merged_weight_files
from lorax_server.utils.attention.utils import block_tables_to_ragged
from lorax_server.utils.dist import MEMORY_FRACTION, MEMORY_WIGGLE_ROOM, initialize_torch_distributed
@@ -811,6 +811,11 @@ def __init__(

config = config_cls.from_pretrained(model_id, revision=revision, trust_remote_code=trust_remote_code)
config.quantize = quantize
if paged_attention.is_fp8_supported() and config.quantize and config.quantize.endswith('_kv'):
self.kv_dtype = torch.float8_e4m3fn
logger.info('Enabling FP8 KV Cache')
else:
self.kv_dtype = dtype

torch.distributed.barrier(group=self.process_group)

@@ -1006,7 +1011,7 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model
self.num_layers,
self.num_kv_heads,
self.head_size,
self.dtype,
self.kv_dtype,
self.device,
)

@@ -1085,7 +1090,7 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model
self.num_layers,
self.num_kv_heads,
self.head_size,
self.dtype,
self.kv_dtype,
self.device,
)

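The kv_dtype switch above means the KV cache blocks are allocated as float8_e4m3fn rather than the model dtype, halving KV cache memory relative to fp16. A rough sketch of the idea (the shape here is illustrative, not the exact layout the server uses):

    import torch

    def allocate_kv_block(num_blocks, block_size, num_kv_heads, head_size, kv_dtype):
        shape = (num_blocks, block_size, num_kv_heads, head_size)
        return torch.empty(shape, dtype=kv_dtype), torch.empty(shape, dtype=kv_dtype)

    k_fp16, _ = allocate_kv_block(8, 16, 8, 128, torch.float16)
    k_fp8, _ = allocate_kv_block(8, 16, 8, 128, torch.float8_e4m3fn)
    assert k_fp8.element_size() == 1 and k_fp16.element_size() == 2  # half the bytes per element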
12 changes: 12 additions & 0 deletions server/lorax_server/utils/flash_attn.py
@@ -127,6 +127,8 @@ def attention(
window_size_left=-1,
causal=True,
softcap=0.0,
k_scale=1.0,
v_scale=1.0,
):
assert window_size_left == -1, "Windowing is not supported with flash infer when using kv cache"
from lorax_server.utils.flashinfer_attention import prefill_state, prefill_with_paged_kv_state
@@ -149,6 +151,8 @@
paged_kv_cache=(key_cache, value_cache),
logits_soft_cap=softcap,
sm_scale=softmax_scale,
k_scale=k_scale,
v_scale=v_scale,
)

elif HAS_FLASH_ATTN_V2_CUDA:
@@ -165,6 +169,8 @@ def attention(
window_size_left=-1,
causal=True,
softcap=0.0,
k_scale=1.0,
v_scale=1.0,
):
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
@@ -206,6 +212,8 @@
window_size_left=-1,
causal=True,
softcap=0.0,
k_scale=1.0,
v_scale=1.0,
):
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
@@ -247,6 +255,8 @@
window_size_left=-1,
causal=True,
softcap=0.0,
k_scale=1.0,
v_scale=1.0,
):
out = torch.empty_like(q)
output, _ = triton_attention(
@@ -277,6 +287,8 @@
window_size_left=-1,
causal=True,
softcap=0.0,
k_scale=1.0,
v_scale=1.0,
):
if window_size_left != -1:
raise NotImplementedError("window_size_left is only available with flash attn v2")
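In this diff, every attention backend now accepts k_scale and v_scale so call sites stay uniform, but only the flashinfer branch forwards them to the kernel; the default of 1.0 makes the extra arguments a no-op everywhere else. A toy sketch of that pattern (backend names here are just labels):

    def attention_stub(backend, k_scale=1.0, v_scale=1.0):
        # Only the flashinfer-style branch consumes the scales; others ignore them.
        if backend == "flashinfer":
            return ("flashinfer", k_scale, v_scale)
        return (backend, None, None)

    assert attention_stub("flashinfer", 0.5, 0.25) == ("flashinfer", 0.5, 0.25)
    assert attention_stub("flash_attn_v2", 0.5, 0.25) == ("flash_attn_v2", None, None)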
36 changes: 24 additions & 12 deletions server/lorax_server/utils/paged_attention.py
@@ -18,13 +18,16 @@
f"Could not import vllm paged attention. Make sure your installation is correct. Error: {e}"
) from e

# TODO(travis): fix for CUDA 8.9 (Lovelace) and 9.0 (Hopper)
# if torch.cuda.is_available():
# fp8_supported = (
# torch.cuda.get_device_capability()[0] >= 9
# ) # or (torch.cuda.get_device_capability()[0] == 8 and torch.cuda.get_device_capability()[1] >= 9)
# else:
fp8_supported = False

def is_fp8_supported():
return (torch.cuda.get_device_capability()[0] >= 9) \
or (torch.cuda.get_device_capability()[0] == 8 and torch.cuda.get_device_capability()[1] >= 9)


def static_per_tensor_quantize(tensor: torch.Tensor, inv_scale: float) -> torch.Tensor:
finfo = torch.finfo(torch.float8_e4m3fn)
qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
return qweight.to(torch.float8_e4m3fn)


def reshape_and_cache(
@@ -33,17 +36,22 @@ def reshape_and_cache(
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
k_scale: float = 1.0,
v_scale: float = 1.0,
):
if FLASH_INFER:
if key_cache.dtype == torch.float8_e4m3fn and value_cache.dtype == torch.float8_e4m3fn:
key = static_per_tensor_quantize(key, k_scale).view(torch.uint8)
value = static_per_tensor_quantize(value, v_scale).view(torch.uint8)
key_cache = key_cache.view(torch.uint8)
value_cache = value_cache.view(torch.uint8)
shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
elif SYSTEM == "xpu":
ipex.llm.modules.PagedAttention.reshape_and_cache(key, value, key_cache, value_cache, slots)
else:
torch.ops._C_cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, "fp8" if fp8_supported else "auto", 1.0, 1.0
)
torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, 'auto', 1.0, 1.0)


def attention(
Expand All @@ -57,6 +65,8 @@ def attention(
input_lengths: torch.Tensor,
max_s: int,
softcap: Optional[float] = None,
k_scale: float = 1.0,
v_scale: float = 1.0,
):
if FLASH_INFER:
from lorax_server.utils.flashinfer_attention import decode_state
@@ -66,6 +76,8 @@
paged_kv_cache=(key_cache, value_cache),
logits_soft_cap=softcap,
sm_scale=softmax_scale,
k_scale=k_scale,
v_scale=v_scale,
)

# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
@@ -128,7 +140,7 @@ def attention(
block_size,
max_s,
None,
"fp8" if fp8_supported else "auto",
'auto',
1.0,
1.0,
)
@@ -162,7 +174,7 @@
block_size,
max_s,
None,
"fp8" if fp8_supported else "auto",
'auto',
1.0,
1.0,
)
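For intuition, a self-contained round trip through static_per_tensor_quantize as defined above: values are divided by the scale, clamped to the e4m3 range, stored as float8, and effectively dequantized later by multiplying the scale back in. The scale below is made up; in the PR it comes from the checkpoint's k_scale / v_scale tensors.

    import torch

    def static_per_tensor_quantize(tensor, inv_scale):
        finfo = torch.finfo(torch.float8_e4m3fn)
        qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
        return qweight.to(torch.float8_e4m3fn)

    x = torch.randn(4, 128, dtype=torch.float16)
    scale = 0.05                                   # illustrative per-tensor scale
    q = static_per_tensor_quantize(x, scale)       # what gets written into the fp8 KV cache
    x_hat = q.to(torch.float16) * scale            # what attention effectively reads back
    print((x - x_hat).abs().max())                 # small, bounded quantization error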
2 changes: 1 addition & 1 deletion server/lorax_server/utils/torch_utils.py
@@ -13,7 +13,7 @@ def is_bf16_supported() -> bool:
def is_fp8_quantized(config, layer_name):
# check if quantization is fp8 and either of the fused layers is not ignored
# typically, either all qkv will be quantized or none so just check for one
if config.quantize == "fp8" and hasattr(config, "quantization_config"):
if config.quantize and config.quantize.startswith("fp8") and hasattr(config, "quantization_config"):
ignored_layers = set(config.quantization_config.get("ignored_layers", []))
if layer_name not in ignored_layers:
return "fp8"
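A small sketch of the relaxed check above with a stand-in config object (layer names here are illustrative): any fp8* quantize value now triggers fp8 weight detection, as long as the layer is not in the checkpoint's ignored_layers list.

    from types import SimpleNamespace

    def is_fp8_quantized(config, layer_name):
        # Same logic as above, reproduced for a quick standalone check.
        if config.quantize and config.quantize.startswith("fp8") and hasattr(config, "quantization_config"):
            ignored_layers = set(config.quantization_config.get("ignored_layers", []))
            if layer_name not in ignored_layers:
                return "fp8"
        return None

    cfg = SimpleNamespace(quantize="fp8_kv", quantization_config={"ignored_layers": ["lm_head"]})
    assert is_fp8_quantized(cfg, "model.layers.0.self_attn.q_proj") == "fp8"
    assert is_fp8_quantized(cfg, "lm_head") is None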
4 changes: 2 additions & 2 deletions server/lorax_server/utils/weights.py
@@ -119,7 +119,7 @@ def get_multi_weights_col(self, prefixes: List[Union[str, Tuple]], quantize: str
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
else:
weight_list = self.get_sharded_list("weight", prefixes, dim=0)
if quantize == "fp8" and weight_list[0].dtype == torch.float8_e4m3fn:
if quantize and quantize.startswith("fp8") and weight_list[0].dtype == torch.float8_e4m3fn:
# Since there is no kernel for concatenating two tensors in PyTorch
# for fp8 datatypes, we have to cast to fp16, concat, cast back to fp8
fp16_weight_list = [w.to(torch.float16) for w in weight_list]
@@ -222,7 +222,7 @@ def get_multi_weights_row(self, prefix: str, quantize: str):
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
else:
weight = self.get_sharded(f"{prefix}.weight", dim=1)
if quantize == "fp8" and weight.dtype == torch.float8_e4m3fn:
if quantize and quantize.startswith("fp8") and weight.dtype == torch.float8_e4m3fn:
# weight_scale could be a tensor but if we're sharding row-wise then no
# need to shard the weight_scale as its row dimension would be 1
weight_scale = self.get_tensor(f"{prefix}.weight_scale", use_self_dtype=False)
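A standalone sketch of the fp16 round trip described in the get_multi_weights_col comment above, since PyTorch (at the time of this PR) has no concat kernel for float8 tensors: shards are upcast to fp16, concatenated, then cast back to float8.

    import torch

    shards = [torch.randn(4, 8).to(torch.float8_e4m3fn) for _ in range(2)]
    fp16_shards = [w.to(torch.float16) for w in shards]
    merged = torch.cat(fp16_shards, dim=0).to(torch.float8_e4m3fn)
    assert merged.shape == (8, 8) and merged.dtype == torch.float8_e4m3fn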