From 535c132ca3aa5d7ea8f6e932852080f4ddbbc926 Mon Sep 17 00:00:00 2001 From: DEKHTIARJonathan Date: Tue, 10 May 2022 14:43:40 +0200 Subject: [PATCH] Metric Export to CSV added --- tftrt/examples/benchmark_args.py | 8 +++ tftrt/examples/benchmark_runner.py | 84 +++++++++++++++++++++++------- tftrt/examples/benchmark_utils.py | 2 +- 3 files changed, 73 insertions(+), 21 deletions(-) diff --git a/tftrt/examples/benchmark_args.py b/tftrt/examples/benchmark_args.py index d68902333..73f3b3525 100644 --- a/tftrt/examples/benchmark_args.py +++ b/tftrt/examples/benchmark_args.py @@ -256,6 +256,14 @@ def __init__(self): "to the set location in JSON format for further processing." ) + self._parser.add_argument( + "--export_metrics_csv_path", + type=str, + default=None, + help="If set, the script will export runtime metrics and arguments " + "to the set location in CSV format for further processing." + ) + self._parser.add_argument( "--tf_profile_export_path", type=str, diff --git a/tftrt/examples/benchmark_runner.py b/tftrt/examples/benchmark_runner.py index 7d0b0a548..a2160ae72 100644 --- a/tftrt/examples/benchmark_runner.py +++ b/tftrt/examples/benchmark_runner.py @@ -7,6 +7,7 @@ import abc import contextlib import copy +import csv import json import logging import sys @@ -125,28 +126,70 @@ def _debug_print(self, msg): def _export_runtime_metrics_to_json(self, metric_dict): - metric_dict = { - # Creating a copy to avoid modifying the original - "results": copy.deepcopy(metric_dict), - "runtime_arguments": vars(self._args) - } + try: - json_path = self._args.export_metrics_json_path - if json_path is not None: - try: - with open(json_path, 'w') as json_f: - json_string = json.dumps( - metric_dict, - default=lambda o: o.__dict__, - sort_keys=True, - indent=4 - ) - print(json_string, file=json_f) - except Exception as e: - print( - "[ERROR] Impossible to save JSON File at path: " - f"{json_path}.\nError: {str(e)}" + file_path = self._args.export_metrics_json_path + if file_path is None: + return + + metric_dict = { + # Creating a copy to avoid modifying the original + "results": copy.deepcopy(metric_dict), + "runtime_arguments": vars(self._args) + } + + with open(file_path, 'w') as json_f: + json_string = json.dumps( + metric_dict, + default=lambda o: o.__dict__, + sort_keys=True, + indent=4 ) + print(json_string, file=json_f) + + except Exception as e: + print(f"An exception occured during export to JSON: {e}") + + def _export_runtime_metrics_to_csv(self, metric_dict): + + try: + + file_path = self._args.export_metrics_csv_path + if file_path is None: + return + + data = {f"metric_{k}": v for k, v in metric_dict.items()} + + args_to_save = [ + "batch_size", + "input_saved_model_dir", + "minimum_segment_size", + "no_tf32", + "precision", + "use_dynamic_shape", + "use_synthetic_data", + "use_tftrt", + "use_xla", + "use_xla_auto_jit" + ] + + runtime_arguments = vars(self._args) + for key in args_to_save: + data[f"arg_{key}"] = str(runtime_arguments[key]).split("/")[-1] + + fieldnames = sorted(data.keys()) + + if not os.path.isfile(file_path): + with open(file_path, 'w') as outcsv: + writer = csv.DictWriter(outcsv, fieldnames=fieldnames, delimiter=',') + writer.writeheader() + + with open(file_path, 'a') as outcsv: + writer = csv.DictWriter(outcsv, fieldnames=fieldnames, delimiter=',') + writer.writerow(data) + + except Exception as e: + print(f"An exception occured during export to CSV: {e}") def _get_graph_func(self): """Retreives a frozen SavedModel and applies TF-TRT @@ -524,6 +567,7 @@ def timing_metrics(time_arr, log_prefix): metrics.update(timing_metrics(memcopy_times, "Data MemCopyHtoD Time")) self._export_runtime_metrics_to_json(metrics) + self._export_runtime_metrics_to_csv(metrics) def log_value(key, val): if isinstance(val, int): diff --git a/tftrt/examples/benchmark_utils.py b/tftrt/examples/benchmark_utils.py index 4709c904e..8eb890565 100644 --- a/tftrt/examples/benchmark_utils.py +++ b/tftrt/examples/benchmark_utils.py @@ -49,7 +49,7 @@ def timed_section(msg, activate=True, start_end_mode=True): total_time = time.time() - start_time if start_end_mode: - print(f"[END] `{msg}` - Duration: {total_time:.1f}s") + print(f"[END] {msg} - Duration: {total_time:.1f}s") print("=" * 80, "\n") else: print(f"{msg:18s}: {total_time:.4f}s")