Support calibrating kv cache scales #17

Merged · 10 commits · Jun 18, 2024
9 changes: 7 additions & 2 deletions auto_fp8/config.py
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Optional, Tuple


class BaseQuantizeConfig:
@@ -17,13 +17,17 @@ class BaseQuantizeConfig:
regex style matching i.e. re.search(), for each Linear layer.
By default, "re:.*lm_head" is included to ignore the embedding
Linear layer usually at the end of decoder LLMs
kv_cache_quant_targets: Tuple of Linear module names to target for
calibration of the output scales for KV cache quantization.
Usually, these should be `("k_proj", "v_proj")`.
"""

def __init__(
self,
quant_method: str = "fp8",
activation_scheme: str = "static",
ignore_patterns: List[str] = [],
ignore_patterns: List[str] = ["re:.*lm_head"],
kv_cache_quant_targets: Optional[Tuple[str]] = None,
):
if quant_method != "fp8":
raise ValueError("Only FP8 quantization is supported.")
@@ -34,4 +38,5 @@ def __init__(
self.quant_method = quant_method
self.activation_scheme = activation_scheme
self.ignore_patterns = ignore_patterns
self.kv_cache_quant_targets = kv_cache_quant_targets
self.ignored_layers = []
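
For reference, a minimal sketch of how the new option might be configured (a hedged example: it assumes BaseQuantizeConfig is importable from the package's top level, which is not shown in this hunk; the argument values come straight from the diff above):

    from auto_fp8 import BaseQuantizeConfig

    # Enable calibration of output scales on the k/v projections for KV cache quantization.
    quantize_config = BaseQuantizeConfig(
        quant_method="fp8",
        activation_scheme="static",
        ignore_patterns=["re:.*lm_head"],
        kv_cache_quant_targets=("k_proj", "v_proj"),
    )
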
49 changes: 36 additions & 13 deletions auto_fp8/modeling.py
@@ -1,5 +1,5 @@
import re
from typing import List
from typing import List, Optional, Tuple

import torch
from transformers import AutoModelForCausalLM
@@ -27,6 +27,16 @@ def __init__(
self.model, quantize_config.ignore_patterns
)

if quantize_config.kv_cache_quant_targets:
kv_cache_quant_layers = get_kv_cache_quant_layers(
self.model, quantize_config.kv_cache_quant_targets
)
if len(kv_cache_quant_layers) == 0:
raise ValueError(
f"Could not find any kv cache layers using kv_cache_quant_targets={quantize_config.kv_cache_quant_targets}, please fix your argument."
)
quantize_config.kv_cache_quant_layers = kv_cache_quant_layers

self.quantize_config = quantize_config

@classmethod
@@ -97,26 +107,28 @@ def skip(*args, **kwargs):

return cls(model, quantize_config)

def quantize(self, calibration_tokens):
def _prepare_calibration_data(calibration_tokens):
if hasattr(calibration_tokens, "input_ids"):
return calibration_tokens.input_ids
return calibration_tokens
def quantize(self, calibration_tokens: Optional[torch.Tensor] = None):

# Always quantize the weights as they do not require calibration data
quantize_weights(self.model, self.quantize_config)

if self.quantize_config.activation_scheme == "static":
assert (
calibration_tokens is not None
), "Calibration tokens required for activation quantization"


def _prepare_calibration_data(calibration_tokens):
if hasattr(calibration_tokens, "input_ids"):
return calibration_tokens.input_ids
return calibration_tokens

quantize_activations(
self.model,
self.quantize_config,
_prepare_calibration_data(calibration_tokens),
)

# import copy
# for layer in self.model.model.layers:
# layer.self_attn.kv_scale = copy.deepcopy(layer.self_attn.k_proj.input_scale)

def save_quantized(self, save_dir):
save_quantized_model(
self.model,
@@ -128,9 +140,6 @@ def save_quantized(self, save_dir):
def get_layers_to_ignore(model, ignore_patterns) -> List[str]:
ignored_layers = set()

# TODO: don't always ignore lm_head
ignore_patterns.append("re:.*lm_head")

for name, linear in model.named_modules():
if not isinstance(linear, torch.nn.Linear):
continue
@@ -148,3 +157,17 @@ def get_layers_to_ignore(model, ignore_patterns) -> List[str]:
ignored_layers.add(name)

return list(ignored_layers)


def get_kv_cache_quant_layers(model, kv_cache_quant_targets: Tuple[str]) -> List[str]:
kv_cache_quant_layers = []

for name, linear in model.named_modules():
if not isinstance(linear, torch.nn.Linear):
continue

for output_quant_target in kv_cache_quant_targets:
if name.endswith(output_quant_target):
kv_cache_quant_layers.append(name)

return kv_cache_quant_layers
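
Putting the modeling changes together, a rough end-to-end calibration flow might look like the following (a sketch only: the AutoFP8ForCausalLM name, the model id, and the output directory are assumptions based on the surrounding code, not part of this hunk):

    from transformers import AutoTokenizer
    from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

    pretrained_model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"  # hypothetical model id
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir)

    quantize_config = BaseQuantizeConfig(
        activation_scheme="static",
        kv_cache_quant_targets=("k_proj", "v_proj"),  # as in the config sketch above
    )

    # Any small calibration set works; quantize() accepts either a tensor of input ids
    # or a tokenizer output exposing `input_ids` (see _prepare_calibration_data).
    examples = ["auto_fp8 is an easy-to-use library for FP8 quantization"]
    tokens = tokenizer(examples, return_tensors="pt").input_ids

    model = AutoFP8ForCausalLM.from_pretrained(pretrained_model_dir, quantize_config)
    model.quantize(tokens)
    model.save_quantized("Meta-Llama-3-8B-Instruct-FP8-KV")  # hypothetical output dir
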
156 changes: 109 additions & 47 deletions auto_fp8/quantize.py
@@ -1,6 +1,6 @@
import gc
import re
from typing import List, Tuple
from typing import Optional, Tuple
import copy

import torch
@@ -61,14 +61,22 @@ def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
return qweight, scale


def static_per_tensor_quantize(tensor: torch.Tensor, inv_scale: float) -> torch.Tensor:
finfo = torch.finfo(torch.float8_e4m3fn)
qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
return qweight.to(torch.float8_e4m3fn)


def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
if A.numel() == 0:
# Deal with empty tensors (triggered by empty MoE experts)
return torch.empty(size=(0, B.shape[0]), dtype=out_dtype, device=A.device)

native_fp8_support = (
torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
)

# TODO: Disable native fp8 gemm for now, always just dequantize
# native_fp8_support = (
# torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
# )
native_fp8_support = False
if native_fp8_support:
need_reshape = A.dim() == 3
if need_reshape:
@@ -98,25 +106,24 @@ def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
return output


class FP8StaticLinearQuantizer(torch.nn.Module):
# Class responsible for quantizing weights
class FP8DynamicLinear(torch.nn.Module):
def __init__(
self, qweight: torch.Tensor, weight_scale: torch.Tensor, bias: torch.Tensor
self,
weight: torch.Tensor,
weight_scale: torch.Tensor,
bias: torch.nn.Parameter,
):
super().__init__()
self.weight = torch.nn.Parameter(qweight, requires_grad=False)
self.weight = torch.nn.Parameter(weight, requires_grad=False)
self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
self.input_scale = None
self.bias = bias

def forward(self, x):
qinput, x_input_scale = per_tensor_quantize(x)
if self.input_scale is None:
self.input_scale = torch.nn.Parameter(x_input_scale)
elif x_input_scale > self.input_scale:
self.input_scale = torch.nn.Parameter(x_input_scale)
qinput, x_scale = per_tensor_quantize(x)
output = fp8_gemm(
A=qinput,
A_scale=self.input_scale,
A_scale=x_scale,
B=self.weight,
B_scale=self.weight_scale,
bias=self.bias,
@@ -125,29 +132,29 @@ forward(self, x):
return output


class FP8StaticLinear(torch.nn.Module):
# Module responsible for taking already quantized weights and recording input scales (and possibly output scales) using an activation observer
class FP8StaticLinearQuantizer(torch.nn.Module):
def __init__(
self,
qweight: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
bias: torch.Tensor,
input_scale: float = 1.0,
bias: torch.nn.Parameter,
quantize_output: bool = False,
):
super().__init__()
self.weight = torch.nn.Parameter(qweight, requires_grad=False)
self.weight = torch.nn.Parameter(weight, requires_grad=False)
self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
self.input_scale = torch.nn.Parameter(input_scale, requires_grad=False)
self.bias = bias

def per_tensor_quantize(
self, tensor: torch.Tensor, inv_scale: float
) -> torch.Tensor:
finfo = torch.finfo(torch.float8_e4m3fn)
qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
return qweight.to(torch.float8_e4m3fn)
self.input_scale = None
self.output_scale = None
self.quantize_output = quantize_output

def forward(self, x):
qinput = self.per_tensor_quantize(x, inv_scale=self.input_scale)
qinput, x_input_scale = per_tensor_quantize(x)
if self.input_scale is None:
self.input_scale = torch.nn.Parameter(x_input_scale, requires_grad=False)
elif x_input_scale > self.input_scale:
self.input_scale = torch.nn.Parameter(x_input_scale, requires_grad=False)
output = fp8_gemm(
A=qinput,
A_scale=self.input_scale,
@@ -156,26 +163,51 @@ def forward(self, x):
bias=self.bias,
out_dtype=x.dtype,
)

# Optionally, quantize output and record scale
if self.quantize_output:
qoutput, output_scale = per_tensor_quantize(output)
if self.output_scale is None:
self.output_scale = torch.nn.Parameter(output_scale, requires_grad=False)
elif output_scale > self.output_scale:
self.output_scale = torch.nn.Parameter(output_scale, requires_grad=False)
output = qoutput.to(output.dtype) * output_scale

return output


class FP8DynamicLinear(torch.nn.Module):
def __init__(self, qweight: torch.Tensor, scale: torch.Tensor, bias: torch.Tensor):
# Module responsible for the final checkpoint representation
class FP8StaticLinear(torch.nn.Module):
def __init__(
self,
weight: torch.nn.Parameter,
weight_scale: torch.nn.Parameter,
bias: torch.nn.Parameter,
input_scale: torch.nn.Parameter,
output_scale: Optional[torch.nn.Parameter] = None,
):
super().__init__()
self.weight = torch.nn.Parameter(qweight, requires_grad=False)
self.weight_scale = torch.nn.Parameter(scale, requires_grad=False)
self.weight = weight
self.weight_scale = weight_scale
self.bias = bias
self.input_scale = input_scale
self.output_scale = output_scale

def forward(self, x):
qinput, x_scale = per_tensor_quantize(x)
qinput = static_per_tensor_quantize(x, self.input_scale)
output = fp8_gemm(
A=qinput,
A_scale=x_scale,
A_scale=self.input_scale,
B=self.weight,
B_scale=self.weight_scale,
bias=self.bias,
out_dtype=x.dtype,
)

if self.output_scale:
qoutput = static_per_tensor_quantize(output, self.output_scale)
output = qoutput.to(output.dtype) * self.output_scale

return output
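
To make the static path above concrete, here is a tiny numeric illustration of the per-tensor FP8 round trip that FP8StaticLinear relies on (values and the choice of scale are illustrative only; it assumes the scale is derived from an observed absmax, as per_tensor_quantize does):

    import torch

    finfo = torch.finfo(torch.float8_e4m3fn)      # finfo.max == 448.0 for e4m3fn
    x = torch.tensor([0.5, -1.0, 2.0], dtype=torch.float16)

    scale = x.abs().max().float() / finfo.max     # observed absmax mapped to the fp8 max
    q = (x / scale).clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)
    x_hat = q.to(torch.float16) * scale           # dequantized values, close to x
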


@@ -194,7 +226,6 @@ def replace_module(model: AutoModelForCausalLM, name: str, new_module: torch.nn.
def quantize_weights(
model: AutoModelForCausalLM,
quantize_config: BaseQuantizeConfig,
ignored_layers: List[str] = [],
):
named_modules = list(model.named_modules())
for name, linear in tqdm.tqdm(named_modules, desc="Quantizing weights"):
@@ -203,9 +234,11 @@ def quantize_weights(
or name in quantize_config.ignored_layers
):
continue
quant_weight, quant_scale = per_tensor_quantize(linear.weight)
quant_weight, weight_scale = per_tensor_quantize(linear.weight)
bias = copy.deepcopy(linear.bias) if linear.bias is not None else None
quant_linear = FP8DynamicLinear(quant_weight, quant_scale, bias)
quant_linear = FP8DynamicLinear(
weight=quant_weight, weight_scale=weight_scale, bias=bias
)
replace_module(model, name, quant_linear)
del linear.weight
del linear.bias
@@ -217,7 +250,6 @@ def quantize_activations(
model: AutoModelForCausalLM,
quantize_config: BaseQuantizeConfig,
calibration_tokens,
ignored_layers: List[str] = [],
):
# Replace weight quantizer with a dynamic activation quantizer observer
for name, dynamic_quant_linear in model.named_modules():
@@ -227,9 +259,13 @@ def quantize_activations(
):
continue
quantizer = FP8StaticLinearQuantizer(
dynamic_quant_linear.weight,
dynamic_quant_linear.weight_scale,
dynamic_quant_linear.bias,
weight=dynamic_quant_linear.weight,
weight_scale=dynamic_quant_linear.weight_scale,
bias=dynamic_quant_linear.bias,
quantize_output=(
hasattr(quantize_config, "kv_cache_quant_layers")
and name in quantize_config.kv_cache_quant_layers
),
)
replace_module(model, name, quantizer)
del dynamic_quant_linear
@@ -251,21 +287,45 @@ def quantize_activations(
):
continue
static_proj = FP8StaticLinear(
quantizer.weight,
quantizer.weight_scale,
quantizer.bias,
quantizer.input_scale,
weight=quantizer.weight,
weight_scale=quantizer.weight_scale,
bias=quantizer.bias,
input_scale=quantizer.input_scale,
output_scale=quantizer.output_scale,
)
replace_module(model, name, static_proj)
del quantizer
cleanup_memory()

# Post-process step for kv cache scales: take the k/v modules' `output_scale`
# parameters, take the max of each pair, and store it in the parent attention
# module as `kv_scale`
# NOTE: if we want to switch to the `output_scale` representation, we can simply remove this block
if hasattr(quantize_config, "kv_cache_quant_layers"):
# Assumes the list is ordered as [layer0.k_proj, layer0.v_proj, layer1.k_proj, layer1.v_proj, ...],
# so we build a list of tuples [(layer0.k_proj, layer0.v_proj), (layer1.k_proj, layer1.v_proj), ...]
kv_proj_pairs = zip(*[iter(quantize_config.kv_cache_quant_layers)]*2)
for k_proj_name, v_proj_name in kv_proj_pairs:
parent_module_name = ".".join(k_proj_name.split(".")[:-1])
assert parent_module_name == ".".join(v_proj_name.split(".")[:-1])
parent_module = dict(model.named_modules())[parent_module_name]

k_proj = dict(model.named_modules())[k_proj_name]
v_proj = dict(model.named_modules())[v_proj_name]

kv_scale = max(k_proj.output_scale, v_proj.output_scale)
parent_module.kv_scale = torch.nn.Parameter(kv_scale, requires_grad=False)

# Remove output_scale from k_proj and v_proj
k_proj.output_scale = None
v_proj.output_scale = None
cleanup_memory()


def save_quantized_model(
model: AutoModelForCausalLM,
quant_config: BaseQuantizeConfig,
save_dir: str,
ignored_layers: List[str] = [],
):
print(model)
print(f"Saving the model to {save_dir}")
@@ -276,6 +336,8 @@ def save_quantized_model(
"ignored_layers": quant_config.ignored_layers,
}
}
if hasattr(quant_config, "kv_cache_quant_layers"):
static_q_dict["quantization_config"]["kv_cache_scheme"] = "static"
model.config.update(static_q_dict)
model.save_pretrained(save_dir)
tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path)
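
With the kv cache option enabled, the dict pushed into model.config by save_quantized_model would end up roughly like this (a sketch only; quant_method and activation_scheme come from earlier hunks, and the ignored_layers value is illustrative):

    static_q_dict = {
        "quantization_config": {
            "quant_method": "fp8",
            "activation_scheme": "static",
            "ignored_layers": ["lm_head"],   # illustrative
            "kv_cache_scheme": "static",
        }
    }
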
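As a closing note on the post-processing step in quantize_activations: the zip(*[iter(...)] * 2) idiom simply walks the flat kv_cache_quant_layers list two entries at a time, pairing each k_proj with its sibling v_proj. A small standalone illustration (the module names are hypothetical):

    layers = [
        "model.layers.0.self_attn.k_proj", "model.layers.0.self_attn.v_proj",
        "model.layers.1.self_attn.k_proj", "model.layers.1.self_attn.v_proj",
    ]
    pairs = list(zip(*[iter(layers)] * 2))
    # [('model.layers.0.self_attn.k_proj', 'model.layers.0.self_attn.v_proj'),
    #  ('model.layers.1.self_attn.k_proj', 'model.layers.1.self_attn.v_proj')]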