From b00be7ff6fd946ee44ae0e7486ceb1ce2fc599c5 Mon Sep 17 00:00:00 2001 From: Han Qi Date: Tue, 23 Jul 2024 03:54:57 +0000 Subject: [PATCH] commit act quant for conditional ffn init params add other scripts debug accuracy --- benchmarks/mixtral_offline.sh | 2 + benchmarks/offline_benchmark.py | 42 +- benchmarks/run_offline.py | 2 + jetstream_pt/config.py | 9 +- jetstream_pt/offline_inference.py | 192 +++++++++ jetstream_pt/third_party/mixtral/model.py | 41 +- mlperf/accuracy_run.sh | 56 +++ mlperf/evaluate_accuracy.py | 251 ++++++++++++ mlperf/gmm.py | 202 +++++++++ mlperf/install.sh | 41 ++ mlperf/mixtral_run.sh | 55 +++ mlperf/mlperf.conf | 98 +++++ mlperf/offline_mode.py | 476 ++++++++++++++++++++++ mlperf/user.conf | 6 + 14 files changed, 1435 insertions(+), 38 deletions(-) create mode 100644 jetstream_pt/offline_inference.py create mode 100644 mlperf/accuracy_run.sh create mode 100644 mlperf/evaluate_accuracy.py create mode 100644 mlperf/gmm.py create mode 100644 mlperf/install.sh create mode 100755 mlperf/mixtral_run.sh create mode 100644 mlperf/mlperf.conf create mode 100644 mlperf/offline_mode.py create mode 100644 mlperf/user.conf diff --git a/benchmarks/mixtral_offline.sh b/benchmarks/mixtral_offline.sh index 5424fd7e..9572366f 100644 --- a/benchmarks/mixtral_offline.sh +++ b/benchmarks/mixtral_offline.sh @@ -3,6 +3,8 @@ BATCH_SIZE=$2 INPUT_SIZE=1024 OUTPUT_SIZE=1024 CHECKPOINT_PATH=mlperf/data/mixtral-instruct-quantized/ +export JAX_COMPILATION_CACHE_DIR="/tmp/jax_cache2" +export XLA_FLAGS="--xla_disable_hlo_passes=rematerialization" pushd .. python -m benchmarks.run_offline \ diff --git a/benchmarks/offline_benchmark.py b/benchmarks/offline_benchmark.py index fe057ff1..0b007eea 100644 --- a/benchmarks/offline_benchmark.py +++ b/benchmarks/offline_benchmark.py @@ -21,48 +21,28 @@ class Stat: Stat( cache_size = 512, batch_size = 2048, - prefill_times = { - 16: 0.016024088603444397, - 32: 0.021154335999926843, - 64: 0.02999803279999469, - 128: 0.043986773600045125, 256: 0.07524209819985117, 512: 0.13882793779994246}, -decode_time = 0.28033976474989686 + prefill_times = {16: 0.02084908019969589, 32: 0.024125573800120037, 64: 0.02697298339990084, 128: 0.03641403259971412, 256: 0.05809259879970341, 512: 0.10703752639965387}, + decode_time = 0.359 + #ecode_time = 0.28 ), Stat( cache_size = 1280, batch_size = 512, - prefill_times = { - 16: 0.016024088603444397, - 32: 0.020686019999993734, 64: 0.02952769919993443, 128: 0.04383329960000992, 256: 0.07538782240008005, 512: 0.13893127239989553, 1024: 0.2693996697998955}, -decode_time=0.11505070800001249, + prefill_times={16: 0.02070321020000847, 32: 0.02408570580009837, 64: 0.02650543759955326, 128: 0.036217428799864136, 256: 0.057748028799687746, 512: 0.10604073840004276, 1024: 0.20993155719988862}, + decode_time=0.094, ), Stat( cache_size = 3072, batch_size = 256, - prefill_times = {32: 0.021193669800049976, 64: 0.030565194799964956, 128: 0.04334795760005363, 256: 0.07586566419995507, 512: 0.13899565000010625, 1024: 0.26945373279995694, 2048: 0.35605709000010394}, - decode_time = 0.06467210225014242, - ) + prefill_times={16: 0.020371186199918158, 32: 0.024281639599939807, 64: 0.02710893359981128, 128: 0.03605372060046648, 256: 0.0574128626001766, 512: 0.10610043820051943, 1024: 0.2097496903996216, 2048: 0.4301163775999157}, + decode_time = 0.0552, + ), ] scenario2 = [ - Stat( - cache_size = 3072, - batch_size = 256, - prefill_times= {16: 0.018725800199899823, 32: 0.02242145979980705, 64: 0.02536285559981479, 128: 0.034608948799723295, 256: 
0.0560826786000689, 512: 0.10566568380017997, 1024: 0.20719572800007882}, - decode_time = 0.0631, - ), - Stat( - cache_size = 3072, - batch_size = 256, - prefill_times= {16: 0.018725800199899823, 32: 0.02242145979980705, 64: 0.02536285559981479, 128: 0.034608948799723295, 256: 0.0560826786000689, 512: 0.10566568380017997, 1024: 0.20719572800007882}, - decode_time = 0.0631, - ), - Stat( - cache_size = 3072, - batch_size = 256, - prefill_times= {16: 0.018725800199899823, 32: 0.02242145979980705, 64: 0.02536285559981479, 128: 0.034608948799723295, 256: 0.0560826786000689, 512: 0.10566568380017997, 1024: 0.20719572800007882}, - decode_time = 0.0631, - ) + scenario1[2], + scenario1[2], + scenario1[2] ] def eval_scenario(dataset, scenario): diff --git a/benchmarks/run_offline.py b/benchmarks/run_offline.py index dd664d8b..e705dfe5 100644 --- a/benchmarks/run_offline.py +++ b/benchmarks/run_offline.py @@ -94,6 +94,8 @@ def main(argv): decode_state = engine.init_decode_state() profiler_started = False for batch, _ in MAXTEXT_PREFILL.items(): + if batch > FLAGS.max_cache_length: + continue runtime, decode_state, profiler_started = run_prefill_time( engine, params, decode_state, batch, profiler_started ) diff --git a/jetstream_pt/config.py b/jetstream_pt/config.py index 70b530fc..b22d0287 100644 --- a/jetstream_pt/config.py +++ b/jetstream_pt/config.py @@ -157,11 +157,14 @@ def create_quantization_config_from_flags(): return config -def create_engine_from_config_flags(): +def create_engine_from_config_flags(batch=None, cache_len=None): """create a pytorch engine from cmd flag""" jax.config.update("jax_default_prng_impl", "unsafe_rbg") os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" + batch = batch or FLAGS.batch_size + cache_len = cache_len or FLAGS.max_cache_length + devices = jax.devices() start = time.perf_counter() @@ -196,9 +199,9 @@ def create_engine_from_config_flags(): bf16_enable=FLAGS.bf16_enable, param_size=FLAGS.size, context_length=FLAGS.context_length, - batch_size=FLAGS.batch_size, + batch_size=batch, quant_config=quant_config, - max_cache_length=FLAGS.max_cache_length, + max_cache_length=cache_len, max_decode_length=FLAGS.max_decode_length, sharding_config=sharding_file_name, shard_on_batch=FLAGS.shard_on_batch, diff --git a/jetstream_pt/offline_inference.py b/jetstream_pt/offline_inference.py new file mode 100644 index 00000000..1834ed3e --- /dev/null +++ b/jetstream_pt/offline_inference.py @@ -0,0 +1,192 @@ +from typing import Callable +import dataclasses +from collections import defaultdict +import jax +from jax import numpy as jnp +import numpy as np + +from jetstream.engine import engine_api + +import logging + +log = logging.getLogger(__name__) + + +@dataclasses.dataclass +class InputData: + id: str + tokens: jax.Array + true_length: int + + +class OfflineInference: + + def __init__(self, engine: engine_api.Engine, params=None): + self.engine = engine + self.decode_state = None + if params is None: + params = engine.load_params() + self.params = params + + self.batch_size = engine.env.batch_size + self.max_decode_length = engine.max_decode_length + metadata = engine.get_tokenizer() + self.tokenizer = engine.build_tokenizer(metadata) + self.dummy = False + + self._cached_pref = {} + self._cached_generate = None + + def init_decode_state(self): + if self.decode_state is None: + self.decode_state = self.engine.init_decode_state() + + def warmup(self, max_length=2048): + self.init_decode_state() + interesting_buckets = [ + 32, + 64, + 128, + 256, + 512, + 1024, + 2048, + 4096, + ] + for 
length in interesting_buckets: + if length > max_length: + break + log.info(f"Compiling prefill: {length}") + input_data = jax.ShapeDtypeStruct((length,), jnp.dtype("int32")) + self._cached_pref[length] = ( + jax.jit(self._prefill_insert, donate_argnums=(4,)) + .lower( + self.params, + tokens=input_data, + slot=0, + true_length=length - 1, + decode_state=self.decode_state) + .compile() + ) + + log.info(f"Compiling decode") + self._cached_generate = ( + jax.jit(self.engine.generate, donate_argnums=(1,)) + .lower(self.params, self.decode_state) + .compile() + ) + + def _prefill_insert(self, params, tokens, slot, true_length, decode_state): + """return decodestate.""" + prefill_result, first_token = self.engine.prefill( + params=params, padded_tokens=tokens, true_length=true_length + ) + decode_state = self.engine.insert(prefill_result, decode_state, slot=slot) + return first_token, decode_state + + def batch_inference_with_callback( + self, + data: InputData, + emit_first_token: Callable[[str, int], bool], + emit_token: Callable[[str, int], bool], + ): + """callback is a function that takes id and token. It will be called once per output + + token. + """ + + def prefill(slot, tokens, true_length): + nonlocal self + if self.dummy: + log.debug("dummy prefill") + return 123 + + prefill_fn = self._prefill_insert + if (cached := self._cached_pref.get(len(tokens))) is not None: + prefill_fn = cached + + first_token, self.decode_state = prefill_fn( + self.params, tokens=tokens, slot=slot, + true_length=true_length, decode_state=self.decode_state + ) + return first_token.data[0][0].item() + + empty_slots = list(range(self.batch_size)) + slot_to_id = {} + + dummy_length = 1 + + def decode(): + log.debug("decode") + nonlocal self + nonlocal slot_to_id + nonlocal dummy_length + if self.dummy: + log.debug("Dummy generate") + res = engine_api.ResultTokens( + data=np.array([[123, 1, dummy_length]] * self.batch_size), + tokens_idx=(0, 0), + valid_idx=(0, 0), + length_idx=(0, 0), + samples_per_slot=(0, 0), + ) + dummy_length += 1 + self.decode_state, result_tokens = self.decode_state, res + else: + gen_fn = self.engine.generate + if self._cached_generate is not None: + gen_fn = self._cached_generate + self.decode_state, result_tokens = gen_fn( + self.params, self.decode_state + ) + + result_tokens = result_tokens.convert_to_numpy() + + newly_empty = [] + for slot, id_ in slot_to_id.items(): + token, is_valid, length = result_tokens.data[slot] + log.debug(f"slot is {slot}, length is {length}") + should_finish = False + if is_valid: + should_finish = emit_token(id_, token.item()) + if should_finish or length >= self.max_decode_length: + newly_empty.append(slot) + + # Add slots of those that are empty to emtpy + for slot in newly_empty: + del slot_to_id[slot] + empty_slots.append(slot) + + for row in data: + log.debug(f"empty_slots {len(empty_slots)}") + while not empty_slots: + # If slots are all full, decode until there are free slots + # to insert + decode() + # do one insert + log.debug(f"prefill {row.id}") + slot = empty_slots.pop() + first_token = prefill(slot, row.tokens, row.true_length) + should_terminate = emit_first_token(row.id, first_token) + if not should_terminate: + slot_to_id[slot] = row.id + else: + empty_slots.append(slot) # dont use the slot + + while slot_to_id: + log.debug(f"slot to id {len(slot_to_id)}") + decode() + + def batch_inference(self, data: InputData): + """data is list of obj with id, tokens, and true length""" + ans = defaultdict(list) + + def callback(id_, token): + nonlocal 
ans + ans[id_].append(token) + return token == self.tokenizer.eos_id + + self.batch_inference_with_callback( + data, emit_first_token=callback, emit_token=callback + ) + return ans diff --git a/jetstream_pt/third_party/mixtral/model.py b/jetstream_pt/third_party/mixtral/model.py index 1602e886..422f4990 100644 --- a/jetstream_pt/third_party/mixtral/model.py +++ b/jetstream_pt/third_party/mixtral/model.py @@ -22,8 +22,10 @@ from torch.nn import functional as F from .config import ModelArgs, find_multiple from jetstream_pt.layers import Attention, get_quantized_linear_layer, get_quantized_enbedding_layer +from jetstream_pt import quantize, torchjax import jax +import jax.numpy as jnp class Transformer(nn.Module): @@ -233,6 +235,31 @@ def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor: else: return self.forward_for_short_seq_len(x, expert_indices) + def _int_ti_eoi_teo(self, lhs, rhs): + # x1 = F.silu(torch.einsum("ti,eoi -> teo", x, self.w1) * self.w1_scaler) + result = torchjax.call_jax( + jax.lax.dot_general, + lhs, + rhs, + (((1,), (2)), ((), ())), + None, + jnp.bfloat16.dtype, + ) + return result + + def _int_teo_eio_tei(self, lhs, rhs): + #torch.einsum("teo, eio -> tei", (x1 * x3), self.w2) * self.w2_scaler + result = torchjax.call_jax( + jax.lax.dot_general, + lhs, + rhs, + (((2,), (2,)), ((1, ), (0, ))), + None, + jnp.bfloat16.dtype, + ) # output is (eti) for some reason + return result.transpose(0, 1) + + def forward_for_short_seq_len( self, x: Tensor, expert_indices: Tensor ) -> Tensor: @@ -260,14 +287,20 @@ def forward_for_long_seq_len(self, x, expert_indices): # o = config.imtermediate size # i = config.dim with jax.named_scope("conditional_ff"): - x1 = F.silu(torch.einsum("ti,eoi -> teo", x, self.w1) * self.w1_scaler) - x3 = torch.einsum("ti, eoi-> teo", x, self.w3) * self.w3_scaler + x_int, x_scaler, _ = quantize.quantize_tensor(x, (1,)) + x_scaler = x_scaler.reshape(seqlen, 1, 1) + + x1 = F.silu(self._int_ti_eoi_teo(x_int, self.w1) * self.w1_scaler * x_scaler) + x3 = self._int_ti_eoi_teo(x_int, self.w3) * self.w3_scaler * x_scaler + + x1x3_int, x1x3_scaler, _ = quantize.quantize_tensor(x1 * x3, (1, 2)) + x1x3_scaler = x1x3_scaler.reshape(seqlen, 1, 1) expert_outs = ( - torch.einsum("teo, eio -> tei", (x1 * x3), self.w2) * self.w2_scaler + self._int_teo_eio_tei(x1x3_int, self.w2) * self.w2_scaler ) # e = 8; need to reduce to 2 seq_indexes = torch.arange(seqlen).unsqueeze(1) - return expert_outs[seq_indexes, expert_indices] + return expert_outs[seq_indexes, expert_indices] * x1x3_scaler class ConditionalFeedForward(nn.Module): diff --git a/mlperf/accuracy_run.sh b/mlperf/accuracy_run.sh new file mode 100644 index 00000000..b940ae74 --- /dev/null +++ b/mlperf/accuracy_run.sh @@ -0,0 +1,56 @@ +#!/usr/bin/env bash +me=$(basename "$0") + +BASEDIR=mlperf +API_URL=0.0.0.0:9000 +USER_CONFIG=$BASEDIR/user.conf +DATA_DISK_DIR=$BASEDIR/data +TOTAL_SAMPLE_COUNT=1000 +DATASET_PATH=$BASEDIR/data/mixtral_15k_data.pkl + +# HF model id +TOKENIZER_PATH="mistralai/Mixtral-8x7B-Instruct-v0.1" +LOADGEN_RUN_TYPE=offline-performance +OUTPUT_LOG_DIR=${DATA_DISK_DIR}/logs/${OUTPUT_LOG_ID} +OUTPUT_LOG_ID=${MODEL_NAME}-${DATASET_TYPE}-${LOADGEN_RUN_TYPE}-${LOADGEN_RUN_TIMESTAMP} + +mkdir -p ${OUTPUT_LOG_DIR} && cp ../${USER_CONFIG} ${OUTPUT_LOG_DIR} + +OUTPUT_ACCURACY_JSON_PATH=${OUTPUT_LOG_DIR}/mlperf_log_accuracy.json + +CACHE_LENGTH=1024 +INPUT_SIZE=512 +OUTPUT_SIZE=512 +CHECKPOINT_PATH=mlperf/data/mixtral-instruct-quantized/ + 
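+# TPU performance-tuning flags below (same set as mlperf/mixtral_run.sh):
+# they enable async collective fusion and overlap of compute with collectives.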
+LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" +# makes subsequent runs faster +export JAX_COMPILATION_CACHE_DIR="/tmp/jax_cache2" +export LIBTPU_INIT_ARGS + +pushd .. +# python -m mlperf.offline_mode \ +# --model_name=mixtral \ +# --max_cache_length=$CACHE_LENGTH \ +# --max_decode_length=$OUTPUT_SIZE \ +# --context_length=$INPUT_SIZE \ +# --checkpoint_path=$CHECKPOINT_PATH/model.safetensors \ +# --tokenizer_path=$CHECKPOINT_PATH/tokenizer.model \ +# --quantize_weights=1 \ +# --quantize_type=int8_per_channel \ +# --quantize_kv_cache=1 \ +# --scenario Offline \ +# --input_mode tokenized \ +# --output_mode tokenized \ +# --mlperf_conf $BASEDIR/mlperf.conf \ +# --user_conf ${USER_CONFIG} \ +# --audit_conf no_audit \ +# --total_sample_count ${TOTAL_SAMPLE_COUNT} \ +# --dataset_path ${DATASET_PATH} \ +# --output_log_dir ${OUTPUT_LOG_DIR} 2>&1 | tee ${OUTPUT_LOG_DIR}/server_accuracy_log.log + +python -m mlperf.evaluate_accuracy \ + --checkpoint-path ${TOKENIZER_PATH} \ + --mlperf-accuracy-file ${OUTPUT_ACCURACY_JSON_PATH} \ + --dataset-file ${DATASET_PATH} 2>&1 | tee ${OUTPUT_LOG_DIR}/evaluate_offline_accuracy_log.log +popd \ No newline at end of file diff --git a/mlperf/evaluate_accuracy.py b/mlperf/evaluate_accuracy.py new file mode 100644 index 00000000..2fd74e1e --- /dev/null +++ b/mlperf/evaluate_accuracy.py @@ -0,0 +1,251 @@ +import argparse +from transformers import AutoTokenizer +import nltk +import evaluate +import numpy as np +import pandas as pd +import json +import re + +import logging +logging.basicConfig(level=logging.DEBUG) +log = logging.getLogger("evaluate_accuracy.py") + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--checkpoint-path", + required=True, + help="Path to Mixtral-8x7b-Instruct checkpoint", + ) + parser.add_argument( + "--mlperf-accuracy-file", + required=True, + help="path to mlperf_log_accuracy.json", + ) + parser.add_argument( + "--dataset-file", + required=True, + help="path to processed validation dataset", + ) + parser.add_argument( + "--n_workers", + default=2, + type=int, + help="Number of workers used for the MBXP evaluation", + ) + parser.add_argument("--verbose", action="store_true", help="verbose messages") + parser.add_argument( + "--dtype", + default="int64", + help="dtype of the accuracy log", + choices=["int32", "int64", "float"], + ) + args = parser.parse_args() + return args + + +def get_groundtruth(processed_dataset_file): + data = pd.read_pickle(processed_dataset_file) + return data + + +# Functions for evaluating GSM8K +def find_numbers(x: str) -> list[str]: + """Finds all numbers in a string.""" + # Search for number, possibly negative (hyphen), with thousand separators + # (comma), and with a decimal point (period inbetween digits). + numbers = re.compile( + r"-?[\d,]*\.?\d+", + re.MULTILINE | re.DOTALL | re.IGNORECASE, + ).findall(x) + return numbers + + +def find_number(x: str, answer_delimiter: str = "The answer is") -> str: + """Finds the most relevant number in a string.""" + # If model uses the answer delimiter, then select the first number following + # that format. 
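+  # e.g. find_number("6 apples cost $3. The answer is 18.") -> "18"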
+ if answer_delimiter in x: + answer = x.split(answer_delimiter)[-1] + numbers = find_numbers(answer) + if numbers: + return numbers[0] + + # In general, select the last number in the string. + numbers = find_numbers(x) + if numbers: + return numbers[-1] + return "" + + +def maybe_remove_comma(x: str) -> str: + # Example: 5,600 -> 5600 + return x.replace(",", "") + + +def try_float(x: str): + try: + ret = float(x) + except BaseException: + ret = None + return ret + + +# Functions for evaluating OpenOrca + + +def postprocess_text(preds, targets): + preds = [pred.strip() for pred in preds] + targets = [target.strip() for target in targets] + + # rougeLSum expects newline after each sentence + preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] + targets = ["\n".join(nltk.sent_tokenize(target)) for target in targets] + + return preds, targets + + +# Functions for MBXP + + +def create_mbxp_dict(row, response): + lang, entry_point = row["id"].split("_", 1) + return { + "lang": lang, + "prompt": row["input"], + "test_code": row["gt_output"], + "entry_point": entry_point, + "response": response, + } + + +def main(): + + args = get_args() + dataset_path = args.dataset_file + checkpoint_path = args.checkpoint_path + metric = evaluate.load("rouge") + nltk.download("punkt") + + tokenizer = AutoTokenizer.from_pretrained( + checkpoint_path, + model_max_length=2048, + padding_side="left", + use_fast=False, + ) + + data = get_groundtruth(args.dataset_file) + query_types, gt_outputs = data["dataset"], data["gt_output"] + + target_required_GSM8K = [] + target_required_OpenOrca = [] + results_MBXP = [] + preds_token_GSM8K = [] + preds_token_OpenOrca = [] + preds_token_MBXP = [] + + eval_dtype = np.int64 + if args.dtype == "int32": + eval_dtype = np.int32 + elif args.dtype == "float": + eval_dtype = np.float32 + + with open(args.mlperf_accuracy_file, "r") as f: + results = json.load(f) + + seen = set() + gen_tok_len = 0 + gen_num = 0 + for pred in results: + gen_num += 1 + qsl_idx = pred["qsl_idx"] + if qsl_idx in seen: + continue + + seen.add(qsl_idx) + + query_type = query_types.iloc[qsl_idx] + if query_type == "GSM8K": + target = gt_outputs.iloc[qsl_idx] + target_required_GSM8K.append(target) + pred = np.frombuffer(bytes.fromhex(pred["data"]), eval_dtype) + gen_tok_len += len(pred) + preds_token_GSM8K.append(pred) + elif query_type == "OpenOrca": + target = gt_outputs.iloc[qsl_idx] + target_required_OpenOrca.append(target) + pred = np.frombuffer(bytes.fromhex(pred["data"]), eval_dtype) + preds_token_OpenOrca.append(pred) + gen_tok_len += len(pred) + else: + target = data.iloc[qsl_idx] + pred = np.frombuffer(bytes.fromhex(pred["data"]), eval_dtype) + pred_str = tokenizer.decode(pred, skip_special_tokens=True) + results_MBXP.append(create_mbxp_dict(target, pred_str)) + gen_tok_len += len(pred) + + # OpenOrca metric + preds_decoded_text = tokenizer.batch_decode( + preds_token_OpenOrca, skip_special_tokens=True + ) + + preds, targets = postprocess_text( + preds_decoded_text, target_required_OpenOrca + ) + + if preds: + result = metric.compute( + predictions=preds, + references=targets, + use_stemmer=True, + use_aggregator=False, + ) + result = {k: round(np.mean(v) * 100, 4) for k, v in result.items()} + prediction_lens = [len(pred) for pred in preds] + + else: + result = {} + prediction_lens = [] + + # GSM8K metric + preds_decoded_text = tokenizer.batch_decode( + preds_token_GSM8K, skip_special_tokens=True + ) + pred_nums = [ + maybe_remove_comma(find_number(pred_text.split("\nQ:")[0])) + for 
pred_text in preds_decoded_text + ] + gsm8k_total = len(target_required_GSM8K) + correct = 0 + for idx in range(len(target_required_GSM8K)): + ref = try_float(target_required_GSM8K[idx]) + tgt = try_float(pred_nums[idx]) + if tgt is None: + continue + correct += ref == tgt + + result["gsm8k"] = 100.0 * correct / gsm8k_total + + # MBXP metric + # from evaluate_mbxp import evaluate_mbxp + + # if results_MBXP: + # result['mbxp'] = evaluate_mbxp(results_MBXP, args.n_workers) + # else: + # result['mbxp'] = 0 + + result = { + **result, + "gen_len": np.sum(prediction_lens), + "gen_num": gen_num, + "gen_tok_len": gen_tok_len, + "tokens_per_sample": round(gen_tok_len / gen_num, 1), + } + + print("\nResults\n") + print(result) + + +if __name__ == "__main__": + main() diff --git a/mlperf/gmm.py b/mlperf/gmm.py new file mode 100644 index 00000000..73c0dd3b --- /dev/null +++ b/mlperf/gmm.py @@ -0,0 +1,202 @@ +import jax +import jax.numpy as jnp + +from jax.sharding import Mesh, PartitionSpec as P, NamedSharding +from jax.experimental import mesh_utils +from jax.experimental.shard_map import shard_map +from functools import partial +from jax.experimental.pallas.ops.tpu.megablox.gmm import gmm + +devices = mesh_utils.create_device_mesh((8, 1)) +mesh = Mesh(devices, axis_names=('x', 'y')) +jax.config.update('jax_default_matmul_precision', "float32") +import torch +interp = False +pdtype = jnp.dtype('float32') + +def _reference_gmm(lhs: torch.Tensor, rhs: torch.Tensor, + group_sizes: torch.Tensor, tiling=None, preferred_element_type=None, interpret=False) -> torch.Tensor: + start = 0 + out = [] + for i, size in enumerate(group_sizes): + result = lhs[start:start + size, :] @ rhs[i, :, :] + out.append(result) + start += group_sizes[i] + return jnp.concatenate(out) + +#gmm = _reference_gmm + + + + +@partial( + shard_map, + mesh=mesh, + in_specs=( + P(), + P(None, 'x', None), + P(None, None, 'x'), + P(None, 'x', None), + P(None, None)), + out_specs=(P()), check_rep=False) +def temp(x, w1, w2, w3, expert_indices): + print('x inside', x.shape) + print('w1', w1.shape) + print('w2', w2.shape) + print('w3', w3.shape) + print('index', expert_indices.shape) + def _histogram(input, min: int, max: int): + assert min <= max, "min must be less than or equal to max." + + def searchsorted(sorted_sequence, values_to_search): + return (jax.numpy.expand_dims(sorted_sequence, 1) == values_to_search).sum(axis=1) + + bin_edges = jax.numpy.linspace( + min, max, max - min + 1, dtype=input.dtype) + return searchsorted(bin_edges, input) + + num_tokens, k = expert_indices.shape + _, n = x.shape + top_flat = expert_indices.flatten() + hidden_states_order = top_flat.argsort() + hidden_states_reverse_order = hidden_states_order.argsort() + # Always replicated, so okay to skip manual sharding. 
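+  # Sort the (token, expert) pairs by expert id so each expert's rows are
+  # contiguous; group_sizes then gives gmm() the per-expert row counts, and
+  # hidden_states_reverse_order undoes the sort at the end.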
+ hidden_states_indices = jnp.arange(num_tokens).repeat(k)[hidden_states_order] + hidden_states_sorted = x[hidden_states_indices] # (num_tokens, hidden_dim/8) + group_sizes = _histogram(top_flat, 0, 7) + # w1 (num_experts, hiddent_dim/8, intermeditate) + gmm1 = gmm(hidden_states_sorted.astype('bfloat16'), + jnp.transpose(w1, (0, 2, 1)).astype('bfloat16'), + group_sizes, tiling=(16,128,128), + preferred_element_type=pdtype, interpret=interp) + gmm3 = gmm( + hidden_states_sorted.astype('float32'), + jnp.transpose(w3, (0, 2, 1)).astype('float32'), + group_sizes, tiling=(16,128,128), preferred_element_type=pdtype, interpret=interp) + + gmm1_s = gmm1 + gmm3_s = gmm3 + #gmm1_s = jax.lax.psum(gmm1, 'x') + #gmm3_s = jax.lax.psum(gmm3, 'x') + silu = jax.nn.silu(gmm1_s) + sgmm = silu * gmm3_s # (num_tokens, intermediate_size) + gmm2 = gmm( + sgmm, + jnp.transpose(w2, (0, 2, 1)).astype('float32'), + group_sizes, + tiling=(8,512,512), + preferred_element_type=pdtype, interpret=interp) #(num_tokens, hidden_dim/8) + print(gmm2.shape) + gmm2 = jax.lax.psum(gmm2, 'x') + current_hidden_states = gmm2[hidden_states_reverse_order].reshape(-1, k, n) + return current_hidden_states + + +# Create a PRNG key +key = jax.random.PRNGKey(123) # Using a different seed for variety + +seqlen = 16 + +expert_indices = jax.random.randint(key, shape=(seqlen, 2), minval=0, maxval=8) +hidden_states = jax.random.normal(key, (seqlen, 4096), dtype=jnp.bfloat16) + +w1 = jnp.broadcast_to(jnp.arange(8).reshape((8, 1, 1)).astype('float32'), (8, 14336, 4096)) +w2 = jnp.broadcast_to(jnp.arange(8).reshape((8, 1, 1)).astype('float32'), (8, 4096, 14336)) +w3 = jnp.broadcast_to(jnp.arange(8).reshape((8, 1, 1)).astype('float32'), (8, 14336, 4096)) + + +hidden_states = jax.device_put(hidden_states, NamedSharding(mesh, P(None, "x"))) +w1 = jax.device_put(w1, NamedSharding(mesh, P(None, "x"))) +w2 = jax.device_put(w2, NamedSharding(mesh, P(None, None, "x"))) +w3 = jax.device_put(w3, NamedSharding(mesh, P(None, "x"))) + +def exp_einsum(x, w1, expert_indices): + w1_weights = w1[expert_indices] + x1 = jnp.einsum("ti,taoi -> tao", x, w1_weights) + return x1 + +def _repeat_index(num_tokens, k): + start = jnp.arange(num_tokens).repeat(k) + start = start.reshape((num_tokens, k)) + return start.T.flatten() + +def exp_gmm(x, w1, expert_indices): + num_tokens, k = expert_indices.shape + _, n = x.shape + e, o, i = w1.shape + top_flat = expert_indices.flatten() + hidden_states_order = top_flat.argsort() + hidden_states_reverse_order = hidden_states_order.argsort() + # Always replicated, so okay to skip manual sharding. + hidden_states_indices = jnp.arange(num_tokens).repeat(k)[hidden_states_order] + hidden_states_sorted = x[hidden_states_indices] # (num_tokens, hidden_dim/8) + def _histogram(input, min: int, max: int): + assert min <= max, "min must be less than or equal to max." 
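+    # Counts how many entries of `input` fall on each integer in [min, max];
+    # here that is the number of tokens routed to each expert id.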
+ + def searchsorted(sorted_sequence, values_to_search): + return (jax.numpy.expand_dims(sorted_sequence, 1) == values_to_search).sum(axis=1) + + bin_edges = jax.numpy.linspace( + min, max, max - min + 1, dtype=input.dtype) + return searchsorted(bin_edges, input) + + group_sizes = _histogram(top_flat, 0, 7) + gmm1 = gmm(hidden_states_sorted.astype('float32'), + jnp.transpose(w1, (0, 2, 1)).astype('float32'), + group_sizes, + tiling=(16,128,128), + preferred_element_type=pdtype, interpret=interp) + return gmm1[hidden_states_reverse_order].reshape(-1, k, o) + + +def forward_for_long_seq_len(x, w1, w2, w3, expert_indices): + seqlen = x.shape[0] + num_experts = w1.shape[0] + + # e = total num of exp = 8 + # t = seqlen + # o = config.imtermediate size + # i = config.dim + with jax.named_scope("conditional_ff"): + x1 = jax.nn.silu(jnp.einsum("ti,eoi -> teo", x, w1)) + x3 = jnp.einsum("ti, eoi-> teo", x, w3) + expert_outs = ( + jnp.einsum("teo, eio -> tei", (x1 * x3), w2) + ) + # e = 8; need to reduce to 2 + seq_indexes = jnp.expand_dims(jnp.arange(seqlen), 1) + return expert_outs[seq_indexes, expert_indices] + +def main(): + + # x = jnp.arange(8).reshape((2, 4)).astype('float64') / 100 + # w1 = jnp.arange(3 * 4 * 5).reshape((3,5,4)).astype('float64') + # w2 = jnp.arange(3 * 4 * 5).reshape((3,4,5)).astype('float64') + # w3 = jnp.arange(3 * 4 * 5).reshape((3,5,4)).astype('float64') + # expert_indices = jnp.array([[0, 2], [1, 0]]) + x = hidden_states + out1 = forward_for_long_seq_len(x, w1, w2, w3, expert_indices) + out = temp(x, w1, w2, w3, expert_indices) + print(jnp.max(jnp.abs(out1 - out))) + + # out1 = exp_einsum(x, w1, expert_indices) + # out_gmm = exp_gmm(x, w1, expert_indices) + # print(out1 - out_gmm) + + + #group_sizes = jnp.array([4] * 8) + #gmm1 = gmm(hidden_states.astype('float32'), + # w1.astype('float32'), + # group_sizes, + # tiling=(16,128,128), interpret=interp) + #gmm1_ref = _reference_gmm(hidden_states.astype('float32'), + # w1.astype('float32'), + # group_sizes, + # tiling=(16,128,128), interpret=interp) + #print(gmm1 - gmm1_ref) + + +if __name__ == '__main__': + main() + + diff --git a/mlperf/install.sh b/mlperf/install.sh new file mode 100644 index 00000000..3a8f037b --- /dev/null +++ b/mlperf/install.sh @@ -0,0 +1,41 @@ +#!/usr/bin/env bash + +DATA_DISK_DIR=data + +mkdir -p $DATA_DISK_DIR + +pip install -U "huggingface_hub[cli]" +pip install \ + transformers \ + nltk==3.8.1 \ + evaluate==0.4.0 \ + absl-py==1.4.0 \ + rouge-score==0.1.2 \ + sentencepiece==0.1.99 \ + accelerate==0.21.0 + +# install loadgen +pip install mlperf-loadgen + + +pushd $DATA_DISK_DIR + +# model weights +gcloud storage cp gs://sixiang_gcp/mixtral-instruct-quantized ./ --recursive +# NOTE: uncomment one so you dont download too much weights to your box +# gcloud storage cp gs://sixiang_gcp/llama2-70b/llama2-70b/ ./ --recursive + +# Get mixtral data +wget https://inference.mlcommons-storage.org/mixtral_8x7b%2F2024.06.06_mixtral_15k_v4.pkl +mv mixtral_8x7b%2F2024.06.06_mixtral_15k_v4.pkl mixtral_15k_data.pkl +wget https://inference.mlcommons-storage.org/mixtral_8x7b%2F2024.06.06_mixtral_15k_calibration_v4.pkl +mv mixtral_8x7b%2F2024.06.06_mixtral_15k_calibration_v4.pkl mixtral_15k_calibration_data.pkl + +# Get llama70b data +gcloud storage cp \ + gs://cloud-tpu-inference-public/mlcommons/inference/language/llama2-70b/data/processed-openorca/open_orca_gpt4_tokenized_llama.calibration_1000.pkl \ + processed-calibration-data.pkl +gcloud storage cp \ + 
gs://cloud-tpu-inference-public/mlcommons/inference/language/llama2-70b/data/processed-openorca/open_orca_gpt4_tokenized_llama.sampled_24576.pkl \ + processed-data.pkl +popd diff --git a/mlperf/mixtral_run.sh b/mlperf/mixtral_run.sh new file mode 100755 index 00000000..77504920 --- /dev/null +++ b/mlperf/mixtral_run.sh @@ -0,0 +1,55 @@ +#!/usr/bin/env bash +me=$(basename "$0") + +BASEDIR=mlperf +USER_CONFIG=$BASEDIR/user.conf +DATA_DISK_DIR=$BASEDIR/data +TOTAL_SAMPLE_COUNT=900 + +# HF model id +TOKENIZER_PATH="mistralai/Mixtral-8x7B-Instruct-v0.1" +LOADGEN_RUN_TYPE=offline-performance +OUTPUT_LOG_DIR=${DATA_DISK_DIR}/logs/${OUTPUT_LOG_ID} +OUTPUT_LOG_ID=${MODEL_NAME}-${DATASET_TYPE}-${LOADGEN_RUN_TYPE}-${LOADGEN_RUN_TIMESTAMP} + +mkdir -p ${OUTPUT_LOG_DIR} && cp ../${USER_CONFIG} ${OUTPUT_LOG_DIR} + +OUTPUT_ACCURACY_JSON_PATH=${OUTPUT_LOG_DIR}/mlperf_log_accuracy.json + +CHECKPOINT_PATH=mlperf/data/mixtral-instruct-quantized/ + +LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" +# makes subsequent runs faster +export JAX_COMPILATION_CACHE_DIR="/tmp/jax_cache2" +export LIBTPU_INIT_ARGS + + +DATASET_PATH=$BASEDIR/data/mixtral_1k_data.pkl + +pushd .. +python -m mlperf.offline_mode \ + --lazy_cache_update=1 \ + --ring_buffer=0 \ + --mlperf_test_mode=$1 \ + --model_name=mixtral \ + --checkpoint_path=$CHECKPOINT_PATH/model.safetensors \ + --tokenizer_path=$CHECKPOINT_PATH/tokenizer.model \ + --quantize_weights=1 \ + --quantize_type=int8_per_channel \ + --quantize_kv_cache=1 \ + --input_mode tokenized \ + --output_mode tokenized \ + --mlperf_conf $BASEDIR/mlperf.conf \ + --user_conf ${USER_CONFIG} \ + --audit_conf no_audit \ + --total_sample_count ${TOTAL_SAMPLE_COUNT} \ + --dataset_path ${DATASET_PATH} \ + --output_log_dir ${OUTPUT_LOG_DIR} 2>&1 | tee ${OUTPUT_LOG_DIR}/server_accuracy_log.log + +if [ "$1" = "accuracy" ]; then +python -m mlperf.evaluate_accuracy \ + --checkpoint-path ${TOKENIZER_PATH} \ + --mlperf-accuracy-file ${OUTPUT_ACCURACY_JSON_PATH} \ + --dataset-file ${DATASET_PATH} 2>&1 | tee ${OUTPUT_LOG_DIR}/evaluate_offline_accuracy_log.log +fi +popd diff --git a/mlperf/mlperf.conf b/mlperf/mlperf.conf new file mode 100644 index 00000000..e9ae205e --- /dev/null +++ b/mlperf/mlperf.conf @@ -0,0 +1,98 @@ +# The format of this config file is 'key = value'. +# The key has the format 'model.scenario.key'. Value is mostly int64_t. +# Model maybe '*' as wildcard. In that case the value applies to all models. +# All times are in milli seconds + +# Set performance_sample_count for each model. +# User can optionally set this to higher values in user.conf. 
+resnet50.*.performance_sample_count_override = 1024 +ssd-mobilenet.*.performance_sample_count_override = 256 +retinanet.*.performance_sample_count_override = 64 +bert.*.performance_sample_count_override = 10833 +dlrm.*.performance_sample_count_override = 204800 +dlrm-v2.*.performance_sample_count_override = 204800 +rnnt.*.performance_sample_count_override = 2513 +gptj.*.performance_sample_count_override = 13368 +llama2-70b.*.performance_sample_count_override = 24576 +stable-diffusion-xl.*.performance_sample_count_override = 5000 +# set to 0 to let entire sample set to be performance sample +3d-unet.*.performance_sample_count_override = 0 + +# Set seeds. The seeds will be distributed two weeks before the submission. +*.*.qsl_rng_seed = 3066443479025735752 +*.*.sample_index_rng_seed = 10688027786191513374 +*.*.schedule_rng_seed = 14962580496156340209 +# Set seeds for TEST_05. The seeds will be distributed two weeks before the submission. +*.*.test05_qsl_rng_seed = 16799458546791641818 +*.*.test05_sample_index_rng_seed = 5453809927556429288 +*.*.test05_schedule_rng_seed = 5435552105434836064 + + +*.SingleStream.target_latency_percentile = 90 +*.SingleStream.min_duration = 600000 + +*.MultiStream.target_latency_percentile = 99 +*.MultiStream.samples_per_query = 8 +*.MultiStream.min_duration = 600000 +*.MultiStream.min_query_count = 662 +retinanet.MultiStream.target_latency = 528 + +# 3D-UNet uses equal issue mode because it has non-uniform inputs +3d-unet.*.sample_concatenate_permutation = 1 + +# LLM benchmarks have non-uniform inputs and outputs, and use equal issue mode for all latency scenario +gptj.*.sample_concatenate_permutation = 1 +llama2-70b.*.sample_concatenate_permutation = 1 +mixtral-8x7B.*.sample_concatenate_permutation = 1 + +*.Server.target_latency = 10 +*.Server.target_latency_percentile = 99 +*.Server.target_duration = 0 +*.Server.min_duration = 600000 +resnet50.Server.target_latency = 15 +retinanet.Server.target_latency = 100 +bert.Server.target_latency = 130 +dlrm.Server.target_latency = 60 +dlrm-v2.Server.target_latency = 60 +rnnt.Server.target_latency = 1000 +gptj.Server.target_latency = 20000 +stable-diffusion-xl.Server.target_latency = 20000 +# Llama2-70b benchmarks measures token latencies +llama2-70b.*.use_token_latencies = 1 +mixtral-8x7b.*.use_token_latencies = 1 +# gptj benchmark infers token latencies +gptj.*.infer_token_latencies = 1 +gptj.*.token_latency_scaling_factor = 69 +# Only ttft and tpot are tracked for the llama2-70b & mixtral-8x7B benchmark therefore target_latency = 0 +llama2-70b.Server.target_latency = 0 +llama2-70b.Server.ttft_latency = 2000 +llama2-70b.Server.tpot_latency = 200 + +mixtral-8x7b.Server.target_latency = 0 +mixtral-8x7b.Server.ttft_latency = 2000 +mixtral-8x7b.Server.tpot_latency = 200 + +*.Offline.target_latency_percentile = 90 +*.Offline.min_duration = 600000 + +# In Offline scenario, we always have one query. But LoadGen maps this to +# min_sample_count internally in Offline scenario. 
If the dataset size is larger +# than 24576 we limit the min_query_count to 24576 and otherwise we use +# the dataset size as the limit + +resnet50.Offline.min_query_count = 24576 +retinanet.Offline.min_query_count = 24576 +dlrm-v2.Offline.min_query_count = 24576 +bert.Offline.min_query_count = 10833 +gptj.Offline.min_query_count = 13368 +rnnt.Offline.min_query_count = 2513 +3d-unet.Offline.min_query_count = 43 +stable-diffusion-xl.Offline.min_query_count = 5000 +llama2-70b.Offline.min_query_count = 1000 +mixtral-8x7b.Offline.min_query_count = 15000 + +# These fields should be defined and overridden by user.conf. +*.SingleStream.target_latency = 10 +*.MultiStream.target_latency = 80 +*.Server.target_qps = 1.0 +*.Offline.target_qps = 4.0 diff --git a/mlperf/offline_mode.py b/mlperf/offline_mode.py new file mode 100644 index 00000000..09b00c25 --- /dev/null +++ b/mlperf/offline_mode.py @@ -0,0 +1,476 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import contextlib +import copy +import gc +import time +import math +import logging +import os +import sys +import array +import collections +import threading + +import torch_xla2 +import jax +import jax.numpy as jnp +import numpy as np +import pandas as pd + +import mlperf_loadgen as lg +from jetstream_pt.config import create_engine_from_config_flags +from jetstream_pt import offline_inference + +_MLPERF_ID = "mixtral-8x7b" + +logging.basicConfig(level=logging.DEBUG) + +sys.path.insert(0, os.getcwd()) +log = logging.getLogger("main2.py") + + +from absl import app, flags + +FLAGS = flags.FLAGS + +flags.DEFINE_string( + "mlperf_test_mode", + "performance", + "performance, accuracy, submission", +) +flags.DEFINE_string( + "api_url", None, "SAX published model path.", required=False +) +flags.DEFINE_string("dataset_path", None, "", required=False) +flags.DEFINE_bool("is_stream", False, "", required=False) +flags.DEFINE_string( + "input_mode", + "tokenized", + "Input mode", +) +flags.DEFINE_string( + "output_mode", + "tokenized", + "Output mode", +) + +flags.DEFINE_string( + "audit_conf", + "audit.conf", + "audit config for LoadGen settings during compliance runs", + required=False, +) +flags.DEFINE_string( + "mlperf_conf", + "mlperf.conf", + "mlperf rules config", + required=False, +) +flags.DEFINE_string( + "user_conf", + "user.conf", + "user config for user LoadGen settings such as target QPS", + required=False, +) +flags.DEFINE_integer( + "total_sample_count", + 15000, + "Number of samples to use in benchmark.", + required=False, +) +flags.DEFINE_integer( + "perf_count_override", + None, + "Overwrite number of samples to use in benchmark.", + required=False, +) +flags.DEFINE_string( + "output_log_dir", + "output-logs", + "Where logs are saved.", + required=False, +) +flags.DEFINE_bool( + "enable_log_trace", + False, + "Enable log tracing. 
This file can become quite large", + required=False, +) +flags.DEFINE_bool( + "skip_warmup", False, "Skips warmup" +) +flags.DEFINE_bool( + "internal_dummy_model", False, "Skips actual model compute, used for testing" +) + + +scenario_map = { + "offline": lg.TestScenario.Offline, + "server": lg.TestScenario.Server, +} + + +def pad_tokens(tokens): + true_length = len(tokens) + target_length = max(int(2 ** math.ceil(math.log2(true_length))), 32) + padded = tokens + [0] * (target_length - true_length) + return padded, true_length + + +@contextlib.contextmanager +def timed(msg): + log.info(msg + " start") + start = time.perf_counter() + yield + end = time.perf_counter() + log.info(msg + " done: " + str(end - start)) + + +def _classify_query(dataset_rows, index): + # return groupped indexes + if FLAGS.model_name == 'mixtral': + sample = dataset_rows[index][1] + input_len = sample.tok_input_len + total_len = sample.tok_input_len + sample.tok_ref_output_len + if total_len <= 512: + return 0 + elif total_len <= 1280 and input_len <= 1024: + return 1 + else: + return 2 + else: + sample = dataset_rows[index][1] + total_len = sample.tok_input_length + sample.tok_output_length + if total_len <= 512: + return 0 + elif total_len <= 1024: + return 1 + else: + return 2 + + + + +def _pick_batch_size(num_samples, max_batch, dataset_size, sample_size): + """max_batch to not run OOM.""" + if num_samples <= max_batch: + return num_samples + mult = math.ceil(num_samples / max_batch) + return math.ceil(num_samples / mult * (sample_size / dataset_size)) + + +def _log_complete(sample_id, response_token_ids): + assert (response_token_ids[0] <= 32000) + n_tokens = len(response_token_ids) + response_token_ids = np.array(response_token_ids, dtype=np.int64) + response_array = array.array("B", response_token_ids.tobytes()) + response_info = response_array.buffer_info() + response_data = response_info[0] + response_size = response_info[1] * response_array.itemsize + query_sample_response = lg.QuerySampleResponse( + sample_id, response_data, response_size, n_tokens + ) + lg.QuerySamplesComplete([query_sample_response]) + # import ipdb; ipdb.set_trace() + +def _log_first(sample_id, response_token_ids): + assert (response_token_ids[0] <= 32000) + assert len(response_token_ids) == 1 + response_token_ids = np.array(response_token_ids, dtype=np.int64) + response_array = array.array("B", response_token_ids.tobytes()) + response_info = response_array.buffer_info() + first_token_response = lg.QuerySampleResponse( + sample_id, response_info[0], response_info[1] + ) + lg.FirstTokenComplete([first_token_response]) + + +class SUT: + + def __init__(self, data, offline_inf): + # dict of int (cache length) -> offline_inf + self.offline_inf = offline_inf + + # pandas dataframe, it has tok + self._dataset = data + self.pandas_rows = list(self._dataset.iterrows()) + + # List of things with .id and .index + self._queries = None + + # index to loaded data + self._processed_data = None + + # self.replicated = self.offline_inf.engine.env.sharding_by_axis(-1) + self._sample_id_to_input = None + self._sample_id_to_bucket = None + self._groupped_queries = [[], [], []] + self._id_to_index = {} + self._index_to_group = {} + self._eos = offline_inf[0].tokenizer.eos_id + + def _get_eos_seq(self, id_): + idx = self._id_to_index[id_] + pandas_row = self.pandas_rows[idx][1] + if hasattr(pandas_row, 'tok_stop_sequence'): + return pandas_row.tok_stop_sequence + else: + return [self._eos] + + def issue_queries(self, queries): + log.info('issue queries 
called') + self._groupped_queries = [[], [], []] + assert self._sample_id_to_input is not None + self._processed_data = [] + self._queries = queries + for q in queries: + self._id_to_index[q.id] = q.index + group = self._index_to_group[q.index] + input_data = copy.copy(self._sample_id_to_input[q.index]) + input_data.id = q.id + self._groupped_queries[group].append(input_data) + + if len(self._queries) != sum(len(q) for q in self._groupped_queries): + import ipdb; ipdb.set_trace() + + # At this point _processed_data is ready + + @timed("flush_queries") + def flush_queries(self): + start = time.perf_counter() + completed = set() + resp = collections.defaultdict(list) + lock = threading.RLock() + def emit_token(id_, token): + nonlocal resp + nonlocal completed + with lock: + resp[id_].append(token) + end_seq = self._get_eos_seq(id_) + is_end = (token == self._eos) or (end_seq == resp[id_][-len(end_seq):]) + if is_end: + _log_complete(id_, resp[id_]) + completed.add(id_) + if id_ in resp: + del resp[id_] + return is_end + + def emit_first_token(id_, token): + nonlocal resp + nonlocal completed + # emit first token + with lock: + _log_first(id_, [token]) + end_seq = self._get_eos_seq(id_) + is_end = (token == self._eos) or (len(end_seq) == 1 and end_seq[0] == token) + if is_end: + # We have four OpenOrca samples that return empty (eos_token) output. + # It was decided to allow two eos_tokens to not break throughput computation + # PR - https://github.com/mlcommons/inference/pull/1778 + _log_complete(id_, [token, self._eos]) + completed.add(id_) + if id_ in resp: + import pdb; pdb.set_trace() + del resp[id_] + return is_end + + for group_idx in [2,1,0]: + group = self._groupped_queries[group_idx] + self.offline_inf[group_idx].init_decode_state() + result = self.offline_inf[group_idx].batch_inference_with_callback( + group, emit_first_token, emit_token) + + # some never reached eos but reached max sequence + with lock: + for key, value in resp.items(): + if key in completed: + continue + _log_complete(key, value) + completed.add(key) + + if group_idx != 0: + # no need to drop state for the last one + self.offline_inf[group_idx].decode_state = None + gc.collect() + + end = time.perf_counter() + + def LoadSamplesToRam(self, sample_list): + """Pads the data, move them to jax array on device""" + log.info("LoadSamplesToRam start") + start = time.perf_counter() + input_data = {} + + for sample_id in sample_list: + p = self.pandas_rows[sample_id][1] + padded, length = pad_tokens(p.tok_input) + input_data[sample_id] = offline_inference.InputData( + "", jnp.array(padded), length # to be filled later + ) + self._index_to_group[sample_id] = _classify_query(self.pandas_rows, sample_id) + + for data in input_data.values(): + # make sure tokens are transfered to device + jax.block_until_ready(data.tokens) + + self._sample_id_to_input = input_data + + end = time.perf_counter() + log.info(f"LoadSamplesToRam finished: {end - start}s") + + def UnloadSamplesFromRam(self, sample_list): + print("UnloadSamplesFromRam called") + pass + + +def _count_by_bucket(dataset): + if FLAGS.model_name == 'mixtral': + total_len = dataset.tok_input_len + dataset.tok_ref_output_len + + group1 = total_len <= 512 + group2 = (total_len <= 1280) & (dataset.tok_input_len <= 1024) + + # with 5 percent extra + mult = FLAGS.total_sample_count / len(dataset) * 1.05 + + counts = [ + # power of 2 + math.ceil(len(dataset[group1]) * mult), + math.ceil(len(dataset[~group1 & group2]) * mult), + math.ceil(len(dataset[~group1 & ~group2]) * mult), + 
] + return counts + else: + total_len = dataset.tok_input_length + dataset.tok_output_length + group1 = total_len <= 512 + group2 = total_len <= 1024 + # with 5 percent extra + mult = FLAGS.total_sample_count / len(dataset) * 1.05 + counts = [ + math.ceil(len(dataset[group1]) * mult), + math.ceil(len(dataset[~group1 & group2]) * mult), + math.ceil(len(dataset[~group1 & ~group2]) * mult), + ] + return counts + + +def main(argv): + del argv + args = FLAGS + jax.config.update("jax_default_prng_impl", "unsafe_rbg") + + if len(jax.devices()) < 4: + print("Looks like TPU not available?", jax.devices()) + return -1 + # jax.config.update("jax_explain_cache_misses", True) + + settings = lg.TestSettings() + settings.scenario = lg.TestScenario.Offline + user_conf = FLAGS.user_conf + + settings.FromConfig(FLAGS.mlperf_conf, _MLPERF_ID, "Offline") + settings.FromConfig(user_conf, _MLPERF_ID, "Offline") + log.info("Mlperf config: %s", FLAGS.mlperf_conf) + log.info("User config: %s", user_conf) + + dataset = pd.read_pickle(FLAGS.dataset_path) + rows = list(dataset.iterrows()) + counts_by_bucket = _count_by_bucket(dataset) + log.info(f"Counts by bucket {counts_by_bucket}") + + if FLAGS.model_name == "mixtral": + length_and_batch = ( + (512, 2048), + (1280, 512), + (3072, 256), + ) + else: + length_and_batch = ( + (512, 512), + (1024, 256), + (2048, 96), + ) + engines = [] + params = None + for i, (length, max_batch) in enumerate(length_and_batch): + batch = min(counts_by_bucket[i], max_batch) + log.info(f"Using batch size of {batch} for {length}") + engine = create_engine_from_config_flags(batch=batch, cache_len=length) + offline_inf = offline_inference.OfflineInference(engine, params) + offline_inf.dummy = FLAGS.internal_dummy_model + params = offline_inf.params + engines.append(offline_inf) + + if not FLAGS.skip_warmup: + with timed("warmup"): + for (length, _), engine in zip(length_and_batch, engines): + log.info(f"warm up for {length}") + engine.init_decode_state() + engine.warmup(length) + if length != 3072: + # dont need to drop state for the last one + engine.decode_state = None # drop state + gc.collect() + + sut = SUT(dataset, engines) + + if FLAGS.mlperf_test_mode == "accuracy": + settings.mode = lg.TestMode.AccuracyOnly + log.warning( + "Accuracy run will generate the accuracy logs, but the evaluation of the log is not completed yet" + ) + elif FLAGS.mlperf_test_mode == "submission": + settings.mode = lg.TestMode.SubmissionRun + settings.print_timestamps = True + else: + settings.mode = lg.TestMode.PerformanceOnly + settings.print_timestamps = True + + settings.use_token_latencies = True + + os.makedirs(FLAGS.output_log_dir, exist_ok=True) + log_output_settings = lg.LogOutputSettings() + log_output_settings.outdir = FLAGS.output_log_dir + log_output_settings.copy_summary_to_stdout = True + log_settings = lg.LogSettings() + log_settings.log_output = log_output_settings + log_settings.enable_trace = FLAGS.enable_log_trace + + lgSUT = lg.ConstructSUT(sut.issue_queries, sut.flush_queries) + qsl = lg.ConstructQSL( + FLAGS.total_sample_count, + FLAGS.total_sample_count, + sut.LoadSamplesToRam, + sut.UnloadSamplesFromRam, + ) + log.info("Starting Benchmark run") + lg.StartTestWithLogSettings( + lgSUT, qsl, settings, log_settings, FLAGS.audit_conf + ) + log.info(f"query counts {[len(q) for q in sut._groupped_queries]}") + log.info("Run Completed!") + log.info("Destroying SUT...") + lg.DestroySUT(lgSUT) + + log.info("Destroying QSL...") + lg.DestroyQSL(qsl) + + +if __name__ == "__main__": + # Disable 
garbage collection to avoid stalls when running tests. + gc.disable() + app.run(main) diff --git a/mlperf/user.conf b/mlperf/user.conf new file mode 100644 index 00000000..95ef75ef --- /dev/null +++ b/mlperf/user.conf @@ -0,0 +1,6 @@ +mixtral-8x7b.Server.target_qps = 2.0 +mixtral-8x7b.Offline.target_qps = 100.0 + +# send unique queries +mixtral-8x7b.Offline.performance_issue_unique = 1 +