Functionality tests and ReadME
Xinyi Wan committed Nov 20, 2023
1 parent a99d554 commit 425b42e
Showing 33 changed files with 391 additions and 69 deletions.
68 changes: 68 additions & 0 deletions ZeroBubble.md
@@ -0,0 +1,68 @@
# Zero Bubble Pipeline Parallelism

Zero Bubble Pipeline Parallelism is a novel pipeline parallelism scheduling algorithm that reduces the pipeline bubble to almost zero.


**Quick settings to enable Zero Bubble:**
```
--zero-bubble-v-schedule
--allow-padding-num-layers
--enable-optimizer-post-validation
```

**Acceleration**
Experiments show that zero bubble pipeline parallelism can accelerate training by up to 30% with similar memory consumption. A detailed table of experiments is coming soon.

**Notices**
* The ZBV schedule requires the number of layers per pipeline stage to be even, so that each stage can be split evenly into two virtual stages.
* To achieve better throughput, we recommend setting `--num-layers` to `k * pipeline-model-parallel-size - 2`, where k is any integer $\gt 1$. For example, with a pipeline size of 8, `--num-layers 30` (k = 4) works well. The two-layer reduction compensates for the additional embedding layers on the first/last pipeline stages, which would otherwise introduce bubbles into all other stages.

## Zero Bubble Schedules
The key to achieving zero bubble is splitting each backward pass into a $B$ pass (gradient with respect to the input) and a $W$ pass (gradient with respect to the weights). $B$ on one stage depends only on the $B$ of the next stage, whereas in 1F1B it depends on both $B$ and $W$.
![image](https://hackmd.io/_uploads/Bkc7CL7N6.png)
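This split can be sketched with plain PyTorch autograd. The snippet below is illustrative only (a single linear stage with made-up shapes), not Megatron's implementation, which performs the split at the schedule level:

```
import torch

# Minimal sketch of splitting a backward pass into B (input gradient) and
# W (weight gradient) for one linear stage. Illustrative only.
x = torch.randn(4, 16, requires_grad=True)   # activations from the previous stage
w = torch.randn(16, 16, requires_grad=True)  # this stage's weights
y = x @ w
grad_y = torch.randn_like(y)                 # gradient arriving from the next stage

# B pass: compute only the input gradient, which the previous stage's B is waiting on.
grad_x, = torch.autograd.grad(y, x, grad_y, retain_graph=True)

# W pass: the weight gradient can be computed later, wherever the schedule has an
# idle slot, because no other stage depends on it.
grad_w, = torch.autograd.grad(y, w, grad_y)
```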

### Comparison of Schedules
* 1F1B
![image](https://hackmd.io/_uploads/Hkq-gD7N6.png)
* ZB1P
![image](https://hackmd.io/_uploads/Hy2GxwmEa.png)
* ZB2P
![image](https://hackmd.io/_uploads/S10QgvmV6.png)
* ZBV
![image](https://hackmd.io/_uploads/HkCgLKEET.png)




|                                                        | 1F1B            | ZB1P             | ZB2P | ZBV (Recommended) |
| ------------------------------------------------------ | --------------- | ---------------- | ---- | ----------------- |
| Bubble Rate                                             | $\frac{p-1}{m}$ | $\frac{p-1}{3m}$ | 0    | 0                 |
| Activation Memory <br> (Compared to 1F1B)               | 1x              | 1x               | 2x   | 1x                |
| Pipeline Communication Volume <br> (Compared to 1F1B)   | 1x              | 1x               | 1x   | 2x                |



* $p$: number of pipeline stages; $m$: number of microbatches
* Assuming $T_F = T_B = T_W$
* Communication volume of DP and TP stays the same
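As a rough worked example using the formulas above (under the same $T_F = T_B = T_W$ assumption), with $p = 8$ stages and $m = 24$ microbatches:

$$
\text{1F1B: } \frac{p-1}{m} = \frac{7}{24} \approx 29\%, \qquad
\text{ZB1P: } \frac{p-1}{3m} = \frac{7}{72} \approx 10\%, \qquad
\text{ZB2P, ZBV: } 0
$$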


## Zero Bubble Command Line Arguments

* `--enable-zero-bubble` Enables zero bubble schedules.
* `--zero-bubble-v-schedule` Enables the ZBV schedule recommended above. Implies `--enable-zero-bubble`.
* `--enable-optimizer-post-validation` Enables the optimizer post validation explained in [Optimizer Post Validation](#Optimizer-Post-Validation).
* `--allow-padding-num-layers` Allows the number of layers to not be a multiple of the number of pipeline stages. This lets the first and last pipeline stages hold one less layer to compensate for the bubble caused by the embedding layers.
* `--zero-bubble-max-pending-backward` Controls the memory limit of zero bubble schedules. Setting it to 1x the number of pipeline stages gives a schedule like ZB1P, while 2x gives ZB2P. It has no effect when the ZBV schedule is enabled via `--zero-bubble-v-schedule`.
* `--zero-bubble-pipeline-timers-start-iter` and `--zero-bubble-pipeline-timers-end-iter` Control the start/end iterations during which the ZB scheduler profiles each F/B/W pass to measure $T_F$, $T_B$ and $T_W$. An example invocation combining these flags follows this list.
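As a concrete example, a ZBV launch could combine these flags as follows. This is only a sketch: the `pretrain_gpt.py` entry point, parallel sizes and layer count are placeholder values to adapt to your own setup.

```
torchrun --nproc_per_node=8 pretrain_gpt.py \
    --pipeline-model-parallel-size 8 \
    --num-layers 30 \
    --zero-bubble-v-schedule \
    --allow-padding-num-layers \
    --enable-optimizer-post-validation \
    --fp16 \
    ...  # plus the usual model, data and optimizer arguments
```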

## Optimizer Post Validation

In most pipeline parallelism setups there is an all-reduce across all pipeline stages for numerical robustness, e.g. the global gradient norm for gradient clipping, the INF/NAN check for mixed precision training, etc. This all-reduce breaks the parallelogram shape of the schedule and makes zero bubble impossible.
Based on the observation that during stable training both gradient clipping and INF/NAN checks rarely trigger, we replace the up-front synchronization with a post-update validation.

![image](https://hackmd.io/_uploads/B16R3q4N6.png)

We eagerly step the optimizer, assuming the gradient clipping and INF/NAN conditions are not triggered. If an amendment to the gradient turns out to be required, a rollback is issued and the optimizer step is redone based on the fully reduced global state.

To enable this feature, add `--enable-optimizer-post-validation`. Experiments show that NOT enabling it causes a performance loss of around 8%.
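A minimal PyTorch sketch of this optimistic-step-then-validate idea is shown below. It is not Megatron's optimizer code: the local norm computation stands in for the cross-stage all-reduce, and a full rollback would also need to restore optimizer state such as momentum, which is omitted here.

```
import torch

def optimistic_step(optimizer, params, clip_threshold=1.0):
    # Sketch of optimizer post validation: step first, validate afterwards,
    # roll back only in the rare bad case. Assumes gradients are already in p.grad.
    snapshot = [p.detach().clone() for p in params]

    # Optimistically step, assuming clipping / INF / NAN will not trigger.
    optimizer.step()

    # In the real schedule this reduction overlaps with other stages' work;
    # here it is computed locally for illustration.
    grad_norm = torch.norm(torch.stack([p.grad.norm() for p in params]))
    found_inf = not torch.isfinite(grad_norm)

    if found_inf or grad_norm > clip_threshold:
        # Rare path: roll back the parameters, clip, and redo the step.
        for p, saved in zip(params, snapshot):
            p.data.copy_(saved)
        torch.nn.utils.clip_grad_norm_(params, clip_threshold)
        optimizer.step()
```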
24 changes: 16 additions & 8 deletions examples/pretrain_llama_7b.sh → examples/pretrain_zero_bubble.sh
@@ -51,8 +51,16 @@ if [ -z "$ZERO_BUBBLE_TIMER_START" ]; then
ZERO_BUBBLE_TIMER_END=110
fi

if [ -z "$EVAL_INTERVAL" ]; then
EVAL_INTERVAL=10000
fi

if [ -z "$TP_SIZE" ]; then
TP_SIZE=1
fi

options=" \
--tensor-model-parallel-size 1 \
--tensor-model-parallel-size $TP_SIZE \
--pipeline-model-parallel-size $PIPELINE_SIZE \
--num-layers $LAYERS \
--hidden-size $HIDDEN_SIZE \
@@ -70,7 +78,7 @@ options=" \
--lr-decay-style cosine \
--log-interval 10 \
--eval-iters 40 \
--eval-interval 10000 \
--eval-interval $EVAL_INTERVAL \
--data-path ${DATASET} \
--tokenizer-type GPTSentencePieceTokenizer \
--tokenizer-model /tokenizers/tokenizer.model \
@@ -84,24 +92,24 @@ options=" \
--profile-step-start 150 \
--profile-step-end 170 \
--profile-ranks $profile_ranks \
--allow-padding-num-layers \
--enable-optimizer-post-validation \
--fp16"
--allow-padding-num-layers"

if [ -z "$FP32" ]; then
options="$options --fp16 \
--enable-optimizer-post-validation "
fi

if [ ! -z "$PROFILED" ]; then
options="$options --profile"
fi

if [ ! -z "$ZERO_BUBBLE_V_SCHEDULE" ]; then
ENABLE_ZERO_BUBBLE=1
options="$options --zero-bubble-v-schedule \
--num-layers-per-virtual-pipeline-stage $(( $(($LAYERS + 2)) / $PIPELINE_SIZE / 2 ))"
options="$options --zero-bubble-v-schedule "
fi

if [ ! -z "$ENABLE_ZERO_BUBBLE" ]; then
options="$options --enable-zero-bubble \
--zero-bubble-pipeline-start-iter 100 \
--zero-bubble-pipeline-timers-start-iter $ZERO_BUBBLE_TIMER_START \
--zero-bubble-pipeline-timers-end-iter $ZERO_BUBBLE_TIMER_END \
--zero-bubble-max-pending-backward $ZERO_BUBBLE_MEM_LIMIT"
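For reference, the renamed script is driven by environment variables, so a ZBV run could be launched roughly as below; this is a sketch only and assumes the dataset and tokenizer paths expected elsewhere in the script are already set up.

```
PIPELINE_SIZE=8 LAYERS=30 TP_SIZE=1 \
ZERO_BUBBLE_V_SCHEDULE=1 EVAL_INTERVAL=10000 \
bash examples/pretrain_zero_bubble.sh
```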
14 changes: 11 additions & 3 deletions megatron/arguments.py
@@ -190,11 +190,22 @@ def validate_args(args, defaults={}):

# TODO: validate more
if args.zero_bubble_v_schedule:
assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \
'number of layers should be divisible by the pipeline parallel size'
num_layers_per_pipeline_stage = args.num_layers // args.transformer_pipeline_model_parallel_size
assert num_layers_per_pipeline_stage % 2 == 0, \
'zero bubble v schedule requires number of layers per pipeline stage to be even'
assert args.num_layers_per_virtual_pipeline_stage is None, \
'num_layers_per_virtual_pipeline_stage should not be set with zero bubble v schedule'
args.virtual_pipeline_model_parallel_size = 2
args.num_layers_per_virtual_pipeline_stage = num_layers_per_pipeline_stage // 2
assert args.virtual_pipeline_model_parallel_size == 2
args.enable_zero_bubble = True
if args.enable_zero_bubble:
assert not args.overlap_grad_reduce, "not supported yet"
assert args.pipeline_model_parallel_size > 1, "zero bubble must be enabled with pipeline parallelism"
if args.enable_optimizer_post_validation:
assert args.fp16, "zero bubble post validation"
else:
args.enable_optimizer_post_validation = False
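For reference, a minimal sketch of the layer arithmetic these ZBV assertions enforce, using hypothetical example values:

```
# Illustrative only; mirrors the validation logic above with example numbers.
num_layers = 32
pipeline_size = 8

assert num_layers % pipeline_size == 0
layers_per_stage = num_layers // pipeline_size    # 4 layers on each pipeline stage
assert layers_per_stage % 2 == 0                  # must split into two virtual stages
layers_per_virtual_stage = layers_per_stage // 2  # 2 layers per virtual stage
virtual_pipeline_size = 2                         # fixed by the V schedule
```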

@@ -1100,9 +1111,6 @@ def _add_mixed_precision_args(parser):

def _add_zero_bubble_args(parser):
group = parser.add_argument_group(title='zero bubble')
group.add_argument('--zero-bubble-pipeline-start-iter',
type=int, default=1000,
help='The starting iteration that skips all sync cross pipeline parallel groups')
group.add_argument('--zero-bubble-pipeline-timers-start-iter',
type=int, default=100,
help='The starting iteration that start timers for auto scheduling of zero-bubble pipeline parallel')
11 changes: 9 additions & 2 deletions megatron/core/pipeline_parallel/auto_schedule.py
@@ -700,11 +700,18 @@ def auto_schedule(nstages, nmb, config):
# auto_schedule(4, 12, GraphConfig(cost_f=5, cost_b=6, cost_w=4, cost_comm=0, max_mem=10))
# auto_schedule(4, 12, GraphConfig(cost_f=5, cost_b=6, cost_w=4, cost_comm=0, max_mem=14))
auto_schedule(24, 72, GraphConfig(cost_f=5, cost_b=6, cost_w=4, cost_comm=0, max_mem=100))
auto_schedule(8, 24, GraphConfig(
auto_schedule(4, 12, GraphConfig(
cost_f=5478,
cost_b=5806,
cost_w=3534,
cost_comm=200,
max_mem=32,
print_scaling=1000
))
))
auto_schedule(32, 16, GraphConfig(
cost_f=1,
cost_b=1,
cost_w=1,
cost_comm=0,
max_mem=64,
))
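Judging from these example calls, the `GraphConfig` fields correspond to the profiled quantities from the README: `cost_f`/`cost_b`/`cost_w` are the measured $T_F$/$T_B$/$T_W$, `cost_comm` is the communication cost between adjacent stages, and `max_mem` bounds the number of in-flight activations. A hedged usage sketch (field meanings inferred from context, not from documentation):

```
# Sketch: schedule 8 pipeline stages with 32 microbatches from profiled timings.
# Field meanings are inferred from the surrounding examples.
auto_schedule(8, 32, GraphConfig(
    cost_f=5478,    # measured forward (F) time per microbatch
    cost_b=5806,    # measured backward-for-input (B) time
    cost_w=3534,    # measured backward-for-weight (W) time
    cost_comm=200,  # communication cost between adjacent stages
    max_mem=16,     # memory budget, in units of pending activations
))
```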
36 changes: 27 additions & 9 deletions megatron/core/pipeline_parallel/zb_schedules.py
@@ -277,6 +277,8 @@ def add_communication(
next_is_comm: bool,
next_compute: auto_schedule.ScheduledNode
):
if self.forward_only and 'BACKWARD' in scheduled_node.type:
return
self.communication_batch[self.direction_map(scheduled_node)].append(
(scheduled_node, self.tensor_shape))
def is_consumer(scheduled_node, next_compute):
@@ -286,7 +288,7 @@ def is_consumer(scheduled_node, next_compute):
if scheduled_node.type == 'RECV_BACKWARD' and next_compute.type == 'B':
return True
return False
if (next_compute is not None and is_consumer(scheduled_node, next_compute)) or not next_is_comm:
if (next_compute is not None and is_consumer(scheduled_node, next_compute)) or not next_is_comm or self.forward_only:
self.flush()

def schedule_f(self, scheduled_node):
@@ -528,8 +530,8 @@ def run(self):
# embedding all-reduce for pipeline parallelism).
self.config.finalize_model_grads_func(self.model)

if get_args().zero_bubble_pipeline_timers_end_iter == ScheduleTimers.iter_counter:
ScheduleTimers.concluded = True
if get_args().zero_bubble_pipeline_timers_end_iter == ScheduleTimers.iter_counter:
ScheduleTimers.concluded = True

return self.forward_data_store

@@ -626,8 +628,8 @@ def multi_no_sync():
decoder_seq_length=decoder_seq_length,
config=config,
)[0] == tensor_shape

ScheduleTimers.iter_counter += 1
if not forward_only:
ScheduleTimers.iter_counter += 1
run_timer = (
get_args().zero_bubble_pipeline_timers_end_iter
>= ScheduleTimers.iter_counter
@@ -648,6 +650,12 @@
self.it = 0

def __call__(self, *args, **kwargs):
if kwargs['forward_only']:
self.prepare(*args, **kwargs)
assert self.do_post_validation
self.do_post_validation = True
self.is_first_run = True
return self.run()
if not get_args().enable_optimizer_post_validation:
self.prepare(*args, **kwargs)
self.is_first_run = False
@@ -790,6 +798,8 @@ def add_communication(
next_is_comm: bool,
next_compute: auto_schedule.ScheduledNode
):
if self.forward_only and 'BACKWARD' in scheduled_node.type:
return
self.communication_batch[self.direction_map(scheduled_node)].append(
(scheduled_node, None))
def is_consumer(scheduled_node, next_compute):
@@ -799,7 +809,7 @@ def is_consumer(scheduled_node, next_compute):
if scheduled_node.type == 'RECV_BACKWARD' and next_compute.type == 'B':
return True
return False
if (next_compute is not None and is_consumer(scheduled_node, next_compute)) or not next_is_comm:
if (next_compute is not None and is_consumer(scheduled_node, next_compute)) or not next_is_comm or self.forward_only:
self.flush()

def schedule_f(self, scheduled_node):
@@ -926,7 +936,8 @@ def prepare(
no_sync_func = contextlib.nullcontext
self.no_sync_func = no_sync_func
self.no_sync_context = None
ScheduleTimers.iter_counter += 1
if not forward_only:
ScheduleTimers.iter_counter += 1

self.disable_grad_sync()

@@ -1061,6 +1072,7 @@ def run(self):
it = self.it
while it < len(self.schedules):
scheduled_node = self.schedules[it]
# print(f"iter {torch.distributed.get_rank()}-{it}: {scheduled_node.type}-{scheduled_node.minibatch}")
if "POST_VALIDATION" in scheduled_node.type:
pass
elif scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
@@ -1107,11 +1119,17 @@ def run(self):
# embedding all-reduce for pipeline parallelism).
self.config.finalize_model_grads_func([self.model])

if get_args().zero_bubble_pipeline_timers_end_iter == ScheduleTimers.iter_counter:
ScheduleTimers.concluded = True
if get_args().zero_bubble_pipeline_timers_end_iter == ScheduleTimers.iter_counter:
ScheduleTimers.concluded = True
return self.forward_data_store

def __call__(self, *args, **kwargs):
if kwargs['forward_only']:
self.prepare(*args, **kwargs)
assert self.do_post_validation
self.do_post_validation = True
self.is_first_run = True
return self.run()
if not get_args().enable_optimizer_post_validation:
self.prepare(*args, **kwargs)
self.is_first_run = False
2 changes: 1 addition & 1 deletion megatron/model/module.py
@@ -24,7 +24,7 @@ def param_is_not_shared(param):
def local_binary_reduction(param: torch.nn.parameter.Parameter, key):
assert param.grad is None
if key in cache:
#cache[key].add_(param)
cache[key].add_(param)
param.copy_(cache[key])
a = cache[key]
del cache[key]
23 changes: 13 additions & 10 deletions megatron/training.py
@@ -412,7 +412,7 @@ def setup_model_and_optimizer(model_provider_func,


def train_step(forward_step_func, data_iterator,
model, optimizer, opt_param_scheduler, config):
model, optimizer, opt_param_scheduler, config, next_is_eval=False):
"""Single training step."""
args = get_args()
timers = get_timers()
@@ -462,7 +462,7 @@ def run_forward_backward_func():
if args.enable_zero_bubble and args.enable_optimizer_post_validation:
from megatron.core.pipeline_parallel.zb_schedules import get_zb_scheduler_instance
zb_scheduler = get_zb_scheduler_instance()
if optimizer.post_validation_enabled:
if optimizer.post_validation_enabled and not next_is_eval:
optimizer.pre_step(args, timers)
zb_scheduler.optimizer = optimizer
assert not zb_scheduler.is_first_run and zb_scheduler.do_post_validation
@@ -797,14 +797,17 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,

update_num_microbatches(args.consumed_train_samples)
args.curr_iteration = iteration

iteration += 1
do_eval = args.eval_interval and iteration % args.eval_interval == 0 and args.do_valid
loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \
train_step(forward_step_func,
train_data_iterator,
model,
optimizer,
opt_param_scheduler,
config)
iteration += 1
config, do_eval)

args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
args.micro_batch_size * \
get_num_microbatches()
@@ -827,8 +830,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
opt_param_scheduler)

# Evaluation
if args.eval_interval and iteration % args.eval_interval == 0 and \
args.do_valid:
if do_eval:
if args.manual_gc and args.manual_gc_eval:
# Collect all objects.
gc.collect()
@@ -930,7 +932,6 @@ def evaluate(forward_step_func,
eval_batch_size = args.global_batch_size
eval_num_microbatches = eval_batch_size // \
(args.micro_batch_size * args.data_parallel_size)

with torch.no_grad():
iteration = 0
if verbose:
@@ -1048,9 +1049,11 @@ def evaluate_and_print_results(prefix, forward_step_func,
process_non_loss_data_func(collected_non_loss_data, iteration, writer)

length = len(string) + 1
print_rank_last('-' * length)
print_rank_last(string)
print_rank_last('-' * length)
pfunc=print_rank_0 if get_args().zero_bubble_v_schedule else print_rank_last

pfunc('-' * length)
pfunc(string)
pfunc('-' * length)


def cyclic_iter(iter):
File renamed without changes.
15 changes: 15 additions & 0 deletions tests/zerobubble_tests/0_test_pp_1f1b_eval_exact.sh
@@ -0,0 +1,15 @@
#!/bin/bash

source $(dirname "${BASH_SOURCE[0]}")/commons.sh
setup;

export WORLD_SIZE_IN_GPUS=8
export GLOBAL_BATCH_SIZE=24
export PIPELINE_SIZE=8
export EVAL_INTERVAL=100
export AIP_RUN_NAME=$(basename $0 | cut -d '.' -f 1)

export ENABLE_EXACTLY_NUMERIC_MATCH=1
launch

check_loss "7.073670E+00"
@@ -14,4 +14,4 @@ export ENABLE_EXACTLY_NUMERIC_MATCH=1

launch

check_loss "$(loss_of test_pp_1f1b_exact)"
check_loss "7.073670E+00"