55import os
66
77import abc
8+ import contextlib
89import copy
910import json
1011import logging
@@ -390,7 +391,12 @@ def log_step(step_idx, display_every, iter_time, memcpyHtoD_time, dequeue_time):
390391 )
391392
@force_gpu_resync
@tf.function(jit_compile=self._args.use_xla)
def dequeue_batch(ds_iter):
    """Fetch the next element from the dataset iterator.

    Wrapped in `tf.function` (optionally XLA-compiled) and
    `force_gpu_resync` so that dequeue latency is measured on its own,
    separately from host-to-device copies and inference.
    """
    fetched = next(ds_iter)
    return fetched
398+ @force_gpu_resync
399+ @tf .function (jit_compile = self ._args .use_xla )
394400 def force_data_on_gpu (data , device = "/gpu:0" ):
395401 with tf .device (device ):
396402 if isinstance (data , (list , tuple )):
@@ -403,58 +409,77 @@ def force_data_on_gpu(data, device="/gpu:0"):
403409 output_data [k ] = tf .identity (v )
404410 else :
405411 output_data = tf .identity (data )
412+
406413 return output_data
407414
# Optionally record a TensorFlow profiler session around the whole loop.
# When no export path is configured, substitute no-op context managers so
# the loop body below stays uniform (no per-step branching on profiling).
if self._args.tf_profile_export_path:
    profiling_ctx = tf.profiler.experimental.Profile(
        self._args.tf_profile_export_path
    )
    tracing_ctx = tf.profiler.experimental.Trace
else:
    profiling_ctx = contextlib.nullcontext()
    tracing_ctx = lambda *a, **kw: contextlib.nullcontext()

step_idx = 0
ds_iter = iter(dataset)

with profiling_ctx:

    while True:

        step_idx += 1

        # BUG FIX (off-by-one): the previous check used `>=` after the
        # increment but *before* the step executed, so only N-1 iterations
        # ever ran for --num_iterations=N. `>` restores exactly N steps.
        if (self._args.num_iterations is not None and
                step_idx > self._args.num_iterations):
            break

        with tracing_ctx('Inference Step', step_num=step_idx, _r=1):

            # --- 1. Dequeue the next batch from the input pipeline -------
            with tracing_ctx('Input Dequeueing', step_num=step_idx, _r=1):
                start_time = time.time()
                try:
                    data_batch = dequeue_batch(ds_iter)
                # BUG FIX: was a bare `except:` that swallowed every error
                # (including KeyboardInterrupt). Only iterator exhaustion
                # should end the loop; tf.function-wrapped iterators may
                # surface it as OutOfRangeError instead of StopIteration.
                except (StopIteration, tf.errors.OutOfRangeError):
                    print("[Exiting] Reached end of dataset ...")
                    break
                dequeue_times.append(time.time() - start_time)

            # --- 2. Copy the batch onto the GPU ---------------------------
            with tracing_ctx('Inputs MemcpyHtoD', step_num=step_idx, _r=1):
                start_time = time.time()
                data_batch = force_data_on_gpu(data_batch)
                memcopy_times.append(time.time() - start_time)

            # --- 3. Split the batch into model inputs and targets ---------
            with tracing_ctx('Inputs Preprocessing', step_num=step_idx, _r=1):
                x, y = self.preprocess_model_inputs(data_batch)

            # --- 4. Run the actual model ----------------------------------
            with tracing_ctx('GPU Inference', step_num=step_idx, _r=1):
                start_time = time.time()
                y_pred = infer_batch(x)
                iter_times.append(time.time() - start_time)

        if not self._args.debug_performance:
            # Report sliding-window means (in ms) over the last
            # `display_every` steps; log_step itself throttles printing.
            log_step(
                step_idx,
                display_every=self._args.display_every,
                iter_time=np.mean(iter_times[-self._args.display_every:]) * 1000,
                memcpyHtoD_time=np.mean(memcopy_times[-self._args.display_every:]) * 1000,
                dequeue_time=np.mean(dequeue_times[-self._args.display_every:]) * 1000
            )
        else:
            print(f"{'GPU Iteration Time':18s}: {iter_times[-1]:08.4f}s")
            # BUG FIX: previously printed `memcpyHtoD_time[-1]`, which is a
            # NameError here — that name is only a keyword argument of
            # log_step; the list collected above is `memcopy_times`.
            print(f"{'Data MemCopyHtoD Time':18s}: {memcopy_times[-1]:08.4f}s")
            print(f"{'Data Dequeue Time':18s}: {dequeue_times[-1]:08.4f}s")

        # With synthetic data there are no meaningful targets to score.
        if not self._args.use_synthetic_data:
            data_aggregator.aggregate_data(y_pred, y)
451476
452477 if (
453478 not self ._args .debug_performance and
454479 step_idx % self ._args .display_every != 0
455480 ): # avoids double printing
456481 log_step (
457- step_idx + 1 ,
482+ step_idx ,
458483 display_every = 1 , # force print
459484 iter_time = np .mean (iter_times [- self ._args .display_every :]) * 1000 ,
460485 memcpyHtoD_time = np .mean (memcopy_times [- self ._args .display_every :]) * 1000 ,
0 commit comments