Merge pull request #662 from google:mor--kv-cache-layout-reformat-output
PiperOrigin-RevId: 646637997
maxtext authors committed Jun 25, 2024
2 parents 9482bf1 + 9606e62 commit 5a215db
Showing 3 changed files with 213 additions and 36 deletions.
1 change: 1 addition & 0 deletions MaxText/configs/base.yml
@@ -323,6 +323,7 @@ inference_microbenchmark_prefill_lengths: "64,128,256,512,1024"
inference_microbenchmark_stages: "prefill,generate"
inference_microbenchmark_loop_iters: 10
inference_microbenchmark_log_file_path: ""
inference_metadata_file: "" # path to a json file

# KV Cache layout control
# Logical layout: 0,1,2,3 ; CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV
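The new inference_metadata_file option points at a JSON file that the sweep script added in this commit reads. As a rough sketch, such a file could be generated like this (the key names come from the sweep script's docstring; the axis-order values, accelerator name, and output path are illustrative assumptions, not shipped defaults):

# Sketch: write an illustrative metadata file for inference_microbenchmark_sweep.py.
# Axis-order lists are colon-delimited; each entry is itself a comma-separated layout.
import json

metadata = {
    "two_axis_order_product_id_list": "1:2",
    "prefill_cache_axis_order_list": "0,1,2,3:0,2,1,3",
    "ar_cache_axis_order_list": "0,1,2,3:0,2,1,3",
    "accelerator": "v5e-8",
    "flatten_microbenchmark_results": "true",
}

with open("inference_metadata.json", "w", encoding="utf-8") as f:
    json.dump(metadata, f, indent=2)
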
105 changes: 69 additions & 36 deletions MaxText/inference_microbenchmark.py
@@ -20,6 +20,9 @@
import json
import sys

from collections.abc import MutableMapping
from typing import Any, Dict, Optional

from jetstream.engine import token_utils

import max_utils
@@ -63,9 +66,9 @@ def prefill_benchmark(
f"\tPrefill TFLOPs/sec/device: {tflops_per_sec_per_device:.3f}\n\n\n\n"
)
result_dict = {
"prefill_time_in_ms": prefill_average_ms,
"prefill_total_tflops_per_device": prefill_tflops_per_device,
"prefill_tflops_per_sec_per_device": tflops_per_sec_per_device,
"time_in_ms": prefill_average_ms,
"total_tflops_per_device": prefill_tflops_per_device,
"tflops_per_sec_per_device": tflops_per_sec_per_device,
}
return result_dict

@@ -106,7 +109,7 @@ def prefill_insert_benchmark(
f"\tPrefill + Insert step average time: {prefill_insert_average_ms:.3f} ms\n\n\n\n"
)
result_dict = {
"prefill_insert_time_in_ms": prefill_insert_average_ms
"time_in_ms": prefill_insert_average_ms
}
return result_dict, decode_state

@@ -147,20 +150,20 @@ def ar_benchmark(config, engine, params, decode_state, global_batch_size, cache_
)

result_dict = {
"ar_step_in_ms": ar_average_ms,
"ar_step_in_ms_per_seq": ar_average_ms / global_batch_size,
"ar_global_batch_size": global_batch_size,
"ar_total_throughput_tokens_per_second": total_throughput,
"ar_device_bandwidth_GB_per_second": bw_per_device,
"step_in_ms": ar_average_ms,
"step_in_ms_per_seq": ar_average_ms / global_batch_size,
"global_batch_size": global_batch_size,
"total_throughput_tokens_per_second": total_throughput,
"device_bandwidth_GB_per_second": bw_per_device,
}
return result_dict, decode_state


def collate_results(config, results, model_size, cache_size, num_model_params, incl_config=False):
"""Adds model/cache size info and optionally config info to results."""
results["sizes"] = {
"Model_size_in_GB": model_size / 1e9,
"cache_size_in_GB": cache_size / 1e9,
"model_size_in_gb": model_size / 1e9,
"cache_size_in_gb": cache_size / 1e9,
"model_params_in_billions": num_model_params / 1e9,
}
if incl_config:
@@ -170,30 +173,45 @@ def collate_results(config, results, model_size, cache_size, num_model_params, i
return results


def write_results(results, filename):
def flatten_dict(dictionary, prefix='', sep='_'):
results = []
for k, v in dictionary.items():
new_key = str(prefix) + sep + str(k) if prefix else k
if isinstance(v, MutableMapping):
results.extend(flatten_dict(v, new_key, sep=sep).items())
else:
results.append((new_key, v))
return dict(results)


def write_results(results, filename, flatten_microbenchmark_results):
"""Write the results microbenchmark results to a json file."""
if flatten_microbenchmark_results:
results['flattened_results'] = flatten_dict(results)
if filename != "":
with open(filename, "w", encoding="utf-8") as f:
json.dump(results, f, indent=2)
return results


def print_results_for_analyze(results):
"""Print results."""
print("\nFor usage in analyze_sharegpt.py :")

if "Prefill" in results:
if "prefill" in results:
prefill_bucket_size_to_ms = {}
for k, v in results["Prefill"].items():
prefill_bucket_size_to_ms[int(k)] = round(v["prefill_time_in_ms"], 3)
for k, v in results["prefill"].items():
prefill_bucket_size_to_ms[int(k)] = round(v["time_in_ms"], 3)
print(f"PREFILL_BUCKET_SIZE_TO_MS = {prefill_bucket_size_to_ms}")

if "Prefill_Insert" in results:
if "insert" in results:
insert_bucket_size_to_ms = {}
for k, v in results["Prefill_Insert"].items():
insert_bucket_size_to_ms[int(k)] = round(v["prefill_insert_time_in_ms"], 3)
print(f"PREFILL_INSERT_BUCKET_SIZE_TO_MS = {insert_bucket_size_to_ms}")
for k, v in results["insert"].items():
insert_bucket_size_to_ms[int(k)] = round(v["time_in_ms"], 3)
print(f"INSERT_BUCKET_SIZE_TO_MS = {insert_bucket_size_to_ms}")

if "AutoRegressive" in results:
print(f"SYSTEM_TIME_PER_DECODE_TOKEN_MS = {results['AutoRegressive']['ar_step_in_ms_per_seq']}")
if "autoregressive" in results:
print(f"SYSTEM_TIME_PER_DECODE_TOKEN_MS = {results['autoregressive']['step_in_ms_per_seq']}")


def summarize_prefill_result(engine, params, tokens, true_length):
@@ -209,16 +227,16 @@ def summarize_prefill_result(engine, params, tokens, true_length):
)
del prefill_result
return {
"num_prefill_logits_params": num_prefill_logits_params,
"total_prefill_logits_size": total_prefill_logits_size,
"avg_prefill_logits_param_size": avg_prefill_logits_param_size,
"num_prefill_cache_params": num_prefill_cache_params,
"total_prefill_cache_size": total_prefill_cache_size,
"avg_prefill_cache_param_size": avg_prefill_cache_param_size,
"num_logits_params": num_prefill_logits_params,
"total_logits_size": total_prefill_logits_size,
"avg_logits_param_size": avg_prefill_logits_param_size,
"num_cache_params": num_prefill_cache_params,
"total_cache_size": total_prefill_cache_size,
"avg_cache_param_size": avg_prefill_cache_param_size,
}


def main(config):
def main(config, inference_metadata: Optional[Dict[str, Any]] = None):
engine = maxengine.MaxEngine(config)
params = engine.load_params()
prefill_lengths = [int(l) for l in config.inference_microbenchmark_prefill_lengths.split(",")]
@@ -236,22 +254,22 @@ def main(config):
benchmark_results = {}
if "prefill" in stages_to_benchmark:

benchmark_results["Prefill_Result"] = {}
benchmark_results["Prefill"] = {}
benchmark_results["Prefill_Insert"] = {}
benchmark_results["prefill-result-sizes"] = {}
benchmark_results["prefill"] = {}
benchmark_results["insert"] = {}
prefill_tokens = {}
prefill_true_lengths = {}

for prefill_length in prefill_lengths:
prefill_tokens[prefill_length], prefill_true_lengths[prefill_length] = token_utils.tokenize_and_pad(
text, vocab, is_bos=True, prefill_lengths=[prefill_length]
)
benchmark_results["Prefill_Result"]["prefill_length"] = summarize_prefill_result(
benchmark_results["prefill-result-sizes"][prefill_length] = summarize_prefill_result(
engine, params, prefill_tokens[prefill_length], prefill_true_lengths[prefill_length]
)

for prefill_length in prefill_lengths:
benchmark_results["Prefill"][prefill_length] = prefill_benchmark(
benchmark_results["prefill"][prefill_length] = prefill_benchmark(
config,
engine,
params,
@@ -261,7 +279,7 @@
benchmark_loop_iters
)

benchmark_results["Prefill_Insert"][prefill_length], decode_state = prefill_insert_benchmark(
prefill_insert_time, decode_state = prefill_insert_benchmark(
config,
engine,
decode_state,
@@ -271,14 +289,29 @@
prefill_true_lengths[prefill_length],
benchmark_loop_iters
)
benchmark_results["insert"][prefill_length] = {}
benchmark_results["insert"][prefill_length]["time_in_ms"] = (
prefill_insert_time["time_in_ms"] - benchmark_results["prefill"][prefill_length]["time_in_ms"]
)

if "generate" in stages_to_benchmark:
benchmark_results["AutoRegressive"], decode_state = ar_benchmark(
benchmark_results["autoregressive"], decode_state = ar_benchmark(
config, engine, params, decode_state, engine.max_concurrent_decodes, cache_size, model_size, benchmark_loop_iters)

results = collate_results(config, benchmark_results, model_size, cache_size, num_model_params)
write_results(results, filename=config.inference_microbenchmark_log_file_path)
print_results_for_analyze(results)
if inference_metadata:
flatten_microbenchmark_results = pyconfig.string_to_bool(
inference_metadata.get('flatten_microbenchmark_results', 'false')
)
else:
flatten_microbenchmark_results = False
results = write_results(
results,
filename=config.inference_microbenchmark_log_file_path,
flatten_microbenchmark_results=flatten_microbenchmark_results
)
return results


if __name__ == "__main__":
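To see why the per-stage prefixes (prefill_, ar_) could be dropped from the individual result dicts, here is a small sketch of how the new flatten_dict helper composes keys; the sample numbers are made up, and running it assumes a working MaxText checkout on the Python path:

# Sketch: flatten_dict prefixes nested keys with their parent keys.
from inference_microbenchmark import flatten_dict

results = {
    "prefill": {64: {"time_in_ms": 12.3}},
    "autoregressive": {"step_in_ms_per_seq": 0.45},
}
print(flatten_dict(results))
# -> {'prefill_64_time_in_ms': 12.3, 'autoregressive_step_in_ms_per_seq': 0.45}
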
143 changes: 143 additions & 0 deletions MaxText/inference_microbenchmark_sweep.py
@@ -0,0 +1,143 @@
"""
Copyright 2024 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""Sweep across inference microbenchmarks."""

import os
import sys
import json
import jsonlines
import inference_microbenchmark
import pyconfig
from jax._src.lib import xla_extension


def main():
"""
User needs to set the config's inference_metadata_file, which is a path to a
json file.
This json should contain the following keys:
- two_axis_order_product_id_list: colon-delimited string of two_axis_order_product_id values
- prefill_cache_axis_order_list: colon-delimited string of prefill_cache_axis_order values (each a comma-separated layout, e.g. 0,1,2,3)
- ar_cache_axis_order_list: colon-delimited string of ar_cache_axis_order values (each a comma-separated layout, e.g. 0,1,2,3)
- accelerator: name of the accelerator
- flatten_microbenchmark_results: whether or not to flatten results; should be true
"""
pyconfig.initialize(sys.argv)
config = pyconfig.config
base_run_name = config.run_name

with open(config.inference_metadata_file, encoding='utf-8') as json_file:
inference_metadata = json.load(json_file)
print(f"inference_metadata: {inference_metadata}")

two_axis_order_product_id_list = inference_metadata['two_axis_order_product_id_list'].split(':')
prefill_cache_axis_order_list = inference_metadata['prefill_cache_axis_order_list'].split(':')
ar_cache_axis_order_list = inference_metadata['ar_cache_axis_order_list'].split(':')

start_two_axis_order_product_id = two_axis_order_product_id_list[0]
end_two_axis_order_product_id = two_axis_order_product_id_list[-1]

results = []
for (
two_axis_order_product_id,
prefill_cache_axis_order,
ar_cache_axis_order,
) in zip(
two_axis_order_product_id_list,
prefill_cache_axis_order_list,
ar_cache_axis_order_list,
):
print(f"two_axis_order_product_id {two_axis_order_product_id}")
print(f"prefill_cache_axis_order {prefill_cache_axis_order}")
print(f"ar_cache_axis_order {ar_cache_axis_order}")

run_tag = (
f"{two_axis_order_product_id}-{prefill_cache_axis_order.replace(',','')}-{ar_cache_axis_order.replace(',','')}"
)
run_name = f"{base_run_name}/{run_tag}"

tensorboard_dir = os.path.join(config.base_output_directory, run_name, "tensorboard", "")
pyconfig._config.keys['prefill_cache_axis_order'] = prefill_cache_axis_order # pylint: disable=protected-access
pyconfig._config.keys['ar_cache_axis_order'] = ar_cache_axis_order # pylint: disable=protected-access
pyconfig._config.keys['tensorboard_dir'] = tensorboard_dir # pylint: disable=protected-access
pyconfig._config.keys['run_name'] = run_name # pylint: disable=protected-access

# Prepare metadata (dimensions) json for XLML
dimensions_json = {
"base_output_directory": config.base_output_directory,
"model_name": config.model_name,
"tokenizer": config.tokenizer_path,
"weight_dtype": config.weight_dtype,
"inference_microbenchmark_prefill_lengths": f"{config.inference_microbenchmark_prefill_lengths}",
"inference_microbenchmark_stages": config.inference_microbenchmark_stages,
"inference_microbenchmark_loop_iters": f"{config.inference_microbenchmark_loop_iters}",
"max_prefill_predict_length": f"{config.max_prefill_predict_length}",
"max_target_length": f"{config.max_target_length}",
"per_device_batch_size": f"{config.per_device_batch_size}",
"ici_fsdp_parallelism": f"{config.ici_fsdp_parallelism}",
"ici_autoregressive_parallelism": f"{config.ici_autoregressive_parallelism}",
"ici_tensor_parallelism": f"{config.ici_tensor_parallelism}",
"profiler": f"{config.profiler}",
"scan_layers": f"{config.scan_layers}",
"quantization": config.quantization,
"quantize_kvcache": f"{config.quantize_kvcache}",
"attention": config.attention,
"two_axis_order_product_id": f"{two_axis_order_product_id}",
"prefill_cache_axis_order": f"{prefill_cache_axis_order}",
"ar_cache_axis_order": f"{ar_cache_axis_order}",
"compute_axis_order": f"{config.compute_axis_order}",
"reshape_q": f"{config.reshape_q}",
"kv_quant_axis": f"{config.kv_quant_axis}",
"run_name": f"{run_name}",
"run_tag": f"{run_tag}",
"config_json_string": json.dumps(
pyconfig._config.keys, # pylint: disable=protected-access
default=lambda x: f"<<non-serializable: {type(x).__qualname__}>>"
)
}
dimensions_json = {
**dimensions_json,
**inference_metadata,
}
try:
microbenchmark_results = inference_microbenchmark.main(config, inference_metadata=inference_metadata)
metrics = microbenchmark_results['flattened_results']
metrics = {k.lower(): v for k, v in metrics.items()}
dimensions_json['oom'] = 'False'
print(f"Completed run {two_axis_order_product_id} out of: "
f"{start_two_axis_order_product_id} to {end_two_axis_order_product_id}")
except xla_extension.XlaRuntimeError:
# OOM
metrics = {}
dimensions_json['oom'] = 'True'
print(f"Failed at run {two_axis_order_product_id} out of: "
f"{start_two_axis_order_product_id} to {end_two_axis_order_product_id}")

final = {'metrics': metrics, 'dimensions': dimensions_json}
print(f"Result: {final}")
results.append(final)

print(f"All results {results}")
path = 'inference_microbenchmark_sweep_results.jsonl'
with jsonlines.open(path, mode="w") as writer:
writer.write_all(results)


if __name__ == "__main__":
main()
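
As a usage sketch, the sweep could be launched with the usual MaxText key=value overrides and its JSONL output read back afterwards; the config path, bucket, run name, and metadata filename below are illustrative assumptions:

# Sketch: launch the sweep (shell command shown as a comment), then read its results.
#   python MaxText/inference_microbenchmark_sweep.py MaxText/configs/base.yml \
#     run_name=kv-layout-sweep base_output_directory=gs://my-bucket/output \
#     inference_metadata_file=inference_metadata.json
import jsonlines

with jsonlines.open("inference_microbenchmark_sweep_results.jsonl") as reader:
    for record in reader:
        # Each record carries 'metrics' (flattened, lower-cased keys) and 'dimensions'.
        print(record["dimensions"]["run_tag"], record["dimensions"]["oom"])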
