
Commit

support padding for num_layers
QPH-SAIL authored and Xinyi Wan committed Nov 20, 2023
1 parent 3763fbd commit cff01bf
Showing 4 changed files with 81 additions and 26 deletions.
examples/pretrain_llama_7b.sh (5 changes: 3 additions & 2 deletions)
@@ -34,7 +34,7 @@ WORLD_SIZE_IN_GPUS=$(( $WORLD_SIZE * $GPUS_PER_NODE ))

if [ -z "$PIPELINE_SIZE" ]; then
  PIPELINE_SIZE=$(( $WORLD_SIZE_IN_GPUS))
-  LAYERS=$(( $PIPELINE_SIZE * 4))
+  LAYERS=$(( $PIPELINE_SIZE * 4 - 2))
  MICRO_BATCH_SIZE=1
  GLOBAL_BATCH_SIZE=$(( $PIPELINE_SIZE * 3 * $MICRO_BATCH_SIZE ))
  HIDDEN_SIZE=4096
@@ -86,6 +86,7 @@ options=" \
  --profile-step-start 150 \
  --profile-step-end 170 \
  --profile-ranks $profile_ranks \
+  --allow-padding-num-layers \
  --fp16"


@@ -96,7 +97,7 @@ fi
if [ ! -z "$ZERO_BUBBLE_V_SCHEDULE" ]; then
  ENABLE_ZERO_BUBBLE=1
  options="$options --zero-bubble-v-schedule \
-    --num-layers-per-virtual-pipeline-stage $(( $LAYERS / $PIPELINE_SIZE / 2 ))"
+    --num-layers-per-virtual-pipeline-stage $(( $(($LAYERS + 2)) / $PIPELINE_SIZE / 2 ))"
fi

if [ ! -z "$ENABLE_ZERO_BUBBLE" ]; then
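The script now asks for two fewer transformer layers than a multiple of the pipeline size and enables --allow-padding-num-layers so the argument validation can round the count back up, while --num-layers-per-virtual-pipeline-stage is derived from the padded total via $(($LAYERS + 2)). A quick sanity check of that arithmetic as a small Python sketch, with PIPELINE_SIZE=8 assumed purely for illustration:

```python
# Illustrative check of the launch-script arithmetic; 8 stages is an assumed example value.
PIPELINE_SIZE = 8
LAYERS = PIPELINE_SIZE * 4 - 2                # 30 real transformer layers requested
PADDED_LAYERS = LAYERS + 2                    # 32, the count restored by --allow-padding-num-layers
PER_VIRTUAL_STAGE = PADDED_LAYERS // PIPELINE_SIZE // 2   # 2 layers per virtual stage (V schedule)

assert PADDED_LAYERS % PIPELINE_SIZE == 0     # padding makes the count divisible again
print(LAYERS, PADDED_LAYERS, PER_VIRTUAL_STAGE)  # 30 32 2
```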
megatron/arguments.py (32 changes: 23 additions & 9 deletions)
@@ -140,6 +140,26 @@ def validate_args(args, defaults={}):
        else:
            setattr(args, key, defaults[key])

+    if args.num_layers is not None:
+        assert args.encoder_num_layers is None, \
+            'cannot have both num-layers and encoder-num-layers specified'
+        args.encoder_num_layers = args.num_layers
+    else:
+        assert args.encoder_num_layers is not None, \
+            'either num-layers or encoder-num-layers should be specified'
+        args.num_layers = args.encoder_num_layers
+
+    remainder = args.num_layers % args.pipeline_model_parallel_size
+    if args.allow_padding_num_layers and remainder > 0:
+        assert not args.standalone_embedding_stage, "not support standalone embedding stage if allow_padding_num_layers is true"
+        # pad num_layers to make num_layers % pipeline_model_parallel_size == 0
+        num_layers_with_padding = args.num_layers - remainder + args.pipeline_model_parallel_size
+    else:
+        num_layers_with_padding = args.num_layers
+    args.num_layers_without_padding = args.num_layers
+    args.num_layers = num_layers_with_padding
+    args.encoder_num_layers = num_layers_with_padding
+
    # Batch size.
    assert args.micro_batch_size is not None
    assert args.micro_batch_size > 0
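For reference, the rounding rule above can be sketched as a standalone helper (pad_num_layers is a hypothetical name, not part of the commit): with 30 layers and a pipeline size of 8 the remainder is 6, so num_layers is padded to 32 while the original 30 is preserved as num_layers_without_padding.

```python
def pad_num_layers(num_layers, pipeline_model_parallel_size):
    """Sketch of the padding rule in validate_args above (illustrative, not the committed code)."""
    remainder = num_layers % pipeline_model_parallel_size
    if remainder > 0:
        # Round up to the next multiple of the pipeline size.
        num_layers_with_padding = num_layers - remainder + pipeline_model_parallel_size
    else:
        num_layers_with_padding = num_layers
    # Returns (num_layers, num_layers_without_padding) as validate_args would set them.
    return num_layers_with_padding, num_layers

assert pad_num_layers(30, 8) == (32, 30)   # 2 padding layers added
assert pad_num_layers(32, 8) == (32, 32)   # already divisible, nothing changes
```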
@@ -246,15 +266,6 @@ def validate_args(args, defaults={}):
        'can only specify one of lr-warmup-fraction ' \
        'and lr-warmup-samples'

-    if args.num_layers is not None:
-        assert args.encoder_num_layers is None, \
-            'cannot have both num-layers and encoder-num-layers specified'
-        args.encoder_num_layers = args.num_layers
-    else:
-        assert args.encoder_num_layers is not None, \
-            'either num-layers or encoder-num-layers should be specified'
-        args.num_layers = args.encoder_num_layers
-
    # Check required arguments.
    required_args = ['num_layers', 'hidden_size', 'num_attention_heads',
                     'max_position_embeddings']
@@ -1114,6 +1125,9 @@ def _add_zero_bubble_args(parser):
    group.add_argument('--zero-bubble-v-schedule', action='store_true',
                       help='Use zero bubble v schedule pipeline. This method achieves zero-bubble without more memory overhead',
                       dest='zero_bubble_v_schedule')
+    group.add_argument('--allow-padding-num-layers', action='store_true',
+                       help='Allow padding num_layers for pipeline parallelism',
+                       dest='allow_padding_num_layers')
    return parser


megatron/core/transformer/transformer_config.py (3 changes: 2 additions & 1 deletion)
@@ -136,6 +136,7 @@ class TransformerConfig(ModelParallelConfig):

    # model architecture
    num_layers: int = 0
+    num_layers_without_padding: int = 0
    hidden_size: int = 0
    num_attention_heads: int = 0
    num_query_groups: int = None
@@ -285,5 +286,5 @@ def __post_init__(self):
        if self.output_layer_init_method is None:
            # TODO
            self.output_layer_init_method = scaled_init_method_normal(
-                self.init_method_std, self.num_layers - 2
+                self.init_method_std, self.num_layers_without_padding
            )
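The output-layer initializer previously received the fork's hard-coded num_layers - 2; it now gets the real (unpadded) depth, so padding layers no longer shrink the initialization scale. As a rough reminder of why the layer count matters, Megatron's scaled_init_method_normal scales the standard deviation by about 1/sqrt(2 * num_layers); a minimal sketch, assuming that formula:

```python
import math

def scaled_init_std(init_method_std, num_layers_without_padding):
    # Assumed form of the scaling in Megatron's scaled_init_method_normal; see the real helper for details.
    return init_method_std / math.sqrt(2.0 * num_layers_without_padding)

# With init_method_std=0.02 and 30 real layers the output-layer std is ~0.0026,
# versus ~0.0025 if the padded count of 32 were used instead.
print(scaled_init_std(0.02, 30), scaled_init_std(0.02, 32))
```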
megatron/model/transformer.py (67 changes: 53 additions & 14 deletions)
@@ -1445,12 +1445,12 @@ def __init__(self, config,
        self.checkpoint_core_attention = config.recompute_granularity == 'selective'

        # Number of layers.
-        self.num_layers = _get_num_layers(args, model_type,
-                                          layer_type==LayerType.decoder)
+        num_layers_per_stage_with_padding = _get_num_layers(args, model_type, layer_type==LayerType.decoder)
+        self.num_layers = num_layers_per_stage_with_padding

        self.drop_path_rates = [
            rate.item() for rate in
-            torch.linspace(0, self.drop_path_rate, config.num_layers - 2)]
+            torch.linspace(0, self.drop_path_rate, config.num_layers_without_padding)]

        self.retro_layer_numbers = None
        if model_type == ModelType.retro_decoder:
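The stochastic-depth schedule is likewise built from num_layers_without_padding, so every real layer gets a drop-path rate and padding layers are not counted. A minimal check, assuming 30 real layers and a maximum rate of 0.1:

```python
import torch

num_layers_without_padding = 30   # assumed example value
drop_path_rate = 0.1
rates = [r.item() for r in torch.linspace(0, drop_path_rate, num_layers_without_padding)]
assert len(rates) == num_layers_without_padding      # one rate per real layer
assert abs(rates[-1] - drop_path_rate) < 1e-6        # deepest real layer gets the full rate
```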
@@ -1540,29 +1540,20 @@ def build_layer(layer_number):
                offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
            else:
                offset = config.num_layers - (mpu.get_pipeline_model_parallel_rank() + 1) * self.num_layers
-            if offset != 0:
-                offset -= 1
        else:
            # Each stage gets a contiguous set of layers.
            if args.model_type == ModelType.encoder_and_decoder and \
                    mpu.get_pipeline_model_parallel_world_size() > 1:
                pipeline_rank = mpu.get_pipeline_model_parallel_rank()
                if layer_type == LayerType.encoder:
                    offset = pipeline_rank * self.num_layers
-                    if offset != 0:
-                        offset -= 1
                else:
                    num_ranks_in_enc = args.pipeline_model_parallel_split_rank
                    offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
            else:
                offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
-                if offset != 0:
-                    offset -= 1
-        # TODO
-        if mpu.is_pipeline_last_stage():
-            self.num_layers -= 1
-        if mpu.is_pipeline_first_stage():
-            self.num_layers -= 1
+        if args.allow_padding_num_layers:
+            self.num_layers, offset = self.get_offset(config, args, num_layers_per_stage_with_padding)
        print(f'num layers on rank {torch.distributed.get_rank()}: {self.num_layers} offset: {offset}')
        if self.num_layers == 0:
            # When a standalone embedding stage is used (e.g.,
@@ -1594,6 +1585,54 @@ def build_layer(layer_number):
# Final layer norm before output.
self.final_norm = get_norm(config)

@staticmethod
def get_offset(config, args, num_layers_per_stage_with_padding):
pipeline_rank = mpu.get_pipeline_model_parallel_rank()
pipeline_world_size = mpu.get_pipeline_model_parallel_world_size()
if config.virtual_pipeline_model_parallel_size is not None:
assert config.num_layers % config.virtual_pipeline_model_parallel_size == 0, \
'num_layers_per_stage must be divisible by ' \
'virtual_pipeline_model_parallel_size'
assert num_layers_per_stage_with_padding % config.virtual_pipeline_model_parallel_size == 0
assert args.model_type != ModelType.encoder_and_decoder
# Number of layers in each model chunk is the number of layers in the stage,
# divided by the number of model chunks in a stage.
num_layers_per_chunk = num_layers_per_stage_with_padding // config.virtual_pipeline_model_parallel_size
num_chunk = pipeline_world_size * config.virtual_pipeline_model_parallel_size
chunk_sizes = [num_layers_per_chunk] * num_chunk
num_padding = args.num_layers - args.num_layers_without_padding
for _index in range(-1, num_padding - 1):
chunk_sizes[_index] -= 1

virtual_rank = mpu.get_virtual_pipeline_model_parallel_rank()
if args.zero_bubble_v_schedule:
assert config.virtual_pipeline_model_parallel_size == 2
if virtual_rank == 0:
chunk_index = pipeline_rank
else:
chunk_index = 2 * pipeline_world_size - pipeline_rank - 1
else:
chunk_index = virtual_rank * pipeline_world_size + pipeline_rank
num_layers = chunk_sizes[chunk_index]
offset = 0
for _index in range(chunk_index):
offset += chunk_sizes[_index]
else:
# Each stage gets a contiguous set of layers.
rank_sizes = [num_layers_per_stage_with_padding] * pipeline_world_size
num_padding = args.num_layers - args.num_layers_without_padding
for _index in range(-1, num_padding - 1):
rank_sizes[_index] -= 1
if args.model_type == ModelType.encoder_and_decoder and \
mpu.get_pipeline_model_parallel_world_size() > 1:
assert False, "Not support yet"
else:
num_layers = rank_sizes[pipeline_rank]
offset = 0
for _index in range(pipeline_rank):
offset += rank_sizes[_index]
return num_layers, offset

def _get_layer(self, layer_number):
return self.layers[layer_number]

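To see how get_offset spreads the padding, here is a simplified, dependency-free re-implementation of its virtual-pipeline branch (chunk_layout is a hypothetical helper written for illustration, not the committed code). With 32 padded and 30 real layers on 8 pipeline ranks and 2 chunks per rank, range(-1, num_padding - 1) trims one layer from the last chunk and one from the first; under the zero-bubble V schedule both trimmed chunks land on rank 0, which therefore hosts 2 real layers while every other rank hosts 4.

```python
def chunk_layout(num_layers_padded, num_layers_real, pipeline_world_size, vpp_size, v_schedule):
    """Simplified sketch of get_offset's virtual-pipeline branch (illustration only)."""
    per_chunk = num_layers_padded // pipeline_world_size // vpp_size
    chunk_sizes = [per_chunk] * (pipeline_world_size * vpp_size)
    num_padding = num_layers_padded - num_layers_real
    for i in range(-1, num_padding - 1):      # trim the last chunk first, then the leading chunks
        chunk_sizes[i] -= 1

    layout = {}
    for rank in range(pipeline_world_size):
        chunks = []
        for virtual_rank in range(vpp_size):
            if v_schedule:                    # V schedule (vpp_size == 2): rank 0 holds the first and last chunk
                idx = rank if virtual_rank == 0 else 2 * pipeline_world_size - rank - 1
            else:                             # standard interleaving
                idx = virtual_rank * pipeline_world_size + rank
            chunks.append((chunk_sizes[idx], sum(chunk_sizes[:idx])))  # (num_layers, offset)
        layout[rank] = chunks
    return layout

layout = chunk_layout(32, 30, 8, 2, v_schedule=True)
print(layout[0])   # [(1, 0), (1, 29)]  -> rank 0: 2 real layers
print(layout[1])   # [(2, 1), (2, 27)]  -> every other rank: 4 real layers
assert sum(n for chunks in layout.values() for n, _ in chunks) == 30
```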
