From 2748e7c7d4ad314f78bbd73f6771699cdbce26c7 Mon Sep 17 00:00:00 2001
From: Deepak Narayanan
Date: Sun, 26 Nov 2023 19:04:24 -0800
Subject: [PATCH] Compute and log throughput if --log-throughput option is
 specified

---
 megatron/arguments.py |  2 ++
 megatron/training.py  | 38 +++++++++++++++++++++++++++++++++-----
 2 files changed, 35 insertions(+), 5 deletions(-)

diff --git a/megatron/arguments.py b/megatron/arguments.py
index 0ca8776eda..d4f1cd5a32 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -657,6 +657,8 @@ def _add_logging_args(parser):
                        help='If set, calculate and log parameters norm.')
     group.add_argument('--log-num-zeros-in-grad', action='store_true',
                        help='If set, calculate and log the number of zeros in gradient.')
+    group.add_argument('--log-throughput', action='store_true',
+                       help='If set, calculate and log throughput per GPU.')
     group.add_argument('--timing-log-level', type=int, default=0,
                        choices=range(0,3),
                        help='Granularity level to measure and report timing. '
diff --git a/megatron/training.py b/megatron/training.py
index 8c5284c2a6..f3e3cafa31 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -56,6 +56,25 @@ def print_datetime(string):
     print_rank_0('[' + string + '] datetime: {} '.format(time_str))
 
 
+def num_floating_point_operations(args, batch_size):
+    if not args.group_query_attention:
+        args.num_query_groups = args.num_attention_heads
+    return (
+        60
+        * batch_size
+        * args.seq_length
+        * args.num_layers
+        * args.hidden_size
+        * args.hidden_size
+        * (
+            1
+            + (args.num_query_groups / (5 * args.num_attention_heads))
+            + (args.seq_length / (5 * args.hidden_size))
+            + (args.padded_vocab_size / (10 * args.num_layers * args.hidden_size))
+        )
+    )
+
+
 def pretrain(train_valid_test_dataset_provider,
              model_provider,
              model_type,
@@ -628,19 +647,28 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
     if iteration % args.log_interval == 0:
         elapsed_time = timers('interval-time').elapsed(barrier=True)
         elapsed_time_per_iteration = elapsed_time / total_iterations
-        if writer:
-            if args.log_timers_to_tensorboard:
+        throughput = num_floating_point_operations(args, batch_size) / (
+            elapsed_time_per_iteration * 10**12 * args.world_size)
+        if args.log_timers_to_tensorboard:
+            if writer:
                 writer.add_scalar('iteration-time',
                                   elapsed_time_per_iteration, iteration)
-                if wandb_writer:
-                    wandb_writer.log({'iteration-time':
-                                      elapsed_time_per_iteration}, iteration)
+            if wandb_writer:
+                wandb_writer.log({'iteration-time': elapsed_time_per_iteration},
+                                 iteration)
         log_string = ' iteration {:8d}/{:8d} |'.format(
             iteration, args.train_iters)
         log_string += ' consumed samples: {:12d} |'.format(
             args.consumed_train_samples)
         log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
             elapsed_time_per_iteration * 1000.0)
+        if args.log_throughput:
+            log_string += f' throughput per GPU (TFLOP/s/GPU): {throughput:.1f} |'
+            if args.log_timers_to_tensorboard:
+                if writer:
+                    writer.add_scalar('throughput', throughput, iteration)
+                if wandb_writer:
+                    wandb_writer.log({'throughput': throughput}, iteration)
        log_string += ' learning rate: {:.3E} |'.format(learning_rate)
         log_string += ' global batch size: {:5d} |'.format(batch_size)
         for key in total_loss_dict:
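
For reference, the throughput formula in this patch can be exercised outside
a training run. The sketch below is not part of the patch: the configuration
and timing values are hypothetical, and the explicit num_query_groups field
stands in for the args attribute that num_floating_point_operations() itself
fills in when grouped-query attention is disabled. Only the formula and the
final TFLOP/s/GPU expression mirror the code added above.

# Standalone sketch of the throughput math in this patch. All numeric
# values below are illustrative, not taken from the patch.
from dataclasses import dataclass


@dataclass
class ModelConfig:
    seq_length: int = 2048
    num_layers: int = 24
    hidden_size: int = 2048
    num_attention_heads: int = 16
    num_query_groups: int = 16    # equals num_attention_heads when GQA is off
    padded_vocab_size: int = 51200


def num_floating_point_operations(cfg, batch_size):
    # Forward + backward FLOPs per iteration. The leading 60*B*s*L*h^2 term
    # counts the query/output projections and the MLP GEMMs; the correction
    # terms add the key/value projections (reduced under grouped-query
    # attention), the QK^T and attention-times-V matmuls, and the final
    # vocabulary (logit) GEMM amortized over the transformer layers.
    return (
        60
        * batch_size
        * cfg.seq_length
        * cfg.num_layers
        * cfg.hidden_size
        * cfg.hidden_size
        * (
            1
            + (cfg.num_query_groups / (5 * cfg.num_attention_heads))
            + (cfg.seq_length / (5 * cfg.hidden_size))
            + (cfg.padded_vocab_size / (10 * cfg.num_layers * cfg.hidden_size))
        )
    )


cfg = ModelConfig()
batch_size = 512                     # global batch size, in sequences
elapsed_time_per_iteration = 0.9     # seconds, as measured by the timers
world_size = 64                      # total number of GPUs

# Same expression as in training_log(): FLOPs per iteration, divided by the
# iteration time, converted to TFLOPs (10**12) and normalized per GPU.
flops = num_floating_point_operations(cfg, batch_size)
throughput = flops / (elapsed_time_per_iteration * 10**12 * world_size)
print(f'throughput per GPU (TFLOP/s/GPU): {throughput:.1f}')

With these illustrative values the script prints roughly 165 TFLOP/s/GPU,
i.e. about half of an A100's 312 TFLOP/s BF16 peak, which is the kind of
sanity check the --log-throughput flag is meant to enable during training.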