Merge branch 'check_memory_usage_v2' into 'main'
Add memory_usage printing to megatron/training.py

See merge request ADLR/megatron-lm!956
jaredcasper committed Nov 28, 2023
2 parents d0beaa7 + 0bbdc62 commit 744adfc
Showing 4 changed files with 177 additions and 79 deletions.
79 changes: 0 additions & 79 deletions compute_memory_usage.py

This file was deleted.

159 changes: 159 additions & 0 deletions megatron/theoretical_memory_usage.py
@@ -0,0 +1,159 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""Computes theoretical memory footprint for model training."""


import math


NUM_BYTES_IN_MEGABYTE = 1024 * 1024


def compute_weight_and_optimizer_memory(args, verbose=False):
    if not args.group_query_attention:
        args.num_query_groups = args.num_attention_heads
    num_parameters_in_transformer_layers = (
        10
        * args.num_layers
        * args.hidden_size
        * args.hidden_size
        * (
            1
            + (args.num_query_groups / (5.0 * args.num_attention_heads))
            + (2 / (5 * args.hidden_size))
            + (1 / (5 * args.num_layers * args.hidden_size))
        )
    )
    embedding_size = args.hidden_size * args.padded_vocab_size
    if args.untie_embeddings_and_output_weights:
        num_total_parameters_with_embeddings = num_parameters_in_transformer_layers + (
            2 * embedding_size
        )
    else:
        num_total_parameters_with_embeddings = num_parameters_in_transformer_layers + embedding_size
    if verbose:
        print(
            f"Number of parameters in billions: {num_total_parameters_with_embeddings / 10**9:.2f}"
        )

    # Most loaded model shard has (1/pp_size transformer layers + 1 embedding layer) / tp_size.
    num_parameters_on_most_loaded_model_shard = (
        (num_parameters_in_transformer_layers / args.pipeline_model_parallel_size) + embedding_size
    ) / args.tensor_model_parallel_size
    if args.untie_embeddings_and_output_weights and args.pipeline_model_parallel_size == 1:
        num_parameters_on_most_loaded_model_shard += (
            embedding_size / args.tensor_model_parallel_size
        )
    if verbose:
        print(
            f"Number of parameters in most loaded shard in billions: {num_parameters_on_most_loaded_model_shard / 10**9:.4f}"
        )

    if args.pipeline_model_parallel_size > 1:
        # Other shards just have (1/pp_size transformer layers) / tp_size.
        num_parameters_on_other_model_shards = num_parameters_in_transformer_layers / (
            args.pipeline_model_parallel_size * args.tensor_model_parallel_size
        )
        if verbose:
            print(
                f"Number of parameters in other shards in billions: {num_parameters_on_other_model_shards / 10**9:.4f}"
            )

    num_bytes_per_parameter = (
        18 if not args.use_distributed_optimizer else 6 + (12 / args.data_parallel_size)
    )
    weight_and_optimizer_memory = (
        num_parameters_on_most_loaded_model_shard * num_bytes_per_parameter
    )

    return weight_and_optimizer_memory


def compute_activation_memory(args, num_microbatches, verbose=False):
    # Using formula in Table 2 of https://arxiv.org/pdf/2205.05198.pdf.
    # We are trying to compute the maximum activation footprint, so all calculations in this
    # function are for the first pipeline stage.

    # Memory footprint from transformer layer (self-attention and MLP).
    activation_memory = (args.seq_length * args.micro_batch_size * args.hidden_size) * 34
    if verbose:
        print(
            f"Activation memory footprint per transformer layer: "
            f"{activation_memory / NUM_BYTES_IN_MEGABYTE / args.tensor_model_parallel_size:.1f} MB"
        )
    activation_memory *= args.num_layers

    # Now add activation memory required for input embeddings, last LayerNorm and output layer.

    # Input to embedding (pp_size microbatches in flight).
    activation_memory += (
        8 * args.seq_length * args.micro_batch_size * args.pipeline_model_parallel_size
    )
    # Dropout in embedding layer (pp_size microbatches in flight).
    activation_memory += (
        args.seq_length
        * args.micro_batch_size
        * args.hidden_size
        * args.pipeline_model_parallel_size
    )

    # Multiply by interleaved PP memory factor.
    if args.virtual_pipeline_model_parallel_size is not None:
        interleaved_schedule_memory_penalty = 1 + (
            (args.pipeline_model_parallel_size - 1)
            / (args.pipeline_model_parallel_size * args.virtual_pipeline_model_parallel_size)
        )
        in_flight_microbatches = math.ceil(
            interleaved_schedule_memory_penalty * args.pipeline_model_parallel_size
        )
        if verbose:
            print(
                f"Memory penalty from interleaved schedule: {interleaved_schedule_memory_penalty:.2f}"
            )
            print(f"Number of in-flight microbatches: {in_flight_microbatches}")
        activation_memory *= interleaved_schedule_memory_penalty

    # If using non-interleaved schedule, number of microbatches in pipeline can be less than
    # pp_size, so discount accordingly.
    if args.virtual_pipeline_model_parallel_size is None and args.pipeline_model_parallel_size > 1:
        if num_microbatches is not None:
            activation_memory *= min(1, num_microbatches / args.pipeline_model_parallel_size)
            in_flight_microbatches = min(num_microbatches, args.pipeline_model_parallel_size)
        else:
            in_flight_microbatches = args.pipeline_model_parallel_size
        if verbose:
            print(f"Number of in-flight microbatches: {in_flight_microbatches}")

    if args.pipeline_model_parallel_size == 1:
        # Inputs to output layer and CE loss.
        activation_memory += (
            args.seq_length
            * args.micro_batch_size
            * args.hidden_size
            * 4
            * (1 + (args.padded_vocab_size / args.hidden_size))
        )

    # Activation memory is partitioned by TP size due to tensor and sequence model parallelism.
    return activation_memory / args.tensor_model_parallel_size


def report_theoretical_memory(args, num_microbatches=None, verbose=False):
    # Formulae here assume sequence parallelism and selective activation recomputation.
    if not args.sequence_parallel or args.recompute_granularity != 'selective':
        return

    weight_and_optimizer_memory = (
        compute_weight_and_optimizer_memory(args, verbose=verbose) / NUM_BYTES_IN_MEGABYTE
    )
    activation_memory = (
        compute_activation_memory(args, num_microbatches=num_microbatches, verbose=verbose)
        / NUM_BYTES_IN_MEGABYTE
    )
    total_memory = weight_and_optimizer_memory + activation_memory

    print(
        f"Theoretical memory footprints: weight and optimizer={weight_and_optimizer_memory:.2f} MB, "
        f"activation={activation_memory:.2f} MB, "
        f"total={total_memory:.2f} MB\n"
    )
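
For reference, the formulas above can be evaluated outside of a training run by passing any object that exposes the attribute names the Megatron argument parser would normally provide. Below is a minimal sketch assuming a hypothetical GPT-style configuration; every value is illustrative, and the attribute list simply mirrors what compute_weight_and_optimizer_memory and compute_activation_memory read. With no group-query attention and tied embeddings, the parameter formula reduces to roughly 12·L·h² plus the embedding table, about 1.3B parameters for the settings below.

# Minimal sketch (not part of the commit): evaluate the theoretical-memory
# formulas for a hypothetical configuration. Attribute names mirror the
# Megatron-LM args consumed above; the values are purely illustrative.
from types import SimpleNamespace

from megatron.theoretical_memory_usage import report_theoretical_memory

hypothetical_args = SimpleNamespace(
    num_layers=24,
    hidden_size=2048,
    num_attention_heads=16,
    group_query_attention=False,
    num_query_groups=16,
    padded_vocab_size=51200,
    untie_embeddings_and_output_weights=False,
    seq_length=2048,
    micro_batch_size=1,
    tensor_model_parallel_size=2,
    pipeline_model_parallel_size=1,
    virtual_pipeline_model_parallel_size=None,
    data_parallel_size=4,
    use_distributed_optimizer=True,
    # Both of these must hold, or report_theoretical_memory returns early.
    sequence_parallel=True,
    recompute_granularity='selective',
)

report_theoretical_memory(hypothetical_args, num_microbatches=8, verbose=True)

The per-parameter byte counts used above (18, or 6 + 12/d with the distributed optimizer) appear to correspond to the usual mixed-precision Adam layout: 16-bit weights, fp32 gradients, and fp32 main weights plus two fp32 Adam moments, with those last 12 bytes sharded across the d data-parallel ranks when the distributed optimizer is enabled.
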
4 changes: 4 additions & 0 deletions megatron/training.py
@@ -10,6 +10,7 @@
from .log_handler import CustomHandler
# Make default logging level INFO, but filter out all log messages not from MCore.
logging.basicConfig(handlers=[CustomHandler()], level=logging.INFO)
from .theoretical_memory_usage import report_theoretical_memory
import time
# The earliest we can measure the start time.
_TRAIN_START_TIME = time.time()
@@ -668,6 +669,9 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
print_rank_last(log_string)
if report_memory_flag and learning_rate > 0.:
    # Report memory after optimizer state has been initialized.
    if torch.distributed.get_rank() == 0:
        num_microbatches = get_num_microbatches()
        report_theoretical_memory(args, num_microbatches=num_microbatches, verbose=True)
    report_memory('(after {} iterations)'.format(iteration))
    report_memory_flag = False
timers.log(timers_to_log, normalizer=args.log_interval)
14 changes: 14 additions & 0 deletions report_theoretical_memory.py
@@ -0,0 +1,14 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""Computes theoretical memory footprint for model training without instantiating
a model and running training iterations on GPU(s)."""

from megatron import get_args
from megatron.initialize import initialize_megatron
from megatron.theoretical_memory_usage import report_theoretical_memory

if __name__ == "__main__":
    initialize_megatron(allow_no_cuda=True, skip_mpu_initialization=True)
    args = get_args()

    report_theoretical_memory(args, verbose=True)
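
Because the standalone script goes through initialize_megatron, it accepts the usual Megatron-LM command-line arguments rather than a config file. A hypothetical invocation might look like the following; the flags shown are standard Megatron options, and depending on the version additional arguments (for example tokenizer or vocabulary settings) may be required before argument validation passes:

python report_theoretical_memory.py \
    --num-layers 24 --hidden-size 2048 --num-attention-heads 16 \
    --seq-length 2048 --max-position-embeddings 2048 \
    --micro-batch-size 1 --global-batch-size 32 \
    --tensor-model-parallel-size 2 --pipeline-model-parallel-size 1 \
    --sequence-parallel --recompute-granularity selective
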
