5
5
import os
6
6
7
7
import abc
8
+ import contextlib
8
9
import copy
9
10
import json
10
11
import logging
@@ -389,6 +390,11 @@ def log_step(step_idx, display_every, iter_time, memcpyHtoD_time, dequeue_time):
389
390
dataset , activate = self ._args .debug_performance
390
391
)
391
392
393
# dequeue_batch: returns the next element of `ds_iter` via `next()`.
# It is compiled with `tf.function` and wrapped by `force_gpu_resync`
# (defined outside this view -- presumably a device sync; TODO confirm).
# The benchmark loop calls it on its own so the dequeue time can be
# recorded separately from the host-to-device copy and the inference call.
+ @force_gpu_resync
394
+ @tf .function ()
395
+ def dequeue_batch (ds_iter ):
396
+ return next (ds_iter )
397
+
392
398
@force_gpu_resync
393
399
@tf .function ()
394
400
def force_data_on_gpu (data , device = "/gpu:0" ):
@@ -408,53 +414,70 @@ def force_data_on_gpu(data, device="/gpu:0"):
408
414
step_idx = 0
409
415
ds_iter = iter (dataset )
410
416
411
- while True :
417
+ if self ._args .tf_profile_export_path :
418
+ profiling_ctx = tf .profiler .experimental .Profile (
419
+ self ._args .tf_profile_export_path
420
+ )
421
+ tracing_ctx = tf .profiler .experimental .Trace
422
+ else :
423
+ profiling_ctx = contextlib .nullcontext ()
424
+ tracing_ctx = lambda * a , ** kw : contextlib .nullcontext ()
412
425
413
- try :
414
- start_time = time .time ()
415
- data_batch = next (ds_iter )
416
- dequeue_times .append (time .time () - start_time )
417
- except :
418
- break
419
-
420
- start_time = time .time ()
421
- data_batch = force_data_on_gpu (data_batch )
422
- memcopy_times .append (time .time () - start_time )
423
-
424
- x , y = self .preprocess_model_inputs (data_batch )
425
-
426
- start_time = time .time ()
427
- y_pred = infer_batch (x )
428
- iter_times .append (time .time () - start_time )
429
-
430
- if not self ._args .debug_performance :
431
- log_step (
432
- step_idx + 1 ,
433
- display_every = self ._args .display_every ,
434
- iter_time = np .mean (iter_times [- self ._args .display_every :]) * 1000 ,
435
- memcpyHtoD_time = np .mean (memcopy_times [- self ._args .display_every :]) * 1000 ,
436
- dequeue_time = np .mean (dequeue_times [- self ._args .display_every :]) * 1000
437
- )
438
- else :
439
- print (f"{ 'GPU Iteration Time' :18s} : { iter_times [- 1 ]:08.4f} s" )
440
- print (f"{ 'Data MemCopyHtoD Time' :18s} : { memcpyHtoD_time [- 1 ]:08.4f} s" )
441
- print (f"{ 'Data Dequeue Time' :18s} : { dequeue_times [- 1 ]:08.4f} s" )
426
+ with profiling_ctx :
427
+
428
+ while True :
442
429
443
- if not self ._args .use_synthetic_data :
444
- data_aggregator .aggregate_data (y_pred , y )
430
+ step_idx += 1
445
431
446
- if (self ._args .num_iterations is not None and
447
- step_idx + 1 >= self ._args .num_iterations ):
448
- break
432
+ with tracing_ctx ('Inference Step' , step_num = step_idx , _r = 1 ):
449
433
450
- step_idx += 1
434
+ with tracing_ctx ('Input Dequeueing' , step_num = step_idx , _r = 1 ):
435
+ try :
436
+ start_time = time .time ()
437
+ data_batch = dequeue_batch (ds_iter )
438
+ dequeue_times .append (time .time () - start_time )
439
+ except :
440
+ break
441
+
442
+ with tracing_ctx ('Inputs MemcpyHtoD' , step_num = step_idx , _r = 1 ):
443
+ start_time = time .time ()
444
+ data_batch = force_data_on_gpu (data_batch )
445
+ memcopy_times .append (time .time () - start_time )
446
+
447
+ with tracing_ctx ('Inputs Preprocessing' , step_num = step_idx , _r = 1 ):
448
+ x , y = self .preprocess_model_inputs (data_batch )
449
+
450
+ with tracing_ctx ('GPU Inference' , step_num = step_idx , _r = 1 ):
451
+ start_time = time .time ()
452
+ y_pred = infer_batch (x )
453
+ iter_times .append (time .time () - start_time )
454
+
455
+ if not self ._args .debug_performance :
456
+ log_step (
457
+ step_idx ,
458
+ display_every = self ._args .display_every ,
459
+ iter_time = np .mean (iter_times [- self ._args .display_every :]) * 1000 ,
460
+ memcpyHtoD_time = np .mean (memcopy_times [- self ._args .display_every :]) * 1000 ,
461
+ dequeue_time = np .mean (dequeue_times [- self ._args .display_every :]) * 1000
462
+ )
463
+ else :
464
+ print (f"{ 'GPU Iteration Time' :18s} : { iter_times [- 1 ]:08.4f} s" )
465
+ print (f"{ 'Data MemCopyHtoD Time' :18s} : { memcpyHtoD_time [- 1 ]:08.4f} s" )
466
+ print (f"{ 'Data Dequeue Time' :18s} : { dequeue_times [- 1 ]:08.4f} s" )
467
+
468
+ if not self ._args .use_synthetic_data :
469
+ data_aggregator .aggregate_data (y_pred , y )
470
+
471
+ if (self ._args .num_iterations is not None and
472
+ step_idx >= self ._args .num_iterations ):
473
+ break
451
474
452
475
if (
453
476
not self ._args .debug_performance and
454
477
step_idx % self ._args .display_every != 0
455
478
): # avoids double printing
456
479
log_step (
457
- step_idx + 1 ,
480
+ step_idx ,
458
481
display_every = 1 , # force print
459
482
iter_time = np .mean (iter_times [- self ._args .display_every :]) * 1000 ,
460
483
memcpyHtoD_time = np .mean (memcopy_times [- self ._args .display_every :]) * 1000 ,
0 commit comments