diff --git a/tftrt/examples/benchmark_args.py b/tftrt/examples/benchmark_args.py
index 22d80e6e3..dc161c211 100644
--- a/tftrt/examples/benchmark_args.py
+++ b/tftrt/examples/benchmark_args.py
@@ -272,6 +272,13 @@ def __init__(self):
             "performance analysis."
         )
 
+        self._add_bool_argument(
+            name="tf_profile_verbose",
+            default=False,
+            required=False,
+            help="If set to True, will add extra information to the TF Profile."
+        )
+
         self._add_bool_argument(
             name="debug",
             default=False,
@@ -378,6 +385,15 @@ def _validate_args(self, args):
                     "calibration."
                 )
 
+        if (
+            args.tf_profile_verbose and
+            args.tf_profile_export_path is None
+        ):
+            raise ValueError(
+                "`--tf_profile_verbose` can only be set if "
+                "`--tf_profile_export_path=/path/to/export` is defined."
+            )
+
     def _post_process_args(self, args):
         return args
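[Editor's note, not part of the patch] For context, the flag dependency enforced by the new `_validate_args` check above, restated as a minimal self-contained argparse sketch. The parser below is illustrative only (the real flags are declared through the project's `_add_bool_argument` helper); just the two flag names and the error message come from the patch:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--tf_profile_export_path", type=str, default=None)
    parser.add_argument("--tf_profile_verbose", action="store_true")
    args = parser.parse_args()

    # Same rule as `_validate_args()` above: verbose profiling only makes
    # sense when a profile export directory is actually being written.
    if args.tf_profile_verbose and args.tf_profile_export_path is None:
        raise ValueError(
            "`--tf_profile_verbose` can only be set if "
            "`--tf_profile_export_path=/path/to/export` is defined."
        )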
" + f"Setting exec path to: `{self._fns[best_fn_id].__name__}`\n") + + self._best_fn = self._fns[best_fn_id] + return self._best_fn(*arg, **kwargs) + + self._call_counter += 1 + return output + + def __call__(self, *arg, **kwargs): + return self._best_fn(*arg, **kwargs) + + +def _force_using_concrete_function(func): + # `context` needs to be a closure of type list or dict for persistance + context = [] + def _wrapper(*args, **kwargs): + try: + return context[0](*args, **kwargs) + except IndexError: + print(f"[INFO] Building the concrete function") + context.append(func.get_concrete_function(*args, **kwargs)) + return context[0](*args, **kwargs) + return _wrapper + + +def auto_tf_func_tuner( + calls_per_func=45, + skip_n_first=30, + use_xla=False, + use_synthetic_data=False +): + + def wrapper(func): + + @force_gpu_resync + def eager_function(*args, **kwargs): + return func(*args, **kwargs) + + @force_gpu_resync + @tf.function(jit_compile=use_xla) + def tf_function(*args, **kwargs): + return func(*args, **kwargs) + + @force_gpu_resync + @_force_using_concrete_function + @tf.function(jit_compile=use_xla) + def tf_concrete_function(*args, **kwargs): + return func(*args, **kwargs) + + eager_function.__name__ = f"{func.__name__}_eager" + tf_function.__name__ = f"{func.__name__}_tf_function" + tf_concrete_function.__name__ = f"{func.__name__}_tf_concrete_function" + + funcs2autotune = [eager_function, tf_function] + if use_synthetic_data: + print("[INFO] Allowing direct concrete_function call with " + "synthetic data loader.") + funcs2autotune.append(tf_concrete_function) + + return _TFFunctionAutoTuner( + funcs2autotune, + calls_per_func=calls_per_func, + skip_n_first=skip_n_first + ) + + return wrapper diff --git a/tftrt/examples/benchmark_runner.py b/tftrt/examples/benchmark_runner.py index a2160ae72..7515e7559 100644 --- a/tftrt/examples/benchmark_runner.py +++ b/tftrt/examples/benchmark_runner.py @@ -15,8 +15,9 @@ from distutils.util import strtobool +from benchmark_autotuner import auto_tf_func_tuner + from benchmark_utils import DataAggregator -from benchmark_utils import force_gpu_resync from benchmark_utils import print_dict from benchmark_utils import timed_section @@ -383,16 +384,14 @@ def execute_benchmark(self): dataset, bypass_data_to_eval = self.get_dataset_batches() if self._args.use_synthetic_data: - old_ds = dataset try: - dataset = SyntheticDataset(old_ds, device="/gpu:0") + dataset = SyntheticDataset(dataset, device="/gpu:0") self._debug_print( "Model dataset has been replaced by a synthetic data " "loader to minimize data loading jitter." ) except Exception as e: - dataset = old_ds print( f"[ERROR] Impossible to transform the dataset into a " f"synthetic dataset. 
diff --git a/tftrt/examples/benchmark_runner.py b/tftrt/examples/benchmark_runner.py
index a2160ae72..7515e7559 100644
--- a/tftrt/examples/benchmark_runner.py
+++ b/tftrt/examples/benchmark_runner.py
@@ -15,8 +15,9 @@
 from distutils.util import strtobool
 
+from benchmark_autotuner import auto_tf_func_tuner
+
 from benchmark_utils import DataAggregator
-from benchmark_utils import force_gpu_resync
 from benchmark_utils import print_dict
 from benchmark_utils import timed_section
 
@@ -383,16 +384,14 @@ def execute_benchmark(self):
         dataset, bypass_data_to_eval = self.get_dataset_batches()
 
         if self._args.use_synthetic_data:
-            old_ds = dataset
             try:
-                dataset = SyntheticDataset(old_ds, device="/gpu:0")
+                dataset = SyntheticDataset(dataset, device="/gpu:0")
                 self._debug_print(
                     "Model dataset has been replaced by a synthetic data "
                     "loader to minimize data loading jitter."
                 )
             except Exception as e:
-                dataset = old_ds
                 print(
                     f"[ERROR] Impossible to transform the dataset into a "
                     f"synthetic dataset. Performance numbers will be "
@@ -401,8 +400,10 @@
         else:
             dataset = ensure_dataset_on_gpu(dataset, device="GPU:0")
 
-        @force_gpu_resync
-        @tf.function(jit_compile=self._args.use_xla)
+        @auto_tf_func_tuner(
+            use_xla=self._args.use_xla,
+            use_synthetic_data=self._args.use_synthetic_data
+        )
         def infer_batch(x):
             if isinstance(x, (tuple, list)):
                 model_out = graph_func(*x)
@@ -439,72 +440,112 @@ def log_step(step_idx, display_every, iter_time, memcpyHtoD_time, dequeue_time):
             )
 
         if self._args.tf_profile_export_path:
-            profiling_ctx = tf.profiler.experimental.Profile(
-                self._args.tf_profile_export_path
-            )
+            def start_profiling():
+                if self._args.tf_profile_verbose:
+                    profiler_opts = tf.profiler.experimental.ProfilerOptions(
+                        # Adjust TraceMe levels:
+                        # - 1: critical
+                        # - 2: info [default]
+                        # - 3: verbose
+                        host_tracer_level=2,
+                        # Enables python function call tracing
+                        # - 0: disable [default]
+                        # - 1: enable
+                        python_tracer_level=1,
+                        # Adjust device (TPU/GPU) tracer level:
+                        # - 0: disable
+                        # - 1: enable [default]
+                        device_tracer_level=1,
+                        # start profiling after 30 sec.
+                        # - Skip tf.function building
+                        # - Skip autotuning
+                        delay_ms=30000
+                    )
+                    print("[INFO] Using verbose TF Profiler.")
+                else:
+                    profiler_opts = None
+
+                profiling_ctx = tf.profiler.experimental.start(
+                    self._args.tf_profile_export_path,
+                    options=profiler_opts
+                )
+
+            stop_profiling = tf.profiler.experimental.stop
+            tracing_ctx = tf.profiler.experimental.Trace
+        else:
+            start_profiling = stop_profiling = lambda *a, **kw: None
             profiling_ctx = contextlib.nullcontext()
             tracing_ctx = lambda *a, **kw: contextlib.nullcontext()
 
         step_idx = 0
         ds_iter = iter(dataset)
 
-        dequeue_batch_fn = get_dequeue_batch_fn(ds_iter)
+        dequeue_batch_fn = get_dequeue_batch_fn(
+            ds_iter,
+            use_xla=self._args.use_xla,
+            use_synthetic_data=self._args.use_synthetic_data
+        )
+
         force_data_on_gpu_fn = get_force_data_on_gpu_fn(
             device="/gpu:0",
-            use_xla=self._args.use_xla
+            use_xla=self._args.use_xla,
+            use_synthetic_data=self._args.use_synthetic_data
         )
 
-        with profiling_ctx:
-
-            while True:
-
-                step_idx += 1
+        while True:
 
-                if (self._args.num_iterations is not None and
-                        step_idx > self._args.num_iterations):
-                    break
-
-                with tracing_ctx('Inference Step', step_num=step_idx, _r=1):
+            step_idx += 1
 
-                    with tracing_ctx('Input Dequeueing', step_num=step_idx, _r=1):
-                        try:
-                            start_time = time.time()
-                            data_batch = dequeue_batch_fn()
-                            dequeue_times.append(time.time() - start_time)
-                        except (StopIteration, OutOfRangeError):
-                            print("[Exiting] Reached end of dataset ...")
-                            break
+            if step_idx == self._args.num_warmup_iterations - 5:
+                start_profiling()
 
-                    with tracing_ctx('Inputs Preprocessing', step_num=step_idx, _r=1):
-                        x, y = self.preprocess_model_inputs(data_batch)
+            if (
+                self._args.num_iterations is not None and
+                step_idx > self._args.num_iterations
+            ):
+                break
 
-                    with tracing_ctx('Inputs MemcpyHtoD', step_num=step_idx, _r=1):
-                        start_time = time.time()
-                        x = force_data_on_gpu_fn(x)
-                        memcopy_times.append(time.time() - start_time)
+            with tracing_ctx('', step_num=step_idx, _r=1):
 
-                    with tracing_ctx('GPU Inference', step_num=step_idx, _r=1):
+                with tracing_ctx('Input Dequeueing'):
+                    try:
                         start_time = time.time()
-                        y_pred = infer_batch(x)
-                        iter_times.append(time.time() - start_time)
-
-                    if not self._args.debug_performance:
-                        log_step(
-                            step_idx,
-                            display_every=self._args.display_every,
-                            iter_time=np.mean(iter_times[-self._args.display_every:]) * 1000,
-                            memcpyHtoD_time=np.mean(memcopy_times[-self._args.display_every:]) * 1000,
-                            dequeue_time=np.mean(dequeue_times[-self._args.display_every:]) * 1000
-                        )
-                    else:
-                        print(f"{'GPU Iteration Time':18s}: {iter_times[-1]:08.4f}s")
-                        print(f"{'Data MemCopyHtoD Time':18s}: {memcpyHtoD_time[-1]:08.4f}s")
-                        print(f"{'Data Dequeue Time':18s}: {dequeue_times[-1]:08.4f}s")
+                        data_batch = dequeue_batch_fn()
+                        dequeue_times.append(time.time() - start_time)
+                    except (StopIteration, OutOfRangeError):
+                        print("[Exiting] Reached end of dataset ...")
+                        break
+
+                with tracing_ctx('Inputs Preprocessing'):
+                    x, y = self.preprocess_model_inputs(data_batch)
+
+                with tracing_ctx('Inputs MemcpyHtoD'):
+                    start_time = time.time()
+                    x = force_data_on_gpu_fn(x)
+                    memcopy_times.append(time.time() - start_time)
+
+                with tracing_ctx('GPU Inference'):
+                    start_time = time.time()
+                    y_pred = infer_batch(x)
+                    iter_times.append(time.time() - start_time)
+
+            if not self._args.debug_performance:
+                log_step(
+                    step_idx,
+                    display_every=self._args.display_every,
+                    iter_time=np.mean(iter_times[-self._args.display_every:]) * 1000,
+                    memcpyHtoD_time=np.mean(memcopy_times[-self._args.display_every:]) * 1000,
+                    dequeue_time=np.mean(dequeue_times[-self._args.display_every:]) * 1000
+                )
+            else:
+                print(f"{'GPU Iteration Time':18s}: {iter_times[-1]:08.4f}s")
+                print(f"{'Data MemCopyHtoD Time':18s}: {memcopy_times[-1]:08.4f}s")
+                print(f"{'Data Dequeue Time':18s}: {dequeue_times[-1]:08.4f}s")
 
-                if not self._args.use_synthetic_data:
-                    data_aggregator.aggregate_data(y_pred, y)
+            if not self._args.use_synthetic_data:
+                data_aggregator.aggregate_data(y_pred, y)
 
         if (
             not self._args.debug_performance and
@@ -518,6 +559,9 @@ def log_step(step_idx, display_every, iter_time, memcpyHtoD_time, dequeue_time):
                 dequeue_time=np.mean(dequeue_times[-self._args.display_every:]) * 1000
             )
 
+        if step_idx >= 100:
+            stop_profiling()
+
         with timed_section("Metric Computation"):
 
             metrics = dict()
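[Editor's note, not part of the patch] The verbose-profiling pattern introduced in `benchmark_runner.py` above, reduced to a standalone sketch; the log directory and the dummy matmul step are placeholders, while the `ProfilerOptions` values mirror the ones used in the patch:

    import tensorflow as tf

    opts = tf.profiler.experimental.ProfilerOptions(
        host_tracer_level=2,    # TraceMe detail: 1=critical, 2=info, 3=verbose
        python_tracer_level=1,  # also trace Python function calls
        device_tracer_level=1,  # keep GPU/TPU tracing enabled
        delay_ms=30000          # begin capturing 30 s after start()
    )
    tf.profiler.experimental.start("/tmp/tf_profile", options=opts)  # placeholder logdir

    for step in range(200):
        with tf.profiler.experimental.Trace("Inference Step", step_num=step, _r=1):
            tf.matmul(tf.random.normal((256, 256)), tf.random.normal((256, 256)))

    tf.profiler.experimental.stop()

Delaying the capture (and calling start/stop explicitly instead of wrapping the whole loop in `tf.profiler.experimental.Profile`) keeps the `tf.function` tracing and auto-tuning phases out of the recorded profile, as the patch's comments indicate.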
diff --git a/tftrt/examples/benchmark_utils.py b/tftrt/examples/benchmark_utils.py
index 8eb890565..127ecbb50 100644
--- a/tftrt/examples/benchmark_utils.py
+++ b/tftrt/examples/benchmark_utils.py
@@ -9,18 +9,30 @@
 
 from contextlib import contextmanager
 
-__all__ = ["DataAggregator", "force_gpu_resync", "print_dict", "timed_section"]
-
 
 def force_gpu_resync(func):
-    p = tf.constant(0.)  # Create small tensor to force GPU resync
+    try:
+        sync_device_fn = tf.experimental.sync_devices
+        print("[INFO] Using API `tf.experimental.sync_devices` to resync GPUs.")
+
+        def wrapper(*args, **kwargs):
+            rslt = func(*args, **kwargs)
+            sync_device_fn()
+            return rslt
+
+        return wrapper
 
-    def wrapper(*args, **kwargs):
-        rslt = func(*args, **kwargs)
-        (p + 1.).numpy()  # Sync the GPU
-        return rslt
+    except AttributeError:
+        print("[WARNING] Using deprecated API to resync GPUs. "
+              "Non-negligible overhead might be present.")
+        p = tf.constant(0.)  # Create small tensor to force GPU resync
 
-    return wrapper
+        def wrapper(*args, **kwargs):
+            rslt = func(*args, **kwargs)
+            (p + 1.).numpy()  # Sync the GPU
+            return rslt
+
+        return wrapper
 
 
 def print_dict(input_dict, prefix=' ', postfix='', redirect_to_str=False):
@@ -191,8 +203,6 @@ def aggregate_data(self, y_pred, y):
                 )
 
             self._expected[key][idx_start:idx_stop] = y[key]
 
-        print()
-
 
 def patch_dali_dataset(dataset):
     import nvidia.dali.plugin.tf as dali_tf
diff --git a/tftrt/examples/dataloading_utils.py b/tftrt/examples/dataloading_utils.py
index 6d3a8d043..9e25aae99 100644
--- a/tftrt/examples/dataloading_utils.py
+++ b/tftrt/examples/dataloading_utils.py
@@ -5,7 +5,7 @@
 import time
 import tensorflow as tf
 
-from benchmark_utils import force_gpu_resync
+from benchmark_autotuner import auto_tf_func_tuner
 
 
 class SyntheticDataset(object):
@@ -25,38 +25,42 @@ def __init__(self, dataset, device):
                 buffer_size=tf.data.experimental.AUTOTUNE
             )
         )
-        self._ds_iter = iter(dataset)
-        self._device = device
+        self._ds = dataset
+        self._data_batch = next(iter(dataset))
 
     def __iter__(self):
-
-        data_batch = next(self._ds_iter)
-
-        while True:
-            yield data_batch
+        return iter(self._ds)
 
 
 def ensure_dataset_on_gpu(dataset, device):
+    if isinstance(dataset, SyntheticDataset):
+        return dataset
+
     try:
         ds_device = dataset._variant_tensor_attr.device.lower()
-    except AttributeError:
+    except AttributeError as e:
+        print(
+            f"[ERROR] Impossible to find the device from the dataset.\n"
+            f"Error: {e}."
+        )
         return dataset
 
     if device.lower() not in ds_device:
-        return dataset.apply(
+        print(f"[INFO] Adding prefetch to device `{device}` to the dataset.")
+        dataset = dataset.apply(
             tf.data.experimental.prefetch_to_device(
                 device=device,
                 buffer_size=tf.data.experimental.AUTOTUNE
            )
        )
-
-    else:
         return dataset
 
+    return dataset
+
 
-def get_dequeue_batch_fn(ds_iter):
+def get_dequeue_batch_fn(ds_iter, use_xla=False, use_synthetic_data=False):
 
-    @force_gpu_resync
+    @auto_tf_func_tuner(use_xla=use_xla, use_synthetic_data=use_synthetic_data)
     def dequeue_batch_fn():
         """This function should not use tf.function().
         It would create two unwanted effects:
@@ -68,10 +72,9 @@ def dequeue_batch_fn():
     return dequeue_batch_fn
 
 
-def get_force_data_on_gpu_fn(device="/gpu:0", use_xla=False):
+def get_force_data_on_gpu_fn(device="/gpu:0", use_xla=False, use_synthetic_data=False):
 
-    @force_gpu_resync
-    @tf.function(jit_compile=use_xla)
+    @auto_tf_func_tuner(use_xla=use_xla, use_synthetic_data=use_synthetic_data)
     def force_data_on_gpu_fn(data):
         with tf.device(device):
             if isinstance(data, (list, tuple)):