Skip to content
This repository was archived by the owner on Feb 3, 2025. It is now read-only.

Commit 064de53

Browse files
author
DEKHTIARJonathan
committed
TF Profiler Instrumentation
1 parent 263043b commit 064de53

File tree

2 files changed

+68
-37
lines changed

2 files changed

+68
-37
lines changed

tftrt/examples/benchmark_args.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,14 @@ def __init__(self):
245245
"to the set location in JSON format for further processing."
246246
)
247247

248+
self._parser.add_argument(
249+
"--tf_profile_export_path",
250+
type=str,
251+
default=None,
252+
help="If set, the script will export tf.profile files for further "
253+
"performance analysis."
254+
)
255+
248256
self._add_bool_argument(
249257
name="debug",
250258
default=False,

tftrt/examples/benchmark_runner.py

Lines changed: 60 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import os
66

77
import abc
8+
import contextlib
89
import copy
910
import json
1011
import logging
@@ -389,6 +390,11 @@ def log_step(step_idx, display_every, iter_time, memcpyHtoD_time, dequeue_time):
389390
dataset, activate=self._args.debug_performance
390391
)
391392

393+
@force_gpu_resync
@tf.function()
def dequeue_batch(ds_iter):
    """Pull the next batch from the dataset iterator `ds_iter`.

    Wrapped in `tf.function` so the dequeue is compiled into a TF graph op,
    and in `force_gpu_resync` (project decorator — presumably forces a GPU
    sync so the wall-clock timing around this call measures the full
    dequeue, not just op dispatch; confirm against its definition).
    Raises StopIteration/OutOfRangeError when the iterator is exhausted,
    which the caller's try/except uses to end the benchmark loop.
    """
    return next(ds_iter)
397+
392398
@force_gpu_resync
393399
@tf.function()
394400
def force_data_on_gpu(data, device="/gpu:0"):
@@ -408,53 +414,70 @@ def force_data_on_gpu(data, device="/gpu:0"):
408414
step_idx = 0
409415
ds_iter = iter(dataset)
410416

411-
while True:
417+
if self._args.tf_profile_export_path:
418+
profiling_ctx = tf.profiler.experimental.Profile(
419+
self._args.tf_profile_export_path
420+
)
421+
tracing_ctx = tf.profiler.experimental.Trace
422+
else:
423+
profiling_ctx = contextlib.nullcontext()
424+
tracing_ctx = lambda *a, **kw: contextlib.nullcontext()
412425

413-
try:
414-
start_time = time.time()
415-
data_batch = next(ds_iter)
416-
dequeue_times.append(time.time() - start_time)
417-
except:
418-
break
419-
420-
start_time = time.time()
421-
data_batch = force_data_on_gpu(data_batch)
422-
memcopy_times.append(time.time() - start_time)
423-
424-
x, y = self.preprocess_model_inputs(data_batch)
425-
426-
start_time = time.time()
427-
y_pred = infer_batch(x)
428-
iter_times.append(time.time() - start_time)
429-
430-
if not self._args.debug_performance:
431-
log_step(
432-
step_idx + 1,
433-
display_every=self._args.display_every,
434-
iter_time=np.mean(iter_times[-self._args.display_every:]) * 1000,
435-
memcpyHtoD_time=np.mean(memcopy_times[-self._args.display_every:]) * 1000,
436-
dequeue_time=np.mean(dequeue_times[-self._args.display_every:]) * 1000
437-
)
438-
else:
439-
print(f"{'GPU Iteration Time':18s}: {iter_times[-1]:08.4f}s")
440-
print(f"{'Data MemCopyHtoD Time':18s}: {memcpyHtoD_time[-1]:08.4f}s")
441-
print(f"{'Data Dequeue Time':18s}: {dequeue_times[-1]:08.4f}s")
426+
with profiling_ctx:
427+
428+
while True:
442429

443-
if not self._args.use_synthetic_data:
444-
data_aggregator.aggregate_data(y_pred, y)
430+
step_idx += 1
445431

446-
if (self._args.num_iterations is not None and
447-
step_idx + 1 >= self._args.num_iterations):
448-
break
432+
with tracing_ctx('Inference Step', step_num=step_idx, _r=1):
449433

450-
step_idx += 1
434+
with tracing_ctx('Input Dequeueing', step_num=step_idx, _r=1):
435+
try:
436+
start_time = time.time()
437+
data_batch = dequeue_batch(ds_iter)
438+
dequeue_times.append(time.time() - start_time)
439+
except:
440+
break
441+
442+
with tracing_ctx('Inputs MemcpyHtoD', step_num=step_idx, _r=1):
443+
start_time = time.time()
444+
data_batch = force_data_on_gpu(data_batch)
445+
memcopy_times.append(time.time() - start_time)
446+
447+
with tracing_ctx('Inputs Preprocessing', step_num=step_idx, _r=1):
448+
x, y = self.preprocess_model_inputs(data_batch)
449+
450+
with tracing_ctx('GPU Inference', step_num=step_idx, _r=1):
451+
start_time = time.time()
452+
y_pred = infer_batch(x)
453+
iter_times.append(time.time() - start_time)
454+
455+
if not self._args.debug_performance:
456+
log_step(
457+
step_idx,
458+
display_every=self._args.display_every,
459+
iter_time=np.mean(iter_times[-self._args.display_every:]) * 1000,
460+
memcpyHtoD_time=np.mean(memcopy_times[-self._args.display_every:]) * 1000,
461+
dequeue_time=np.mean(dequeue_times[-self._args.display_every:]) * 1000
462+
)
463+
else:
464+
print(f"{'GPU Iteration Time':18s}: {iter_times[-1]:08.4f}s")
465+
print(f"{'Data MemCopyHtoD Time':18s}: {memcpyHtoD_time[-1]:08.4f}s")
466+
print(f"{'Data Dequeue Time':18s}: {dequeue_times[-1]:08.4f}s")
467+
468+
if not self._args.use_synthetic_data:
469+
data_aggregator.aggregate_data(y_pred, y)
470+
471+
if (self._args.num_iterations is not None and
472+
step_idx >= self._args.num_iterations):
473+
break
451474

452475
if (
453476
not self._args.debug_performance and
454477
step_idx % self._args.display_every != 0
455478
): # avoids double printing
456479
log_step(
457-
step_idx + 1,
480+
step_idx,
458481
display_every=1, # force print
459482
iter_time=np.mean(iter_times[-self._args.display_every:]) * 1000,
460483
memcpyHtoD_time=np.mean(memcopy_times[-self._args.display_every:]) * 1000,

0 commit comments

Comments (0)