Commit: act quant for conditional ffn
init params

add other scripts

debug accuracy
qihqi committed Jul 23, 2024
1 parent b7a2310 commit b00be7f
Showing 14 changed files with 1,435 additions and 38 deletions.
2 changes: 2 additions & 0 deletions benchmarks/mixtral_offline.sh
@@ -3,6 +3,8 @@ BATCH_SIZE=$2
INPUT_SIZE=1024
OUTPUT_SIZE=1024
CHECKPOINT_PATH=mlperf/data/mixtral-instruct-quantized/
export JAX_COMPILATION_CACHE_DIR="/tmp/jax_cache2"
export XLA_FLAGS="--xla_disable_hlo_passes=rematerialization"

pushd ..
python -m benchmarks.run_offline \
42 changes: 11 additions & 31 deletions benchmarks/offline_benchmark.py
@@ -21,48 +21,28 @@ class Stat:
Stat(
cache_size = 512,
batch_size = 2048,
-prefill_times = {
-  16: 0.016024088603444397,
-  32: 0.021154335999926843,
-  64: 0.02999803279999469,
-  128: 0.043986773600045125, 256: 0.07524209819985117, 512: 0.13882793779994246},
-decode_time = 0.28033976474989686
+prefill_times = {16: 0.02084908019969589, 32: 0.024125573800120037, 64: 0.02697298339990084, 128: 0.03641403259971412, 256: 0.05809259879970341, 512: 0.10703752639965387},
+decode_time = 0.359
+#ecode_time = 0.28
),
Stat(
cache_size = 1280,
batch_size = 512,
-prefill_times = {
-  16: 0.016024088603444397,
-  32: 0.020686019999993734, 64: 0.02952769919993443, 128: 0.04383329960000992, 256: 0.07538782240008005, 512: 0.13893127239989553, 1024: 0.2693996697998955},
-decode_time=0.11505070800001249,
+prefill_times={16: 0.02070321020000847, 32: 0.02408570580009837, 64: 0.02650543759955326, 128: 0.036217428799864136, 256: 0.057748028799687746, 512: 0.10604073840004276, 1024: 0.20993155719988862},
+decode_time=0.094,
),
Stat(
cache_size = 3072,
batch_size = 256,
-prefill_times = {32: 0.021193669800049976, 64: 0.030565194799964956, 128: 0.04334795760005363, 256: 0.07586566419995507, 512: 0.13899565000010625, 1024: 0.26945373279995694, 2048: 0.35605709000010394},
-decode_time = 0.06467210225014242,
-)
+prefill_times={16: 0.020371186199918158, 32: 0.024281639599939807, 64: 0.02710893359981128, 128: 0.03605372060046648, 256: 0.0574128626001766, 512: 0.10610043820051943, 1024: 0.2097496903996216, 2048: 0.4301163775999157},
+decode_time = 0.0552,
+),
]

scenario2 = [
-Stat(
-cache_size = 3072,
-batch_size = 256,
-prefill_times= {16: 0.018725800199899823, 32: 0.02242145979980705, 64: 0.02536285559981479, 128: 0.034608948799723295, 256: 0.0560826786000689, 512: 0.10566568380017997, 1024: 0.20719572800007882},
-decode_time = 0.0631,
-),
-Stat(
-cache_size = 3072,
-batch_size = 256,
-prefill_times= {16: 0.018725800199899823, 32: 0.02242145979980705, 64: 0.02536285559981479, 128: 0.034608948799723295, 256: 0.0560826786000689, 512: 0.10566568380017997, 1024: 0.20719572800007882},
-decode_time = 0.0631,
-),
-Stat(
-cache_size = 3072,
-batch_size = 256,
-prefill_times= {16: 0.018725800199899823, 32: 0.02242145979980705, 64: 0.02536285559981479, 128: 0.034608948799723295, 256: 0.0560826786000689, 512: 0.10566568380017997, 1024: 0.20719572800007882},
-decode_time = 0.0631,
-)
+scenario1[2],
+scenario1[2],
+scenario1[2]
]
def eval_scenario(dataset, scenario):

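For context, each Stat above appears to pair measured prefill latencies per padded prompt bucket with a per-step decode latency for a given batch/cache configuration. The body of eval_scenario is not shown in this diff, so the following is only a hypothetical sketch of how such numbers could be combined into a batch time estimate; the name estimate_batch_seconds is illustrative, not from the repo.

import dataclasses

@dataclasses.dataclass
class Stat:
  cache_size: int
  batch_size: int
  prefill_times: dict   # padded prefill length -> measured seconds per prefill call
  decode_time: float    # measured seconds per batched decode step

def estimate_batch_seconds(stat, prefill_lens, decode_steps):
  """Rough cost model: each request pays the prefill cost of the smallest measured
  bucket that fits it, plus decode_steps full decode iterations for the batch."""
  buckets = sorted(stat.prefill_times)
  total = 0.0
  for length in prefill_lens:
    bucket = next((b for b in buckets if b >= length), buckets[-1])
    total += stat.prefill_times[bucket]
  return total + decode_steps * stat.decode_time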
2 changes: 2 additions & 0 deletions benchmarks/run_offline.py
@@ -94,6 +94,8 @@ def main(argv):
decode_state = engine.init_decode_state()
profiler_started = False
for batch, _ in MAXTEXT_PREFILL.items():
if batch > FLAGS.max_cache_length:
continue
runtime, decode_state, profiler_started = run_prefill_time(
engine, params, decode_state, batch, profiler_started
)
9 changes: 6 additions & 3 deletions jetstream_pt/config.py
@@ -157,11 +157,14 @@ def create_quantization_config_from_flags():
return config


-def create_engine_from_config_flags():
+def create_engine_from_config_flags(batch=None, cache_len=None):
"""create a pytorch engine from cmd flag"""
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

+batch = batch or FLAGS.batch_size
+cache_len = cache_len or FLAGS.max_cache_length

devices = jax.devices()
start = time.perf_counter()

@@ -196,9 +199,9 @@ def create_engine_from_config_flags():
bf16_enable=FLAGS.bf16_enable,
param_size=FLAGS.size,
context_length=FLAGS.context_length,
-batch_size=FLAGS.batch_size,
+batch_size=batch,
quant_config=quant_config,
-max_cache_length=FLAGS.max_cache_length,
+max_cache_length=cache_len,
max_decode_length=FLAGS.max_decode_length,
sharding_config=sharding_file_name,
shard_on_batch=FLAGS.shard_on_batch,
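The two new keyword arguments make the engine geometry overridable per call while keeping the existing flag-driven defaults. A minimal sketch of the intended call pattern (assuming the usual absl flag parsing; the concrete values are illustrative):

from absl import app
from jetstream_pt import config

def main(argv):
  # Default behaviour: geometry comes from --batch_size / --max_cache_length.
  engine = config.create_engine_from_config_flags()
  # Per-call override added by this commit, e.g. for sweeping cache sizes
  # without re-parsing flags.
  small_engine = config.create_engine_from_config_flags(batch=256, cache_len=3072)

if __name__ == "__main__":
  app.run(main)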
192 changes: 192 additions & 0 deletions jetstream_pt/offline_inference.py
@@ -0,0 +1,192 @@
from typing import Callable
import dataclasses
from collections import defaultdict
import jax
from jax import numpy as jnp
import numpy as np

from jetstream.engine import engine_api

import logging

log = logging.getLogger(__name__)


@dataclasses.dataclass
class InputData:
id: str
tokens: jax.Array
true_length: int


class OfflineInference:

def __init__(self, engine: engine_api.Engine, params=None):
self.engine = engine
self.decode_state = None
if params is None:
params = engine.load_params()
self.params = params

self.batch_size = engine.env.batch_size
self.max_decode_length = engine.max_decode_length
metadata = engine.get_tokenizer()
self.tokenizer = engine.build_tokenizer(metadata)
self.dummy = False

self._cached_pref = {}
self._cached_generate = None

def init_decode_state(self):
if self.decode_state is None:
self.decode_state = self.engine.init_decode_state()

def warmup(self, max_length=2048):
self.init_decode_state()
interesting_buckets = [
32,
64,
128,
256,
512,
1024,
2048,
4096,
]
for length in interesting_buckets:
if length > max_length:
break
log.info(f"Compiling prefill: {length}")
input_data = jax.ShapeDtypeStruct((length,), jnp.dtype("int32"))
self._cached_pref[length] = (
jax.jit(self._prefill_insert, donate_argnums=(4,))
.lower(
self.params,
tokens=input_data,
slot=0,
true_length=length - 1,
decode_state=self.decode_state)
.compile()
)

log.info(f"Compiling decode")
self._cached_generate = (
jax.jit(self.engine.generate, donate_argnums=(1,))
.lower(self.params, self.decode_state)
.compile()
)

def _prefill_insert(self, params, tokens, slot, true_length, decode_state):
"""return decodestate."""
prefill_result, first_token = self.engine.prefill(
params=params, padded_tokens=tokens, true_length=true_length
)
decode_state = self.engine.insert(prefill_result, decode_state, slot=slot)
return first_token, decode_state

def batch_inference_with_callback(
self,
data: InputData,
emit_first_token: Callable[[str, int], bool],
emit_token: Callable[[str, int], bool],
):
"""callback is a function that takes id and token. It will be called once per output
token.
"""

def prefill(slot, tokens, true_length):
nonlocal self
if self.dummy:
log.debug("dummy prefill")
return 123

prefill_fn = self._prefill_insert
if (cached := self._cached_pref.get(len(tokens))) is not None:
prefill_fn = cached

first_token, self.decode_state = prefill_fn(
self.params, tokens=tokens, slot=slot,
true_length=true_length, decode_state=self.decode_state
)
return first_token.data[0][0].item()

empty_slots = list(range(self.batch_size))
slot_to_id = {}

dummy_length = 1

def decode():
log.debug("decode")
nonlocal self
nonlocal slot_to_id
nonlocal dummy_length
if self.dummy:
log.debug("Dummy generate")
res = engine_api.ResultTokens(
data=np.array([[123, 1, dummy_length]] * self.batch_size),
tokens_idx=(0, 0),
valid_idx=(0, 0),
length_idx=(0, 0),
samples_per_slot=(0, 0),
)
dummy_length += 1
self.decode_state, result_tokens = self.decode_state, res
else:
gen_fn = self.engine.generate
if self._cached_generate is not None:
gen_fn = self._cached_generate
self.decode_state, result_tokens = gen_fn(
self.params, self.decode_state
)

result_tokens = result_tokens.convert_to_numpy()

newly_empty = []
for slot, id_ in slot_to_id.items():
token, is_valid, length = result_tokens.data[slot]
log.debug(f"slot is {slot}, length is {length}")
should_finish = False
if is_valid:
should_finish = emit_token(id_, token.item())
if should_finish or length >= self.max_decode_length:
newly_empty.append(slot)

# Return the newly freed slots to the empty pool
for slot in newly_empty:
del slot_to_id[slot]
empty_slots.append(slot)

for row in data:
log.debug(f"empty_slots {len(empty_slots)}")
while not empty_slots:
# If slots are all full, decode until there are free slots
# to insert
decode()
# do one insert
log.debug(f"prefill {row.id}")
slot = empty_slots.pop()
first_token = prefill(slot, row.tokens, row.true_length)
should_terminate = emit_first_token(row.id, first_token)
if not should_terminate:
slot_to_id[slot] = row.id
else:
empty_slots.append(slot) # dont use the slot

while slot_to_id:
log.debug(f"slot to id {len(slot_to_id)}")
decode()

def batch_inference(self, data: InputData):
"""data is list of obj with id, tokens, and true length"""
ans = defaultdict(list)

def callback(id_, token):
nonlocal ans
ans[id_].append(token)
return token == self.tokenizer.eos_id

self.batch_inference_with_callback(
data, emit_first_token=callback, emit_token=callback
)
return ans
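The new OfflineInference class drives a slot-based prefill/insert/decode loop: requests are prefilled into free slots, decode() runs batched generation until slots free up, and tokens are reported through the emit callbacks. A minimal usage sketch (not part of the diff), assuming an engine built with create_engine_from_config_flags and a pre-tokenized, padded prompt:

import jax.numpy as jnp
from jetstream_pt import config, offline_inference

engine = config.create_engine_from_config_flags()
runner = offline_inference.OfflineInference(engine)
runner.warmup(max_length=1024)  # pre-compiles prefill buckets up to 1024 and the decode step

requests = [
    offline_inference.InputData(
        id="req-0",
        tokens=jnp.zeros((128,), dtype=jnp.int32),  # padded token ids; a real caller would tokenize a prompt
        true_length=97,
    ),
]
results = runner.batch_inference(requests)  # dict: request id -> list of emitted token ids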
41 changes: 37 additions & 4 deletions jetstream_pt/third_party/mixtral/model.py
@@ -22,8 +22,10 @@
from torch.nn import functional as F
from .config import ModelArgs, find_multiple
from jetstream_pt.layers import Attention, get_quantized_linear_layer, get_quantized_enbedding_layer
from jetstream_pt import quantize, torchjax

import jax
import jax.numpy as jnp


class Transformer(nn.Module):
@@ -233,6 +235,31 @@ def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor:
else:
return self.forward_for_short_seq_len(x, expert_indices)

def _int_ti_eoi_teo(self, lhs, rhs):
# x1 = F.silu(torch.einsum("ti,eoi -> teo", x, self.w1) * self.w1_scaler)
result = torchjax.call_jax(
jax.lax.dot_general,
lhs,
rhs,
(((1,), (2,)), ((), ())),
None,
jnp.bfloat16.dtype,
)
return result

def _int_teo_eio_tei(self, lhs, rhs):
#torch.einsum("teo, eio -> tei", (x1 * x3), self.w2) * self.w2_scaler
result = torchjax.call_jax(
jax.lax.dot_general,
lhs,
rhs,
(((2,), (2,)), ((1, ), (0, ))),
None,
jnp.bfloat16.dtype,
) # output is (eti) for some reason
return result.transpose(0, 1)


def forward_for_short_seq_len(
self, x: Tensor, expert_indices: Tensor
) -> Tensor:
@@ -260,14 +287,20 @@ def forward_for_long_seq_len(self, x, expert_indices):
# o = config.imtermediate size
# i = config.dim
with jax.named_scope("conditional_ff"):
-x1 = F.silu(torch.einsum("ti,eoi -> teo", x, self.w1) * self.w1_scaler)
-x3 = torch.einsum("ti, eoi-> teo", x, self.w3) * self.w3_scaler
+x_int, x_scaler, _ = quantize.quantize_tensor(x, (1,))
+x_scaler = x_scaler.reshape(seqlen, 1, 1)
+
+x1 = F.silu(self._int_ti_eoi_teo(x_int, self.w1) * self.w1_scaler * x_scaler)
+x3 = self._int_ti_eoi_teo(x_int, self.w3) * self.w3_scaler * x_scaler
+
+x1x3_int, x1x3_scaler, _ = quantize.quantize_tensor(x1 * x3, (1, 2))
+x1x3_scaler = x1x3_scaler.reshape(seqlen, 1, 1)
expert_outs = (
-torch.einsum("teo, eio -> tei", (x1 * x3), self.w2) * self.w2_scaler
+self._int_teo_eio_tei(x1x3_int, self.w2) * self.w2_scaler
)
# e = 8; need to reduce to 2
seq_indexes = torch.arange(seqlen).unsqueeze(1)
-return expert_outs[seq_indexes, expert_indices]
+return expert_outs[seq_indexes, expert_indices] * x1x3_scaler


class ConditionalFeedForward(nn.Module):
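The rewritten forward_for_long_seq_len quantizes the activations per token (quantize.quantize_tensor(x, (1,))), runs the expert matmuls on the integer values via jax.lax.dot_general with bfloat16 as the preferred element type, and applies the activation and weight scalers after the contraction, which is valid because the contraction is linear in both operands. A small self-contained check (plain JAX with illustrative shapes, not part of the commit) that the dimension numbers used by the new helpers reproduce the original einsum contractions:

import jax
import jax.numpy as jnp
import numpy as np

t, e, o, i = 5, 8, 16, 12  # tokens, experts, intermediate size, model dim
x = jnp.asarray(np.random.randn(t, i), dtype=jnp.float32)
w1 = jnp.asarray(np.random.randn(e, o, i), dtype=jnp.float32)
w2 = jnp.asarray(np.random.randn(e, i, o), dtype=jnp.float32)

# "ti,eoi->teo": contract x dim 1 against w1 dim 2, no batch dims.
teo = jax.lax.dot_general(x, w1, (((1,), (2,)), ((), ())))
assert jnp.allclose(teo, jnp.einsum("ti,eoi->teo", x, w1), rtol=1e-3, atol=1e-3)

# "teo,eio->tei": contract dim 2 against dim 2, batch lhs dim 1 against rhs dim 0.
# dot_general puts batch dims first, so the raw result is (e, t, i); the new
# _int_teo_eio_tei helper transposes it back to (t, e, i).
eti = jax.lax.dot_general(teo, w2, (((2,), (2,)), ((1,), (0,))))
assert jnp.allclose(jnp.swapaxes(eti, 0, 1), jnp.einsum("teo,eio->tei", teo, w2), rtol=1e-3, atol=1e-3)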
The remaining changed files are not shown here.
