Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

commit act quant for conditional ffn #156

Draft
wants to merge 3 commits into
base: mlperf-mixtral
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions benchmarks/mixtral_offline.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
CACHE_LENGTH=$1
BATCH_SIZE=$2
INPUT_SIZE=1024
OUTPUT_SIZE=1024
CHECKPOINT_PATH=mlperf/data/mixtral-instruct-quantized/
export JAX_COMPILATION_CACHE_DIR="/tmp/jax_cache2"
export XLA_FLAGS="--xla_disable_hlo_passes=rematerialization"

pushd ..
python -m benchmarks.run_offline \
--lazy_cache_update=1 \
--ring_buffer=0 \
--model_name=mixtral \
--batch_size=$BATCH_SIZE \
--max_cache_length=$CACHE_LENGTH \
--max_decode_length=$OUTPUT_SIZE \
--context_length=$INPUT_SIZE \
--checkpoint_path=$CHECKPOINT_PATH/model.safetensors \
--tokenizer_path=$CHECKPOINT_PATH/tokenizer.model \
--quantize_weights=1 \
--quantize_type=int8_per_channel \
--quantize_kv_cache=1 \
--profiling_output=/mnt/disks/hanq/mixtral-profiles
popd
echo "batch was $2 cache was $1"
97 changes: 97 additions & 0 deletions benchmarks/offline_benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import math
import pandas as pd
import dataclasses
from collections import defaultdict
from absl import flags, app

from typing import Dict

FLAGS = flags.FLAGS

flags.DEFINE_string('dataset_path', '', '')

@dataclasses.dataclass
class Stat:
cache_size: int
batch_size: int
prefill_times: Dict[int, float]
decode_time: float

scenario1 = [
Stat(
cache_size = 512,
batch_size = 2048,
prefill_times = {16: 0.02084908019969589, 32: 0.024125573800120037, 64: 0.02697298339990084, 128: 0.03641403259971412, 256: 0.05809259879970341, 512: 0.10703752639965387},
decode_time = 0.359
#ecode_time = 0.28
),
Stat(
cache_size = 1280,
batch_size = 512,
prefill_times={16: 0.02070321020000847, 32: 0.02408570580009837, 64: 0.02650543759955326, 128: 0.036217428799864136, 256: 0.057748028799687746, 512: 0.10604073840004276, 1024: 0.20993155719988862},
decode_time=0.094,
),
Stat(
cache_size = 3072,
batch_size = 256,
prefill_times={16: 0.020371186199918158, 32: 0.024281639599939807, 64: 0.02710893359981128, 128: 0.03605372060046648, 256: 0.0574128626001766, 512: 0.10610043820051943, 1024: 0.2097496903996216, 2048: 0.4301163775999157},
decode_time = 0.0552,
),
]

scenario2 = [
scenario1[2],
scenario1[2],
scenario1[2]
]
def eval_scenario(dataset, scenario):

total_input_tokens = 0
total_output_tokens = 0
total_prefill_times = defaultdict(float)
total_decode_times = defaultdict(float)
output_tokens_by_bucket = defaultdict(int)
for _, data in dataset.iterrows():
stat = scenario[data.bucket]
total_input_tokens += data.tok_input_len
total_output_tokens += data.tok_ref_output_len
input_len_bucket = 2**math.ceil(math.log2(data.tok_input_len))
if input_len_bucket == 2048 and data.bucket == 1:
import pdb; pdb.set_trace()
total_prefill_times[input_len_bucket] += stat.prefill_times[input_len_bucket]
output_tokens_by_bucket[data.bucket] += data.tok_ref_output_len

for k in output_tokens_by_bucket.keys():
stat = scenario[k]
total_decode_times[k] = output_tokens_by_bucket[k] / stat.batch_size * scenario[k].decode_time

prefill_total = sum(total_prefill_times.values())
decode_total = sum(total_decode_times.values())
print('Total input tokens', total_input_tokens)
print('Total output tokens', total_output_tokens)
print('Input / output', total_input_tokens / total_output_tokens)
print('Prefill times', total_prefill_times)
print('pref throughput', total_input_tokens / sum(total_prefill_times.values()))
print('decode times', total_decode_times)
print('decode throughput', total_output_tokens / sum(total_decode_times.values()) )
print('overall throughput',
total_output_tokens /
(sum(total_decode_times.values()) + sum(total_prefill_times.values())))
print('prefill total time', prefill_total)
print('decode total time', decode_total)



def main(argv):
dataset = pd.read_pickle(FLAGS.dataset_path)
total_len = dataset.tok_input_len + dataset.tok_ref_output_len
bucket = 0 + (total_len > 512) + ((total_len > 1280) | (dataset.tok_input_len > 1024))
dataset.insert(2, 'bucket', bucket)
eval_scenario(dataset, scenario1)
print('======== scenario 2 ========')
eval_scenario(dataset, scenario2)

if __name__ == '__main__':
app.run(main)


23 changes: 17 additions & 6 deletions benchmarks/run_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
flags.DEFINE_string("sharegpt_path", "", "path to sharegpt json file")


def run_prefill_time(engine, params, decode_state, seqlen):
def run_prefill_time(engine, params, decode_state, seqlen, profiler_started):
"""Run prefill and measure time."""
metadata = engine.get_tokenizer()
tokenizer = engine.build_tokenizer(metadata)
Expand All @@ -53,15 +53,20 @@ def run_prefill_time(engine, params, decode_state, seqlen):
nums = 5
start = time.perf_counter()
for i in range(nums):
if i == nums - 1 and FLAGS.profiling_prefill and not profiler_started:
jax.profiler.start_trace(FLAGS.profiling_output)
profiler_started = True

prefill_result, _ = engine.prefill(
params=params, padded_tokens=tokens, true_length=true_length
)
decode_state = engine.insert(
prefill_result, decode_state, slot=jnp.int32(i)
)
jax.block_until_ready(decode_state)

end = time.perf_counter()
return (end - start) / nums, decode_state
return (end - start) / nums, decode_state, profiler_started


MAXTEXT_PREFILL = {
Expand All @@ -72,6 +77,7 @@ def run_prefill_time(engine, params, decode_state, seqlen):
256: 23.59,
512: 35.28,
1024: 60.28,
2048: 60.28,
}


Expand All @@ -86,9 +92,12 @@ def main(argv):
prefill_times = {}

decode_state = engine.init_decode_state()
profiler_started = False
for batch, _ in MAXTEXT_PREFILL.items():
runtime, decode_state = run_prefill_time(
engine, params, decode_state, batch
if batch > FLAGS.max_cache_length:
continue
runtime, decode_state, profiler_started = run_prefill_time(
engine, params, decode_state, batch, profiler_started
)
prefill_times[batch] = runtime

Expand All @@ -103,10 +112,12 @@ def main(argv):

profiling_output = FLAGS.profiling_output
print("======= decode starting ===")

dec_times = []
for i in range(10):
if profiling_output and i == 7:
if profiling_output and i == 7 and not profiler_started:
jax.profiler.start_trace(profiling_output)
profiler_started = True
start = time.perf_counter()
# pylint: disable-next=all
decode_state, sampled_tokens = engine.generate(params, decode_state)
Expand All @@ -116,7 +127,7 @@ def main(argv):
dec_times.append(end - start)
print(i, "decode time", (end - start))

if profiling_output:
if profiler_started:
jax.profiler.stop_trace()

print("prefill ", prefill_times)
Expand Down
Loading
Loading