diff --git a/benchmarks/mixtral_offline.sh b/benchmarks/mixtral_offline.sh new file mode 100644 index 00000000..9572366f --- /dev/null +++ b/benchmarks/mixtral_offline.sh @@ -0,0 +1,25 @@ +CACHE_LENGTH=$1 +BATCH_SIZE=$2 +INPUT_SIZE=1024 +OUTPUT_SIZE=1024 +CHECKPOINT_PATH=mlperf/data/mixtral-instruct-quantized/ +export JAX_COMPILATION_CACHE_DIR="/tmp/jax_cache2" +export XLA_FLAGS="--xla_disable_hlo_passes=rematerialization" + +pushd .. +python -m benchmarks.run_offline \ + --lazy_cache_update=1 \ + --ring_buffer=0 \ + --model_name=mixtral \ + --batch_size=$BATCH_SIZE \ + --max_cache_length=$CACHE_LENGTH \ + --max_decode_length=$OUTPUT_SIZE \ + --context_length=$INPUT_SIZE \ + --checkpoint_path=$CHECKPOINT_PATH/model.safetensors \ + --tokenizer_path=$CHECKPOINT_PATH/tokenizer.model \ + --quantize_weights=1 \ + --quantize_type=int8_per_channel \ + --quantize_kv_cache=1 \ + --profiling_output=/mnt/disks/hanq/mixtral-profiles +popd +echo "batch was $2 cache was $1" diff --git a/benchmarks/offline_benchmark.py b/benchmarks/offline_benchmark.py new file mode 100644 index 00000000..0b007eea --- /dev/null +++ b/benchmarks/offline_benchmark.py @@ -0,0 +1,97 @@ +import math +import pandas as pd +import dataclasses +from collections import defaultdict +from absl import flags, app + +from typing import Dict + +FLAGS = flags.FLAGS + +flags.DEFINE_string('dataset_path', '', '') + +@dataclasses.dataclass +class Stat: + cache_size: int + batch_size: int + prefill_times: Dict[int, float] + decode_time: float + +scenario1 = [ + Stat( + cache_size = 512, + batch_size = 2048, + prefill_times = {16: 0.02084908019969589, 32: 0.024125573800120037, 64: 0.02697298339990084, 128: 0.03641403259971412, 256: 0.05809259879970341, 512: 0.10703752639965387}, + decode_time = 0.359 + #ecode_time = 0.28 + ), + Stat( + cache_size = 1280, + batch_size = 512, + prefill_times={16: 0.02070321020000847, 32: 0.02408570580009837, 64: 0.02650543759955326, 128: 0.036217428799864136, 256: 0.057748028799687746, 512: 0.10604073840004276, 1024: 0.20993155719988862}, + decode_time=0.094, + ), + Stat( + cache_size = 3072, + batch_size = 256, + prefill_times={16: 0.020371186199918158, 32: 0.024281639599939807, 64: 0.02710893359981128, 128: 0.03605372060046648, 256: 0.0574128626001766, 512: 0.10610043820051943, 1024: 0.2097496903996216, 2048: 0.4301163775999157}, + decode_time = 0.0552, + ), +] + +scenario2 = [ + scenario1[2], + scenario1[2], + scenario1[2] +] +def eval_scenario(dataset, scenario): + + total_input_tokens = 0 + total_output_tokens = 0 + total_prefill_times = defaultdict(float) + total_decode_times = defaultdict(float) + output_tokens_by_bucket = defaultdict(int) + for _, data in dataset.iterrows(): + stat = scenario[data.bucket] + total_input_tokens += data.tok_input_len + total_output_tokens += data.tok_ref_output_len + input_len_bucket = 2**math.ceil(math.log2(data.tok_input_len)) + if input_len_bucket == 2048 and data.bucket == 1: + import pdb; pdb.set_trace() + total_prefill_times[input_len_bucket] += stat.prefill_times[input_len_bucket] + output_tokens_by_bucket[data.bucket] += data.tok_ref_output_len + + for k in output_tokens_by_bucket.keys(): + stat = scenario[k] + total_decode_times[k] = output_tokens_by_bucket[k] / stat.batch_size * scenario[k].decode_time + + prefill_total = sum(total_prefill_times.values()) + decode_total = sum(total_decode_times.values()) + print('Total input tokens', total_input_tokens) + print('Total output tokens', total_output_tokens) + print('Input / output', total_input_tokens / 
total_output_tokens) + print('Prefill times', total_prefill_times) + print('pref throughput', total_input_tokens / sum(total_prefill_times.values())) + print('decode times', total_decode_times) + print('decode throughput', total_output_tokens / sum(total_decode_times.values()) ) + print('overall throughput', + total_output_tokens / + (sum(total_decode_times.values()) + sum(total_prefill_times.values()))) + print('prefill total time', prefill_total) + print('decode total time', decode_total) + + + +def main(argv): + dataset = pd.read_pickle(FLAGS.dataset_path) + total_len = dataset.tok_input_len + dataset.tok_ref_output_len + bucket = 0 + (total_len > 512) + ((total_len > 1280) | (dataset.tok_input_len > 1024)) + dataset.insert(2, 'bucket', bucket) + eval_scenario(dataset, scenario1) + print('======== scenario 2 ========') + eval_scenario(dataset, scenario2) + +if __name__ == '__main__': + app.run(main) + + diff --git a/benchmarks/run_offline.py b/benchmarks/run_offline.py index daeafac7..e705dfe5 100644 --- a/benchmarks/run_offline.py +++ b/benchmarks/run_offline.py @@ -32,7 +32,7 @@ flags.DEFINE_string("sharegpt_path", "", "path to sharegpt json file") -def run_prefill_time(engine, params, decode_state, seqlen): +def run_prefill_time(engine, params, decode_state, seqlen, profiler_started): """Run prefill and measure time.""" metadata = engine.get_tokenizer() tokenizer = engine.build_tokenizer(metadata) @@ -53,6 +53,10 @@ def run_prefill_time(engine, params, decode_state, seqlen): nums = 5 start = time.perf_counter() for i in range(nums): + if i == nums - 1 and FLAGS.profiling_prefill and not profiler_started: + jax.profiler.start_trace(FLAGS.profiling_output) + profiler_started = True + prefill_result, _ = engine.prefill( params=params, padded_tokens=tokens, true_length=true_length ) @@ -60,8 +64,9 @@ def run_prefill_time(engine, params, decode_state, seqlen): prefill_result, decode_state, slot=jnp.int32(i) ) jax.block_until_ready(decode_state) + end = time.perf_counter() - return (end - start) / nums, decode_state + return (end - start) / nums, decode_state, profiler_started MAXTEXT_PREFILL = { @@ -72,6 +77,7 @@ def run_prefill_time(engine, params, decode_state, seqlen): 256: 23.59, 512: 35.28, 1024: 60.28, + 2048: 60.28, } @@ -86,9 +92,12 @@ def main(argv): prefill_times = {} decode_state = engine.init_decode_state() + profiler_started = False for batch, _ in MAXTEXT_PREFILL.items(): - runtime, decode_state = run_prefill_time( - engine, params, decode_state, batch + if batch > FLAGS.max_cache_length: + continue + runtime, decode_state, profiler_started = run_prefill_time( + engine, params, decode_state, batch, profiler_started ) prefill_times[batch] = runtime @@ -103,10 +112,12 @@ def main(argv): profiling_output = FLAGS.profiling_output print("======= decode starting ===") + dec_times = [] for i in range(10): - if profiling_output and i == 7: + if profiling_output and i == 7 and not profiler_started: jax.profiler.start_trace(profiling_output) + profiler_started = True start = time.perf_counter() # pylint: disable-next=all decode_state, sampled_tokens = engine.generate(params, decode_state) @@ -116,7 +127,7 @@ def main(argv): dec_times.append(end - start) print(i, "decode time", (end - start)) - if profiling_output: + if profiler_started: jax.profiler.stop_trace() print("prefill ", prefill_times) diff --git a/jetstream_pt/attention_kernel.py b/jetstream_pt/attention_kernel.py index 96bb4233..6d571d2c 100644 --- a/jetstream_pt/attention_kernel.py +++ b/jetstream_pt/attention_kernel.py @@ 
-16,6 +16,7 @@ def ragged_flash_attention_kernel( + layer_ref, start_ref, end_ref, line_end_ref, @@ -105,30 +106,45 @@ def run(): @functools.partial( - jax.jit, static_argnames=["bk", "mask_value", "normalize_var"] + jax.jit, + static_argnames=[ + "bk", + "mask_value", + "normalize_var", + "testing", + "quantized", + ], ) def ragged_mqa( q: jax.Array, k: jax.Array, v: jax.Array, + layer, start: jax.Array, end: jax.Array, - k_scaler: jax.Array | None = None, - v_scaler: jax.Array | None = None, ragged_batch_index=None, ragged_block_index=None, + k_scaler: jax.Array | None = None, + v_scaler: jax.Array | None = None, bk: int = 512, mask_value: float = DEFAULT_MASK_VALUE, normalize_var: bool = True, + testing: bool = False, + quantized: bool = False, ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: """Ragged multi query attention.""" with jax.named_scope("ragged_mqa"): - batch_size, num_heads, head_dim = q.shape - seq_len = k.shape[1] + batch_size, time, head_dim = q.shape + seq_len = k.shape[-2] + + stacked = False + if k.ndim == 5: + stacked = True def kv_index_map( b, i, + layer_ref, start_ref, end_ref, line_end_ref, @@ -136,11 +152,20 @@ def kv_index_map( ragged_block_index_ref, ): index = b * (seq_len // bk) + i + + if stacked: + return ( + layer_ref[0], + ragged_batch_index_ref[index], + ragged_block_index_ref[index], + 0, + ) return ragged_batch_index_ref[index], ragged_block_index_ref[index], 0 def q_index_map( b, i, + layer_ref, start_ref, end_ref, line_end_ref, @@ -148,17 +173,32 @@ def q_index_map( ragged_block_index_ref, ): index = b * (seq_len // bk) + i + if stacked: + return layer_ref[0], ragged_batch_index_ref[index], 0, 0 return ragged_batch_index_ref[index], 0, 0 - def scaler_index_map(b, i, *_): + def scaler_index_map(b, i, layer_ref, *_): + if stacked: + return layer_ref[0], b, 0, i return b, 0, i line_end = jnp.where(start < end, end, seq_len - 1) + if stacked: + q_bp = (None, None, time, head_dim) + kv_bp = (None, None, bk, head_dim) + ks_bp = (None, None, 1, bk) + else: + q_bp = (None, time, head_dim) + kv_bp = (None, bk, head_dim) + ks_bp = (None, 1, bk) + in_specs = [ - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), - pl.BlockSpec(kv_index_map, (None, bk, head_dim)), - pl.BlockSpec(kv_index_map, (None, bk, head_dim)), + pl.BlockSpec(q_index_map, q_bp), + pl.BlockSpec(kv_index_map, kv_bp), + pl.BlockSpec(kv_index_map, kv_bp), + pl.BlockSpec(scaler_index_map, ks_bp), + pl.BlockSpec(scaler_index_map, ks_bp), ] inputs = ( start, @@ -169,15 +209,9 @@ def scaler_index_map(b, i, *_): q, k, v, + k_scaler, + v_scaler, ) - quantized = False - if k_scaler is not None: - in_specs = in_specs + [ - pl.BlockSpec(scaler_index_map, (None, 1, bk)), - pl.BlockSpec(scaler_index_map, (None, 1, bk)), - ] - inputs = inputs + (k_scaler, v_scaler) - quantized = True out, m, l = pl.pallas_call( functools.partial( @@ -191,33 +225,241 @@ def scaler_index_map(b, i, *_): num_scalar_prefetch=5, in_specs=in_specs, out_specs=[ - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), - pl.BlockSpec(q_index_map, (None, num_heads, head_dim)), + pl.BlockSpec(q_index_map, (None, time, head_dim)), + pl.BlockSpec(q_index_map, (None, time, head_dim)), + pl.BlockSpec(q_index_map, (None, time, head_dim)), ], grid=(batch_size, seq_len // bk), ), compiler_params={"dimension_semantics": ("parallel", "arbitrary")}, + interpret=testing, out_shape=[ q, - jax.ShapeDtypeStruct( - (batch_size, num_heads, head_dim), jnp.float32 - ), - 
jax.ShapeDtypeStruct( - (batch_size, num_heads, head_dim), jnp.float32 - ), + jax.ShapeDtypeStruct((batch_size, time, head_dim), jnp.float32), + jax.ShapeDtypeStruct((batch_size, time, head_dim), jnp.float32), ], )(*inputs) return out, (m[..., 0], l[..., 0]) +def ragged_mqa_kernel_reference( + layer_ref, + start_ref, + end_ref, + line_end_ref, + pre_b_ref, + pre_i_ref, + q_ref, + k_ref, + v_ref, + k_scaler_ref, + v_scaler_ref, + o_ref, + m_ref, + l_ref, + bk: int, + mask_value: float, + normalize_var: bool, + quantized: bool, +): + """Pallas kernel for ragged attention.""" + b, i = pl.program_id(0), pl.program_id(1) + del layer_ref + + @pl.when(i == 0) + def init(): + m_ref[...] = jnp.full_like(m_ref, -jnp.inf) + l_ref[...] = jnp.zeros_like(l_ref) + o_ref[...] = jnp.zeros_like(o_ref) + + # length = lengths_ref[b] + # Always start from 0, left aligned + length = end_ref[b] + + @pl.when(i * bk < length) + def run(): + q = q_ref[...].astype(jnp.float32) + k = k_ref[...].astype(jnp.float32) + v = v_ref[...].astype(jnp.float32) + m_prev, l_prev = m_ref[...], l_ref[...] + + qk = jax.lax.dot_general( + q, k, (((1,), (1,)), ((), ())), preferred_element_type=jnp.float32 + ) + + if normalize_var: + qk = qk / math.sqrt(k.shape[-1]) # Align with meta llama + # Quantized + if quantized: + qk = qk * k_scaler_ref[...] + + mask = i * bk + jax.lax.broadcasted_iota(jnp.int32, qk.shape, 1) < length + qk = qk + jnp.where(mask, 0.0, mask_value) + m_curr = qk.max(axis=-1) + + s_curr = jnp.exp(qk - m_curr[..., None]) + + l_curr = jax.lax.broadcast_in_dim(s_curr.sum(axis=-1), l_prev.shape, (0,)) + # Quantized + if quantized: + s_curr = s_curr * v_scaler_ref[...] + + o_curr_times_l_curr = jnp.dot(s_curr, v) + + m_curr = jax.lax.broadcast_in_dim(m_curr, m_prev.shape, (0,)) + m_next = jnp.maximum(m_prev, m_curr) + alpha = jnp.exp(m_prev - m_next) + beta = jnp.exp(m_curr - m_next) + l_next = alpha * l_prev + beta * l_curr + l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next) + + m_ref[...], l_ref[...] = m_next, l_next_safe + o_ref[...] = ( + (l_prev * alpha * o_ref[...] + beta * o_curr_times_l_curr) / l_next_safe + ).astype(o_ref.dtype) + + @functools.partial( - jax.jit, static_argnames=["bk", "mask_value", "normalize_var", "shard_axis"] + jax.jit, + static_argnames=[ + "bk", + "mask_value", + "normalize_var", + "testing", + "quantized", + ], +) +def ragged_mqa_reference( + q: jax.Array, + k: jax.Array, + v: jax.Array, + layer, + start: jax.Array, + end: jax.Array, + ragged_batch_index=None, + ragged_block_index=None, + k_scaler: jax.Array = None, + v_scaler: jax.Array = None, + bk: int = 512, + mask_value: float = DEFAULT_MASK_VALUE, + normalize_var: bool = True, + testing: bool = False, + quantized: bool = False, +) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: + """Ragged multi query attention.""" + batch_size, time, head_dim = q.shape + # assert end.shape == (batch_size,) + seq_len = k.shape[-2] + + stacked = False + if k.ndim == 4: + stacked = True + + def _compute_ragged_block_indices(b, i, lengths_ref): + length = lengths_ref[b] + not_done = i * bk < length + am_last_batch = b == batch_size - 1 + # if length < bk, then it's -1, should be 0? 
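+    # Illustrative arithmetic for the line below (assumed example numbers):
+    # with bk=512, length=1300 gives div(1300, 512) - 1 = 1, while
+    # length=300 (< bk) gives -1, which is the edge case the comment above
+    # is asking about.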
+ last_good_block = jax.lax.div(length, bk) - 1 + + # if not done, then still work on b, otherwise next batch + b_next = jnp.where(not_done, b, jnp.where(am_last_batch, b, b + 1)) + # if not done, i next = i + # if done + # if last batch, previous good block + # if not last batch, i next = 0 + i_next = jnp.where( + not_done, i, jnp.where(am_last_batch, last_good_block, 0) + ) + return b_next, i_next + + def kv_index_map(b, i, layer_ref, start_ref, end_ref, *_): + b_next, i_next = _compute_ragged_block_indices(b, i, end_ref) + if stacked: + return layer_ref[0], b_next, i_next, 0 + return b_next, i_next, 0 + + def kv_scale_index_map(b, i, layer_ref, start_ref, end_ref, *_): + b_next, i_next = _compute_ragged_block_indices(b, i, end_ref) + if stacked: + return layer_ref[0], b_next, 0, i_next + return b_next, 0, i_next + + if stacked: + kv_bp = (None, None, bk, head_dim) + ks_bp = (None, None, 1, bk) + else: + kv_bp = (None, bk, head_dim) + ks_bp = (None, 1, bk) + + in_specs = [ + pl.BlockSpec(lambda b, i, *_: (b, 0, 0), (None, time, head_dim)), # q + pl.BlockSpec(kv_index_map, kv_bp), # k + pl.BlockSpec(kv_index_map, kv_bp), # v + pl.BlockSpec(kv_scale_index_map, ks_bp), # k_scaler + pl.BlockSpec(kv_scale_index_map, ks_bp), # v_scaler + ] + + inputs = ( + jnp.array([layer]), + start, + end, + end, # line_end, not actually used + ragged_batch_index, + ragged_block_index, + q, + k, + v, + k_scaler, + v_scaler, + ) + + out, m, l = pl.pallas_call( + functools.partial( + ragged_mqa_kernel_reference, + bk=bk, + mask_value=mask_value, + normalize_var=normalize_var, + quantized=quantized, + ), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=6, + in_specs=in_specs, + out_specs=[ + pl.BlockSpec(lambda b, *_: (b, 0, 0), (None, time, head_dim)), + pl.BlockSpec(lambda b, *_: (b, 0, 0), (None, time, head_dim)), + pl.BlockSpec(lambda b, *_: (b, 0, 0), (None, time, head_dim)), + ], + grid=(batch_size, seq_len // bk), + ), + interpret=testing, + # debug=True, + compiler_params={"dimension_semantics": ("parallel", "arbitrary")}, + out_shape=[ + q, + jax.ShapeDtypeStruct((batch_size, time, head_dim), jnp.float32), + jax.ShapeDtypeStruct((batch_size, time, head_dim), jnp.float32), + ], + )(*inputs) + return out, (m[..., 0], l[..., 0]) + + +@functools.partial( + jax.jit, + static_argnames=[ + "bk", + "mask_value", + "normalize_var", + "q_shard_axis", + "kv_shard_axis", + "testing", + ], ) def ragged_mha( q: jax.Array, k: jax.Array, v: jax.Array, + layer, start: jax.Array, end: jax.Array, ragged_batch_index: jax.Array, @@ -227,7 +469,9 @@ def ragged_mha( bk: int = 512, mask_value: float = DEFAULT_MASK_VALUE, normalize_var: bool = True, - shard_axis: int = 1, + q_shard_axis: int = 0, + kv_shard_axis: int = 0, + testing: bool = False, ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: """Ragged multi head attention. Args: @@ -251,35 +495,66 @@ def ragged_mha( softmax denominator ([batch_size, num_heads, compute_dim, 1]). 
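  For reference, a minimal sketch (not part of this kernel) of how the
  returned max/denominator pair lets two partial results over disjoint key
  blocks be merged into one attention output:

      m = jnp.maximum(m1, m2)
      l = l1 * jnp.exp(m1 - m) + l2 * jnp.exp(m2 - m)
      out = (o1 * l1 * jnp.exp(m1 - m) + o2 * l2 * jnp.exp(m2 - m)) / l

  The attention layer applies the same combination (without the shared max
  subtraction) when it merges the cached-history output with the output for
  the newly written token.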
""" mask_value = DEFAULT_MASK_VALUE + bk = min(bk, k.shape[-2]) + bq, hq, tq, dq = q.shape + hkv = k.shape[-3] + tk = k.shape[-2] + + assert k.shape[-1] == q.shape[-1] + assert k.shape[-4] == q.shape[-4] + + rep = hq // hkv + if rep > 1: + q = q.reshape(bq, hkv, rep, tq, dq).reshape(bq, hkv, rep * tq, dq) + stacked = k.ndim == 5 + + replicated_in_axes = 7 if k_scaler is None: - replicated_in_axes = 4 - replicated_inputs = (ragged_batch_index, ragged_block_index) + quantized = False + if k.ndim == 5: + kv_scale_shape = (k.shape[0], bq, 1, tk) + else: + kv_scale_shape = (bq, 1, tk) + k_scale = jnp.ones(kv_scale_shape, dtype=jnp.bfloat16) + v_scale = jnp.ones(kv_scale_shape, dtype=jnp.bfloat16) else: - replicated_in_axes = 6 - replicated_inputs = ( - jnp.squeeze(k_scaler, -1), - jnp.squeeze(v_scaler, -1), - ragged_batch_index, - ragged_block_index, - ) + quantized = True + k_scale = jnp.squeeze(k_scaler, -1) + v_scale = jnp.squeeze(v_scaler, -1) + + if stacked: + assert k_scale.shape == (k.shape[0], bq, 1, tk) + else: + assert k_scale.shape == (bq, 1, tk) + + replicated_inputs = ( + ragged_batch_index, + ragged_block_index, + k_scale, + v_scale, + ) + # New cache has t=1 with jax.named_scope("ragged_mha_vmap"): out, (m, l) = jax.vmap( functools.partial( - ragged_mqa, + # ragged_mqa, + ragged_mqa_reference, bk=bk, mask_value=mask_value, normalize_var=normalize_var, + testing=testing, + quantized=quantized, # out_dtype=out_dtype, ), in_axes=( - shard_axis, - shard_axis, - shard_axis, + q_shard_axis, + kv_shard_axis, + kv_shard_axis, *([None] * replicated_in_axes), ), - out_axes=shard_axis, - )(q, k, v, start, end, *replicated_inputs) + out_axes=q_shard_axis, + )(q, k, v, layer, start, end, *replicated_inputs) return out, (m, l) @@ -310,15 +585,77 @@ def dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None): return output +def flash_attention( + xq, + keys, + values, + layer, + k_scaler=None, + v_scaler=None, + mask=None, + normalize_var=True, +): + """Flash attention kernel.""" + if keys.ndim == 5: + keys = keys[layer] + values = values[layer] + k_scaler = k_scaler[layer] if k_scaler is not None else None + v_scaler = v_scaler[layer] if v_scaler is not None else None + + logits = torch.einsum( + "bhqd,bhkd->bhqk", xq.type(torch.float32), keys.type(torch.float32) + ) + + if normalize_var: + logits = logits / math.sqrt(keys.shape[-1]) # Align with meta llama + # Quantized + if k_scaler is not None: + logits = logits * k_scaler.reshape( + k_scaler.shape[-4], 1, 1, k_scaler.shape[-2] + ) + + # mask = jnp.arange(keys.shape[1])[None] < lengths[:, None] + if mask is not None: + # logits = logits + jnp.where(mask, 0.0, DEFAULT_MASK_VALUE)[:, None] + logits = logits + mask + + logits_max, _ = torch.max(logits, axis=-1, keepdim=True) + unnormalized = torch.exp(logits - logits_max) + # Quantized, should not put here, otherwise sum will have this too, which cancels with denominator + # unnormalized = unnormalized * v_scaler + + denominator = unnormalized.sum(axis=-1, keepdim=True) + if v_scaler is not None: + unnormalized = unnormalized * v_scaler.reshape( + v_scaler.shape[-4], 1, 1, v_scaler.shape[-2] + ) + o = ( + torch.einsum("bhqk,bhkd->bhqd", unnormalized.type_as(xq), values) + / denominator + ) + + return o, (logits_max, denominator) + + class RaggedAttentionKernel: """Ragged attention kernel.""" - def __init__(self, env, input_specs, output_specs, sharding_axis): + def __init__( + self, env, input_specs, output_specs, q_shard_axis, kv_shard_axis + ): self.binded_ragged_mha = 
functools.partial( - ragged_mha, bk=env.block_size, shard_axis=sharding_axis + ragged_mha, + bk=env.block_size, + q_shard_axis=q_shard_axis, + kv_shard_axis=kv_shard_axis, + testing=env.testing, ) self.binded_ragged_mha = shard_map( - ragged_mha, env.mesh, input_specs, output_specs, check_rep=False + self.binded_ragged_mha, + env.mesh, + input_specs, + output_specs, + check_rep=False, ) self.binded_ragged_mha = jax.jit(self.binded_ragged_mha) @@ -327,6 +664,7 @@ def __call__( xq, keys, values, + layer, start, end, ragged_batch_index, @@ -338,6 +676,7 @@ def __call__( xq, keys, values, + layer, start, end, ragged_batch_index, diff --git a/jetstream_pt/cache_manager.py b/jetstream_pt/cache_manager.py index 13789f91..76f44120 100644 --- a/jetstream_pt/cache_manager.py +++ b/jetstream_pt/cache_manager.py @@ -14,7 +14,10 @@ import jax import jax.numpy as jnp +from jax.experimental.shard_map import shard_map import torch +import torch_xla2 + from jetstream_pt import torchjax @@ -38,19 +41,20 @@ def update(self, key, value): class KVCachePrefill: """Prefill kv cache""" - def __init__(self, kv_quantize=False): + def __init__(self, kv_quantize=False, stacked=False): self.kv_quantize = kv_quantize self.cache_k = None self.cache_v = None + self.stacked = stacked - def update(self, key, value): + def update(self, key, value, layer_id): """This cache just remembers the stuff.""" self.cache_k = key self.cache_v = value if self.kv_quantize: # pretend to be quantized bsz, _, seq, _ = key.shape ones = torchjax.to_torch(jnp.ones((bsz, 1, seq, 1), dtype=jnp.bfloat16)) - return key, value, ones, ones + return key, value, None, None, ones, ones, None, None return key, value @@ -58,6 +62,11 @@ def state(self): """Get prefill cache state""" return self.cache_k, self.cache_v + # Placeholder, to match with GenerateCache + def finalize(self): + """Finalize the cache operation and updates the cache.""" + return + # pylint: disable-next=all def KVCachePrefill_flatten(cache): @@ -80,57 +89,225 @@ def KVCachePrefill_unflatten(auxdata, data): ) -# Refactor out cache management -# Easier to test for quantized kv cache class KVCacheGenerate: """Kvache generator without quantization""" + # pylint: disable=too-many-instance-attributes + # More than 7 is reasonable in this case. def __init__( self, cache_k: torch.Tensor, # previous cache cache_v: torch.Tensor, # previous cache - position: int, # position to store the cache + position: int | torch.Tensor, # position to store the cache sharding, env=None, ): super().__init__() self.cache_k = cache_k self.cache_v = cache_v - self.pos = position + self.input_pos = position self.sharding = sharding self.env = env - def update(self, key, value): + self.new_ks = None + self.new_vs = None + self.env = env + # Keep this one it's used in the specific model code. 
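+    # Lazy-update flow (see update() and finalize() below): when
+    # env.lazy_cache_update is set, update() only buffers the current step's
+    # K/V in self.new_ks / self.new_vs (a stacked tensor or a per-layer list),
+    # and finalize() scatters them into cache_k / cache_v once per decode
+    # step, after all layers have run.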
+ self.stacked = env.generate_cache_stacked + self.batch = jnp.arange(self.env.batch_size) + # The other way is to store the list and loop over to insert in finalize() + if self.env.lazy_cache_update: + if self.env.generate_cache_stacked: + if self.env.new_cache_stacked: + layer, batch, heads, _, dim = self.cache_k.shape + new_dim = (layer, batch, heads, 1, dim) + self.new_ks, self.new_vs = torchjax.to_torch( + ( + jnp.zeros(new_dim, dtype=self.env.default_type), + jnp.zeros(new_dim, dtype=self.env.default_type), + ) + ) + else: + self.new_ks, self.new_vs = [], [] + else: # when generate cache is not stacked, new cache cannot stack + assert not self.env.new_cache_stacked + + cache_pspec = self.env.partition_by_axis( + self.env.cache_sharding_axis + ) # Number of heads + none_pspec = self.env.partition_by_axis() + in_specs = (cache_pspec, cache_pspec, cache_pspec, cache_pspec, none_pspec) + out_specs = (cache_pspec, cache_pspec) + self.update_single_cache_line = jax.jit( + shard_map( + self.update_single_cache_line, + self.env.mesh, + in_specs, + out_specs, + check_rep=False, + ) + ) + + # pylint: disable=method-hidden + # False alarm. The jit above doesn't hide this method. + def update_single_cache_line(self, cache_k, cache_v, new_ks, new_vs, pos): + """The shard map version of single cache line update.""" + b = cache_k.shape[-4] + for bb, pp in enumerate(pos.reshape(b)): + slice_dim = 0 + update_start_indices = (bb, 0, pp, 0) + if self.env.generate_cache_stacked: + if self.env.new_cache_stacked: + slice_dim = 1 + update_start_indices = (0, bb, 0, pp, 0) + # We are not handling generate_cache_stacked=True new_cache_stacked=False here + new_ks_slice = jax.lax.dynamic_slice_in_dim(new_ks, bb, 1, slice_dim) + new_vs_slice = jax.lax.dynamic_slice_in_dim(new_vs, bb, 1, slice_dim) + cache_k = jax.lax.dynamic_update_slice( + cache_k, new_ks_slice, update_start_indices + ) + cache_v = jax.lax.dynamic_update_slice( + cache_v, new_vs_slice, update_start_indices + ) + return cache_k, cache_v + + def finalize(self): + """Finalize the cache operation and updates the cache.""" + if not self.env.lazy_cache_update: + return + + if self.env.ring_buffer: + # Assume no cache stack for ring buffer + # pylint: disable-next=all + self.cache_k._elem = ( + self.cache_k.jax().at[..., self.input_pos, :].set(self.new_ks.jax()) + ) + # pylint: disable-next=all + self.cache_v._elem = ( + self.cache_v.jax().at[..., self.input_pos, :].set(self.new_vs.jax()) + ) + else: + if self.env.generate_cache_stacked: + _, b, head, _, dim = self.cache_k.shape + if self.env.new_cache_stacked: + self.cache_k, self.cache_v = torch_xla2.interop.call_jax( + self.update_single_cache_line, + self.cache_k, + self.cache_v, + self.new_ks, + self.new_vs, + self.input_pos, + ) + else: + for i in range(self.env.num_layers): + # pylint: disable-next=all + self.cache_k._elem = ( + self.cache_k.jax() + .at[i, self.batch, :, self.input_pos, :] + .set(self.new_ks[i].jax().reshape(b, head, dim)) + ) + # pylint: disable-next=all + self.cache_v._elem = ( + self.cache_v.jax() + .at[i, self.batch, :, self.input_pos, :] + .set(self.new_vs[i].jax().reshape(b, head, dim)) + ) + else: + # Try to use shard_map to get rid of the data copy + self.cache_k, self.cache_v = torch_xla2.interop.call_jax( + self.update_single_cache_line, + self.cache_k, + self.cache_v, + self.new_ks, + self.new_vs, + self.input_pos, + ) + + def update(self, key, value, layer_id: int): """Update kv cache""" keyj, valuej = torchjax.to_torch((key, value)) + if 
self.env.lazy_cache_update: + if self.env.new_cache_stacked: + assert ( + self.env.generate_cache_stacked + ), "When new cache stacked, must have generate_cache_stacked!" + self.new_ks[layer_id, ...] = keyj + self.new_vs[layer_id, ...] = valuej + return self.cache_k[layer_id], self.cache_v[layer_id] + + # Generate cache stacked, but new cache unstacked + if self.env.generate_cache_stacked: + self.new_ks.append(keyj) + self.new_vs.append(valuej) + return self.cache_k[layer_id], self.cache_v[layer_id] + + # all cache unstacked + self.new_ks = keyj + self.new_vs = valuej + return self.cache_k, self.cache_v + if self.env.ring_buffer: + assert ( + not self.env.new_cache_stacked and not self.env.generate_cache_stacked + ), "Ring buffer doesn't support stacked cache." # pylint: disable-next=all - self.cache_k._elem = self.cache_k._elem.at[:, :, self.pos].set(keyj) + self.cache_k._elem = ( + self.cache_k.jax().at[..., self.input_pos, :].set(keyj) + ) # pylint: disable-next=all - self.cache_v._elem = self.cache_v._elem.at[:, :, self.pos].set(valuej) - else: - batch = jnp.arange(self.env.batch_size) + self.cache_v._elem = ( + self.cache_v.jax().at[..., self.input_pos, :].set(valuej) + ) + return self.cache_k, self.cache_v + + # Non lazy cache update, non ring buffer, generate cache stacked + if self.env.generate_cache_stacked: # pylint: disable-next=all - self.cache_k._elem = self.cache_k._elem.at[batch, :, self.pos].set( - keyj.squeeze(2) + self.cache_k._elem = ( + self.cache_k.jax() + .at[layer_id, self.batch, :, self.input_pos, :] + .set(keyj.squeeze(2)) ) # pylint: disable-next=all - self.cache_v._elem = self.cache_v._elem.at[batch, :, self.pos].set( - valuej.squeeze(2) + self.cache_v._elem = ( + self.cache_v.jax() + .at[layer_id, self.batch, :, self.input_pos, :] + .set(valuej.squeeze(2)) ) + return self.cache_k[layer_id], self.cache_v[layer_id] + + # Non lazy cache update, non ring buffer, generate cache non stacked + # pylint: disable-next=all + self.cache_k._elem = ( + self.cache_k.jax() + .at[self.batch, :, self.input_pos, :] + .set(keyj.squeeze(2)) + ) + # pylint: disable-next=all + self.cache_v._elem = ( + self.cache_v.jax() + .at[self.batch, :, self.input_pos, :] + .set(valuej.squeeze(2)) + ) return self.cache_k, self.cache_v def state(self): """Get kv cache state""" - # pylint: disable-next=all return self.cache_k.jax(), self.cache_v.jax() @classmethod - def empty(cls, shape, device, bf16_enable, env): + def empty(cls, shape, device, env): """Create empty kv caches""" - default_dtype = jnp.bfloat16 if bf16_enable else jnp.float32 - k = jnp.zeros(shape, device=device, dtype=default_dtype) - v = jnp.zeros(shape, device=device, dtype=default_dtype) + default_dtype = jnp.bfloat16 if env.bf16_enable else jnp.float32 + in_shape = shape + if env.testing: + key = jax.random.key(env.testing_seed) + k_key, v_key = jax.random.split(key) + k = jax.random.uniform(k_key, shape=in_shape, dtype=default_dtype) + v = jax.random.uniform(v_key, shape=in_shape, dtype=default_dtype) + else: + k = jnp.zeros(in_shape, device=device, dtype=default_dtype) + v = jnp.zeros(in_shape, device=device, dtype=default_dtype) k, v = torchjax.to_torch((k, v)) return cls(k, v, 0, device, env=env) @@ -159,7 +336,8 @@ def KVCacheGenerate_unflatten(auxdata, data): class Int8KVCacheGenerate: """Int8 quantized kvache with scalers""" - # pylint: disable-next=all + # pylint: disable=too-many-instance-attributes + # More than 7 is reasonable in this case. 
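+  # Quantization scheme (see quantize() below): values are stored as int8
+  # together with a bfloat16 scaler per (batch, position), where
+  # scale = amax(|v|) / 127 over the head and head_dim axes, so a value is
+  # recovered (up to rounding error) as int8_value * scale.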
def __init__( self, cache_k, @@ -175,9 +353,153 @@ def __init__( self.cache_v = cache_v self.k_scaler = cache_k_scaler self.v_scaler = cache_v_scaler + self.new_ks = None + self.new_vs = None + self.new_k_scaler = None + self.new_v_scaler = None + + self.batch = jnp.arange(env.batch_size) self.input_pos = input_pos self.sharding = sharding self.env = env + self.stacked = env.generate_cache_stacked + + if self.env.lazy_cache_update: + if self.env.generate_cache_stacked: + layer, batch, heads, _, dim = self.cache_k.shape + new_kv_dim = (layer, batch, heads, 1, dim) + self.new_ks, self.new_vs = torchjax.to_torch( + ( + jnp.zeros(new_kv_dim, dtype=jnp.int8), + jnp.zeros(new_kv_dim, dtype=jnp.int8), + ) + ) + if self.env.new_cache_stacked: + new_scale_dim = (layer, batch, 1, 1, 1) + self.new_k_scaler, self.new_v_scaler = torchjax.to_torch( + ( + jnp.zeros(new_scale_dim, dtype=self.env.default_type), + jnp.zeros(new_scale_dim, dtype=self.env.default_type), + ) + ) + else: + self.new_ks, self.new_vs, self.new_k_scaler, self.new_v_scaler = ( + [], + [], + [], + [], + ) + else: # when generate cache is not stacked, new cache cannot stack + assert not self.env.new_cache_stacked + + cache_pspec = self.env.partition_by_axis( + self.env.cache_sharding_axis + ) # Number of heads + new_cache_pspec = ( + self.env.partition_by_axis(2) + if self.env.new_cache_stacked + else self.env.partition_by_axis(1) + ) + none_pspec = self.env.partition_by_axis() + in_specs = ( + *([cache_pspec] * 2), + *([new_cache_pspec] * 2), + *([none_pspec] * 5), + ) + out_specs = (cache_pspec, cache_pspec, none_pspec, none_pspec) + self.update_single_cache_line = shard_map( + self.update_single_cache_line, + self.env.mesh, + in_specs, + out_specs, + check_rep=False, + ) + self.update_single_cache_line = jax.jit(self.update_single_cache_line) + + # pylint: disable=method-hidden + # False alarm. The jit above doesn't hide this method. 
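+  # Runs under shard_map (caches sharded on the head axis); for each batch
+  # row it scatters the buffered single-token K/V slice and its scalers into
+  # that row's own insertion position via dynamic_update_slice. The per-layer
+  # loop handles generate_cache_stacked=True with new_cache_stacked=False,
+  # where the new values arrive as a list of per-layer tensors.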
+ def update_single_cache_line( + self, + cache_k, + cache_v, + new_ks, + new_vs, + k_scaler, + v_scaler, + new_k_scaler, + new_v_scaler, + pos, + ): + """The shard map version of single cache line update.""" + b = cache_k.shape[-4] + + for bb, pp in enumerate(pos.reshape(b)): + slice_dim = 0 + update_start_indices = (bb, 0, pp, 0) + if self.env.generate_cache_stacked: + if self.env.new_cache_stacked: + slice_dim = 1 + update_start_indices = (0, bb, 0, pp, 0) + if self.env.generate_cache_stacked and not self.env.new_cache_stacked: + for layer in range(self.env.num_layers): + update_start_indices = (layer, bb, 0, pp, 0) + new_ks_slice = jax.lax.dynamic_slice_in_dim( + new_ks[layer], bb, 1, slice_dim + ) + new_ks_slice = jnp.expand_dims(new_ks_slice, 0) + cache_k = jax.lax.dynamic_update_slice( + cache_k, new_ks_slice, update_start_indices + ) + + new_vs_slice = jax.lax.dynamic_slice_in_dim( + new_vs[layer], bb, 1, slice_dim + ) + new_vs_slice = jnp.expand_dims(new_vs_slice, 0) + cache_v = jax.lax.dynamic_update_slice( + cache_v, new_vs_slice, update_start_indices + ) + + new_k_scaler_slice = jax.lax.dynamic_slice_in_dim( + new_k_scaler[layer], bb, 1, slice_dim + ) + new_k_scaler_slice = jnp.expand_dims(new_k_scaler_slice, 0) + k_scaler = jax.lax.dynamic_update_slice( + k_scaler, new_k_scaler_slice, update_start_indices + ) + + new_v_scaler_slice = jax.lax.dynamic_slice_in_dim( + new_v_scaler[layer], bb, 1, slice_dim + ) + new_v_scaler_slice = jnp.expand_dims(new_v_scaler_slice, 0) + v_scaler = jax.lax.dynamic_update_slice( + v_scaler, new_v_scaler_slice, update_start_indices + ) + else: + new_ks_slice = jax.lax.dynamic_slice_in_dim(new_ks, bb, 1, slice_dim) + cache_k = jax.lax.dynamic_update_slice( + cache_k, new_ks_slice, update_start_indices + ) + + new_vs_slice = jax.lax.dynamic_slice_in_dim(new_vs, bb, 1, slice_dim) + cache_v = jax.lax.dynamic_update_slice( + cache_v, new_vs_slice, update_start_indices + ) + + new_k_scaler_slice = jax.lax.dynamic_slice_in_dim( + new_k_scaler, bb, 1, slice_dim + ) + k_scaler = jax.lax.dynamic_update_slice( + k_scaler, new_k_scaler_slice, update_start_indices + ) + + new_v_scaler_slice = jax.lax.dynamic_slice_in_dim( + new_v_scaler, bb, 1, slice_dim + ) + v_scaler = jax.lax.dynamic_update_slice( + v_scaler, new_v_scaler_slice, update_start_indices + ) + + return cache_k, cache_v, k_scaler, v_scaler def state(self): """Get kv cache state""" @@ -189,13 +511,17 @@ def scalers(self): @classmethod # pylint: disable-next=all - def empty(cls, shape, device, bf16_enable, env): + def empty(cls, shape, device, env): """Create empty kv caches""" cache_k = jnp.zeros(shape, device=device, dtype=jnp.int8) cache_v = jnp.zeros(shape, device=device, dtype=jnp.int8) - # bf16_enable is a placeholder parameter, it's not used in Int8KVCache - kscaler = jnp.ones((shape[0], 1, shape[2], 1), dtype=jnp.bfloat16) - vscaler = jnp.ones((shape[0], 1, shape[2], 1), dtype=jnp.bfloat16) + + if env.generate_cache_stacked: + s_shape = (shape[0], shape[1], 1, shape[3], 1) + else: + s_shape = (shape[0], 1, shape[2], 1) + kscaler = jnp.ones(s_shape, dtype=jnp.bfloat16) + vscaler = jnp.ones(s_shape, dtype=jnp.bfloat16) cache_k, cache_v, kscaler, vscaler = torchjax.to_torch( (cache_k, cache_v, kscaler, vscaler) @@ -205,23 +531,126 @@ def empty(cls, shape, device, bf16_enable, env): def quantize(self, val): """Quantize value""" # val is (batch, heads, seqlen, dim) - scale = torch.amax(val.abs(), axis=(1, 3), keepdim=True) + scale = torch.amax(val.abs(), axis=(-3, -1), keepdim=True) scale = 
scale / 127 return (val / scale).to(torch.int8), scale - def update(self, xk, xv): + def update(self, xk, xv, layer_id: int): """Update kv cache""" k_quant, kscale = self.quantize(xk) v_quant, vscale = self.quantize(xv) - if self.env.ring_buffer: + + if self.env.lazy_cache_update: + if self.env.new_cache_stacked: + self.new_ks[layer_id, ...] = k_quant + self.new_vs[layer_id, ...] = v_quant + self.new_k_scaler[layer_id, ...] = kscale + self.new_v_scaler[layer_id, ...] = vscale + else: + if self.env.generate_cache_stacked: + self.new_ks.append(k_quant) + self.new_vs.append(v_quant) + self.new_k_scaler.append(kscale) + self.new_v_scaler.append(vscale) + else: + self.new_ks = k_quant + self.new_vs = v_quant + self.new_k_scaler = kscale + self.new_v_scaler = vscale + elif self.env.ring_buffer: self.cache_k[:, :, self.input_pos, :] = k_quant self.cache_v[:, :, self.input_pos, :] = v_quant self.k_scaler[:, :, self.input_pos, :] = kscale self.v_scaler[:, :, self.input_pos, :] = vscale else: - batch = jnp.arange(self.env.batch_size) - self.cache_k[batch, :, self.input_pos, :] = k_quant.squeeze(2) - self.cache_v[batch, :, self.input_pos, :] = v_quant.squeeze(2) - self.k_scaler[batch, :, self.input_pos, :] = kscale.squeeze(2) - self.v_scaler[batch, :, self.input_pos, :] = vscale.squeeze(2) - return self.cache_k, self.cache_v, self.k_scaler, self.v_scaler + # We don't handle left aligned but lazy_cache_update=False + self.cache_k[self.batch, :, self.input_pos, :] = k_quant.squeeze(2) + self.cache_v[self.batch, :, self.input_pos, :] = v_quant.squeeze(2) + self.k_scaler[self.batch, :, self.input_pos, :] = kscale.squeeze(2) + self.v_scaler[self.batch, :, self.input_pos, :] = vscale.squeeze(2) + + return ( + self.cache_k, + self.cache_v, + k_quant, + v_quant, + self.k_scaler, + self.v_scaler, + kscale, + vscale, + ) + + def finalize(self): + """Finalize the cache operation and updates the cache.""" + if not self.env.lazy_cache_update: + return + if self.env.ring_buffer: + # Assume no cache stack for ring buffer + # pylint: disable-next=all + self.cache_k._elem = ( + self.cache_k.jax().at[..., self.input_pos, :].set(self.new_ks.jax()) + ) + # pylint: disable-next=all + self.cache_v._elem = ( + self.cache_v.jax().at[..., self.input_pos, :].set(self.new_vs.jax()) + ) + else: + if self.env.generate_cache_stacked: + if self.env.new_cache_stacked: + # new kv scaler also has to go through shard_map instead of indexing + # because it needs to reshape to (batch, layer) which mess up with the data + caches = [ + self.cache_k, + self.cache_v, + self.new_ks, + self.new_vs, + self.k_scaler, + self.v_scaler, + self.new_k_scaler, + self.new_v_scaler, + ] + ( + self.cache_k, + self.cache_v, + self.k_scaler, + self.v_scaler, + ) = torch_xla2.interop.call_jax( + self.update_single_cache_line, *caches, self.input_pos + ) + else: + caches = [ + self.cache_k, + self.cache_v, + self.new_ks, + self.new_vs, + self.k_scaler, + self.v_scaler, + self.new_k_scaler, + self.new_v_scaler, + ] + ( + self.cache_k, + self.cache_v, + self.k_scaler, + self.v_scaler, + ) = torch_xla2.interop.call_jax( + self.update_single_cache_line, *caches, self.input_pos + ) + else: + ( + self.cache_k, + self.cache_v, + self.k_scaler, + self.v_scaler, + ) = torch_xla2.interop.call_jax( + self.update_single_cache_line, + self.cache_k, + self.cache_v, + self.new_ks, + self.new_vs, + self.k_scaler, + self.v_scaler, + self.new_k_scaler, + self.new_v_scaler, + self.input_pos, + ) diff --git a/jetstream_pt/config.py b/jetstream_pt/config.py index 
78f8da9f..b22d0287 100644 --- a/jetstream_pt/config.py +++ b/jetstream_pt/config.py @@ -90,6 +90,31 @@ "Whether to enable ring buffer", required=False, ) +flags.DEFINE_bool( + "flash_attention", + False, + "Whether to enable flas attention. Only takes effect at test mode", + required=False, +) +flags.DEFINE_bool( + "generate_cache_stacked", + False, + "Whether to stack the generate cache to the layer dimension. Only takes effect at test mode", + required=False, +) +flags.DEFINE_bool( + "new_cache_stacked", + False, + "Whether to stack the generate cache to the layer dimension. Only takes effect at test mode", + required=False, +) +flags.DEFINE_bool( + "lazy_cache_update", + False, + "Whether to update the cache during attention or delayed until all the layers are done. " + "Only takes effect at test mode", + required=False, +) flags.DEFINE_float( "temperature", 1.0, @@ -132,11 +157,14 @@ def create_quantization_config_from_flags(): return config -def create_engine_from_config_flags(): +def create_engine_from_config_flags(batch=None, cache_len=None): """create a pytorch engine from cmd flag""" jax.config.update("jax_default_prng_impl", "unsafe_rbg") os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" + batch = batch or FLAGS.batch_size + cache_len = cache_len or FLAGS.max_cache_length + devices = jax.devices() start = time.perf_counter() @@ -171,9 +199,9 @@ def create_engine_from_config_flags(): bf16_enable=FLAGS.bf16_enable, param_size=FLAGS.size, context_length=FLAGS.context_length, - batch_size=FLAGS.batch_size, + batch_size=batch, quant_config=quant_config, - max_cache_length=FLAGS.max_cache_length, + max_cache_length=cache_len, max_decode_length=FLAGS.max_decode_length, sharding_config=sharding_file_name, shard_on_batch=FLAGS.shard_on_batch, @@ -184,6 +212,10 @@ def create_engine_from_config_flags(): nucleus_topp=FLAGS.nucleus_topp, topk=FLAGS.topk, ring_buffer=FLAGS.ring_buffer, + flash_attention=FLAGS.flash_attention, + generate_cache_stacked=FLAGS.generate_cache_stacked, + new_cache_stacked=FLAGS.new_cache_stacked, + lazy_cache_update=FLAGS.lazy_cache_update, ) print("Initialize engine", time.perf_counter() - start) diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index c168614d..9ee6a3fd 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -139,7 +139,7 @@ def init_decode_state( (self.env.batch_size, self.env.cache_sequence_length), float("-inf"), dtype=self.default_dtype, - ), + ), # mask ) # pylint: disable-next=all @@ -189,7 +189,10 @@ def _call_model_generate( # The mode is needed so that tensors created inside of # the model (such as via torch.ones etc) also have the right type res = torch.func.functional_call(self.pt_model, paramst, argst) - updated_caches = [c.state() for c in caches_obj] + updated_caches = [] + for c in caches_obj: + c.finalize() + updated_caches.append(c.state()) scales = [] if self.env.quant_config.enable_kv_quantization: scales = [c.scalers() for c in caches_obj] @@ -212,7 +215,8 @@ def _call_model_prefill(self, weights, tokens, input_indexes): dtype=self.default_dtype, ) mask = jnp.triu(mask, k=1) - args = (tokens, input_indexes, caches, mask) + start = jnp.zeros((tokens.shape[0],), dtype=jnp.int32) + args = (tokens, input_indexes, caches, mask, start) paramst, argst = torchjax.to_torch((weights, args)) with self._lock: @@ -322,7 +326,7 @@ def _insert_no_wrap( tokens = decode_state.tokens.at[slot].set(prefix.token) x = jnp.arange(0, self.env.cache_sequence_length) - cond = jnp.logical_and(x <= current_pos, x >= pos) + cond = 
jnp.logical_and(x < current_pos, x >= pos) mask_insert = jnp.where(cond, 0, float("-inf")) mask = decode_state.mask.at[slot].set(mask_insert) start = decode_state.start.at[slot].set( @@ -332,31 +336,48 @@ def _insert_no_wrap( if not self.env.quant_config.enable_kv_quantization: @functools.partial(jax.jit, donate_argnums=(0, 1), inline=True) - def insert(cache, new_entry): + def insert(cache, new_entry, update_index): res = jax.lax.dynamic_update_slice( cache, new_entry, - [slot, 0, pos, 0], + update_index, ) res = jax.lax.with_sharding_constraint(res, self.cache_sharding) return res - caches = [ - (insert(k, newk), insert(v, newv)) - for (k, v), (newk, newv) in zip(decode_state.caches, prefix.caches) - ] + if self.env.generate_cache_stacked: + caches = decode_state.caches + for idx, (newk, newv) in enumerate(prefix.caches): + update_index = [idx, slot, 0, pos, 0] + newk = jnp.expand_dims(newk, 0) + newv = jnp.expand_dims(newv, 0) + caches = [ + ( + insert(caches[0][0], newk, update_index), + insert(caches[0][1], newv, update_index), + ) + ] + else: + update_index = [slot, 0, pos, 0] + caches = [ + (insert(k, newk, update_index), insert(v, newv, update_index)) + for (k, v), (newk, newv) in zip(decode_state.caches, prefix.caches) + ] else: @functools.partial(jax.jit, donate_argnums=(0, 1), inline=True) - def insert(cache, scaler, new_entry): - reduce_axis = (1, 3) + def insert(cache, scaler, new_entry, update_index): + reduce_axis = (-3, -1) vals, scales, _ = torchjax.call_torch( quantize.quantize_tensor, new_entry, reduce_axis ) + if self.env.generate_cache_stacked: + vals = jnp.expand_dims(vals, 0) + scales = jnp.expand_dims(scales, 0) new_scaler = jax.lax.dynamic_update_slice( scaler, scales, - [slot, 0, pos, 0], + update_index, ) new_scaler = jax.lax.with_sharding_constraint( new_scaler, self.replicated @@ -364,19 +385,37 @@ def insert(cache, scaler, new_entry): res = jax.lax.dynamic_update_slice( cache, vals, - [slot, 0, pos, 0], + update_index, ) res = jax.lax.with_sharding_constraint(res, self.cache_sharding) return res, new_scaler - for (k, v), (kscaler, vscaler), (newk, newv) in zip( - decode_state.caches, decode_state.cache_scales, prefix.caches - ): - kcache, kscale = insert(k, kscaler, newk) - vcache, vscale = insert(v, vscaler, newv) - caches.append((kcache, vcache)) - scales.append((kscale, vscale)) - + if self.env.generate_cache_stacked: + cache_k, k_scale = ( + decode_state.caches[0][0], + decode_state.cache_scales[0][0], + ) + cache_v, v_scale = ( + decode_state.caches[0][1], + decode_state.cache_scales[0][1], + ) + for idx, (newk, newv) in enumerate(prefix.caches): + update_index = [idx, slot, 0, pos, 0] + # newk = jnp.expand_dims(newk, 0) + # newv = jnp.expand_dims(newv, 0) + cache_k, k_scale = insert(cache_k, k_scale, newk, update_index) + cache_v, v_scale = insert(cache_v, v_scale, newv, update_index) + caches = [(cache_k, cache_v)] + scales = [(k_scale, v_scale)] + else: + update_index = [slot, 0, pos, 0] + for (k, v), (kscaler, vscaler), (newk, newv) in zip( + decode_state.caches, decode_state.cache_scales, prefix.caches + ): + kcache, kscale = insert(k, kscaler, newk, update_index) + vcache, vscale = insert(v, vscaler, newv, update_index) + caches.append((kcache, vcache)) + scales.append((kscale, vscale)) lens = decode_state.lens.at[slot].set(1) return DecodeState( tokens, @@ -410,10 +449,10 @@ def _insert_wrap( cond = jax.lax.cond( decode_state.current_position > start_insert, lambda x, start_insert, current_position: jnp.logical_and( - x >= start_insert, x <= 
current_position + x >= start_insert, x < current_position ), lambda x, start_insert, current_position: jnp.logical_or( - x >= start_insert, x <= current_position + x >= start_insert, x < current_position ), x, start_insert, @@ -488,11 +527,6 @@ def insert( decode_state: DecodeState, slot: int, ) -> DecodeState: - # logging.info( - # 'Jet input prefix: %s, decode state before insert: %s', - # prefix, - # decode_state, - # ) if self.env.ring_buffer: start_insert = decode_state.current_position - prefix.seq_len end_insert = start_insert + prefix.caches[0][0].shape[2] # padded seclen @@ -574,11 +608,9 @@ def generate( pos = decode_state.current_position if self.env.ring_buffer: input_indexes = jnp.full((1,), pos) - mask = decode_state.mask.at[:, decode_state.current_position].set(0) else: input_indexes = decode_state.input_pos - batch = jnp.arange(self.env.batch_size) - mask = decode_state.mask.at[batch, decode_state.input_pos].set(0) + ragged_batch_index, ragged_block_index = ( self.precompute_ragged_block_indices(decode_state) ) @@ -586,6 +618,16 @@ def generate( (-1) ), ragged_block_index.reshape((-1)) + def update_mask(): + if self.env.ring_buffer: + return decode_state.mask.at[:, decode_state.current_position].set(0) + + batch = jnp.arange(self.env.batch_size) + return decode_state.mask.at[batch, decode_state.input_pos].set(0) + + mask = decode_state.mask + if not self.env.lazy_cache_update: + mask = update_mask() logits, new_caches, new_scales = self._call_model_generate( params, decode_state.tokens, @@ -599,6 +641,10 @@ def generate( ragged_block_index, ) + if self.env.lazy_cache_update: + # fill mask later, now use flash attention + mask = update_mask() + next_token = self._sampling(logits, self.env.batch_size) if self.env.ring_buffer: input_pos = decode_state.input_pos + 1 @@ -642,11 +688,6 @@ def generate( input_pos, mask, ) - print( - "new_pos", - (decode_state.current_position + 1) % self.env.cache_sequence_length, - ) - print(f"new_token: {jnp.squeeze(next_token)}") return new_decode_state, result_tokens # pylint: disable-next=all @@ -824,6 +865,10 @@ def create_pytorch_engine( nucleus_topp=None, topk=None, ring_buffer=True, + flash_attention=False, + generate_cache_stacked=False, + new_cache_stacked=False, + lazy_cache_update=False, ) -> PyTorchEngine: """Returns: The pytorch engine.""" @@ -894,6 +939,10 @@ def create_pytorch_engine( nucleus_topp=nucleus_topp, topk=topk, ring_buffer=ring_buffer, + flash_attention=flash_attention, + generate_cache_stacked=generate_cache_stacked, + new_cache_stacked=new_cache_stacked, + lazy_cache_update=lazy_cache_update, ) if shard_on_batch and sharding_config: diff --git a/jetstream_pt/environment.py b/jetstream_pt/environment.py index fce606d9..227d57fa 100644 --- a/jetstream_pt/environment.py +++ b/jetstream_pt/environment.py @@ -18,6 +18,7 @@ import yaml import jax +import jax.numpy as jnp import jax.sharding as jsharding from jax.experimental import mesh_utils import torch_xla2 @@ -98,11 +99,18 @@ class JetEngineEnvironmentData: block_size: int = 512 # Starting position - starting_position: int = 512 + starting_position: int = 0 # Ring buffer ring_buffer: bool = True + flash_attention: bool = False + + generate_cache_stacked: bool = False + + new_cache_stacked: bool = False + + lazy_cache_update: bool = False # Variables used in token sampling # sampling algorithm to use ("greedy", "weighted", "neucleus", "topk") sampling_algorithm: str = "greedy" @@ -116,6 +124,10 @@ class JetEngineEnvironmentData: # temperature parameter for scaling 
probability temperature: float = 1.0 + testing: bool = False + + testing_seed: int = 0 + # pylint: disable-next=all class JetEngineEnvironment: @@ -126,10 +138,34 @@ def __init__(self, data: JetEngineEnvironmentData): self.batch_size = self._data.batch_size self.seq_len = self._data.max_input_sequence_length self.cache_len = self._data.cache_sequence_length - self.ragged_mha = self._data.ragged_mha self.block_size = self._data.block_size self.starting_position = self._data.starting_position + self.num_layers = self._data.num_layers + self.testing = self._data.testing + self.testing_seed = self._data.testing_seed self.ring_buffer = self._data.ring_buffer + + if not self.ring_buffer: + self.lazy_cache_update = True + self.ragged_mha = True + self.flash_attention = True + self.generate_cache_stacked = True + self.new_cache_stacked = True + + if self.testing: + self.lazy_cache_update = self._data.lazy_cache_update + self.ragged_mha = self._data.ragged_mha + self.flash_attention = self._data.flash_attention + self.generate_cache_stacked = self._data.generate_cache_stacked + self.new_cache_stacked = self._data.new_cache_stacked + + self.default_type = jnp.bfloat16 if self._data.bf16_enable else jnp.float32 + + if self.generate_cache_stacked: + self.cache_shape = (self.num_layers, *self._data.cache_shape) + else: + self.cache_shape = self._data.cache_shape + P = jax.sharding.PartitionSpec num_of_partitions = jax.device_count() @@ -143,19 +179,29 @@ def __init__(self, data: JetEngineEnvironmentData): self.x_sharding = jsharding.NamedSharding(self.mesh, P("x")) self.replicated = jsharding.NamedSharding(self.mesh, P()) + if self.generate_cache_stacked: + self.attention_kv_axis_names = ( + "layer", + "batch", + "num_attn_heads", + "sequence_length", + "head_dim", + ) if data.shard_on_batch: - cache_sharding_axis = 0 + self.kv_cache_shard_axis = "batch" else: - cache_sharding_axis = self.attention_kv_axis_names.index( - self.kv_cache_shard_axis - ) + self.kv_cache_shard_axis = "num_attn_heads" - if self.cache_shape[cache_sharding_axis] == 1: + self.cache_sharding_axis = self.attention_kv_axis_names.index( + self.kv_cache_shard_axis + ) + + if self.cache_shape[self.cache_sharding_axis] == 1: # cannot shard on an axis that is 1 # default to last - cache_sharding_axis = len(self.cache_shape) - 1 + self.cache_sharding_axis = len(self.cache_shape) - 1 - self.cache_sharding = self.sharding_by_axis(cache_sharding_axis) + self.cache_sharding = self.sharding_by_axis(self.cache_sharding_axis) self._load_sharding_config() def _load_sharding_config(self): @@ -201,19 +247,20 @@ def make_caches_prefill(self): def make_caches_generate(self): """Create kv caches for inference generation""" caches = [] - shape = self._data.cache_shape - for _ in range(self.num_layers): + layered_cache_count = 1 if self.generate_cache_stacked else self.num_layers + + for _ in range(layered_cache_count): if self._data.quant_config.enable_kv_quantization: caches.append( cache_manager.Int8KVCacheGenerate.empty( - shape, self.cache_sharding, self.bf16_enable, env=self + self.cache_shape, self.cache_sharding, env=self ) ) else: caches.append( cache_manager.KVCacheGenerate.empty( - shape, self.cache_sharding, self.bf16_enable, env=self + self.cache_shape, self.cache_sharding, env=self ) ) return caches diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py index 8ef7f131..dc98f605 100644 --- a/jetstream_pt/layers.py +++ b/jetstream_pt/layers.py @@ -349,9 +349,19 @@ def apply_rotary_emb( def repeat_kv(x: torch.Tensor, n_rep: int) -> 
torch.Tensor: """torch.repeat_interleave(x, dim=2, repeats=n_rep).""" - bs, n_kv_heads, slen, head_dim = x.shape + *_, bs, n_kv_heads, slen, head_dim = x.shape + stacked = x.ndim == 5 + if n_rep == 1: return x + + if stacked: + layer = x.shape[0] + return ( + x[:, :, :, None, :, :] + .expand(layer, bs, n_kv_heads, n_rep, slen, head_dim) + .reshape(layer, bs, n_kv_heads * n_rep, slen, head_dim) + ) return ( x[:, :, None, :, :] .expand(bs, n_kv_heads, n_rep, slen, head_dim) @@ -361,18 +371,36 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: class AttentionKernel: - def __init__(self, env): + def __init__(self, env, layer_id): self.env = env - self.shard_axis = 0 if self.env.shard_on_batch else 1 - qkv_pspec = self.env.partition_by_axis(self.shard_axis) # Number of heads + self.q_shard_axis = 0 if self.env.shard_on_batch else 1 + self.kv_shard_axis = ( + 0 + if self.env.shard_on_batch + else 2 + if self.env.generate_cache_stacked + else 1 + ) + q_pspec = self.env.partition_by_axis(self.q_shard_axis) # Number of heads + kv_pspec = self.env.partition_by_axis(self.kv_shard_axis) # Number of heads others_pspec = self.env.partition_by_axis() self.dense_attention = ak.dense_attention - self.ragged_attention = ak.RaggedAttentionKernel( + self.flash_attention = ak.flash_attention + self.ragged_attention_orig = ak.RaggedAttentionKernel( + env, + input_specs=(q_pspec, kv_pspec, kv_pspec, *([others_pspec] * 7)), + output_specs=(q_pspec, (q_pspec, q_pspec)), + q_shard_axis=self.q_shard_axis, + kv_shard_axis=self.kv_shard_axis, + ) + self.ragged_attention_new = ak.RaggedAttentionKernel( env, - input_specs=(*([qkv_pspec] * 3), *([others_pspec] * 4)), - output_specs=(qkv_pspec, (others_pspec, others_pspec)), - sharding_axis=self.shard_axis, + input_specs=(q_pspec, q_pspec, q_pspec, *([others_pspec] * 7)), + output_specs=(q_pspec, (q_pspec, q_pspec)), + q_shard_axis=self.q_shard_axis, + kv_shard_axis=self.q_shard_axis, ) + self.layer_id = layer_id def __call__( self, @@ -395,53 +423,141 @@ def __call__( cache: CacheManagerInterface object """ bsz, num_heads, seqlen, head_dim = xq.shape - _, num_kv_heads, _, kv_head_dim = xk.shape + num_kv_heads = xk.shape[-3] + kv_head_dim = xk.shape[-1] n_rep = num_heads // num_kv_heads - if not self.env.ragged_mha and seqlen == 1: - xq = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3])) - with jax.named_scope("attn_insert_cache"): - keys, values = cache.update(xk, xv) - keys = repeat_kv(keys, n_rep) - values = repeat_kv(values, n_rep) + def attend(xq, keys, values, local_mask=None): + if keys.ndim == 4: + impl = self.ragged_attention_new + else: + impl = self.ragged_attention_orig + + true_len = seqlen + # When GQA is enabled, it not necessary to expand + if n_rep == 1 and seqlen == 1: + true_len = 2 + xq = torch.nn.functional.pad( + xq, (0, 0, 0, true_len - seqlen), "constant", 0 + ) - with jax.named_scope("attn_qkv"): if self.env.ragged_mha and seqlen == 1: - output, _ = torch_xla2.interop.call_jax( - self.ragged_attention, + local_output, (local_max, local_denom) = torch_xla2.interop.call_jax( + impl, xq, keys, values, + self.layer_id, start, end, ragged_batch_index, ragged_block_index, ) + elif self.env.flash_attention: + with torch_xla2.default_env(): + local_output, (local_max, local_denom) = self.flash_attention( + xq, keys, values, self.layer_id, mask=local_mask + ) else: - output = self.dense_attention(xq, keys, values, None, None, mask) + local_output = self.dense_attention( + xq, keys, values, None, None, local_mask + ) + local_max 
= None + local_denom = None + + local_output = local_output.reshape(bsz, num_heads, true_len, head_dim) + if local_max is not None: + local_max = local_max.reshape(bsz, num_heads, true_len, 1) + local_denom = local_denom.reshape(bsz, num_heads, true_len, 1) + + if true_len != seqlen: + local_output = local_output[:, :, 0:seqlen, :] + if local_max is not None: + local_max = local_max[:, :, 0:seqlen, :] + if local_denom is not None: + local_denom = local_denom[:, :, 0:seqlen, :] + + # print(f"attention kernel local_output {local_output.shape} seqlen {seqlen}") + # if local_max is not None and local_denom is not None: + # print(f"local_max {local_max.shape} local_denom {local_denom.shape}") + self.env.apply_sharding(local_output, axis=self.q_shard_axis) + return local_output, (local_max, local_denom) + + with jax.named_scope("attn_insert_cache"): + orig_keys, orig_values = cache.update(xk, xv, self.layer_id) + # We are not using ragged attention for prefill yet. + if not self.env.ragged_mha or seqlen > 1: + orig_keys = repeat_kv(orig_keys, n_rep) + orig_values = repeat_kv(orig_values, n_rep) + + # print(f"attention kernel xq {xq.shape} seqlen {seqlen} keys {keys.shape} mask {mask.shape}") + with jax.named_scope("attn_qkv"): + existing_output, (existing_max, existing_denom) = attend( + xq, orig_keys, orig_values, mask + ) + # Updating cache during each step still has very large impact on latency. + # For non flash attention or prefill, existing output contains everything + if not self.env.lazy_cache_update or seqlen > 1: + return existing_output + + # For flash attention, existing output contains the existing kv cache generated logits + with jax.named_scope("attn_new_qkv"): + if not self.env.ragged_mha or seqlen > 1: + xk = repeat_kv(xk, n_rep) + xv = repeat_kv(xv, n_rep) + new_output, (new_max, new_denom) = attend(xq, xk, xv, None) + + with jax.named_scope("attn_global"): + # print(f"existing_output {existing_output} existing_max {existing_max} existing_denom {existing_denom}") + # print(f"new_output {new_output} new_max {new_max} new_denom {new_denom}") + + global_sum = existing_denom * torch.exp( + existing_max + ) + new_denom * torch.exp(new_max) + existing_output = ( + existing_output + * existing_denom + * torch.exp(existing_max) + / global_sum + ) + new_output = new_output * new_denom * torch.exp(new_max) / global_sum + attn_out = existing_output + new_output - if not self.env.ragged_mha and seqlen == 1: - output = output[:, :, 0:1, :] - # For XLA matmul performance boost - # output = torch.matmul(scores, values) - self.env.apply_sharding(output, axis=self.shard_axis) - return output + return attn_out class Int8KVAttentionKernel: - def __init__(self, env): + def __init__(self, env, layer_id): self.env = env - self.shard_axis = 0 if self.env.shard_on_batch else 1 - qkv_pspec = self.env.partition_by_axis(self.shard_axis) # Number of heads + self.q_shard_axis = 0 if self.env.shard_on_batch else 1 + self.kv_shard_axis = ( + 0 + if self.env.shard_on_batch + else 2 + if self.env.generate_cache_stacked + else 1 + ) + q_pspec = self.env.partition_by_axis(self.q_shard_axis) # Number of heads + kv_pspec = self.env.partition_by_axis(self.kv_shard_axis) # Number of heads others_pspec = self.env.partition_by_axis() self.dense_attention = ak.dense_attention - self.ragged_attention = ak.RaggedAttentionKernel( + self.flash_attention = ak.flash_attention + self.ragged_attention_orig = ak.RaggedAttentionKernel( + env, + input_specs=(q_pspec, kv_pspec, kv_pspec, *([others_pspec] * 7)), + 
output_specs=(q_pspec, (q_pspec, q_pspec)), + q_shard_axis=self.q_shard_axis, + kv_shard_axis=self.kv_shard_axis, + ) + self.ragged_attention_new = ak.RaggedAttentionKernel( env, - input_specs=(*([qkv_pspec] * 3), *([others_pspec] * 6)), - output_specs=(qkv_pspec, (others_pspec, others_pspec)), - sharding_axis=self.shard_axis, + input_specs=(q_pspec, q_pspec, q_pspec, *([others_pspec] * 7)), + output_specs=(q_pspec, (q_pspec, q_pspec)), + q_shard_axis=self.q_shard_axis, + kv_shard_axis=self.q_shard_axis, ) + self.layer_id = layer_id def __call__( self, @@ -464,24 +580,33 @@ def __call__( cache: CacheManagerInterface object """ bsz, num_heads, seqlen, head_dim = xq.shape - _, num_kv_heads, _, kv_head_dim = xk.shape + num_kv_heads = xk.shape[-3] + kv_head_dim = xk.shape[-1] n_rep = num_heads // num_kv_heads - if not self.env.ragged_mha and seqlen == 1: - xq = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3])) - - with jax.named_scope("attn_insert_cache"): - keys, values, k_scaler, v_scaler = cache.update(xk, xv) - keys = repeat_kv(keys, n_rep) - values = repeat_kv(values, n_rep) + def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None): + if keys.ndim == 4: + impl = self.ragged_attention_new + else: + impl = self.ragged_attention_orig + + true_len = seqlen + # When GQA is enabled, it not necessary to expand + if n_rep == 1 and seqlen == 1: + true_len = 2 + xq = torch.nn.functional.pad( + xq, (0, 0, 0, true_len - seqlen), "constant", 0 + ) + # xq = torch.broadcast_to(xq, (bsz, num_heads, true_len, head_dim)) - with jax.named_scope("attn_qkv"): + # We are not using ragged attention for prefill yet. if self.env.ragged_mha and seqlen == 1: - output, _ = torch_xla2.interop.call_jax( - self.ragged_attention, + local_output, (local_max, local_denom) = torch_xla2.interop.call_jax( + impl, xq, keys, values, + self.layer_id, start, end, ragged_batch_index, @@ -489,22 +614,94 @@ def __call__( k_scaler, v_scaler, ) + elif self.env.flash_attention: + with torch_xla2.default_env(): + local_output, (local_max, local_denom) = self.flash_attention( + xq, + keys, + values, + self.layer_id, + k_scaler, + v_scaler, + mask=local_mask, + ) else: - output = self.dense_attention( - xq, keys, values, k_scaler, v_scaler, mask + local_output = self.dense_attention( + xq, keys, values, k_scaler, v_scaler, local_mask ) + local_max = None + local_denom = None - if not self.env.ragged_mha and seqlen == 1: - output = output[:, :, 0:1, :] + local_output = local_output.reshape(bsz, num_heads, true_len, head_dim) + if local_max is not None: + local_max = local_max.reshape(bsz, num_heads, true_len, 1) + local_denom = local_denom.reshape(bsz, num_heads, true_len, 1) - self.env.apply_sharding(output, axis=self.shard_axis) - return output + if true_len != seqlen: + local_output = local_output[:, :, 0:seqlen, :] + if local_max is not None: + local_max = local_max[:, :, 0:seqlen, :] + local_denom = local_denom[:, :, 0:seqlen, :] + + self.env.apply_sharding(local_output, axis=self.q_shard_axis) + return local_output, (local_max, local_denom) + + with jax.named_scope("attn_insert_cache"): + ( + orig_keys, + orig_values, + new_key, + new_value, + k_scaler, + v_scaler, + new_k_scaler, + new_v_scaler, + ) = cache.update(xk, xv, self.layer_id) + # We are not using ragged attention for prefill yet. 
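+      # Lazy cache update: when env.lazy_cache_update is enabled, cache.update()
+      # returns the existing quantized cache together with the new key/value
+      # (plus quantization scalers for both) rather than a single merged cache.
+      # Attention is computed once over the existing cache and once over the new
+      # entry, and the partial results are combined in the attn_global scope
+      # below using the running max/denominator returned by the ragged/flash
+      # kernels.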
+ if not self.env.ragged_mha or seqlen > 1: + orig_keys = repeat_kv(orig_keys, n_rep) + orig_values = repeat_kv(orig_values, n_rep) + with jax.named_scope("attn_qkv"): + existing_output, (existing_max, existing_denom) = attend( + xq, orig_keys, orig_values, k_scaler, v_scaler, mask + ) + + # For non flash attention or prefill, existing output contains everything + if not self.env.lazy_cache_update or seqlen > 1: + return existing_output + + # For flash attention, existing output contains the existing kv cache generated logits + with jax.named_scope("attn_new_qkv"): + # At this point, flash attention or ragged attention must have been enabled + if not self.env.ragged_mha or seqlen > 1: + new_key = repeat_kv(new_key, n_rep) + new_value = repeat_kv(new_value, n_rep) + new_output, (new_max, new_denom) = attend( + xq, new_key, new_value, new_k_scaler, new_v_scaler, None + ) + + with jax.named_scope("attn_global"): + global_sum = existing_denom * torch.exp( + existing_max + ) + new_denom * torch.exp(new_max) + existing_output = ( + existing_output + * existing_denom + * torch.exp(existing_max) + / global_sum + ) + new_output = new_output * new_denom * torch.exp(new_max) / global_sum + attn_out = existing_output + new_output + + return attn_out class Attention(nn.Module): """Attention module.""" - def __init__(self, n_heads, n_kv_heads, head_dim, hidden_size, device, env): + def __init__( + self, n_heads, n_kv_heads, head_dim, hidden_size, device, env, layer_id + ): super().__init__() self.n_heads = n_heads self.n_kv_heads = n_kv_heads @@ -512,6 +709,7 @@ def __init__(self, n_heads, n_kv_heads, head_dim, hidden_size, device, env): self.n_rep = self.n_heads // self.n_kv_heads self.env = env self.hidden_size = hidden_size + self.layer_id = layer_id LinearLayer = get_quantized_linear_layer(env.quant_config) linear_kwargs = {} @@ -531,7 +729,7 @@ def __init__(self, n_heads, n_kv_heads, head_dim, hidden_size, device, env): if env.quant_config.enable_kv_quantization else AttentionKernel ) - self.attention_kernel = Kernel(env) + self.attention_kernel = Kernel(env, self.layer_id) self.q_size = n_heads * self.head_dim self.kv_size = self.n_kv_heads * self.head_dim @@ -611,16 +809,26 @@ def forward( xv = xv.transpose(1, 2) xq = xq.transpose(1, 2) + if mask.ndim == 2: + if seqlen == 1: + mask = mask[:, None, None, :] + else: + mask = mask[None, None, :, :] + + # if cache is not None and cache.cache_k is not None: + # print(f"xq {xq.shape} xk {xk.shape} cache shape {cache.cache_k.shape}") output = self.attention_kernel( xq, xk, xv, mask, + # cache[self.layer_id], cache, start, end, ragged_batch_index, ragged_block_index, ).type_as(xq) + # print(f"output {output.shape}") output = output.transpose(-3, -2).contiguous().view(bsz, seqlen, -1) return self.wo(output) diff --git a/jetstream_pt/offline_inference.py b/jetstream_pt/offline_inference.py new file mode 100644 index 00000000..1834ed3e --- /dev/null +++ b/jetstream_pt/offline_inference.py @@ -0,0 +1,192 @@ +from typing import Callable +import dataclasses +from collections import defaultdict +import jax +from jax import numpy as jnp +import numpy as np + +from jetstream.engine import engine_api + +import logging + +log = logging.getLogger(__name__) + + +@dataclasses.dataclass +class InputData: + id: str + tokens: jax.Array + true_length: int + + +class OfflineInference: + + def __init__(self, engine: engine_api.Engine, params=None): + self.engine = engine + self.decode_state = None + if params is None: + params = engine.load_params() + self.params = params + 
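+    # `params` can be supplied by the caller so that several OfflineInference
+    # instances (e.g. one per cache-length bucket, as in mlperf/offline_mode.py)
+    # share one set of loaded weights instead of reloading them per engine.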
+ self.batch_size = engine.env.batch_size + self.max_decode_length = engine.max_decode_length + metadata = engine.get_tokenizer() + self.tokenizer = engine.build_tokenizer(metadata) + self.dummy = False + + self._cached_pref = {} + self._cached_generate = None + + def init_decode_state(self): + if self.decode_state is None: + self.decode_state = self.engine.init_decode_state() + + def warmup(self, max_length=2048): + self.init_decode_state() + interesting_buckets = [ + 32, + 64, + 128, + 256, + 512, + 1024, + 2048, + 4096, + ] + for length in interesting_buckets: + if length > max_length: + break + log.info(f"Compiling prefill: {length}") + input_data = jax.ShapeDtypeStruct((length,), jnp.dtype("int32")) + self._cached_pref[length] = ( + jax.jit(self._prefill_insert, donate_argnums=(4,)) + .lower( + self.params, + tokens=input_data, + slot=0, + true_length=length - 1, + decode_state=self.decode_state) + .compile() + ) + + log.info(f"Compiling decode") + self._cached_generate = ( + jax.jit(self.engine.generate, donate_argnums=(1,)) + .lower(self.params, self.decode_state) + .compile() + ) + + def _prefill_insert(self, params, tokens, slot, true_length, decode_state): + """return decodestate.""" + prefill_result, first_token = self.engine.prefill( + params=params, padded_tokens=tokens, true_length=true_length + ) + decode_state = self.engine.insert(prefill_result, decode_state, slot=slot) + return first_token, decode_state + + def batch_inference_with_callback( + self, + data: InputData, + emit_first_token: Callable[[str, int], bool], + emit_token: Callable[[str, int], bool], + ): + """callback is a function that takes id and token. It will be called once per output + + token. + """ + + def prefill(slot, tokens, true_length): + nonlocal self + if self.dummy: + log.debug("dummy prefill") + return 123 + + prefill_fn = self._prefill_insert + if (cached := self._cached_pref.get(len(tokens))) is not None: + prefill_fn = cached + + first_token, self.decode_state = prefill_fn( + self.params, tokens=tokens, slot=slot, + true_length=true_length, decode_state=self.decode_state + ) + return first_token.data[0][0].item() + + empty_slots = list(range(self.batch_size)) + slot_to_id = {} + + dummy_length = 1 + + def decode(): + log.debug("decode") + nonlocal self + nonlocal slot_to_id + nonlocal dummy_length + if self.dummy: + log.debug("Dummy generate") + res = engine_api.ResultTokens( + data=np.array([[123, 1, dummy_length]] * self.batch_size), + tokens_idx=(0, 0), + valid_idx=(0, 0), + length_idx=(0, 0), + samples_per_slot=(0, 0), + ) + dummy_length += 1 + self.decode_state, result_tokens = self.decode_state, res + else: + gen_fn = self.engine.generate + if self._cached_generate is not None: + gen_fn = self._cached_generate + self.decode_state, result_tokens = gen_fn( + self.params, self.decode_state + ) + + result_tokens = result_tokens.convert_to_numpy() + + newly_empty = [] + for slot, id_ in slot_to_id.items(): + token, is_valid, length = result_tokens.data[slot] + log.debug(f"slot is {slot}, length is {length}") + should_finish = False + if is_valid: + should_finish = emit_token(id_, token.item()) + if should_finish or length >= self.max_decode_length: + newly_empty.append(slot) + + # Add slots of those that are empty to emtpy + for slot in newly_empty: + del slot_to_id[slot] + empty_slots.append(slot) + + for row in data: + log.debug(f"empty_slots {len(empty_slots)}") + while not empty_slots: + # If slots are all full, decode until there are free slots + # to insert + decode() + # do one insert + 
log.debug(f"prefill {row.id}") + slot = empty_slots.pop() + first_token = prefill(slot, row.tokens, row.true_length) + should_terminate = emit_first_token(row.id, first_token) + if not should_terminate: + slot_to_id[slot] = row.id + else: + empty_slots.append(slot) # dont use the slot + + while slot_to_id: + log.debug(f"slot to id {len(slot_to_id)}") + decode() + + def batch_inference(self, data: InputData): + """data is list of obj with id, tokens, and true length""" + ans = defaultdict(list) + + def callback(id_, token): + nonlocal ans + ans[id_].append(token) + return token == self.tokenizer.eos_id + + self.batch_inference_with_callback( + data, emit_first_token=callback, emit_token=callback + ) + return ans diff --git a/jetstream_pt/ray_worker.py b/jetstream_pt/ray_worker.py index 25857f75..2a35bb7e 100644 --- a/jetstream_pt/ray_worker.py +++ b/jetstream_pt/ray_worker.py @@ -567,7 +567,7 @@ def insert(cache, new_entry): @functools.partial(jax.jit, donate_argnums=(0, 1), inline=True) def insert(cache, scaler, new_entry): - reduce_axis = (1, 3) + reduce_axis = (-3, -1) vals, scales, _ = torchjax.call_torch( quantize.quantize_tensor, new_entry, reduce_axis ) diff --git a/jetstream_pt/third_party/gemma/model.py b/jetstream_pt/third_party/gemma/model.py index 1072dad9..5773b8bd 100644 --- a/jetstream_pt/third_party/gemma/model.py +++ b/jetstream_pt/third_party/gemma/model.py @@ -73,6 +73,7 @@ def __init__( head_dim: int, device, env, + layer_id, ): super().__init__() @@ -135,7 +136,7 @@ def __init__( if env.quant_config.enable_kv_quantization else layers.AttentionKernel ) - self.attention_kernel = Kernel(env) + self.attention_kernel = Kernel(env, layer_id) def forward( self, @@ -272,7 +273,7 @@ def forward(self, x): class GemmaDecoderLayer(nn.Module): - def __init__(self, config: gemma_config.GemmaConfig, env): + def __init__(self, config: gemma_config.GemmaConfig, env, layer_id): super().__init__() self.self_attn = GemmaAttention( config.hidden_size, @@ -281,6 +282,7 @@ def __init__(self, config: gemma_config.GemmaConfig, env): config.head_dim, config.device, env, + layer_id, ) self.mlp = GemmaMLP( @@ -340,8 +342,8 @@ def __init__(self, config: gemma_config.GemmaConfig, env): self.env = env self.layers = nn.ModuleList() - for _ in range(config.num_hidden_layers): - self.layers.append(GemmaDecoderLayer(config, env)) + for layer_id in range(config.num_hidden_layers): + self.layers.append(GemmaDecoderLayer(config, env, layer_id)) self.norm = RMSNorm( config.hidden_size, eps=config.rms_norm_eps, device=config.device ) diff --git a/jetstream_pt/third_party/llama/model_args.py b/jetstream_pt/third_party/llama/model_args.py index 7956667d..1b72c0a7 100755 --- a/jetstream_pt/third_party/llama/model_args.py +++ b/jetstream_pt/third_party/llama/model_args.py @@ -45,7 +45,8 @@ def get_arg( "dim": 128, "vocab_size": 32000, "multiple_of": 32, - "n_heads": 8, + "n_heads": 64, + "n_kv_heads": 8, "n_layers": 3, "norm_eps": 1e-05, } diff --git a/jetstream_pt/third_party/llama/model_exportable.py b/jetstream_pt/third_party/llama/model_exportable.py index 15f4fd04..a045cd20 100644 --- a/jetstream_pt/third_party/llama/model_exportable.py +++ b/jetstream_pt/third_party/llama/model_exportable.py @@ -94,6 +94,7 @@ def __init__( args.dim, env=env, device=args.device, + layer_id=layer_id, ) self.feed_forward = FeedForward( dim=args.dim, @@ -217,7 +218,6 @@ def forward( ragged_batch_index: precomputed batch index for ragged attention ragged_block_index: precomputed block index for ragged attention """ - with 
jax.named_scope("transformer_tok"): seqlen = tokens.shape[-1] h = self.tok_embeddings(tokens) @@ -227,12 +227,16 @@ def forward( freqs_cis = self.freqs_cis[input_pos] freqs_cis = freqs_cis.reshape(bsz, seqlen, -1) - assert len(caches) == len( - self.layers - ), f"Number of caches ({len(caches)}) and layers ({len(self.layers)}) dont match" end = None if start is None else (start + input_pos) % self.env.cache_len - for layer, cache in zip(self.layers, caches): - with jax.named_scope("TransformerBlock_Layer_" + str(layer.layer_id)): + # For stacked case, cannot get cache inside the loop which will cause cache copy + for layer_id, layer in enumerate(self.layers): + if caches[0].stacked: + cache = caches[0] + else: + cache = caches[layer_id] + # else: # For stacked case, there is only 1 yer of kv cache + + with jax.named_scope("TransformerBlock_Layer_" + str(layer_id)): h = layer( h, freqs_cis, diff --git a/jetstream_pt/third_party/mixtral/model.py b/jetstream_pt/third_party/mixtral/model.py index b0d8d573..422f4990 100644 --- a/jetstream_pt/third_party/mixtral/model.py +++ b/jetstream_pt/third_party/mixtral/model.py @@ -22,8 +22,10 @@ from torch.nn import functional as F from .config import ModelArgs, find_multiple from jetstream_pt.layers import Attention, get_quantized_linear_layer, get_quantized_enbedding_layer +from jetstream_pt import quantize, torchjax import jax +import jax.numpy as jnp class Transformer(nn.Module): @@ -38,7 +40,8 @@ def __init__(self, config: ModelArgs, env) -> None: config.vocab_size, config.dim, device=config.device ) self.layers = nn.ModuleList( - TransformerBlock(config, env) for _ in range(config.n_layer) + TransformerBlock(config, env, layer_id) + for layer_id in range(config.n_layer) ) self.norm = RMSNorm(config.dim, eps=config.norm_eps) LinearLayer = get_quantized_linear_layer(env.quant_config) @@ -76,11 +79,15 @@ def forward( bsz, seqlen = idx.shape freqs_cis = self.freqs_cis[input_pos] freqs_cis = freqs_cis.reshape(bsz, seqlen, -1) - assert len(caches) == len( - self.layers - ), f"Number of caches ({len(caches)}) and layers ({len(self.layers)}) dont match" - for layer, cache in zip(self.layers, caches): - with jax.named_scope("TransformerBlock"): + + for layer_id, layer in enumerate(self.layers): + if caches[0].stacked: + cache = caches[0] + else: + cache = caches[layer_id] + # else: # For stacked case, there is only 1 yer of kv cache + + with jax.named_scope("TransformerBlock_Layer_" + str(layer_id)): x = layer( x, freqs_cis, @@ -142,7 +149,7 @@ def get_weight_sharding_type(): class TransformerBlock(nn.Module): - def __init__(self, config: ModelArgs, env) -> None: + def __init__(self, config: ModelArgs, env, layer_id) -> None: super().__init__() self.attention = Attention( config.n_head, @@ -151,6 +158,7 @@ def __init__(self, config: ModelArgs, env) -> None: config.dim, env=env, device=config.device, + layer_id=layer_id, ) self.block_sparse_moe = MOEFeedForward(config, config.device, env) self.ffn_norm = RMSNorm(config.dim, config.norm_eps) @@ -227,6 +235,31 @@ def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor: else: return self.forward_for_short_seq_len(x, expert_indices) + def _int_ti_eoi_teo(self, lhs, rhs): + # x1 = F.silu(torch.einsum("ti,eoi -> teo", x, self.w1) * self.w1_scaler) + result = torchjax.call_jax( + jax.lax.dot_general, + lhs, + rhs, + (((1,), (2)), ((), ())), + None, + jnp.bfloat16.dtype, + ) + return result + + def _int_teo_eio_tei(self, lhs, rhs): + #torch.einsum("teo, eio -> tei", (x1 * x3), self.w2) * self.w2_scaler + 
result = torchjax.call_jax( + jax.lax.dot_general, + lhs, + rhs, + (((2,), (2,)), ((1, ), (0, ))), + None, + jnp.bfloat16.dtype, + ) # output is (eti) for some reason + return result.transpose(0, 1) + + def forward_for_short_seq_len( self, x: Tensor, expert_indices: Tensor ) -> Tensor: @@ -254,14 +287,20 @@ def forward_for_long_seq_len(self, x, expert_indices): # o = config.imtermediate size # i = config.dim with jax.named_scope("conditional_ff"): - x1 = F.silu(torch.einsum("ti,eoi -> teo", x, self.w1) * self.w1_scaler) - x3 = torch.einsum("ti, eoi-> teo", x, self.w3) * self.w3_scaler + x_int, x_scaler, _ = quantize.quantize_tensor(x, (1,)) + x_scaler = x_scaler.reshape(seqlen, 1, 1) + + x1 = F.silu(self._int_ti_eoi_teo(x_int, self.w1) * self.w1_scaler * x_scaler) + x3 = self._int_ti_eoi_teo(x_int, self.w3) * self.w3_scaler * x_scaler + + x1x3_int, x1x3_scaler, _ = quantize.quantize_tensor(x1 * x3, (1, 2)) + x1x3_scaler = x1x3_scaler.reshape(seqlen, 1, 1) expert_outs = ( - torch.einsum("teo, eio -> tei", (x1 * x3), self.w2) * self.w2_scaler + self._int_teo_eio_tei(x1x3_int, self.w2) * self.w2_scaler ) # e = 8; need to reduce to 2 seq_indexes = torch.arange(seqlen).unsqueeze(1) - return expert_outs[seq_indexes, expert_indices] + return expert_outs[seq_indexes, expert_indices] * x1x3_scaler class ConditionalFeedForward(nn.Module): diff --git a/mlperf/accuracy_run.sh b/mlperf/accuracy_run.sh new file mode 100644 index 00000000..b940ae74 --- /dev/null +++ b/mlperf/accuracy_run.sh @@ -0,0 +1,56 @@ +#!/usr/bin/env bash +me=$(basename "$0") + +BASEDIR=mlperf +API_URL=0.0.0.0:9000 +USER_CONFIG=$BASEDIR/user.conf +DATA_DISK_DIR=$BASEDIR/data +TOTAL_SAMPLE_COUNT=1000 +DATASET_PATH=$BASEDIR/data/mixtral_15k_data.pkl + +# HF model id +TOKENIZER_PATH="mistralai/Mixtral-8x7B-Instruct-v0.1" +LOADGEN_RUN_TYPE=offline-performance +OUTPUT_LOG_DIR=${DATA_DISK_DIR}/logs/${OUTPUT_LOG_ID} +OUTPUT_LOG_ID=${MODEL_NAME}-${DATASET_TYPE}-${LOADGEN_RUN_TYPE}-${LOADGEN_RUN_TIMESTAMP} + +mkdir -p ${OUTPUT_LOG_DIR} && cp ../${USER_CONFIG} ${OUTPUT_LOG_DIR} + +OUTPUT_ACCURACY_JSON_PATH=${OUTPUT_LOG_DIR}/mlperf_log_accuracy.json + +CACHE_LENGTH=1024 +INPUT_SIZE=512 +OUTPUT_SIZE=512 +CHECKPOINT_PATH=mlperf/data/mixtral-instruct-quantized/ + +LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" +# makes subsequent runs faster +export JAX_COMPILATION_CACHE_DIR="/tmp/jax_cache2" +export LIBTPU_INIT_ARGS + +pushd .. 
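+# The mlperf.offline_mode invocation below is left commented out here; presumably
+# the MLPerf logs (including ${OUTPUT_ACCURACY_JSON_PATH}) come from an earlier
+# run, and this script only re-runs the accuracy evaluation step at the bottom
+# against them.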
+# python -m mlperf.offline_mode \ +# --model_name=mixtral \ +# --max_cache_length=$CACHE_LENGTH \ +# --max_decode_length=$OUTPUT_SIZE \ +# --context_length=$INPUT_SIZE \ +# --checkpoint_path=$CHECKPOINT_PATH/model.safetensors \ +# --tokenizer_path=$CHECKPOINT_PATH/tokenizer.model \ +# --quantize_weights=1 \ +# --quantize_type=int8_per_channel \ +# --quantize_kv_cache=1 \ +# --scenario Offline \ +# --input_mode tokenized \ +# --output_mode tokenized \ +# --mlperf_conf $BASEDIR/mlperf.conf \ +# --user_conf ${USER_CONFIG} \ +# --audit_conf no_audit \ +# --total_sample_count ${TOTAL_SAMPLE_COUNT} \ +# --dataset_path ${DATASET_PATH} \ +# --output_log_dir ${OUTPUT_LOG_DIR} 2>&1 | tee ${OUTPUT_LOG_DIR}/server_accuracy_log.log + +python -m mlperf.evaluate_accuracy \ + --checkpoint-path ${TOKENIZER_PATH} \ + --mlperf-accuracy-file ${OUTPUT_ACCURACY_JSON_PATH} \ + --dataset-file ${DATASET_PATH} 2>&1 | tee ${OUTPUT_LOG_DIR}/evaluate_offline_accuracy_log.log +popd \ No newline at end of file diff --git a/mlperf/evaluate_accuracy.py b/mlperf/evaluate_accuracy.py new file mode 100644 index 00000000..2fd74e1e --- /dev/null +++ b/mlperf/evaluate_accuracy.py @@ -0,0 +1,251 @@ +import argparse +from transformers import AutoTokenizer +import nltk +import evaluate +import numpy as np +import pandas as pd +import json +import re + +import logging +logging.basicConfig(level=logging.DEBUG) +log = logging.getLogger("evaluate_accuracy.py") + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--checkpoint-path", + required=True, + help="Path to Mixtral-8x7b-Instruct checkpoint", + ) + parser.add_argument( + "--mlperf-accuracy-file", + required=True, + help="path to mlperf_log_accuracy.json", + ) + parser.add_argument( + "--dataset-file", + required=True, + help="path to processed validation dataset", + ) + parser.add_argument( + "--n_workers", + default=2, + type=int, + help="Number of workers used for the MBXP evaluation", + ) + parser.add_argument("--verbose", action="store_true", help="verbose messages") + parser.add_argument( + "--dtype", + default="int64", + help="dtype of the accuracy log", + choices=["int32", "int64", "float"], + ) + args = parser.parse_args() + return args + + +def get_groundtruth(processed_dataset_file): + data = pd.read_pickle(processed_dataset_file) + return data + + +# Functions for evaluating GSM8K +def find_numbers(x: str) -> list[str]: + """Finds all numbers in a string.""" + # Search for number, possibly negative (hyphen), with thousand separators + # (comma), and with a decimal point (period inbetween digits). + numbers = re.compile( + r"-?[\d,]*\.?\d+", + re.MULTILINE | re.DOTALL | re.IGNORECASE, + ).findall(x) + return numbers + + +def find_number(x: str, answer_delimiter: str = "The answer is") -> str: + """Finds the most relevant number in a string.""" + # If model uses the answer delimiter, then select the first number following + # that format. + if answer_delimiter in x: + answer = x.split(answer_delimiter)[-1] + numbers = find_numbers(answer) + if numbers: + return numbers[0] + + # In general, select the last number in the string. 
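+  # Illustrative examples (not taken from the dataset):
+  #   find_number("12 minus 4 gives 8. The answer is 8")  -> "8"
+  #   find_number("the total comes to 1,234.5 dollars")   -> "1,234.5"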
+ numbers = find_numbers(x) + if numbers: + return numbers[-1] + return "" + + +def maybe_remove_comma(x: str) -> str: + # Example: 5,600 -> 5600 + return x.replace(",", "") + + +def try_float(x: str): + try: + ret = float(x) + except BaseException: + ret = None + return ret + + +# Functions for evaluating OpenOrca + + +def postprocess_text(preds, targets): + preds = [pred.strip() for pred in preds] + targets = [target.strip() for target in targets] + + # rougeLSum expects newline after each sentence + preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] + targets = ["\n".join(nltk.sent_tokenize(target)) for target in targets] + + return preds, targets + + +# Functions for MBXP + + +def create_mbxp_dict(row, response): + lang, entry_point = row["id"].split("_", 1) + return { + "lang": lang, + "prompt": row["input"], + "test_code": row["gt_output"], + "entry_point": entry_point, + "response": response, + } + + +def main(): + + args = get_args() + dataset_path = args.dataset_file + checkpoint_path = args.checkpoint_path + metric = evaluate.load("rouge") + nltk.download("punkt") + + tokenizer = AutoTokenizer.from_pretrained( + checkpoint_path, + model_max_length=2048, + padding_side="left", + use_fast=False, + ) + + data = get_groundtruth(args.dataset_file) + query_types, gt_outputs = data["dataset"], data["gt_output"] + + target_required_GSM8K = [] + target_required_OpenOrca = [] + results_MBXP = [] + preds_token_GSM8K = [] + preds_token_OpenOrca = [] + preds_token_MBXP = [] + + eval_dtype = np.int64 + if args.dtype == "int32": + eval_dtype = np.int32 + elif args.dtype == "float": + eval_dtype = np.float32 + + with open(args.mlperf_accuracy_file, "r") as f: + results = json.load(f) + + seen = set() + gen_tok_len = 0 + gen_num = 0 + for pred in results: + gen_num += 1 + qsl_idx = pred["qsl_idx"] + if qsl_idx in seen: + continue + + seen.add(qsl_idx) + + query_type = query_types.iloc[qsl_idx] + if query_type == "GSM8K": + target = gt_outputs.iloc[qsl_idx] + target_required_GSM8K.append(target) + pred = np.frombuffer(bytes.fromhex(pred["data"]), eval_dtype) + gen_tok_len += len(pred) + preds_token_GSM8K.append(pred) + elif query_type == "OpenOrca": + target = gt_outputs.iloc[qsl_idx] + target_required_OpenOrca.append(target) + pred = np.frombuffer(bytes.fromhex(pred["data"]), eval_dtype) + preds_token_OpenOrca.append(pred) + gen_tok_len += len(pred) + else: + target = data.iloc[qsl_idx] + pred = np.frombuffer(bytes.fromhex(pred["data"]), eval_dtype) + pred_str = tokenizer.decode(pred, skip_special_tokens=True) + results_MBXP.append(create_mbxp_dict(target, pred_str)) + gen_tok_len += len(pred) + + # OpenOrca metric + preds_decoded_text = tokenizer.batch_decode( + preds_token_OpenOrca, skip_special_tokens=True + ) + + preds, targets = postprocess_text( + preds_decoded_text, target_required_OpenOrca + ) + + if preds: + result = metric.compute( + predictions=preds, + references=targets, + use_stemmer=True, + use_aggregator=False, + ) + result = {k: round(np.mean(v) * 100, 4) for k, v in result.items()} + prediction_lens = [len(pred) for pred in preds] + + else: + result = {} + prediction_lens = [] + + # GSM8K metric + preds_decoded_text = tokenizer.batch_decode( + preds_token_GSM8K, skip_special_tokens=True + ) + pred_nums = [ + maybe_remove_comma(find_number(pred_text.split("\nQ:")[0])) + for pred_text in preds_decoded_text + ] + gsm8k_total = len(target_required_GSM8K) + correct = 0 + for idx in range(len(target_required_GSM8K)): + ref = try_float(target_required_GSM8K[idx]) + tgt = 
try_float(pred_nums[idx]) + if tgt is None: + continue + correct += ref == tgt + + result["gsm8k"] = 100.0 * correct / gsm8k_total + + # MBXP metric + # from evaluate_mbxp import evaluate_mbxp + + # if results_MBXP: + # result['mbxp'] = evaluate_mbxp(results_MBXP, args.n_workers) + # else: + # result['mbxp'] = 0 + + result = { + **result, + "gen_len": np.sum(prediction_lens), + "gen_num": gen_num, + "gen_tok_len": gen_tok_len, + "tokens_per_sample": round(gen_tok_len / gen_num, 1), + } + + print("\nResults\n") + print(result) + + +if __name__ == "__main__": + main() diff --git a/mlperf/gmm.py b/mlperf/gmm.py new file mode 100644 index 00000000..73c0dd3b --- /dev/null +++ b/mlperf/gmm.py @@ -0,0 +1,202 @@ +import jax +import jax.numpy as jnp + +from jax.sharding import Mesh, PartitionSpec as P, NamedSharding +from jax.experimental import mesh_utils +from jax.experimental.shard_map import shard_map +from functools import partial +from jax.experimental.pallas.ops.tpu.megablox.gmm import gmm + +devices = mesh_utils.create_device_mesh((8, 1)) +mesh = Mesh(devices, axis_names=('x', 'y')) +jax.config.update('jax_default_matmul_precision', "float32") +import torch +interp = False +pdtype = jnp.dtype('float32') + +def _reference_gmm(lhs: torch.Tensor, rhs: torch.Tensor, + group_sizes: torch.Tensor, tiling=None, preferred_element_type=None, interpret=False) -> torch.Tensor: + start = 0 + out = [] + for i, size in enumerate(group_sizes): + result = lhs[start:start + size, :] @ rhs[i, :, :] + out.append(result) + start += group_sizes[i] + return jnp.concatenate(out) + +#gmm = _reference_gmm + + + + +@partial( + shard_map, + mesh=mesh, + in_specs=( + P(), + P(None, 'x', None), + P(None, None, 'x'), + P(None, 'x', None), + P(None, None)), + out_specs=(P()), check_rep=False) +def temp(x, w1, w2, w3, expert_indices): + print('x inside', x.shape) + print('w1', w1.shape) + print('w2', w2.shape) + print('w3', w3.shape) + print('index', expert_indices.shape) + def _histogram(input, min: int, max: int): + assert min <= max, "min must be less than or equal to max." + + def searchsorted(sorted_sequence, values_to_search): + return (jax.numpy.expand_dims(sorted_sequence, 1) == values_to_search).sum(axis=1) + + bin_edges = jax.numpy.linspace( + min, max, max - min + 1, dtype=input.dtype) + return searchsorted(bin_edges, input) + + num_tokens, k = expert_indices.shape + _, n = x.shape + top_flat = expert_indices.flatten() + hidden_states_order = top_flat.argsort() + hidden_states_reverse_order = hidden_states_order.argsort() + # Always replicated, so okay to skip manual sharding. 
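+  # Regrouping trick, e.g. with expert_indices = [[0, 2], [1, 0]]:
+  # top_flat = [0, 2, 1, 0] and hidden_states_order = argsort = [0, 3, 2, 1],
+  # so token rows are gathered to sit contiguously per expert; group_sizes then
+  # tells gmm() how many rows belong to each expert, and
+  # hidden_states_reverse_order undoes the permutation after the grouped matmuls.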
+ hidden_states_indices = jnp.arange(num_tokens).repeat(k)[hidden_states_order] + hidden_states_sorted = x[hidden_states_indices] # (num_tokens, hidden_dim/8) + group_sizes = _histogram(top_flat, 0, 7) + # w1 (num_experts, hiddent_dim/8, intermeditate) + gmm1 = gmm(hidden_states_sorted.astype('bfloat16'), + jnp.transpose(w1, (0, 2, 1)).astype('bfloat16'), + group_sizes, tiling=(16,128,128), + preferred_element_type=pdtype, interpret=interp) + gmm3 = gmm( + hidden_states_sorted.astype('float32'), + jnp.transpose(w3, (0, 2, 1)).astype('float32'), + group_sizes, tiling=(16,128,128), preferred_element_type=pdtype, interpret=interp) + + gmm1_s = gmm1 + gmm3_s = gmm3 + #gmm1_s = jax.lax.psum(gmm1, 'x') + #gmm3_s = jax.lax.psum(gmm3, 'x') + silu = jax.nn.silu(gmm1_s) + sgmm = silu * gmm3_s # (num_tokens, intermediate_size) + gmm2 = gmm( + sgmm, + jnp.transpose(w2, (0, 2, 1)).astype('float32'), + group_sizes, + tiling=(8,512,512), + preferred_element_type=pdtype, interpret=interp) #(num_tokens, hidden_dim/8) + print(gmm2.shape) + gmm2 = jax.lax.psum(gmm2, 'x') + current_hidden_states = gmm2[hidden_states_reverse_order].reshape(-1, k, n) + return current_hidden_states + + +# Create a PRNG key +key = jax.random.PRNGKey(123) # Using a different seed for variety + +seqlen = 16 + +expert_indices = jax.random.randint(key, shape=(seqlen, 2), minval=0, maxval=8) +hidden_states = jax.random.normal(key, (seqlen, 4096), dtype=jnp.bfloat16) + +w1 = jnp.broadcast_to(jnp.arange(8).reshape((8, 1, 1)).astype('float32'), (8, 14336, 4096)) +w2 = jnp.broadcast_to(jnp.arange(8).reshape((8, 1, 1)).astype('float32'), (8, 4096, 14336)) +w3 = jnp.broadcast_to(jnp.arange(8).reshape((8, 1, 1)).astype('float32'), (8, 14336, 4096)) + + +hidden_states = jax.device_put(hidden_states, NamedSharding(mesh, P(None, "x"))) +w1 = jax.device_put(w1, NamedSharding(mesh, P(None, "x"))) +w2 = jax.device_put(w2, NamedSharding(mesh, P(None, None, "x"))) +w3 = jax.device_put(w3, NamedSharding(mesh, P(None, "x"))) + +def exp_einsum(x, w1, expert_indices): + w1_weights = w1[expert_indices] + x1 = jnp.einsum("ti,taoi -> tao", x, w1_weights) + return x1 + +def _repeat_index(num_tokens, k): + start = jnp.arange(num_tokens).repeat(k) + start = start.reshape((num_tokens, k)) + return start.T.flatten() + +def exp_gmm(x, w1, expert_indices): + num_tokens, k = expert_indices.shape + _, n = x.shape + e, o, i = w1.shape + top_flat = expert_indices.flatten() + hidden_states_order = top_flat.argsort() + hidden_states_reverse_order = hidden_states_order.argsort() + # Always replicated, so okay to skip manual sharding. + hidden_states_indices = jnp.arange(num_tokens).repeat(k)[hidden_states_order] + hidden_states_sorted = x[hidden_states_indices] # (num_tokens, hidden_dim/8) + def _histogram(input, min: int, max: int): + assert min <= max, "min must be less than or equal to max." 
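+    # Effectively a fixed-length bincount: each integer bin edge in [min, max]
+    # is compared against `input` and exact matches are counted, so the result
+    # always has static shape (max - min + 1,) regardless of which experts appear.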
+ + def searchsorted(sorted_sequence, values_to_search): + return (jax.numpy.expand_dims(sorted_sequence, 1) == values_to_search).sum(axis=1) + + bin_edges = jax.numpy.linspace( + min, max, max - min + 1, dtype=input.dtype) + return searchsorted(bin_edges, input) + + group_sizes = _histogram(top_flat, 0, 7) + gmm1 = gmm(hidden_states_sorted.astype('float32'), + jnp.transpose(w1, (0, 2, 1)).astype('float32'), + group_sizes, + tiling=(16,128,128), + preferred_element_type=pdtype, interpret=interp) + return gmm1[hidden_states_reverse_order].reshape(-1, k, o) + + +def forward_for_long_seq_len(x, w1, w2, w3, expert_indices): + seqlen = x.shape[0] + num_experts = w1.shape[0] + + # e = total num of exp = 8 + # t = seqlen + # o = config.imtermediate size + # i = config.dim + with jax.named_scope("conditional_ff"): + x1 = jax.nn.silu(jnp.einsum("ti,eoi -> teo", x, w1)) + x3 = jnp.einsum("ti, eoi-> teo", x, w3) + expert_outs = ( + jnp.einsum("teo, eio -> tei", (x1 * x3), w2) + ) + # e = 8; need to reduce to 2 + seq_indexes = jnp.expand_dims(jnp.arange(seqlen), 1) + return expert_outs[seq_indexes, expert_indices] + +def main(): + + # x = jnp.arange(8).reshape((2, 4)).astype('float64') / 100 + # w1 = jnp.arange(3 * 4 * 5).reshape((3,5,4)).astype('float64') + # w2 = jnp.arange(3 * 4 * 5).reshape((3,4,5)).astype('float64') + # w3 = jnp.arange(3 * 4 * 5).reshape((3,5,4)).astype('float64') + # expert_indices = jnp.array([[0, 2], [1, 0]]) + x = hidden_states + out1 = forward_for_long_seq_len(x, w1, w2, w3, expert_indices) + out = temp(x, w1, w2, w3, expert_indices) + print(jnp.max(jnp.abs(out1 - out))) + + # out1 = exp_einsum(x, w1, expert_indices) + # out_gmm = exp_gmm(x, w1, expert_indices) + # print(out1 - out_gmm) + + + #group_sizes = jnp.array([4] * 8) + #gmm1 = gmm(hidden_states.astype('float32'), + # w1.astype('float32'), + # group_sizes, + # tiling=(16,128,128), interpret=interp) + #gmm1_ref = _reference_gmm(hidden_states.astype('float32'), + # w1.astype('float32'), + # group_sizes, + # tiling=(16,128,128), interpret=interp) + #print(gmm1 - gmm1_ref) + + +if __name__ == '__main__': + main() + + diff --git a/mlperf/install.sh b/mlperf/install.sh new file mode 100644 index 00000000..3a8f037b --- /dev/null +++ b/mlperf/install.sh @@ -0,0 +1,41 @@ +#!/usr/bin/env bash + +DATA_DISK_DIR=data + +mkdir -p $DATA_DISK_DIR + +pip install -U "huggingface_hub[cli]" +pip install \ + transformers \ + nltk==3.8.1 \ + evaluate==0.4.0 \ + absl-py==1.4.0 \ + rouge-score==0.1.2 \ + sentencepiece==0.1.99 \ + accelerate==0.21.0 + +# install loadgen +pip install mlperf-loadgen + + +pushd $DATA_DISK_DIR + +# model weights +gcloud storage cp gs://sixiang_gcp/mixtral-instruct-quantized ./ --recursive +# NOTE: uncomment one so you dont download too much weights to your box +# gcloud storage cp gs://sixiang_gcp/llama2-70b/llama2-70b/ ./ --recursive + +# Get mixtral data +wget https://inference.mlcommons-storage.org/mixtral_8x7b%2F2024.06.06_mixtral_15k_v4.pkl +mv mixtral_8x7b%2F2024.06.06_mixtral_15k_v4.pkl mixtral_15k_data.pkl +wget https://inference.mlcommons-storage.org/mixtral_8x7b%2F2024.06.06_mixtral_15k_calibration_v4.pkl +mv mixtral_8x7b%2F2024.06.06_mixtral_15k_calibration_v4.pkl mixtral_15k_calibration_data.pkl + +# Get llama70b data +gcloud storage cp \ + gs://cloud-tpu-inference-public/mlcommons/inference/language/llama2-70b/data/processed-openorca/open_orca_gpt4_tokenized_llama.calibration_1000.pkl \ + processed-calibration-data.pkl +gcloud storage cp \ + 
gs://cloud-tpu-inference-public/mlcommons/inference/language/llama2-70b/data/processed-openorca/open_orca_gpt4_tokenized_llama.sampled_24576.pkl \ + processed-data.pkl +popd diff --git a/mlperf/mixtral_run.sh b/mlperf/mixtral_run.sh new file mode 100755 index 00000000..77504920 --- /dev/null +++ b/mlperf/mixtral_run.sh @@ -0,0 +1,55 @@ +#!/usr/bin/env bash +me=$(basename "$0") + +BASEDIR=mlperf +USER_CONFIG=$BASEDIR/user.conf +DATA_DISK_DIR=$BASEDIR/data +TOTAL_SAMPLE_COUNT=900 + +# HF model id +TOKENIZER_PATH="mistralai/Mixtral-8x7B-Instruct-v0.1" +LOADGEN_RUN_TYPE=offline-performance +OUTPUT_LOG_DIR=${DATA_DISK_DIR}/logs/${OUTPUT_LOG_ID} +OUTPUT_LOG_ID=${MODEL_NAME}-${DATASET_TYPE}-${LOADGEN_RUN_TYPE}-${LOADGEN_RUN_TIMESTAMP} + +mkdir -p ${OUTPUT_LOG_DIR} && cp ../${USER_CONFIG} ${OUTPUT_LOG_DIR} + +OUTPUT_ACCURACY_JSON_PATH=${OUTPUT_LOG_DIR}/mlperf_log_accuracy.json + +CHECKPOINT_PATH=mlperf/data/mixtral-instruct-quantized/ + +LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" +# makes subsequent runs faster +export JAX_COMPILATION_CACHE_DIR="/tmp/jax_cache2" +export LIBTPU_INIT_ARGS + + +DATASET_PATH=$BASEDIR/data/mixtral_1k_data.pkl + +pushd .. +python -m mlperf.offline_mode \ + --lazy_cache_update=1 \ + --ring_buffer=0 \ + --mlperf_test_mode=$1 \ + --model_name=mixtral \ + --checkpoint_path=$CHECKPOINT_PATH/model.safetensors \ + --tokenizer_path=$CHECKPOINT_PATH/tokenizer.model \ + --quantize_weights=1 \ + --quantize_type=int8_per_channel \ + --quantize_kv_cache=1 \ + --input_mode tokenized \ + --output_mode tokenized \ + --mlperf_conf $BASEDIR/mlperf.conf \ + --user_conf ${USER_CONFIG} \ + --audit_conf no_audit \ + --total_sample_count ${TOTAL_SAMPLE_COUNT} \ + --dataset_path ${DATASET_PATH} \ + --output_log_dir ${OUTPUT_LOG_DIR} 2>&1 | tee ${OUTPUT_LOG_DIR}/server_accuracy_log.log + +if [ "$1" = "accuracy" ]; then +python -m mlperf.evaluate_accuracy \ + --checkpoint-path ${TOKENIZER_PATH} \ + --mlperf-accuracy-file ${OUTPUT_ACCURACY_JSON_PATH} \ + --dataset-file ${DATASET_PATH} 2>&1 | tee ${OUTPUT_LOG_DIR}/evaluate_offline_accuracy_log.log +fi +popd diff --git a/mlperf/mlperf.conf b/mlperf/mlperf.conf new file mode 100644 index 00000000..e9ae205e --- /dev/null +++ b/mlperf/mlperf.conf @@ -0,0 +1,98 @@ +# The format of this config file is 'key = value'. +# The key has the format 'model.scenario.key'. Value is mostly int64_t. +# Model maybe '*' as wildcard. In that case the value applies to all models. +# All times are in milli seconds + +# Set performance_sample_count for each model. +# User can optionally set this to higher values in user.conf. 
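+# Example of the key format, taken from mlperf/user.conf in this change:
+#   mixtral-8x7b.Offline.target_qps = 100.0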
+resnet50.*.performance_sample_count_override = 1024 +ssd-mobilenet.*.performance_sample_count_override = 256 +retinanet.*.performance_sample_count_override = 64 +bert.*.performance_sample_count_override = 10833 +dlrm.*.performance_sample_count_override = 204800 +dlrm-v2.*.performance_sample_count_override = 204800 +rnnt.*.performance_sample_count_override = 2513 +gptj.*.performance_sample_count_override = 13368 +llama2-70b.*.performance_sample_count_override = 24576 +stable-diffusion-xl.*.performance_sample_count_override = 5000 +# set to 0 to let entire sample set to be performance sample +3d-unet.*.performance_sample_count_override = 0 + +# Set seeds. The seeds will be distributed two weeks before the submission. +*.*.qsl_rng_seed = 3066443479025735752 +*.*.sample_index_rng_seed = 10688027786191513374 +*.*.schedule_rng_seed = 14962580496156340209 +# Set seeds for TEST_05. The seeds will be distributed two weeks before the submission. +*.*.test05_qsl_rng_seed = 16799458546791641818 +*.*.test05_sample_index_rng_seed = 5453809927556429288 +*.*.test05_schedule_rng_seed = 5435552105434836064 + + +*.SingleStream.target_latency_percentile = 90 +*.SingleStream.min_duration = 600000 + +*.MultiStream.target_latency_percentile = 99 +*.MultiStream.samples_per_query = 8 +*.MultiStream.min_duration = 600000 +*.MultiStream.min_query_count = 662 +retinanet.MultiStream.target_latency = 528 + +# 3D-UNet uses equal issue mode because it has non-uniform inputs +3d-unet.*.sample_concatenate_permutation = 1 + +# LLM benchmarks have non-uniform inputs and outputs, and use equal issue mode for all latency scenario +gptj.*.sample_concatenate_permutation = 1 +llama2-70b.*.sample_concatenate_permutation = 1 +mixtral-8x7B.*.sample_concatenate_permutation = 1 + +*.Server.target_latency = 10 +*.Server.target_latency_percentile = 99 +*.Server.target_duration = 0 +*.Server.min_duration = 600000 +resnet50.Server.target_latency = 15 +retinanet.Server.target_latency = 100 +bert.Server.target_latency = 130 +dlrm.Server.target_latency = 60 +dlrm-v2.Server.target_latency = 60 +rnnt.Server.target_latency = 1000 +gptj.Server.target_latency = 20000 +stable-diffusion-xl.Server.target_latency = 20000 +# Llama2-70b benchmarks measures token latencies +llama2-70b.*.use_token_latencies = 1 +mixtral-8x7b.*.use_token_latencies = 1 +# gptj benchmark infers token latencies +gptj.*.infer_token_latencies = 1 +gptj.*.token_latency_scaling_factor = 69 +# Only ttft and tpot are tracked for the llama2-70b & mixtral-8x7B benchmark therefore target_latency = 0 +llama2-70b.Server.target_latency = 0 +llama2-70b.Server.ttft_latency = 2000 +llama2-70b.Server.tpot_latency = 200 + +mixtral-8x7b.Server.target_latency = 0 +mixtral-8x7b.Server.ttft_latency = 2000 +mixtral-8x7b.Server.tpot_latency = 200 + +*.Offline.target_latency_percentile = 90 +*.Offline.min_duration = 600000 + +# In Offline scenario, we always have one query. But LoadGen maps this to +# min_sample_count internally in Offline scenario. 
If the dataset size is larger +# than 24576 we limit the min_query_count to 24576 and otherwise we use +# the dataset size as the limit + +resnet50.Offline.min_query_count = 24576 +retinanet.Offline.min_query_count = 24576 +dlrm-v2.Offline.min_query_count = 24576 +bert.Offline.min_query_count = 10833 +gptj.Offline.min_query_count = 13368 +rnnt.Offline.min_query_count = 2513 +3d-unet.Offline.min_query_count = 43 +stable-diffusion-xl.Offline.min_query_count = 5000 +llama2-70b.Offline.min_query_count = 1000 +mixtral-8x7b.Offline.min_query_count = 15000 + +# These fields should be defined and overridden by user.conf. +*.SingleStream.target_latency = 10 +*.MultiStream.target_latency = 80 +*.Server.target_qps = 1.0 +*.Offline.target_qps = 4.0 diff --git a/mlperf/offline_mode.py b/mlperf/offline_mode.py new file mode 100644 index 00000000..09b00c25 --- /dev/null +++ b/mlperf/offline_mode.py @@ -0,0 +1,476 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import contextlib +import copy +import gc +import time +import math +import logging +import os +import sys +import array +import collections +import threading + +import torch_xla2 +import jax +import jax.numpy as jnp +import numpy as np +import pandas as pd + +import mlperf_loadgen as lg +from jetstream_pt.config import create_engine_from_config_flags +from jetstream_pt import offline_inference + +_MLPERF_ID = "mixtral-8x7b" + +logging.basicConfig(level=logging.DEBUG) + +sys.path.insert(0, os.getcwd()) +log = logging.getLogger("main2.py") + + +from absl import app, flags + +FLAGS = flags.FLAGS + +flags.DEFINE_string( + "mlperf_test_mode", + "performance", + "performance, accuracy, submission", +) +flags.DEFINE_string( + "api_url", None, "SAX published model path.", required=False +) +flags.DEFINE_string("dataset_path", None, "", required=False) +flags.DEFINE_bool("is_stream", False, "", required=False) +flags.DEFINE_string( + "input_mode", + "tokenized", + "Input mode", +) +flags.DEFINE_string( + "output_mode", + "tokenized", + "Output mode", +) + +flags.DEFINE_string( + "audit_conf", + "audit.conf", + "audit config for LoadGen settings during compliance runs", + required=False, +) +flags.DEFINE_string( + "mlperf_conf", + "mlperf.conf", + "mlperf rules config", + required=False, +) +flags.DEFINE_string( + "user_conf", + "user.conf", + "user config for user LoadGen settings such as target QPS", + required=False, +) +flags.DEFINE_integer( + "total_sample_count", + 15000, + "Number of samples to use in benchmark.", + required=False, +) +flags.DEFINE_integer( + "perf_count_override", + None, + "Overwrite number of samples to use in benchmark.", + required=False, +) +flags.DEFINE_string( + "output_log_dir", + "output-logs", + "Where logs are saved.", + required=False, +) +flags.DEFINE_bool( + "enable_log_trace", + False, + "Enable log tracing. 
This file can become quite large", + required=False, +) +flags.DEFINE_bool( + "skip_warmup", False, "Skips warmup" +) +flags.DEFINE_bool( + "internal_dummy_model", False, "Skips actual model compute, used for testing" +) + + +scenario_map = { + "offline": lg.TestScenario.Offline, + "server": lg.TestScenario.Server, +} + + +def pad_tokens(tokens): + true_length = len(tokens) + target_length = max(int(2 ** math.ceil(math.log2(true_length))), 32) + padded = tokens + [0] * (target_length - true_length) + return padded, true_length + + +@contextlib.contextmanager +def timed(msg): + log.info(msg + " start") + start = time.perf_counter() + yield + end = time.perf_counter() + log.info(msg + " done: " + str(end - start)) + + +def _classify_query(dataset_rows, index): + # return groupped indexes + if FLAGS.model_name == 'mixtral': + sample = dataset_rows[index][1] + input_len = sample.tok_input_len + total_len = sample.tok_input_len + sample.tok_ref_output_len + if total_len <= 512: + return 0 + elif total_len <= 1280 and input_len <= 1024: + return 1 + else: + return 2 + else: + sample = dataset_rows[index][1] + total_len = sample.tok_input_length + sample.tok_output_length + if total_len <= 512: + return 0 + elif total_len <= 1024: + return 1 + else: + return 2 + + + + +def _pick_batch_size(num_samples, max_batch, dataset_size, sample_size): + """max_batch to not run OOM.""" + if num_samples <= max_batch: + return num_samples + mult = math.ceil(num_samples / max_batch) + return math.ceil(num_samples / mult * (sample_size / dataset_size)) + + +def _log_complete(sample_id, response_token_ids): + assert (response_token_ids[0] <= 32000) + n_tokens = len(response_token_ids) + response_token_ids = np.array(response_token_ids, dtype=np.int64) + response_array = array.array("B", response_token_ids.tobytes()) + response_info = response_array.buffer_info() + response_data = response_info[0] + response_size = response_info[1] * response_array.itemsize + query_sample_response = lg.QuerySampleResponse( + sample_id, response_data, response_size, n_tokens + ) + lg.QuerySamplesComplete([query_sample_response]) + # import ipdb; ipdb.set_trace() + +def _log_first(sample_id, response_token_ids): + assert (response_token_ids[0] <= 32000) + assert len(response_token_ids) == 1 + response_token_ids = np.array(response_token_ids, dtype=np.int64) + response_array = array.array("B", response_token_ids.tobytes()) + response_info = response_array.buffer_info() + first_token_response = lg.QuerySampleResponse( + sample_id, response_info[0], response_info[1] + ) + lg.FirstTokenComplete([first_token_response]) + + +class SUT: + + def __init__(self, data, offline_inf): + # dict of int (cache length) -> offline_inf + self.offline_inf = offline_inf + + # pandas dataframe, it has tok + self._dataset = data + self.pandas_rows = list(self._dataset.iterrows()) + + # List of things with .id and .index + self._queries = None + + # index to loaded data + self._processed_data = None + + # self.replicated = self.offline_inf.engine.env.sharding_by_axis(-1) + self._sample_id_to_input = None + self._sample_id_to_bucket = None + self._groupped_queries = [[], [], []] + self._id_to_index = {} + self._index_to_group = {} + self._eos = offline_inf[0].tokenizer.eos_id + + def _get_eos_seq(self, id_): + idx = self._id_to_index[id_] + pandas_row = self.pandas_rows[idx][1] + if hasattr(pandas_row, 'tok_stop_sequence'): + return pandas_row.tok_stop_sequence + else: + return [self._eos] + + def issue_queries(self, queries): + log.info('issue queries 
called') + self._groupped_queries = [[], [], []] + assert self._sample_id_to_input is not None + self._processed_data = [] + self._queries = queries + for q in queries: + self._id_to_index[q.id] = q.index + group = self._index_to_group[q.index] + input_data = copy.copy(self._sample_id_to_input[q.index]) + input_data.id = q.id + self._groupped_queries[group].append(input_data) + + if len(self._queries) != sum(len(q) for q in self._groupped_queries): + import ipdb; ipdb.set_trace() + + # At this point _processed_data is ready + + @timed("flush_queries") + def flush_queries(self): + start = time.perf_counter() + completed = set() + resp = collections.defaultdict(list) + lock = threading.RLock() + def emit_token(id_, token): + nonlocal resp + nonlocal completed + with lock: + resp[id_].append(token) + end_seq = self._get_eos_seq(id_) + is_end = (token == self._eos) or (end_seq == resp[id_][-len(end_seq):]) + if is_end: + _log_complete(id_, resp[id_]) + completed.add(id_) + if id_ in resp: + del resp[id_] + return is_end + + def emit_first_token(id_, token): + nonlocal resp + nonlocal completed + # emit first token + with lock: + _log_first(id_, [token]) + end_seq = self._get_eos_seq(id_) + is_end = (token == self._eos) or (len(end_seq) == 1 and end_seq[0] == token) + if is_end: + # We have four OpenOrca samples that return empty (eos_token) output. + # It was decided to allow two eos_tokens to not break throughput computation + # PR - https://github.com/mlcommons/inference/pull/1778 + _log_complete(id_, [token, self._eos]) + completed.add(id_) + if id_ in resp: + import pdb; pdb.set_trace() + del resp[id_] + return is_end + + for group_idx in [2,1,0]: + group = self._groupped_queries[group_idx] + self.offline_inf[group_idx].init_decode_state() + result = self.offline_inf[group_idx].batch_inference_with_callback( + group, emit_first_token, emit_token) + + # some never reached eos but reached max sequence + with lock: + for key, value in resp.items(): + if key in completed: + continue + _log_complete(key, value) + completed.add(key) + + if group_idx != 0: + # no need to drop state for the last one + self.offline_inf[group_idx].decode_state = None + gc.collect() + + end = time.perf_counter() + + def LoadSamplesToRam(self, sample_list): + """Pads the data, move them to jax array on device""" + log.info("LoadSamplesToRam start") + start = time.perf_counter() + input_data = {} + + for sample_id in sample_list: + p = self.pandas_rows[sample_id][1] + padded, length = pad_tokens(p.tok_input) + input_data[sample_id] = offline_inference.InputData( + "", jnp.array(padded), length # to be filled later + ) + self._index_to_group[sample_id] = _classify_query(self.pandas_rows, sample_id) + + for data in input_data.values(): + # make sure tokens are transfered to device + jax.block_until_ready(data.tokens) + + self._sample_id_to_input = input_data + + end = time.perf_counter() + log.info(f"LoadSamplesToRam finished: {end - start}s") + + def UnloadSamplesFromRam(self, sample_list): + print("UnloadSamplesFromRam called") + pass + + +def _count_by_bucket(dataset): + if FLAGS.model_name == 'mixtral': + total_len = dataset.tok_input_len + dataset.tok_ref_output_len + + group1 = total_len <= 512 + group2 = (total_len <= 1280) & (dataset.tok_input_len <= 1024) + + # with 5 percent extra + mult = FLAGS.total_sample_count / len(dataset) * 1.05 + + counts = [ + # power of 2 + math.ceil(len(dataset[group1]) * mult), + math.ceil(len(dataset[~group1 & group2]) * mult), + math.ceil(len(dataset[~group1 & ~group2]) * mult), + 
] + return counts + else: + total_len = dataset.tok_input_length + dataset.tok_output_length + group1 = total_len <= 512 + group2 = total_len <= 1024 + # with 5 percent extra + mult = FLAGS.total_sample_count / len(dataset) * 1.05 + counts = [ + math.ceil(len(dataset[group1]) * mult), + math.ceil(len(dataset[~group1 & group2]) * mult), + math.ceil(len(dataset[~group1 & ~group2]) * mult), + ] + return counts + + +def main(argv): + del argv + args = FLAGS + jax.config.update("jax_default_prng_impl", "unsafe_rbg") + + if len(jax.devices()) < 4: + print("Looks like TPU not available?", jax.devices()) + return -1 + # jax.config.update("jax_explain_cache_misses", True) + + settings = lg.TestSettings() + settings.scenario = lg.TestScenario.Offline + user_conf = FLAGS.user_conf + + settings.FromConfig(FLAGS.mlperf_conf, _MLPERF_ID, "Offline") + settings.FromConfig(user_conf, _MLPERF_ID, "Offline") + log.info("Mlperf config: %s", FLAGS.mlperf_conf) + log.info("User config: %s", user_conf) + + dataset = pd.read_pickle(FLAGS.dataset_path) + rows = list(dataset.iterrows()) + counts_by_bucket = _count_by_bucket(dataset) + log.info(f"Counts by bucket {counts_by_bucket}") + + if FLAGS.model_name == "mixtral": + length_and_batch = ( + (512, 2048), + (1280, 512), + (3072, 256), + ) + else: + length_and_batch = ( + (512, 512), + (1024, 256), + (2048, 96), + ) + engines = [] + params = None + for i, (length, max_batch) in enumerate(length_and_batch): + batch = min(counts_by_bucket[i], max_batch) + log.info(f"Using batch size of {batch} for {length}") + engine = create_engine_from_config_flags(batch=batch, cache_len=length) + offline_inf = offline_inference.OfflineInference(engine, params) + offline_inf.dummy = FLAGS.internal_dummy_model + params = offline_inf.params + engines.append(offline_inf) + + if not FLAGS.skip_warmup: + with timed("warmup"): + for (length, _), engine in zip(length_and_batch, engines): + log.info(f"warm up for {length}") + engine.init_decode_state() + engine.warmup(length) + if length != 3072: + # dont need to drop state for the last one + engine.decode_state = None # drop state + gc.collect() + + sut = SUT(dataset, engines) + + if FLAGS.mlperf_test_mode == "accuracy": + settings.mode = lg.TestMode.AccuracyOnly + log.warning( + "Accuracy run will generate the accuracy logs, but the evaluation of the log is not completed yet" + ) + elif FLAGS.mlperf_test_mode == "submission": + settings.mode = lg.TestMode.SubmissionRun + settings.print_timestamps = True + else: + settings.mode = lg.TestMode.PerformanceOnly + settings.print_timestamps = True + + settings.use_token_latencies = True + + os.makedirs(FLAGS.output_log_dir, exist_ok=True) + log_output_settings = lg.LogOutputSettings() + log_output_settings.outdir = FLAGS.output_log_dir + log_output_settings.copy_summary_to_stdout = True + log_settings = lg.LogSettings() + log_settings.log_output = log_output_settings + log_settings.enable_trace = FLAGS.enable_log_trace + + lgSUT = lg.ConstructSUT(sut.issue_queries, sut.flush_queries) + qsl = lg.ConstructQSL( + FLAGS.total_sample_count, + FLAGS.total_sample_count, + sut.LoadSamplesToRam, + sut.UnloadSamplesFromRam, + ) + log.info("Starting Benchmark run") + lg.StartTestWithLogSettings( + lgSUT, qsl, settings, log_settings, FLAGS.audit_conf + ) + log.info(f"query counts {[len(q) for q in sut._groupped_queries]}") + log.info("Run Completed!") + log.info("Destroying SUT...") + lg.DestroySUT(lgSUT) + + log.info("Destroying QSL...") + lg.DestroyQSL(qsl) + + +if __name__ == "__main__": + # Disable 
garbage collection to avoid stalls when running tests. + gc.disable() + app.run(main) diff --git a/mlperf/user.conf b/mlperf/user.conf new file mode 100644 index 00000000..95ef75ef --- /dev/null +++ b/mlperf/user.conf @@ -0,0 +1,6 @@ +mixtral-8x7b.Server.target_qps = 2.0 +mixtral-8x7b.Offline.target_qps = 100.0 + +# send unique queries +mixtral-8x7b.Offline.performance_issue_unique = 1 + diff --git a/run_interactive.py b/run_interactive.py index eef2def8..e5193db3 100644 --- a/run_interactive.py +++ b/run_interactive.py @@ -42,11 +42,17 @@ def main(argv): max_output_length = 1024 profiling_output = FLAGS.profiling_output - profiling_prefill = FLAGS.profiling_prefill - if profiling_output and profiling_prefill: - jax.profiler.start_trace(profiling_output) + profiling_prefill = ( + FLAGS.profiling_prefill + and profiling_output is not None + and profiling_output != "" + ) + if profiling_prefill: + jax.profiler.start_trace(profiling_output) decode_state = engine.init_decode_state() + if profiling_prefill: + jax.profiler.stop_trace() prompts: List[str] = [ "I believe the meaning of life is", "To add an element to an ArrayList of a specific class type in Java, you can follow the following steps:\n\n1. Create an instance of the class to be added.\n2. Get a reference to the ArrayList.\n3. Call the `add()` method on the ArrayList, passing the instance of the class as the argument.\n\nHere's an example of how to add an object of type `Person` to an ArrayList of type `ArrayList`:\n```csharp\n// Create a new instance of the Person class\nPerson person = new Person(\"John\", 25);\n\n// Get a reference to the ArrayList\nArrayList peopleList = new ArrayList<>();\n\n// Add the person object to the ArrayList\npeopleList.add(person);\n```\nIn this example, the `Person` class is assumed to have a constructor that takes two arguments: a String for the person's name, and an int for their age. 
You can substitute your own class and constructor as necessary.", @@ -62,21 +68,27 @@ def main(argv): print(f"---- Encoded tokens are: {tokens}") # pylint: disable-next=all + if profiling_prefill: + jax.profiler.start_trace(profiling_output) prefill_result, _ = engine.prefill( params=params, padded_tokens=tokens, true_length=true_length ) # pylint: disable-next=all decode_state = engine.insert(prefill_result, decode_state, slot=slot) + if profiling_prefill: + jax.profiler.stop_trace() + sampled_tokens_list = [] print(f"---- Streaming decode started on #slot{slot}.") complete = np.zeros((1,), dtype=np.bool_) while True: - if profiling_output and not profiling_prefill: + if profiling_output: jax.profiler.start_trace(profiling_output) decode_state, result_tokens = engine.generate(params, decode_state) - if profiling_output and not profiling_prefill: - jax.profiler.stop_trace() result_tokens = result_tokens.convert_to_numpy() + + if profiling_output: + jax.profiler.stop_trace() output, complete = token_utils.process_result_tokens( tokenizer=tokenizer, slot=slot, @@ -94,9 +106,6 @@ def main(argv): print("---- All output text.") print(tokenizer.decode(sampled_tokens_list)) - if profiling_output and profiling_prefill: - jax.profiler.stop_trace() - if __name__ == "__main__": os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" diff --git a/tests/helpers.py b/tests/helpers.py index 00442517..62c0789b 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -6,7 +6,7 @@ from jetstream_pt import environment -def make_env_tiny(bf16_enable=True): +def make_env_tiny(bf16_enable=True, env_data_update_fn=lambda _: None): torch_dtype = torch.bfloat16 if bf16_enable else torch.float32 torch.set_default_dtype(torch_dtype) jax.config.update("jax_dynamic_shapes", False) @@ -26,6 +26,8 @@ def make_env_tiny(bf16_enable=True): environment_data.cache_sequence_length, config.dim // config.n_heads, ) + environment_data.testing = True + env_data_update_fn(environment_data) env = environment.JetEngineEnvironment(environment_data) env.apply_sharding = lambda *args, **kwargs: None # don't shard on cpu return env, config diff --git a/tests/test_llama_e2e.py b/tests/test_llama_e2e.py index dcbcf5f2..73d0ce6c 100644 --- a/tests/test_llama_e2e.py +++ b/tests/test_llama_e2e.py @@ -23,15 +23,16 @@ import torch_xla2 from torch.utils import _pytree as pytree - from jetstream_pt.engine import PyTorchEngine from jetstream_pt.third_party.llama import model_exportable, model_args from jetstream_pt.third_party.llama.generation_original import LlamaOriginal from jetstream_pt import environment from tests import helpers +from jetstream_pt import torchjax +from absl.testing import parameterized -class LlamaE2ETest(unittest.TestCase): +class LlamaE2ETest(parameterized.TestCase): """This test class includes all E2E test for llama2""" def _from_torch(self, tree): @@ -187,6 +188,9 @@ def _llama_e2e(self, env, model_arg): model_ours = model_exportable.Transformer(model_arg, env) + for k, v in model_ours.state_dict().items(): + if "scale" in k: + state_dict[k] = helpers.to_xla_tensor(v) engine = PyTorchEngine(pt_model=model_ours, env=env) params = self._from_torch(state_dict) @@ -233,6 +237,58 @@ def test_llama_e2e_bfloat16(self): out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg) self.assertNotEqual(out_tokens, expected_output_tokens) + @parameterized.named_parameters( + ("ring_buffer_f32", True, False, False), + ("left_aligned_f32", False, False, False), + ) + def test_llama_e2e_result_verification( + self, ring_buffer, quantized, 
bf16_enabled + ): + """end to end jetstream llama test with float32""" + jax.config.update("jax_platform_name", "cpu") + print(f"---------> {jax.devices()}") + + def update_env_data(env_data): + env_data.ring_buffer = ring_buffer + env_data.ragged_mha = not ring_buffer + env_data.flash_attention = not ring_buffer + env_data.generate_cache_stacked = not ring_buffer + env_data.new_cache_stacked = not ring_buffer + env_data.lazy_cache_update = not ring_buffer + env_data.ragged_mha = not ring_buffer + env_data.quant_config.enable_kv_quantization = quantized + + env, model_arg = helpers.make_env_tiny(bf16_enabled, update_env_data) + out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg) + self.assertEqual(out_tokens, expected_output_tokens) + + @parameterized.named_parameters( + ("ring_buffer_int8", True, True, True), + ("ring_buffer_bf16", True, False, True), + ("left_aligned_int8", False, True, True), + ("left_aligned_bf16", False, False, True), + ) + def test_llama_e2e_no_result_verification( + self, ring_buffer, quantized, bf16_enabled + ): + """end to end jetstream llama test with float32""" + jax.config.update("jax_platform_name", "cpu") + print(f"---------> {jax.devices()}") + + def update_env_data(env_data): + env_data.ring_buffer = ring_buffer + env_data.ragged_mha = not ring_buffer + env_data.flash_attention = not ring_buffer + env_data.generate_cache_stacked = not ring_buffer + env_data.new_cache_stacked = not ring_buffer + env_data.lazy_cache_update = not ring_buffer + env_data.ragged_mha = not ring_buffer + env_data.quant_config.enable_kv_quantization = quantized + + env, model_arg = helpers.make_env_tiny(bf16_enabled, update_env_data) + out_tokens, expected_output_tokens = self._llama_e2e(env, model_arg) + self.assertNotEqual(out_tokens, expected_output_tokens) + # pylint: disable-next=all def test_llama_e2e_two_addtional_tokens(self): """end to end jetstream llama with addtional tokens""" diff --git a/tests/test_model_impl.py b/tests/test_model_impl.py index 4d4ddfd6..703ce444 100644 --- a/tests/test_model_impl.py +++ b/tests/test_model_impl.py @@ -65,9 +65,13 @@ def _make_freqs_cis(self, model_arg, seqlen, start_pos): freqs_cis = freqs_cis[start_pos : start_pos + seqlen] return freqs_cis - def _generate_mask(self, cache_length, pos, seqlen): + def _generate_mask(self, cache_length, pos, seqlen, ring_buffer=True): x = jnp.arange(0, cache_length) - cond = jnp.logical_and(x <= pos, x >= pos - seqlen) + if ring_buffer: + cond = jnp.logical_and(x <= pos, x >= pos - seqlen) + else: + # Left aligned buffer we postpone the cache update + cond = jnp.logical_and(x < pos, x >= pos - seqlen) res = jnp.where(cond, 0, float("-inf")) return torchjax.to_torch(res) @@ -91,6 +95,7 @@ def _make_one_cache_for_generate(self, env, pos): # pylint: disable-next=all def test_attention(self): + torch.manual_seed(0) env, model_arg = helpers.make_env_tiny(False) attention_orig = model_original.Attention(model_arg) @@ -101,6 +106,7 @@ def test_attention(self): hidden_size=model_arg.dim, device="cpu", env=env, + layer_id=0, ) seqlen = 32 @@ -136,11 +142,11 @@ def test_attention(self): # insert prefilled cache entry cache_decode.cache_k._elem = cache_decode.cache_k._elem.at[ - :, :, :pos, : + ..., :pos, : ].set(cache.cache_k._elem) cache_decode.cache_v._elem = cache_decode.cache_v._elem.at[ - :, :, :pos, : + ..., :pos, : ].set(cache.cache_v._elem) # self._compare_cache(attention_orig.cache_k, cache_decode.cache_k) @@ -154,7 +160,7 @@ def test_attention(self): None, # mask is none for decode ) 
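
A small illustrative sketch of the masking difference that the new ring_buffer argument to _generate_mask captures: the ring-buffer path has already written position pos into the cache when attention runs, while the left-aligned (lazy-update) path postpones the cache write, so pos itself must stay masked. decode_mask is a hypothetical helper, not the test's implementation.

import jax.numpy as jnp

def decode_mask(cache_length, pos, seqlen, ring_buffer):
    x = jnp.arange(0, cache_length)
    if ring_buffer:
        visible = jnp.logical_and(x <= pos, x >= pos - seqlen)
    else:
        # Left-aligned / lazy update: the current position is not in the cache yet.
        visible = jnp.logical_and(x < pos, x >= pos - seqlen)
    return jnp.where(visible, 0.0, float("-inf"))

m_ring = decode_mask(8, pos=4, seqlen=3, ring_buffer=True)
m_left = decode_mask(8, pos=4, seqlen=3, ring_buffer=False)
assert m_ring[4] == 0.0 and m_left[4] == float("-inf")
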
expected_out = attention_orig(*inputs_orig2) - cache_decode.pos = [pos] # next position to update + cache_decode.input_pos = [pos] # next position to update mask = self._generate_mask(env.cache_sequence_length, pos, seqlen) mask = mask.reshape(1, 1, 1, -1) # seq dim is the last one freqs_cis = freqs_cis.reshape(batch, 1, -1) @@ -203,6 +209,7 @@ def init_weights(model): head_dim=head_dim, device="meta", env=env, + layer_id=0, ) def load_hook(state_dict, prefix, *args): @@ -228,8 +235,8 @@ def load_hook(state_dict, prefix, *args): freqs_cis = self._make_freqs_cis(model_arg, seqlen, start_pos) mask = self._prefill_mask(seqlen, start_pos) kv_write_indexes = torch.arange(0, seqlen) - cache_k = torch.zeros((batch, seqlen, num_heads, head_dim)) - cache_v = torch.zeros((batch, seqlen, num_heads, head_dim)) + cache_k = torch.zeros((batch, seqlen, num_kv_heads, head_dim)) + cache_v = torch.zeros((batch, seqlen, num_kv_heads, head_dim)) inputs_orig = (x, freqs_cis, kv_write_indexes, (cache_k, cache_v), mask) expected_out = attention_orig(*inputs_orig) @@ -300,10 +307,10 @@ def test_transformer_block(self): # insert prefilled cache entry cache_decode.cache_k._elem = cache_decode.cache_k._elem.at[ - :, :, :pos, : + ..., :pos, : ].set(cache.cache_k._elem) cache_decode.cache_v._elem = cache_decode.cache_v._elem.at[ - :, :, :pos, : + ..., :pos, : ].set(cache.cache_v._elem) # Now do one with decode @@ -316,7 +323,7 @@ def test_transformer_block(self): None, # mask is none for decode ) expected_out = block_orig(*inputs_orig2) - cache_decode.pos = [pos] # next position to update + cache_decode.input_pos = [pos] # next position to update mask = self._generate_mask(env.cache_sequence_length, pos, seqlen) mask = mask.reshape(1, 1, 1, -1) # seq dim is the last one freqs_cis = freqs_cis.reshape(batch, 1, -1) diff --git a/tests/test_quantization.py b/tests/test_quantization.py index 581553d1..6e0e4866 100644 --- a/tests/test_quantization.py +++ b/tests/test_quantization.py @@ -21,7 +21,7 @@ import torch import torch_xla2 from jax.experimental import mesh_utils -from jetstream_pt import cache_manager, layers, quantize, torchjax +from jetstream_pt import cache_manager, layers, quantize, torchjax, environment from jetstream_pt.environment import QuantizationConfig from jetstream_pt.layers import ( WeightOnlyBlockwiseQuantizedLinear, @@ -31,11 +31,13 @@ from tests import helpers from torch.utils import _pytree as pytree from torch_xla2 import tensor +import copy +from absl.testing import parameterized torch.manual_seed(12345) -class QuantizationTest(unittest.TestCase): +class QuantizationTest(parameterized.TestCase): """test kv cache quantization""" def _xla_tensor(self, shape): @@ -68,72 +70,216 @@ def _print_diff(self, w, w_dq): print(" norm: ", (w - w_dq).norm()) print(" cosine dist: ", self._calc_cosine_dist(w, w_dq)) - def test_kv_cache(self): + @parameterized.named_parameters( + ("ring_buffer", True), + ("left_aligned", False), + ) + def test_kv_cache(self, ring_buffer): """test kv cache quantization""" - cache_shape = (3, 2, 100, 2) # bs, num heads, seqlen, dim + + def update_env_data(env_data): + env_data.ring_buffer = ring_buffer + env_data.ragged_mha = not ring_buffer + env_data.flash_attention = not ring_buffer + env_data.generate_cache_stacked = not ring_buffer + env_data.new_cache_stacked = not ring_buffer + env_data.lazy_cache_update = not ring_buffer + env_data.quant_config.enable_kv_quantization = True + env_data.batch_size = 4 + + env, _ = helpers.make_env_tiny(True, update_env_data) + + batch = 
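
As a side note on the [:, :, :pos, :] to [..., :pos, :] change above: the ellipsis absorbs any leading axes, so the same cache insert works both for the per-layer 4-D cache and for the stacked 5-D cache that carries a leading layer dimension. A throwaway sketch with made-up shapes:

import jax.numpy as jnp

pos = 3
flat = jnp.zeros((2, 4, 8, 16))        # batch, heads, seqlen, head_dim
stacked = jnp.zeros((6, 2, 4, 8, 16))  # layers, batch, heads, seqlen, head_dim

# Identical indexing expression handles both layouts.
flat = flat.at[..., :pos, :].set(jnp.ones((2, 4, pos, 16)))
stacked = stacked.at[..., :pos, :].set(jnp.ones((6, 2, 4, pos, 16)))

assert float(flat[..., :pos, :].sum()) == 2 * 4 * pos * 16
assert float(stacked[..., :pos, :].sum()) == 6 * 2 * 4 * pos * 16
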
env.batch_size + if env.generate_cache_stacked: + cache_shape = ( + env.num_layers, + batch, + 2, + 100, + 2, + ) # layer, bs, num heads, seqlen, dim + else: + cache_shape = (batch, 2, 100, 2) # bs, num heads, seqlen, dim with jax.default_device(jax.devices("cpu")[0]): - env, _ = helpers.make_env_tiny() - cache = cache_manager.Int8KVCacheGenerate.empty( - cache_shape, None, False, env - ) - # seqlen is 1 - k = self._xla_tensor((3, 2, 1, 2)) - v = self._xla_tensor((3, 2, 1, 2)) - cache.input_pos = [57] - new_k, new_v, scaler_k, scaler_v = cache.update(k, v) - new_k = new_k * scaler_k - new_v = new_v * scaler_v + cache = cache_manager.Int8KVCacheGenerate.empty(cache_shape, None, env) + # seqlen is 1 + k = self._xla_tensor((batch, 2, 1, 2)) + v = self._xla_tensor((batch, 2, 1, 2)) + + def update_finalize_compare(in_k, in_v, in_layer, in_pos): + cache.input_pos = ( + [in_pos] if env.ring_buffer else jnp.array([in_pos] * batch) + ) + + # layer id may or may not take effect, depends on the env config. + cache.update(in_k, in_v, layer_id=in_layer) + cache.finalize() + if env.quant_config.enable_kv_quantization: + new_k = cache.cache_k * cache.k_scaler + new_v = cache.cache_v * cache.v_scaler + else: + new_k = cache.cache_k + new_v = cache.cache_v + + if env.generate_cache_stacked: + self.assertTrue( + jnp.allclose( + k._elem, + new_k._elem[in_layer, :, :, in_pos : (in_pos + 1), :], + atol=0.1, + ) + ) + self.assertTrue( + jnp.allclose( + v._elem, + new_v._elem[in_layer, :, :, in_pos : (in_pos + 1), :], + atol=0.1, + ) + ) + else: + self.assertTrue( + jnp.allclose( + k._elem, new_k._elem[:, :, in_pos : (in_pos + 1), :], atol=0.1 + ) + ) + self.assertTrue( + jnp.allclose( + v._elem, new_v._elem[:, :, in_pos : (in_pos + 1), :], atol=0.1 + ) + ) + + update_finalize_compare(k, v, in_layer=1, in_pos=57) + update_finalize_compare(k, v, in_layer=1, in_pos=58) + update_finalize_compare(k, v, in_layer=2, in_pos=3) + + @parameterized.named_parameters( + ("ring_buffer", True), + ("left_aligned", False), + ) + def test_kv_kernel(self, ring_buffer): + """test kv cache quantization""" - self.assertTrue( - jnp.allclose(k._elem, new_k._elem[:, :, 57:58, :], atol=0.1) - ) - self.assertTrue( - jnp.allclose(v._elem, new_v._elem[:, :, 57:58, :], atol=0.1) - ) + def update_env_data(env_data): + env_data.ring_buffer = ring_buffer + env_data.ragged_mha = not ring_buffer + env_data.flash_attention = not ring_buffer + env_data.generate_cache_stacked = not ring_buffer + env_data.new_cache_stacked = not ring_buffer + env_data.lazy_cache_update = not ring_buffer + env_data.quant_config.enable_kv_quantization = True + env_data.batch_size = 4 + + env, _ = helpers.make_env_tiny(False, update_env_data) + + batch = env.batch_size + if env.generate_cache_stacked: + cache_shape = ( + env.num_layers, + batch, + 2, + 100, + 2, + ) # bs, num heads, seqlen, dim + else: + cache_shape = (batch, 2, 100, 2) # layers, bs, num heads, seqlen, dim - def test_kv_kernel(self): - """test kv cache quantization""" - cache_shape = (3, 2, 100, 2) # bs, num heads, seqlen, dim with jax.default_device(jax.devices("cpu")[0]): - env, _ = helpers.make_env_tiny(False) + key = jax.random.PRNGKey(123) key2 = jax.random.PRNGKey(456) - cache_k_jax = jax.random.normal(key, cache_shape) - cache_v_jax = jax.random.normal(key2, cache_shape) + cache_k_jax = jax.random.normal(key, cache_shape, dtype=env.default_type) + cache_v_jax = jax.random.normal(key2, cache_shape, dtype=env.default_type) - cache_k, cache_v = torchjax.to_torch((cache_k_jax, cache_v_jax)) + start 
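
For context on the k_scaler / v_scaler dequantization checked above (cache.cache_k * cache.k_scaler), here is a hedged sketch of a symmetric per-channel int8 round trip over the same reduction axes passed to quantize_tensor(x, (-3, -1)). int8_quantize is my own minimal stand-in for illustration, not jetstream_pt's implementation.

import jax
import jax.numpy as jnp

def int8_quantize(x, reduce_axes):
    """Symmetric per-channel int8 quantization; one scale per non-reduced slice."""
    scale = jnp.max(jnp.abs(x), axis=reduce_axes, keepdims=True) / 127.0
    scale = jnp.where(scale == 0, 1.0, scale)  # guard all-zero channels
    q = jnp.clip(jnp.round(x / scale), -128, 127).astype(jnp.int8)
    return q, scale

x = jax.random.normal(jax.random.PRNGKey(0), (4, 2, 16, 2))
q, scale = int8_quantize(x, reduce_axes=(-3, -1))  # reduce over heads and head_dim
x_dq = q.astype(jnp.float32) * scale               # dequantize, as the tests do
assert jnp.allclose(x, x_dq, atol=0.1)
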
= jnp.zeros((batch,), dtype=jnp.int32) - cache = cache_manager.KVCacheGenerate(cache_k, cache_v, [0], None, env) + cache_k, cache_v, start = torchjax.to_torch( + (cache_k_jax, cache_v_jax, start) + ) + + # Prepare quantized cache before written in + cache_k_int, cache_k_scaler, _ = quantize_tensor(cache_k, (-3, -1)) + cache_v_int, cache_v_scaler, _ = quantize_tensor(cache_v, (-3, -1)) # 1 is seqlen - xq = jax.random.normal(key, (3, 2, 1, 2)) - xk = jax.random.normal(key, (3, 2, 1, 2)) - xv = jax.random.normal(key, (3, 2, 1, 2)) + xq = jax.random.normal(key, (batch, 2, 1, 2), dtype=env.default_type) + xk = jax.random.normal(key, (batch, 2, 1, 2), dtype=env.default_type) + xv = jax.random.normal(key, (batch, 2, 1, 2), dtype=env.default_type) xq, xk, xv = torchjax.to_torch((xq, xk, xv)) - attention_float = layers.AttentionKernel(env) - float_res = attention_float(xq, xk, xv, None, cache) + def get_var(position: int): + pos = ( + [position] + if env.ring_buffer + else jnp.array([position] * batch, dtype=jnp.int64) + ) + mask = jax.lax.broadcast_in_dim( + jnp.array([0] * position + [float("-inf")] * (100 - position)), + (env.batch_size, 1, 1, 100), + (3,), + ) + mask = torchjax.to_torch((mask)) + return pos, mask + + cache = cache_manager.KVCacheGenerate(cache_k, cache_v, None, None, env) + # layer_id doesn't matter, will assign later + attention_float = layers.AttentionKernel(env, layer_id=0) + + float_res = [] + + def update_finalize_record( + in_attention, in_cache, in_q, in_k, in_v, in_layer, in_pos + ): + pos, mask = get_var(in_pos) + in_attention.layer_id = in_layer + in_cache.input_pos = pos + ret = in_attention( + in_q, in_k, in_v, mask, in_cache, start=start, end=pos + ) + in_cache.finalize() + return ret + + float_res.append( + update_finalize_record(attention_float, cache, xq, xk, xv, 1, 57) + ) + float_res.append( + update_finalize_record(attention_float, cache, xq, xk, xv, 1, 58) + ) + float_res.append( + update_finalize_record(attention_float, cache, xq, xk, xv, 2, 3) + ) - # == + # Running into the issue of multiple env object always share the same quant_config. + # Record the results and compare as a workaround. + env._data.quant_config.enable_kv_quantization = True + env = environment.JetEngineEnvironment(env._data) - cache_k, cache_v = torchjax.to_torch((cache_k_jax, cache_v_jax)) - cache_k_int, cache_k_scaler, _ = quantize_tensor(cache_k, (1, 3)) - cache_v_int, cache_v_scaler, _ = quantize_tensor(cache_v, (1, 3)) cache_int = cache_manager.Int8KVCacheGenerate( cache_k_int, cache_v_int, cache_k_scaler, cache_v_scaler, - [0], + None, None, env, ) - attention_quant = layers.Int8KVAttentionKernel(env) - int_res = attention_quant(xq, xk, xv, None, cache_int) + # layer_id doesn't matter, will assign later + attention_quant = layers.Int8KVAttentionKernel(env, layer_id=0) + + int_res = [] + int_res.append( + update_finalize_record(attention_quant, cache_int, xq, xk, xv, 1, 57) + ) + int_res.append( + update_finalize_record(attention_quant, cache_int, xq, xk, xv, 1, 58) + ) + int_res.append( + update_finalize_record(attention_quant, cache_int, xq, xk, xv, 2, 3) + ) - self.assertTrue(jnp.allclose(float_res.jax(), int_res.jax(), atol=0.01)) + for f, i in zip(float_res, int_res): + self.assertTrue(jnp.allclose(f.jax(), i.jax(), atol=0.01)) def test_quantize_dequantize_tensor(self):