Merge pull request #737 from google:msingh-kv
PiperOrigin-RevId: 648407356
maxtext authors committed Jul 1, 2024
2 parents 93efadf + 4378403 commit bdeab2b
Showing 8 changed files with 102 additions and 66 deletions.
1 change: 1 addition & 0 deletions MaxText/configs/base.yml
@@ -68,6 +68,7 @@ quantize_kvcache: False
# Defaults to "heads_and_dkv" for faster computation; kv_quant_axis is not used when quantize_kvcache is False
# - "dkv" is expected to give better accuracy at the cost of slower computation
kv_quant_axis: "heads_and_dkv"
kv_quant_dtype: "int8" # "int8" (default) or "int4"; only used when quantize_kvcache is True
checkpoint_is_quantized: False # Set to True if reading from a saved aqt quantized checkpoint
# Saves params quantized on fly at following path
save_quantized_params_path: ""
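
For reference, a minimal sketch of how these keys are consumed, with a SimpleNamespace standing in for a parsed MaxText config (the import path may need adjusting for your checkout):

from types import SimpleNamespace
from MaxText.layers import quantizations  # or: from layers import quantizations

cfg = SimpleNamespace(quantize_kvcache=True, kv_quant_axis="heads_and_dkv", kv_quant_dtype="int8")
kv_quant = quantizations.configure_kv_quant(cfg)   # KVQuant instance; "int8" resolves to jnp.int8

cfg_off = SimpleNamespace(quantize_kvcache=False, kv_quant_axis="heads_and_dkv", kv_quant_dtype="int8")
assert quantizations.configure_kv_quant(cfg_off) is None  # quantization disabled, no KVQuant is built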
77 changes: 40 additions & 37 deletions MaxText/layers/attentions.py
@@ -44,6 +44,7 @@
RotaryEmbedding = embeddings.RotaryEmbedding
NdInitializer = initializers.NdInitializer
Quant = quantizations.AqtQuantization
KVQuant = quantizations.KVQuant

AxisNames = common_types.AxisNames
AxisIdxes = common_types.AxisIdxes
@@ -129,8 +130,7 @@ class AttentionOp(nn.Module):
dropout_rate: float = 0.0
dtype: DType = jnp.float32
quant: Optional[Quant] = None
quantize_kvcache: bool = False
kv_quant_axis: str = "heads_and_dkv"
kv_quant: Optional[KVQuant] = None

def check_attention_inputs(self, query: Array, key: Array, value: Array) -> None:
"""Check attention inputs."""
@@ -425,9 +425,21 @@ def reverse_transepose(self, transposed_array, transpose_axis_order):
def transpose_tuple(self, items: tuple[Any, Any, Any, Any], axis_order: AxisIdxes) -> tuple[Any, Any, Any, Any]:
return tuple([items[i] for i in axis_order])

def _get_cached_kv_dtype(self, dtype):
return self.kv_quant.dtype if self.kv_quant else dtype

def _get_cache_scale_logical_shape(self, batch, heads):
assert self.kv_quant
if self.kv_quant.axis_cfg == "dkv":
return (batch, self.max_prefill_predict_length, heads, 1)
if self.kv_quant.axis_cfg == "heads_and_dkv":
return (batch, self.max_prefill_predict_length, 1, 1)
raise f"Invalid config for kv_quant_axis:{self.kv_quant.axis_cfg}"


def _get_prefill_cache_vars(self, batch, heads, kv_head_size):

dtype = jnp.int8 if self.quantize_kvcache else self.dtype
dtype = self._get_cached_kv_dtype(self.dtype)
cache_logical_shape = (batch, self.max_prefill_predict_length, heads, kv_head_size)

cache_axis_names = self.transpose_tuple(self.cache_logical_axis_names, self.prefill_cache_axis_order)
@@ -455,13 +467,8 @@ def _get_prefill_cache_vars(self, batch, heads, kv_head_size):
jnp.int32,
)

if self.quantize_kvcache:

if self.kv_quant_axis == "dkv":
cache_scale_logical_shape = (batch, self.max_prefill_predict_length, heads, 1)
elif self.kv_quant_axis == "heads_and_dkv":
cache_scale_logical_shape = (batch, self.max_prefill_predict_length, 1, 1)

if self.kv_quant:
cache_scale_logical_shape = self._get_cache_scale_logical_shape(batch, heads)
cache_scale_axis_names = self.transpose_tuple(self.cache_scale_logical_axis_names, self.prefill_cache_axis_order)
cache_scale_shape = self.transpose_tuple(cache_scale_logical_shape, self.prefill_cache_axis_order)

@@ -489,7 +496,7 @@ def _get_prefill_cache_vars(self, batch, heads, kv_head_size):

def _get_ar_cache_vars(self, batch, heads, kv_head_size):

dtype = jnp.int8 if self.quantize_kvcache else self.dtype
dtype = self._get_cached_kv_dtype(self.dtype)
cache_length = self.max_target_length - self.max_prefill_predict_length
cache_logical_shape = (batch, cache_length, heads, kv_head_size)

@@ -529,13 +536,8 @@ def _get_ar_cache_vars(self, batch, heads, kv_head_size):
jnp.int32,
)

if self.quantize_kvcache:

if self.kv_quant_axis == "dkv":
cache_scale_logical_shape = (batch, cache_length, heads, 1)
elif self.kv_quant_axis == "heads_and_dkv":
cache_scale_logical_shape = (batch, cache_length, 1, 1)

if self.kv_quant:
cache_scale_logical_shape = self._get_cache_scale_logical_shape(batch, heads)
cache_scale_axis_names = self.transpose_tuple(self.cache_scale_logical_axis_names, self.ar_cache_axis_order)
cache_scale_shape = self.transpose_tuple(cache_scale_logical_shape, self.ar_cache_axis_order)

@@ -590,12 +592,13 @@ def kv_cache_prefill(
key_shaped_for_cache = jnp.transpose(key, self.prefill_cache_axis_order)
value_shaped_for_cache = jnp.transpose(value, self.prefill_cache_axis_order)

if self.quantize_kvcache:
prefill_key_axis_names = self.transpose_tuple(self.cache_logical_axis_names, self.prefill_cache_axis_order)
key_shaped_for_cache, key_scale_shaped_for_cache = quantizations.quantize_kv(
key_shaped_for_cache, self.kv_quant_axis, prefill_key_axis_names)
value_shaped_for_cache, value_scale_shaped_for_cache = quantizations.quantize_kv(
value_shaped_for_cache, self.kv_quant_axis, prefill_key_axis_names)
if self.kv_quant:
prefill_key_axis_names = self.transpose_tuple(
self.cache_logical_axis_names, self.prefill_cache_axis_order)
key_shaped_for_cache, key_scale_shaped_for_cache = self.kv_quant.quantize(
key_shaped_for_cache, prefill_key_axis_names)
value_shaped_for_cache, value_scale_shaped_for_cache = self.kv_quant.quantize(
value_shaped_for_cache, prefill_key_axis_names)
cached_prefill_key_vars[1].value = key_scale_shaped_for_cache
cached_prefill_value_vars[1].value = value_scale_shaped_for_cache

@@ -637,11 +640,11 @@ def update_ar_key_value(
one_token_value_shaped_for_cache = jnp.transpose(one_token_value, self.ar_cache_axis_order)

ar_cache_axis_names = self.transpose_tuple(self.cache_logical_axis_names, self.ar_cache_axis_order)
if self.quantize_kvcache:
one_token_key_shaped_for_cache, one_token_key_scale_shaped_for_cache = quantizations.quantize_kv(
one_token_key_shaped_for_cache, self.kv_quant_axis, ar_cache_axis_names)
one_token_value_shaped_for_cache, one_token_value_scale_shaped_for_cache = quantizations.quantize_kv(
one_token_value_shaped_for_cache, self.kv_quant_axis, ar_cache_axis_names)
if self.kv_quant:
one_token_key_shaped_for_cache, one_token_key_scale_shaped_for_cache = self.kv_quant.quantize(
one_token_key_shaped_for_cache, ar_cache_axis_names)
one_token_value_shaped_for_cache, one_token_value_scale_shaped_for_cache = self.kv_quant.quantize(
one_token_value_shaped_for_cache, ar_cache_axis_names)

one_hot_indices = one_hot_indices.astype(int)
ar_cache_update_idx = jnp.squeeze(one_hot_indices)
@@ -654,7 +657,7 @@
cached_value_var.value, one_token_value_shaped_for_cache, ar_cache_update_idx, ar_cache_update_axis)
cached_value_var.value = nn.with_logical_constraint(cached_value_var.value, ar_cache_axis_names)

if self.quantize_kvcache:
if self.kv_quant:
ar_cache_scale_axis_names = self.transpose_tuple(self.cache_scale_logical_axis_names, self.ar_cache_axis_order)
ar_cache_scale_update_axis = ar_cache_scale_axis_names.index(CACHE_SCALE_SEQUENCE)
cached_key_scale_var.value = jax.lax.dynamic_update_index_in_dim(
@@ -669,7 +672,7 @@ def get_cached_values(self, cache_vars, target_dtype, cache_axis_order):
cached_value = cache_var.value
if cache_scale_var is not None:
cached_scale_value = cache_scale_var.value
cached_value = quantizations.unquantize_kv(cached_value, cached_scale_value, target_dtype)
cached_value = self.kv_quant.unquantize(cached_value, cached_scale_value, target_dtype)

cache_value_in_logical_shape = self.reverse_transepose(cached_value, cache_axis_order)
return cache_value_in_logical_shape
@@ -830,7 +833,7 @@ class Attention(nn.Module):
float32_logits: bool, if True then cast logits to float32 before softmax to avoid
numerical issues with bfloat16.
quant: Quant, stores quantization parameters, defaults to None implying no quantization.
quantize_kvcache: bool, quantize the kv cache.
kv_quant: KVQuant, stores KV cache quantization parameters, defaults to None implying no KV cache quantization.
"""

config: Config
@@ -848,7 +851,7 @@ class Attention(nn.Module):
float32_qk_product: bool = False # computes logits in float32 for stability.
float32_logits: bool = False # cast logits in float32 for stability.
quant: Optional[Quant] = None
quantize_kvcache: bool = False
kv_quant: Optional[KVQuant] = None

# Shard the query activation as the same as the key and value.
# TODO: Find a better sharding axis name.
@@ -861,7 +864,7 @@
ar_cache_axis_order: AxisIdxes = (1, 2, 0, 3)
compute_axis_order: AxisIdxes = (0, 1, 2, 3)
reshape_q: bool = False
kv_quant_axis: str = "heads_and_dkv"


def query_projection(self, inputs_q: Array) -> Array:
"""Query projection."""
@@ -950,7 +953,7 @@ def out_projection(self, output_dim: int, out: Array) -> Array:

def key_rotary(self, key: Array, inputs_positions: Array):
"""Apply Rotary Embedding to key."""
key = RotaryEmbedding(min_timescale=self.config.rope_min_timescale, max_timescale = self.config.rope_max_timescale,
embedding_dims=self.head_dim, name="key_rotary")(inputs=key, position=inputs_positions)
return key

@@ -1008,6 +1011,7 @@ def __call__(
value = nn.with_logical_constraint(value, self.value_axis_names)
value = checkpoint_name(value, "value_proj")

assert not self.config.quantize_kvcache or self.kv_quant
attention_op = AttentionOp(
mesh=self.mesh,
attention_kernel=self.attention_kernel,
@@ -1016,7 +1020,7 @@
float32_qk_product=self.float32_qk_product,
float32_logits=self.float32_logits,
quant=self.quant,
quantize_kvcache=self.quantize_kvcache,
kv_quant=self.kv_quant,
num_query_heads=self.num_query_heads,
num_kv_heads=self.num_kv_heads,
dropout_rate=self.dropout_rate,
@@ -1025,7 +1029,6 @@
ar_cache_axis_order=self.ar_cache_axis_order,
compute_axis_order=self.compute_axis_order,
reshape_q=self.reshape_q,
kv_quant_axis=self.kv_quant_axis,
)

out = attention_op(query, key, value, decoder_segment_ids, model_mode)
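
To make the two kv_quant_axis layouts concrete, here is a small sketch of the logical scale shapes returned by _get_cache_scale_logical_shape for the prefill cache, with made-up dimensions (the real code then transposes them by prefill_cache_axis_order):

# Hypothetical sizes, for illustration only.
batch, max_prefill_predict_length, kv_heads, head_dim = 4, 1024, 8, 128

cache_logical_shape = (batch, max_prefill_predict_length, kv_heads, head_dim)
# kv_quant_axis="dkv": one scale per (batch, position, head)
scale_shape_dkv = (batch, max_prefill_predict_length, kv_heads, 1)
# kv_quant_axis="heads_and_dkv": one scale per (batch, position), shared across heads
scale_shape_heads_and_dkv = (batch, max_prefill_predict_length, 1, 1)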
3 changes: 2 additions & 1 deletion MaxText/layers/gemma.py
@@ -47,6 +47,7 @@

nd_dense_init = initializers.nd_dense_init
Quant = quantizations.AqtQuantization
KVQuant = quantizations.KVQuant


# Decoder and Model definitions
@@ -93,7 +94,7 @@ def __call__(
float32_qk_product=True,
float32_logits=True,
quant=self.quant,
quantize_kvcache=cfg.quantize_kvcache,
kv_quant=quantizations.configure_kv_quant(cfg),
)

attention_lnx = attention_layer(
6 changes: 5 additions & 1 deletion MaxText/layers/gpt3.py
@@ -53,6 +53,7 @@
Initializer = initializers.Initializer
nd_dense_init = initializers.nd_dense_init
Quant = quantizations.AqtQuantization
KVQuant = quantizations.KVQuant


# -----------------------------------------
@@ -144,13 +145,15 @@ class Gpt3MultiHeadAttention(nn.Module):
float32_logits: bool = True # cast logits in float32 for stability.
fused_qkv: bool = True
quant: Optional[Quant] = None
kv_quant: Optional[KVQuant] = None
use_bias: bool = True

query_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV)
key_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV)
value_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV)
out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV)


def qkv_projection(self, inputs: Array, proj_name: str):
"""Fused QKV projection"""

@@ -233,7 +236,7 @@ def __call__(
float32_qk_product=self.float32_qk_product,
float32_logits=self.float32_logits,
quant=self.quant,
quantize_kvcache=self.config.quantize_kvcache,
kv_quant=self.kv_quant,
num_query_heads=self.num_heads,
num_kv_heads=self.num_heads,
dtype=self.dtype,
@@ -306,6 +309,7 @@ def __call__(
fused_qkv=cfg.fused_qkv,
use_bias=True,
quant=self.quant,
kv_quant=quantizations.configure_kv_quant(cfg)
)

attention_lnx = attention_layer(
3 changes: 1 addition & 2 deletions MaxText/layers/llama2.py
@@ -105,12 +105,11 @@ def __call__(
dropout_rate=cfg.dropout_rate,
name="self_attention",
quant=self.quant,
quantize_kvcache=cfg.quantize_kvcache,
kv_quant=quantizations.configure_kv_quant(cfg),
prefill_cache_axis_order=tuple([int(i) for i in cfg.prefill_cache_axis_order.split(",")]),
ar_cache_axis_order=tuple([int(i) for i in cfg.ar_cache_axis_order.split(",")]),
compute_axis_order=tuple([int(i) for i in cfg.compute_axis_order.split(",")]),
reshape_q=cfg.reshape_q,
kv_quant_axis=cfg.kv_quant_axis,
)

attention_lnx = attention_layer(
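
A side note on the axis-order plumbing above: the config carries cache and compute axis orders as comma-separated strings, which the layer converts to integer tuples for AttentionOp. A tiny sketch, using a hypothetical value matching the defaults shown in attentions.py:

def parse_axis_order(s: str) -> tuple[int, ...]:
    # Same conversion as the inline list comprehensions above.
    return tuple(int(i) for i in s.split(","))

assert parse_axis_order("1,2,0,3") == (1, 2, 0, 3)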
2 changes: 1 addition & 1 deletion MaxText/layers/mistral.py
@@ -97,7 +97,7 @@ def __call__(
dropout_rate=cfg.dropout_rate,
name="self_attention",
quant=self.quant,
quantize_kvcache=cfg.quantize_kvcache,
kv_quant=quantizations.configure_kv_quant(cfg),
)

attention_lnx = attention_layer(
3 changes: 1 addition & 2 deletions MaxText/layers/models.py
@@ -92,12 +92,11 @@ def __call__(
dropout_rate=cfg.dropout_rate,
name="self_attention",
quant=self.quant,
quantize_kvcache=cfg.quantize_kvcache,
kv_quant=quantizations.configure_kv_quant(cfg),
prefill_cache_axis_order=tuple([int(i) for i in cfg.prefill_cache_axis_order.split(",")]),
ar_cache_axis_order=tuple([int(i) for i in cfg.ar_cache_axis_order.split(",")]),
compute_axis_order=tuple([int(i) for i in cfg.compute_axis_order.split(",")]),
reshape_q=cfg.reshape_q,
kv_quant_axis=cfg.kv_quant_axis,
)

attention_lnx = attention_layer(
73 changes: 51 additions & 22 deletions MaxText/layers/quantizations.py
@@ -28,6 +28,7 @@
from typing import Tuple, Sequence

MAX_INT8 = 127.5
MAX_INT4 = 7.5

Array = common_types.Array
Config = common_types.Config
@@ -195,26 +196,54 @@ def remove_quantized_params(params, aqt_vars):
tree_flat[i] = v
return tree_unflatten(tree_struct, tree_flat)

def configure_kv_quant(config):
return None if not config.quantize_kvcache else KVQuant(config)

class KVQuant:
axis_cfg = ""
dtype = None

def __init__(self, config:Config):
assert config.quantize_kvcache
self.axis_cfg = config.kv_quant_axis
self.dtype = self._get_dtype(config.kv_quant_dtype)

def _get_dtype(self, dtype_cfg: str):
if dtype_cfg == "int4":
return jnp.int4
if dtype_cfg == "int8":
return jnp.int8
raise ValueError(f"Invalid kv_quant_dtype: {dtype_cfg}")

def _get_max_axis(self, axis_names: AxisNames):
if self.axis_cfg == "dkv":
return axis_names.index(CACHE_KV)
if self.axis_cfg == "heads_and_dkv":
return (
axis_names.index(CACHE_HEADS),
axis_names.index(CACHE_KV)
)
raise ValueError(f"Invalid KV quant axis cfg: {self.axis_cfg}")

def quantize(self, kv: Array, axis_names: AxisNames):
"""Quantize key/values stored in kvcache."""
assert self.axis_cfg, 'KV quant axis cannot be None'
max_axis = self._get_max_axis(axis_names)
scale = jnp.max(jnp.abs(kv), axis=max_axis, keepdims=True)
if self.dtype == jnp.int8:
value = jnp.int8(jnp.rint(kv * (MAX_INT8 / scale)))
return value, scale
if self.dtype == jnp.int4:
value = jnp.int4(jnp.rint(kv * (MAX_INT4 / scale)))
return value, scale
raise ValueError(f"Invalid KV quant dtype:{self.dtype}.")


def unquantize(self, value: Array, scale: Array, dtype: jnp.dtype):
"""Unquantize key/values stored in kvcache."""
if self.dtype == jnp.int8:
return value.astype(dtype) * scale / MAX_INT8
if self.dtype == jnp.int4:
return value.astype(dtype) * scale / MAX_INT4
raise ValueError(f"Invalid KV quant dtype: {self.dtype}.")

def configure_kv_quantization(config: Config):
"""Configure kv quantization based on user config."""
return False if not config.quantize_kvcache else True


def quantize_kv(kv: Array, kv_quant_axis: str, axis_names: AxisNames):
"""Quantize key/values stored in kvcache."""
if kv_quant_axis == "dkv":
max_axis_over = axis_names.index(CACHE_KV)
elif kv_quant_axis == "heads_and_dkv":
max_axis_over = (
axis_names.index(CACHE_HEADS),
axis_names.index(CACHE_KV)
)
scale = jnp.max(jnp.abs(kv), axis=max_axis_over, keepdims=True)
value = jnp.int8(jnp.rint(kv * (MAX_INT8 / scale)))
return value, scale


def unquantize_kv(value: Array, scale: Array, dtype: jnp.dtype):
"""Unquantize key/values stored in kvcache."""
return value.astype(dtype) * scale / MAX_INT8
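
A minimal end-to-end sketch of the new KVQuant path, with made-up shapes and a stand-in config object (import paths may need adjusting for your checkout; the first two axis names are placeholders, since only the heads/kv entries matter to _get_max_axis):

import jax
import jax.numpy as jnp
from types import SimpleNamespace
from MaxText.layers import quantizations  # or: from layers import quantizations

cfg = SimpleNamespace(quantize_kvcache=True, kv_quant_axis="heads_and_dkv", kv_quant_dtype="int8")
kv_quant = quantizations.KVQuant(cfg)

# Logical cache layout: (batch, sequence, heads, head_dim).
axis_names = ("cache_batch", "cache_sequence", quantizations.CACHE_HEADS, quantizations.CACHE_KV)
kv = jax.random.normal(jax.random.PRNGKey(0), (2, 16, 8, 128), dtype=jnp.bfloat16)

values, scales = kv_quant.quantize(kv, axis_names)            # int8 values plus per-token max-abs scales
restored = kv_quant.unquantize(values, scales, jnp.bfloat16)  # dequantized approximation of kv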
