From e96bc84b85837bd09bf561677ced915185c760f8 Mon Sep 17 00:00:00 2001 From: "panxuchen.pxc" Date: Tue, 12 Dec 2023 20:00:17 +0800 Subject: [PATCH 01/12] tune EE only --- megatron/arguments.py | 2 ++ megatron/core/pipeline_parallel/schedules.py | 23 ++++++++++--------- megatron/model/transformer.py | 24 ++++++++++++++------ megatron/training.py | 19 +++++++++------- 4 files changed, 42 insertions(+), 26 deletions(-) diff --git a/megatron/arguments.py b/megatron/arguments.py index cc853b95..86a160d0 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -1257,6 +1257,8 @@ def _add_multi_exit_args(parser): group.add_argument('--num-fill-warmup-microbatches', type=int, default=None) group.add_argument('--num-fill-cooldown-microbatches', type=int, default=None) group.add_argument('--backward-forward-ratio', type=float, default=2.0) + group.add_argument('--tune-exit', action='store_true', + help='Only finetune early exit parameters.') group.add_argument('--use-dynamic-exit-layer-weight', action='store_true') return parser diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index b3bbf224..f43c61c7 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -2123,18 +2123,19 @@ def early_exit_backward_step(input_tensor, output_tensor, output_tensor_grad, co # Backward pass. if output_tensor_grad[0] is None and config.grad_scale_func is not None: output_tensor[0] = config.grad_scale_func(output_tensor[0]) - if early_exit_loss is not None: - if output_tensor_grad[0] is not None: - fake_loss = early_exit_loss + torch.sum(output_tensor[0] * output_tensor_grad[0]) - elif output_tensor[0].numel() == 1: - fake_loss = early_exit_loss + output_tensor[0] + with torch.enable_grad(): + if early_exit_loss is not None: + if output_tensor_grad[0] is not None: + fake_loss = early_exit_loss + torch.sum(output_tensor[0] * output_tensor_grad[0]) + elif output_tensor[0].numel() == 1: + fake_loss = early_exit_loss + output_tensor[0] + else: + fake_loss = early_exit_loss + custom_backward(fake_loss, None) + elif config.deallocate_pipeline_outputs: + custom_backward(output_tensor[0], output_tensor_grad[0]) else: - fake_loss = early_exit_loss - custom_backward(fake_loss, None) - elif config.deallocate_pipeline_outputs: - custom_backward(output_tensor[0], output_tensor_grad[0]) - else: - torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0]) + torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0]) # Collect the grad of the input_tensor. input_tensor_grad = [None] diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 09614b9c..db85e4ea 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -1283,6 +1283,7 @@ def __init__(self, config, self.use_exit_mlp = args.use_exit_mlp self.use_exit_norm = args.use_exit_norm self.use_exit_block = args.use_exit_block + self.tune_exit = args.tune_exit self.exit_layer_temperature = args.exit_layer_temperature[args.exit_layer_nums.index(self.layer_number)] self.exit_output_weight = None @@ -1426,10 +1427,17 @@ def _forward_main(self, hidden_states, attention_mask, return output # exit MLP. 
- exit_output = self._forward_mlp(mlp=self.mlp.branch, - norm_output=norm_output, - residual=residual, - bias_dropout_add_func=bias_dropout_add_func) + if self.tune_exit: + exit_output = partial(self._forward_mlp, + mlp=self.mlp.branch, + norm_output=norm_output, + residual=residual, + bias_dropout_add_func=bias_dropout_add_func) + else: + exit_output = self._forward_mlp(mlp=self.mlp.branch, + norm_output=norm_output, + residual=residual, + bias_dropout_add_func=bias_dropout_add_func) return output, exit_output def _cal_exit_loss(self, hidden_states, exit_process_func, exit_loss_func, @@ -1470,8 +1478,9 @@ def _forward_exit(self, hidden_states, exit_process_func, exit_loss_func, hidden_states=hidden_states, exit_process_func=exit_process_func, exit_loss_func=exit_loss_func, - lazy_hidden_states=False) - return lazy_exit_forward_func, False + lazy_hidden_states=self.tune_exit and self.use_exit_mlp) + exit = self.tune_exit and (self.layer_number == mpu.get_early_exit_layer_nums()[-1]) and not mpu.post_stage_has_early_exit() + return lazy_exit_forward_func, exit def forward(self, hidden_states, attention_mask, encoder_output=None, enc_dec_attn_mask=None, @@ -2057,6 +2066,7 @@ def __init__(self, config, drop_path_rate ) self.exit_states = list(map(lambda x: x in mpu.get_early_exit_layer_nums(), self.layer_nums)) + self.tune_exit = get_args().tune_exit def _build_layer(self, layer_number, args, config, model_type, layer_type, self_attn_mask_type): @@ -2141,7 +2151,7 @@ def forward(self, hidden_states, attention_mask, inference_params=inference_params, rotary_pos_emb=rotary_pos_emb) - if torch.is_grad_enabled() and self.training: + if (torch.is_grad_enabled() or self.tune_exit) and self.training: self.microbatch_count += 1 if self.post_process and self.post_norm: diff --git a/megatron/training.py b/megatron/training.py index c4cb06bd..74fb68d6 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -767,14 +767,17 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler, update_num_microbatches(args.consumed_train_samples) args.curr_iteration = iteration - - loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \ - train_step(forward_backward_func, - train_data_iterator, - model, - optimizer, - opt_param_scheduler, - config) + context = nullcontext + if args.tune_exit: + context = torch.no_grad + with context(): + loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \ + train_step(forward_backward_func, + train_data_iterator, + model, + optimizer, + opt_param_scheduler, + config) iteration += 1 args.consumed_train_samples += mpu.get_data_parallel_world_size() * \ args.micro_batch_size * \ From a42ffe3cb0147b12c79564a960cb76273374fde3 Mon Sep 17 00:00:00 2001 From: pxc Date: Wed, 13 Dec 2023 04:59:50 +0000 Subject: [PATCH 02/12] load_ee --- megatron/arguments.py | 8 +++++--- megatron/initialize.py | 30 +++++++++++++++++++++--------- 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/megatron/arguments.py b/megatron/arguments.py index 86a160d0..539df7bb 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -41,7 +41,7 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False): parser = _add_transformer_engine_args(parser) parser = _add_retro_args(parser) parser = _add_experimental_args(parser) - parser = _add_multi_exit_args(parser) + parser = _add_early_exit_args(parser) # Custom arguments. 
if extra_args_provider is not None: @@ -1229,7 +1229,7 @@ def _add_data_args(parser): return parser -def _add_multi_exit_args(parser): +def _add_early_exit_args(parser): group = parser.add_argument_group(title='multexit') group.add_argument('--exit-layer-nums', type=int, nargs='+', default=[], @@ -1257,9 +1257,11 @@ def _add_multi_exit_args(parser): group.add_argument('--num-fill-warmup-microbatches', type=int, default=None) group.add_argument('--num-fill-cooldown-microbatches', type=int, default=None) group.add_argument('--backward-forward-ratio', type=float, default=2.0) + group.add_argument('--use-dynamic-exit-layer-weight', action='store_true') group.add_argument('--tune-exit', action='store_true', help='Only finetune early exit parameters.') - group.add_argument('--use-dynamic-exit-layer-weight', action='store_true') + group.add_argument('--tune-exit-tensor-parallel-size', type=int, default=None) + group.add_argument('--tune-exit-pipeline-parallel-size', type=int, default=None) return parser diff --git a/megatron/initialize.py b/megatron/initialize.py index 9f1cfe4f..26ae323c 100644 --- a/megatron/initialize.py +++ b/megatron/initialize.py @@ -207,15 +207,27 @@ def _initialize_distributed(): if mpu.model_parallel_is_initialized(): print("model parallel is already initialized") else: - mpu.initialize_model_parallel( - args.tensor_model_parallel_size, - args.pipeline_model_parallel_size, - args.virtual_pipeline_model_parallel_size, - args.pipeline_model_parallel_split_rank, - expert_model_parallel_size=args.expert_model_parallel_size, - num_layers=args.num_layers, - early_exit_layer_nums=args.exit_layer_nums - ) + if args.tune_exit and args.tune_exit_tensor_parallel_size is not None \ + and args.tune_exit_pipeline_parallel_size is not None: + mpu.initialize_model_parallel( + args.tune_exit_tensor_parallel_size, + args.tune_exit_pipeline_parallel_size, + args.virtual_pipeline_model_parallel_size, + args.pipeline_model_parallel_split_rank, + expert_model_parallel_size=args.expert_model_parallel_size, + num_layers=args.num_layers / (args.pipeline_model_parallel_size / args.tune_exit_pipeline_parallel_size), + early_exit_layer_nums=args.exit_layer_nums + ) + else: + mpu.initialize_model_parallel( + args.tensor_model_parallel_size, + args.pipeline_model_parallel_size, + args.virtual_pipeline_model_parallel_size, + args.pipeline_model_parallel_split_rank, + expert_model_parallel_size=args.expert_model_parallel_size, + num_layers=args.num_layers, + early_exit_layer_nums=args.exit_layer_nums + ) if args.rank == 0: print( f"> initialized tensor model parallel with size " From b70b62745696237263d5b1443d6a0a578f036986 Mon Sep 17 00:00:00 2001 From: pxc Date: Wed, 13 Dec 2023 11:48:35 +0000 Subject: [PATCH 03/12] tune EE only with partial checkpoint loading --- megatron/arguments.py | 24 ++-- megatron/checkpointing.py | 2 +- megatron/core/parallel_state.py | 34 ++++- megatron/core/pipeline_parallel/schedules.py | 123 ++++++++++++++++++- megatron/initialize.py | 11 +- megatron/model/transformer.py | 2 + megatron/training.py | 5 +- pretrain_early_exit_gpt.py | 3 +- 8 files changed, 184 insertions(+), 20 deletions(-) diff --git a/megatron/arguments.py b/megatron/arguments.py index 539df7bb..8d1ab6a3 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -67,9 +67,10 @@ def validate_args(args, defaults={}): ' ({}) is not divisible by tensor model parallel size ({})'.format( args.world_size, args.tensor_model_parallel_size) # Pipeline model parallel size. 
- args.pipeline_model_parallel_size = min( - args.pipeline_model_parallel_size, - (args.world_size // args.tensor_model_parallel_size)) + if not args.tune_exit: + args.pipeline_model_parallel_size = min( + args.pipeline_model_parallel_size, + (args.world_size // args.tensor_model_parallel_size)) args.transformer_pipeline_model_parallel_size = ( args.pipeline_model_parallel_size - 1 if args.standalone_embedding_stage else @@ -78,11 +79,14 @@ def validate_args(args, defaults={}): # Checks. model_parallel_size = args.pipeline_model_parallel_size * \ args.tensor_model_parallel_size - assert args.world_size % model_parallel_size == 0, 'world size ({}) is not'\ - ' divisible by tensor parallel size ({}) times pipeline parallel ' \ - 'size ({})'.format(args.world_size, args.tensor_model_parallel_size, - args.pipeline_model_parallel_size) - args.data_parallel_size = args.world_size // model_parallel_size + if not args.tune_exit: + assert args.world_size % model_parallel_size == 0, 'world size ({}) is not'\ + ' divisible by tensor parallel size ({}) times pipeline parallel ' \ + 'size ({})'.format(args.world_size, args.tensor_model_parallel_size, + args.pipeline_model_parallel_size) + args.data_parallel_size = args.world_size // model_parallel_size + else: + args.data_parallel_size = args.world_size // (args.tensor_model_parallel_size * args.tune_exit_pipeline_parallel_size) if args.rank == 0: print('using world size: {}, data-parallel-size: {}, ' 'tensor-model-parallel size: {}, ' @@ -400,6 +404,9 @@ def validate_args(args, defaults={}): if not args.add_position_embedding and args.position_embedding_type != 'rope': raise RuntimeError('--no-position-embedding is deprecated, use --position-embedding-type') + if args.position_embedding_type == 'rope': + args.add_position_embedding = False + # MoE Spec check if args.num_experts is not None: assert args.model_spec is None, "Model Spec must be None when using MoEs" @@ -1260,7 +1267,6 @@ def _add_early_exit_args(parser): group.add_argument('--use-dynamic-exit-layer-weight', action='store_true') group.add_argument('--tune-exit', action='store_true', help='Only finetune early exit parameters.') - group.add_argument('--tune-exit-tensor-parallel-size', type=int, default=None) group.add_argument('--tune-exit-pipeline-parallel-size', type=int, default=None) return parser diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py index 14cde010..efe7a2c7 100644 --- a/megatron/checkpointing.py +++ b/megatron/checkpointing.py @@ -89,7 +89,7 @@ def get_checkpoint_name(checkpoints_path, iteration, release=False, # Use both the tensor and pipeline MP rank. 
if pipeline_parallel is None: - pipeline_parallel = (mpu.get_pipeline_model_parallel_world_size() > 1) + pipeline_parallel = mpu.has_pipeline_parallel() if tensor_rank is None: tensor_rank = mpu.get_tensor_model_parallel_rank() if pipeline_rank is None: diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py index 38743e53..93dbaebe 100644 --- a/megatron/core/parallel_state.py +++ b/megatron/core/parallel_state.py @@ -80,6 +80,10 @@ _EARLY_EXIT_STAGES = None +_TUNE_EXIT = False + +_FULL_EXIT_PIPELINE_PARALLEL_SIZE = None + _EMBEDDING_STAGES = None def initialize_model_parallel( @@ -92,6 +96,8 @@ def initialize_model_parallel( expert_model_parallel_size: int = 1, num_layers: Optional[int] = None, early_exit_layer_nums: Optional[List[int]] = None, + tune_exit: Optional[bool] = False, + full_exit_pipeline_parallel_size: Optional[int] = None, ) -> None: """Initialize model data parallel groups. @@ -342,7 +348,17 @@ def initialize_model_parallel( assert _EARLY_EXIT_LAYER_NUMS is None, 'early exit layer nums is already initialized' global _EARLY_EXIT_STAGES assert _EARLY_EXIT_STAGES is None, 'early exit stages is already initialized' - layer_per_stage = num_layers / pipeline_model_parallel_size + global _FULL_EXIT_PIPELINE_PARALLEL_SIZE + assert _FULL_EXIT_PIPELINE_PARALLEL_SIZE is None, 'full exit pipeline parallel size is already initialized' + global _TUNE_EXIT + _TUNE_EXIT = tune_exit + if tune_exit: + if full_exit_pipeline_parallel_size is None: + full_exit_pipeline_parallel_size = pipeline_model_parallel_size + layer_per_stage = num_layers // full_exit_pipeline_parallel_size + _FULL_EXIT_PIPELINE_PARALLEL_SIZE = full_exit_pipeline_parallel_size + else: + layer_per_stage = num_layers // pipeline_model_parallel_size _EARLY_EXIT_STAGES = list(set(map(lambda layer_num: int((layer_num - 1) // layer_per_stage), early_exit_layer_nums))) for i in range(num_pipeline_model_parallel_groups): ranks = range(i, world_size, num_pipeline_model_parallel_groups) @@ -921,6 +937,18 @@ def has_early_exit(): """Return true if pipeline stage has early exit output""" return _EARLY_EXIT_LAYER_NUMS != None and len(_EARLY_EXIT_LAYER_NUMS) > 0 +def is_tune_exit(): + return _TUNE_EXIT + +def has_pipeline_parallel(): + if _TUNE_EXIT: + return _FULL_EXIT_PIPELINE_PARALLEL_SIZE > 1 + else: + return get_pipeline_model_parallel_world_size() > 1 + +def is_real_pipeline_last_stage_in_tune_exit(): + return get_pipeline_model_parallel_rank() == (_FULL_EXIT_PIPELINE_PARALLEL_SIZE - 1) + def get_early_exit_layer_nums(): return _EARLY_EXIT_LAYER_NUMS @@ -1021,3 +1049,7 @@ def destroy_model_parallel(): _EARLY_EXIT_LAYER_NUMS = None global _EARLY_EXIT_STAGES _EARLY_EXIT_STAGES = None + global _FULL_EXIT_PIPELINE_PARALLEL_SIZE + _FULL_EXIT_PIPELINE_PARALLEL_SIZE = None + global _TUNE_EXIT + _TUNE_EXIT = False \ No newline at end of file diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index f43c61c7..775ae694 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -113,6 +113,8 @@ def forward_step(data_iterator, model): num_fill_warmup_microbatches=args.num_fill_warmup_microbatches, num_fill_cooldown_microbatches=args.num_fill_cooldown_microbatches, early_exit_loss_weight=early_exit_loss_weight) + elif args.tune_exit: + forward_backward_func = partial(forward_backward_pipelining_for_early_exit_tuning, early_exit_loss_weight=early_exit_loss_weight) else: forward_backward_func = 
partial(early_exit_forward_backward_pipelining, early_exit_loss_weight=early_exit_loss_weight) elif parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None: @@ -1442,8 +1444,125 @@ def early_exit_forward_backward_no_pipelining( # Finalize model grads (perform full grad all-reduce / reduce-scatter for # data parallelism and layernorm all-reduce for sequence parallelism). config.finalize_model_grads_func([model]) + if len(backward_data_store) == len(forward_data_store): + backward_data_store = [{**f, **b} for (f, b) in zip(forward_data_store, backward_data_store)] + return backward_data_store - forward_data_store = [{**f, **b} for (f, b) in zip(forward_data_store, backward_data_store)] + +def forward_backward_pipelining_for_early_exit_tuning( + *, + forward_step_func, + data_iterator: Union[Iterator, List[Iterator]], + model: Union[torch.nn.Module, List[torch.nn.Module]], + num_microbatches: int, + seq_length: int, + micro_batch_size: int, + decoder_seq_length: int = None, + forward_only: bool = False, + collect_non_loss_data: bool = False, + early_exit_loss_weight: EarlyExitLossWeight = None, +): + if isinstance(model, list): + assert ( + len(model) == 1 + ), "non-interleaved pipeline parallelism does not support model chunking" + model = model[0] + if isinstance(data_iterator, list): + assert ( + len(data_iterator) == 1 + ), "non-pipeline-parallel schedule does not support model chunking" + data_iterator = data_iterator[0] + + config = get_model_config(model) + if config.overlap_p2p_comm: + raise ValueError( + "Non-interleaved pipeline parallelism does not support overlapping p2p communication" + ) + + if config.timers is not None: + config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time) + + # Disable async grad reductions + no_sync_func = config.no_sync_func + if no_sync_func is None: + no_sync_func = contextlib.nullcontext + no_sync_context = None + has_early_exit = parallel_state.has_early_exit() + + def disable_grad_sync(): + """Disable asynchronous grad reductions""" + nonlocal no_sync_context + if no_sync_context is None: + no_sync_context = no_sync_func() + no_sync_context.__enter__() + + def enable_grad_sync(): + """Enable asynchronous grad reductions""" + nonlocal no_sync_context + if no_sync_context is not None: + no_sync_context.__exit__(None, None, None) + no_sync_context = None + + disable_grad_sync() + + if early_exit_loss_weight: + early_exit_loss_weight.update() + + + model_type = get_model_type(model) + + rank = parallel_state.get_pipeline_model_parallel_rank() + recv_tensor_shapes = get_tensor_shapes( + rank=rank - 1, + model_type=model_type, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + decoder_seq_length=decoder_seq_length, + config=config, + ) + send_tensor_shapes = get_tensor_shapes( + rank=rank, + model_type=model_type, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + decoder_seq_length=decoder_seq_length, + config=config, + ) + + forward_data_store = [] + + # Run warmup forward passes. 
+ for i in range(num_microbatches): + input_tensor = recv_forward(recv_tensor_shapes, config) + output_tensor, early_exit_loss_func = early_exit_forward_step( + forward_step_func, + data_iterator, + model, + num_microbatches, + input_tensor, + forward_data_store, + config, + collect_non_loss_data, + ) + send_forward(output_tensor, send_tensor_shapes, config) + + if has_early_exit and not forward_only: + if i == num_microbatches - 1: + if config.grad_sync_func is None or rank == 0: + enable_grad_sync() + exit_loss = cal_early_exit_loss(early_exit_loss_func, forward_data_store, num_microbatches, early_exit_loss_weight) + early_exit_backward_step(input_tensor, output_tensor, [None], config, + early_exit_loss=exit_loss + ) + + if config.timers is not None: + config.timers('forward-backward').stop() + + if config.finalize_model_grads_func is not None and not forward_only: + # Finalize model grads (perform full grad all-reduce / reduce-scatter for + # data parallelism, layernorm all-reduce for sequence parallelism, and + # embedding all-reduce for pipeline parallelism). + config.finalize_model_grads_func([model]) return forward_data_store @@ -2076,7 +2195,7 @@ def early_exit_forward_step( output_tensor = lm_output loss_dict = {} - if parallel_state.is_pipeline_last_stage(): + if parallel_state.is_pipeline_last_stage() and not parallel_state.is_tune_exit(): output_tensor = loss_func(output_tensor=output_tensor, log_dict=loss_dict, log_key='lm loss') diff --git a/megatron/initialize.py b/megatron/initialize.py index 26ae323c..f3ffa684 100644 --- a/megatron/initialize.py +++ b/megatron/initialize.py @@ -207,16 +207,17 @@ def _initialize_distributed(): if mpu.model_parallel_is_initialized(): print("model parallel is already initialized") else: - if args.tune_exit and args.tune_exit_tensor_parallel_size is not None \ - and args.tune_exit_pipeline_parallel_size is not None: + if args.tune_exit: mpu.initialize_model_parallel( - args.tune_exit_tensor_parallel_size, + args.tensor_model_parallel_size, args.tune_exit_pipeline_parallel_size, args.virtual_pipeline_model_parallel_size, args.pipeline_model_parallel_split_rank, expert_model_parallel_size=args.expert_model_parallel_size, - num_layers=args.num_layers / (args.pipeline_model_parallel_size / args.tune_exit_pipeline_parallel_size), - early_exit_layer_nums=args.exit_layer_nums + num_layers=args.num_layers, + early_exit_layer_nums=args.exit_layer_nums, + tune_exit=True, + full_exit_pipeline_parallel_size=args.pipeline_model_parallel_size ) else: mpu.initialize_model_parallel( diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index db85e4ea..7c8b32d7 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -1594,6 +1594,8 @@ def _get_num_layers(args, model_type, is_decoder=False): else: if not is_decoder: num_layers = args.encoder_num_layers + if args.tune_exit: + num_layers = num_layers // args.pipeline_model_parallel_size else: num_layers = args.decoder_num_layers return num_layers diff --git a/megatron/training.py b/megatron/training.py index 74fb68d6..41fa8469 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -241,7 +241,10 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap model.append(this_model) else: pre_process = mpu.is_pipeline_first_stage() - post_process = mpu.is_pipeline_last_stage() + if args.tune_exit: + post_process = mpu.is_real_pipeline_last_stage_in_tune_exit() + else: + post_process = mpu.is_pipeline_last_stage() add_encoder = True 
add_decoder = True if model_type == ModelType.encoder_and_decoder: diff --git a/pretrain_early_exit_gpt.py b/pretrain_early_exit_gpt.py index f8ace67b..88ec97d0 100644 --- a/pretrain_early_exit_gpt.py +++ b/pretrain_early_exit_gpt.py @@ -51,7 +51,8 @@ def get_batch(data_iterator): tokens_ = data_b['text'].long() labels = tokens_[:, 1:].contiguous() tokens = tokens_[:, :-1].contiguous() - + # for Llama2Tokenizer + tokens.masked_fill_(tokens == 32002, 2) # Get the masks and postition ids. attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( tokens, From 5b68a7cf9efcfa34b0dd736e79e791e0d90bab98 Mon Sep 17 00:00:00 2001 From: pxc Date: Wed, 13 Dec 2023 12:20:54 +0000 Subject: [PATCH 04/12] set requires_grad to False for non-EE params in tune-EE-only mode --- megatron/training.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/megatron/training.py b/megatron/training.py index 41fa8469..285dcdb3 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -217,6 +217,15 @@ def update_train_iters(args): print_rank_0('setting training iterations to {}'.format(args.train_iters)) +def is_early_exit_param(param_name): + # for exit_output_layer / exit_norm / exit_block + if 'exit' in param_name: + return True + # for branch mlp + if '.branch.' in param_name: + return True + return False + def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True): """Build the model.""" args = get_args() @@ -274,6 +283,13 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap if not isinstance(model, list): model = [model] + # tune early exit only + if args.tune_exit: + for model_module in model: + for name, param in model_module.named_parameters(): + if not is_early_exit_param(name): + param.requires_grad = False + # Disallow training and inference with Transformer Engine # for non-GPT models args.allow_transformer_engine = all([type(m) == GPTModel for m in model]) From 219b39489e9474457395114ea0ddea808a631525 Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 14 Dec 2023 09:25:45 +0000 Subject: [PATCH 05/12] simplify tune EE only --- megatron/core/pipeline_parallel/schedules.py | 63 ++++++++++---------- megatron/training.py | 18 +++--- pretrain_early_exit_gpt.py | 2 +- 3 files changed, 39 insertions(+), 44 deletions(-) diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index 775ae694..173d4f1d 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -1533,18 +1533,19 @@ def enable_grad_sync(): # Run warmup forward passes. 
for i in range(num_microbatches): - input_tensor = recv_forward(recv_tensor_shapes, config) - output_tensor, early_exit_loss_func = early_exit_forward_step( - forward_step_func, - data_iterator, - model, - num_microbatches, - input_tensor, - forward_data_store, - config, - collect_non_loss_data, - ) - send_forward(output_tensor, send_tensor_shapes, config) + with torch.no_grad(): + input_tensor = recv_forward(recv_tensor_shapes, config) + output_tensor, early_exit_loss_func = early_exit_forward_step( + forward_step_func, + data_iterator, + model, + num_microbatches, + input_tensor, + forward_data_store, + config, + collect_non_loss_data, + ) + send_forward(output_tensor, send_tensor_shapes, config) if has_early_exit and not forward_only: if i == num_microbatches - 1: @@ -2139,14 +2140,13 @@ def enable_grad_sync(): def cal_early_exit_loss(early_exit_loss_funcs, forward_data_store, num_microbatches, early_exit_loss_weight): exit_loss_dict = {} exit_losses = [] - with torch.enable_grad(): - for layer_num, exit_loss_func in early_exit_loss_funcs.items(): - loss = exit_loss_func(log_dict=exit_loss_dict) - loss_weight = early_exit_loss_weight.get_weight(layer_num) - exit_losses.append(loss.multiply_(loss_weight)) - exit_loss_dict[f'exit weight [{layer_num}]'] = early_exit_loss_weight.get_weight(layer_num) - forward_data_store.append(exit_loss_dict) - return torch.sum(torch.stack(exit_losses), dim=0).div(num_microbatches) + for layer_num, exit_loss_func in early_exit_loss_funcs.items(): + loss = exit_loss_func(log_dict=exit_loss_dict) + loss_weight = early_exit_loss_weight.get_weight(layer_num) + exit_losses.append(loss.multiply_(loss_weight)) + exit_loss_dict[f'exit weight [{layer_num}]'] = early_exit_loss_weight.get_weight(layer_num) + forward_data_store.append(exit_loss_dict) + return torch.sum(torch.stack(exit_losses), dim=0).div(num_microbatches) def early_exit_forward_step( @@ -2242,19 +2242,18 @@ def early_exit_backward_step(input_tensor, output_tensor, output_tensor_grad, co # Backward pass. if output_tensor_grad[0] is None and config.grad_scale_func is not None: output_tensor[0] = config.grad_scale_func(output_tensor[0]) - with torch.enable_grad(): - if early_exit_loss is not None: - if output_tensor_grad[0] is not None: - fake_loss = early_exit_loss + torch.sum(output_tensor[0] * output_tensor_grad[0]) - elif output_tensor[0].numel() == 1: - fake_loss = early_exit_loss + output_tensor[0] - else: - fake_loss = early_exit_loss - custom_backward(fake_loss, None) - elif config.deallocate_pipeline_outputs: - custom_backward(output_tensor[0], output_tensor_grad[0]) + if early_exit_loss is not None: + if output_tensor_grad[0] is not None: + fake_loss = early_exit_loss + torch.sum(output_tensor[0] * output_tensor_grad[0]) + elif output_tensor[0].numel() == 1: + fake_loss = early_exit_loss + output_tensor[0] else: - torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0]) + fake_loss = early_exit_loss + custom_backward(fake_loss, None) + elif config.deallocate_pipeline_outputs: + custom_backward(output_tensor[0], output_tensor_grad[0]) + else: + torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0]) # Collect the grad of the input_tensor. 
input_tensor_grad = [None] diff --git a/megatron/training.py b/megatron/training.py index 285dcdb3..d45fa244 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -786,17 +786,13 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler, update_num_microbatches(args.consumed_train_samples) args.curr_iteration = iteration - context = nullcontext - if args.tune_exit: - context = torch.no_grad - with context(): - loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \ - train_step(forward_backward_func, - train_data_iterator, - model, - optimizer, - opt_param_scheduler, - config) + loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \ + train_step(forward_backward_func, + train_data_iterator, + model, + optimizer, + opt_param_scheduler, + config) iteration += 1 args.consumed_train_samples += mpu.get_data_parallel_world_size() * \ args.micro_batch_size * \ diff --git a/pretrain_early_exit_gpt.py b/pretrain_early_exit_gpt.py index 88ec97d0..0fb3d521 100644 --- a/pretrain_early_exit_gpt.py +++ b/pretrain_early_exit_gpt.py @@ -52,7 +52,7 @@ def get_batch(data_iterator): labels = tokens_[:, 1:].contiguous() tokens = tokens_[:, :-1].contiguous() # for Llama2Tokenizer - tokens.masked_fill_(tokens == 32002, 2) + tokens.masked_fill_(tokens >= 32000, 2) # Get the masks and postition ids. attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( tokens, From 7f74c31aedf5462fdf93d0569568f666bcd35434 Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 18 Dec 2023 06:24:57 +0000 Subject: [PATCH 06/12] support use specific EE --- megatron/early_exit_text_generation_server.py | 13 ++++++- megatron/text_generation/api.py | 34 ++++++++++++++----- megatron/text_generation/generation.py | 10 ++++-- megatron/text_generation/inference_params.py | 7 +++- tools/request_client.py | 13 ++++--- 5 files changed, 59 insertions(+), 18 deletions(-) diff --git a/megatron/early_exit_text_generation_server.py b/megatron/early_exit_text_generation_server.py index 16c8d610..2271083d 100644 --- a/megatron/early_exit_text_generation_server.py +++ b/megatron/early_exit_text_generation_server.py @@ -59,7 +59,8 @@ async def generate(self, req): random_seed=req['random_seed'], early_exit_thres=req['early_exit_thres'], use_early_exit=req['use_early_exit'], - print_max_prob=req['print_max_prob']) + print_max_prob=req['print_max_prob'], + exit_layers=req['exit_layers']) end_time = time.time() print(f"Response(use {end_time - start_time}s): " + str(response)) return { @@ -130,6 +131,16 @@ def put(self): else: raw_req['print_max_prob'] = False + if "exit_layers" in raw_req: + if not type(raw_req['exit_layers']) == list: + return "exit_layers must be a list of int" + else: + for i in raw_req['exit_layers']: + if not type(i) == int: + return "exit_layers must be a list of int" + else: + raw_req['exit_layers'] = [] + top_k = 0.0 if "top_k" in raw_req: top_k = raw_req["top_k"] diff --git a/megatron/text_generation/api.py b/megatron/text_generation/api.py index 6d1fdba6..51eeee85 100644 --- a/megatron/text_generation/api.py +++ b/megatron/text_generation/api.py @@ -37,7 +37,8 @@ def generate_and_post_process(model, return_logits=False, early_exit_thres=1.0, use_early_exit=False, - print_max_prob=False): + print_max_prob=False, + exit_layers=[]): """Run inference and post-process outputs, i.e., detokenize, move to cpu and convert to list.""" @@ -62,7 +63,8 @@ def generate_and_post_process(model, random_seed=random_seed, early_exit_thres=early_exit_thres, use_early_exit=use_early_exit, - 
print_max_prob=print_max_prob) + print_max_prob=print_max_prob, + exit_layers=exit_layers) # Only post-process on first stage. if mpu.is_pipeline_first_stage(): @@ -104,7 +106,8 @@ def generate(model, random_seed=-1, early_exit_thres=1.0, use_early_exit=False, - print_max_prob=False): + print_max_prob=False, + exit_layers=[]): """Given prompts and input parameters, run inference and return: tokens: prompts plus the generated tokens. lengths: length of the prompt + generations. Note that we can @@ -124,9 +127,16 @@ def generate(model, if stop_token_ids != None: stop_token_ids = torch.tensor(stop_token_ids, dtype=torch.int64) values.append(len(stop_token_ids)) - values.extend(stop_token_ids) else: values.append(0) + stop_token_ids = [] + + if len(exit_layers) > 0: + exit_layers = torch.tensor(exit_layers, dtype=torch.int64) + values.append(len(exit_layers)) + else: + values.append(0) + values_float_tensor = broadcast_float_list(len(values), float_list=values) tokens_to_generate = int(values_float_tensor[0].item()) return_output_log_probs = bool(values_float_tensor[1].item()) @@ -144,13 +154,19 @@ def generate(model, early_exit_thres = values_float_tensor[13].item() use_early_exit = bool(values_float_tensor[14].item()) print_max_prob = bool(values_float_tensor[15].item()) - stop_tokens_length = int(values_float_tensor[16].item()) + exit_layers_length = int(values_float_tensor[17].item()) + if stop_tokens_length > 0: - stop_token_ids = values_float_tensor[17: 17 + stop_tokens_length].int() + stop_token_ids = broadcast_float_list(stop_tokens_length, float_list=stop_token_ids) else: stop_token_ids = None + if exit_layers_length > 0: + exit_layers = broadcast_float_list(exit_layers_length, float_list=exit_layers).int().cpu().numpy().tolist() + else: + exit_layers = [] + if random_seed != -1: torch.random.manual_seed(random_seed) @@ -184,7 +200,8 @@ def generate(model, echo_prompts=echo_prompts, early_exit_thres=early_exit_thres, use_early_exit=use_early_exit, - print_max_prob=print_max_prob) + print_max_prob=print_max_prob, + exit_layers=exit_layers) else: output = generate_tokens_probs_and_return_on_first_stage( model, context_tokens_tensor, context_length_tensor, @@ -200,7 +217,8 @@ def generate(model, echo_prompts=echo_prompts, early_exit_thres=early_exit_thres, use_early_exit=use_early_exit, - print_max_prob=print_max_prob) + print_max_prob=print_max_prob, + exit_layers=exit_layers) except Exception as e: traceback.print_exc() return output diff --git a/megatron/text_generation/generation.py b/megatron/text_generation/generation.py index 32874979..d867cff0 100644 --- a/megatron/text_generation/generation.py +++ b/megatron/text_generation/generation.py @@ -103,6 +103,7 @@ def generate_tokens_probs_and_return_on_first_stage( early_exit_thres=1.0, use_early_exit=False, print_max_prob=False, + exit_layers=[] ): """Main token generation function. Arguments: @@ -151,7 +152,8 @@ def generate_tokens_probs_and_return_on_first_stage( top_p_decay=top_p_decay, early_exit_thres=early_exit_thres, use_early_exit=use_early_exit, - print_max_prob=print_max_prob) + print_max_prob=print_max_prob, + exit_layers=exit_layers) # forward step. forward_step = ForwardStep(model, inference_params=inference_params) @@ -452,7 +454,8 @@ def generate_with_pipelined_early_exit_and_return_on_first_stage( echo_prompts=False, early_exit_thres=1.0, use_early_exit=False, - print_max_prob=False + print_max_prob=False, + exit_layers=[] ): """Main token generation function. 
Arguments: @@ -501,7 +504,8 @@ def generate_with_pipelined_early_exit_and_return_on_first_stage( top_p_decay=top_p_decay, early_exit_thres=early_exit_thres, use_early_exit=use_early_exit, - print_max_prob=print_max_prob) + print_max_prob=print_max_prob, + exit_layers=exit_layers) # forward step. forward_step = ForwardStep(model, inference_params=inference_params) diff --git a/megatron/text_generation/inference_params.py b/megatron/text_generation/inference_params.py index 2e57d387..27c44082 100644 --- a/megatron/text_generation/inference_params.py +++ b/megatron/text_generation/inference_params.py @@ -15,7 +15,8 @@ def __init__(self, max_batch_size, max_sequence_length, top_k=0, top_p=0, temperature=1.0, top_p_decay=0, top_p_bound=0, early_exit_thres=None, use_early_exit=False, - print_max_prob=False): + print_max_prob=False, + exit_layers=[]): self.max_sequence_length = max_sequence_length self.max_batch_size = max_batch_size self.sequence_len_offset = 0 @@ -31,6 +32,8 @@ def __init__(self, max_batch_size, max_sequence_length, self.top_p_decay = top_p_decay self.top_p_bound = top_p_bound self.print_max_probs = print_max_prob + self.exit_layers = set(exit_layers) + self.use_all_exit = len(exit_layers) == 0 self.has_early_exited = False self.is_first_step = True @@ -47,6 +50,8 @@ def clear_early_exit_states(self): def do_early_exit(self, logits, layer_num): if self.has_early_exited or self.prev_has_early_exited: return False + if not (self.use_all_exit or (layer_num in self.exit_layers)): + return False last_token_logits = logits[:, -1, :] log_probs = F.log_softmax(last_token_logits, dim=1) max_log_prob, token_id = torch.max(log_probs[:, :], dim=1) diff --git a/tools/request_client.py b/tools/request_client.py index ce70f7b6..83023c8d 100644 --- a/tools/request_client.py +++ b/tools/request_client.py @@ -14,6 +14,7 @@ def request( use_early_exit=True, early_exit_thres=0.8, print_max_prob=False, + exit_layers=[] ): length = len(prompts) for i in range(length): @@ -25,6 +26,7 @@ def request( "random_seed": int(time.time_ns()) % 16384, "echo_prompts": False, "early_exit_thres": early_exit_thres, + "exit_layers": exit_layers } if use_early_exit: data["use_early_exit"] = True @@ -46,14 +48,14 @@ def request( def main( - file_name, tokens_to_generate, use_early_exit, early_exit_thres, print_max_prob + file_name, tokens_to_generate, use_early_exit, early_exit_thres, print_max_prob, exit_layers ): prompts = [] with open(file_name, "r") as f: for line in f.readlines(): prompts.append(json.loads(line)["text"]) request( - prompts, tokens_to_generate, use_early_exit, early_exit_thres, print_max_prob + prompts, tokens_to_generate, use_early_exit, early_exit_thres, print_max_prob, exit_layers ) @@ -61,7 +63,8 @@ def main( main( "tools/prompt_example.jsonl", tokens_to_generate=100, - use_early_exit=True, - early_exit_thres=0.8, - print_max_prob=False, + use_early_exit=False, + early_exit_thres=1.0, + print_max_prob=True, + exit_layers=[] ) From 89ea3083ac05a998a5f8951425c0cfae8656cd31 Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 21 Dec 2023 08:53:06 +0000 Subject: [PATCH 07/12] fix bugs in EE only tune --- megatron/arguments.py | 2 +- megatron/checkpointing.py | 1 - megatron/model/language_model.py | 22 ---------------------- pretrain_early_exit_gpt.py | 2 +- tools/checkpoint/checkpoint_converter.py | 4 ++-- tools/checkpoint/saver_megatron.py | 3 +++ tools/request_client.py | 6 +++--- 7 files changed, 10 insertions(+), 30 deletions(-) diff --git a/megatron/arguments.py b/megatron/arguments.py index 
8d1ab6a3..5309d6f4 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -1267,7 +1267,7 @@ def _add_early_exit_args(parser): group.add_argument('--use-dynamic-exit-layer-weight', action='store_true') group.add_argument('--tune-exit', action='store_true', help='Only finetune early exit parameters.') - group.add_argument('--tune-exit-pipeline-parallel-size', type=int, default=None) + group.add_argument('--tune-exit-pipeline-parallel-size', type=int, default=1) return parser diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py index efe7a2c7..4cb1f8ea 100644 --- a/megatron/checkpointing.py +++ b/megatron/checkpointing.py @@ -55,7 +55,6 @@ def _compare(arg_name, old_arg_name=None, default=None): _compare('num_layers') _compare('hidden_size') _compare('num_attention_heads') - _compare('add_position_embedding', default=True) if args.vocab_file: _compare('max_position_embeddings') _compare('make_vocab_size_divisible_by') diff --git a/megatron/model/language_model.py b/megatron/model/language_model.py index 854e1117..066e82fc 100644 --- a/megatron/model/language_model.py +++ b/megatron/model/language_model.py @@ -722,28 +722,6 @@ def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask, return encoder_output, early_exit_output - - def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): - """For easy load.""" - - state_dict_ = {} - state_dict_[self._word_embeddings_key] \ - = self.word_embeddings.state_dict(prefix=prefix, - keep_vars=keep_vars) - if self.untie_exit_output_weights: - state_dict_[self._exit_output_key] = self.exit_output_layer.state_dict(prefix=prefix, keep_vars=keep_vars) - - if self.add_position_embedding: - state_dict_[self._position_embeddings_key] \ - = self.position_embeddings.state_dict(prefix=prefix, - keep_vars=keep_vars) - if self.num_tokentypes > 0: - state_dict_[self._tokentype_embeddings_key] \ - = self.tokentype_embeddings.state_dict(prefix=prefix, - keep_vars=keep_vars) - - return state_dict_ - def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): """For easy load.""" diff --git a/pretrain_early_exit_gpt.py b/pretrain_early_exit_gpt.py index 0fb3d521..5cd4a40d 100644 --- a/pretrain_early_exit_gpt.py +++ b/pretrain_early_exit_gpt.py @@ -52,7 +52,7 @@ def get_batch(data_iterator): labels = tokens_[:, 1:].contiguous() tokens = tokens_[:, :-1].contiguous() # for Llama2Tokenizer - tokens.masked_fill_(tokens >= 32000, 2) + tokens.masked_fill_(tokens >= 32000, -1) # Get the masks and postition ids. 
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( tokens, diff --git a/tools/checkpoint/checkpoint_converter.py b/tools/checkpoint/checkpoint_converter.py index 0582cd3f..5f5a0f42 100644 --- a/tools/checkpoint/checkpoint_converter.py +++ b/tools/checkpoint/checkpoint_converter.py @@ -108,7 +108,7 @@ def add_exit(args, checkpoint_load_dir, checkpoint_save_dir): print("Can't add exit layers and change exit position at the same time") return use_pre_exit = checkpoint_args.pre_exit - target_exit_layer_nums = list(set(checkpoint_args.exit_layer_nums + args.add_exit_layer_nums)) + target_exit_layer_nums = sorted(list(set(checkpoint_args.exit_layer_nums + args.add_exit_layer_nums))) tensor_parallel_size = checkpoint_args.tensor_model_parallel_size pipeline_parallel_size = checkpoint_args.pipeline_model_parallel_size use_pipeline_parallel = pipeline_parallel_size > 1 @@ -222,7 +222,7 @@ def add_exit(args, checkpoint_load_dir, checkpoint_save_dir): if args.use_exit_norm: checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_norm.weight'] = final_norm_weight if final_norm_bias is not None: - checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_norm'] = final_norm_bias + checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_norm.bias'] = final_norm_bias if not use_pipeline_parallel: checkpoint_save_path = os.path.join(checkpoint_save_dir, f'mp_rank_{tensor_rank:02d}', 'model_optim_rng.pt') else: diff --git a/tools/checkpoint/saver_megatron.py b/tools/checkpoint/saver_megatron.py index 5b576af3..483c2179 100644 --- a/tools/checkpoint/saver_megatron.py +++ b/tools/checkpoint/saver_megatron.py @@ -224,6 +224,9 @@ def get_models(count, dtype, pre_process, post_process): layer_per_stage = md.num_layers / args.target_pipeline_parallel_size mpu.set_early_exit_layer_nums(list(filter(lambda x: 0 < x <= layer_per_stage, md.exit_layer_nums))) mpu.set_early_exit_stages(list(set(map(lambda layer_num: int((layer_num - 1) // layer_per_stage), md.exit_layer_nums)))) + else: + mpu.set_early_exit_layer_nums([]) + mpu.set_early_exit_stages([]) fused_kernels.load(margs) # Embeddings diff --git a/tools/request_client.py b/tools/request_client.py index 83023c8d..ab79fd60 100644 --- a/tools/request_client.py +++ b/tools/request_client.py @@ -62,9 +62,9 @@ def main( if __name__ == "__main__": main( "tools/prompt_example.jsonl", - tokens_to_generate=100, - use_early_exit=False, - early_exit_thres=1.0, + tokens_to_generate=50, + use_early_exit=True, + early_exit_thres=0.9, print_max_prob=True, exit_layers=[] ) From ab8f201ecc778f239b8a2d00d7d0f075b1b898b1 Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 22 Dec 2023 09:35:34 +0000 Subject: [PATCH 08/12] random init EE --- tools/checkpoint/checkpoint_converter.py | 115 ++++++++++++++++++----- 1 file changed, 91 insertions(+), 24 deletions(-) diff --git a/tools/checkpoint/checkpoint_converter.py b/tools/checkpoint/checkpoint_converter.py index 5f5a0f42..465b844c 100644 --- a/tools/checkpoint/checkpoint_converter.py +++ b/tools/checkpoint/checkpoint_converter.py @@ -3,6 +3,7 @@ import sys import torch import argparse +import math from collections import OrderedDict def get_args(): @@ -16,6 +17,8 @@ def get_args(): parser.add_argument('--use-exit-mlp', action='store_true') parser.add_argument('--use-exit-block', action='store_true') parser.add_argument('--use-exit-norm', action='store_true') + parser.add_argument('--random-init', 
action='store_true') + parser.add_argument('--init-method-std', type=float, default=0.02) parser.add_argument('--megatron-path', type=str, default=None, help='Base directory of deepspeed repository') return parser.parse_args() @@ -32,6 +35,23 @@ def load_checkpoint_args(checkpoint_root_path): model = torch.load(checkpoint_path) return model['args'] +# Init method from megatron-lm +def init_method_normal(sigma): + + def init_(tensor): + return torch.nn.init.normal_(tensor, mean=0.0, std=sigma) + + return init_ + +def scaled_init_method_normal(sigma, num_layers): + std = sigma / math.sqrt(2.0 * num_layers) + + def init_(tensor): + return torch.nn.init.normal_(tensor, mean=0.0, std=std) + + return init_ + + def change_exit_position(args, checkpoint_load_dir, checkpoint_save_dir): checkpoint_args = load_checkpoint_args(checkpoint_load_dir) cur_exit_position = 'pre' if checkpoint_args.pre_exit else 'post' @@ -113,6 +133,10 @@ def add_exit(args, checkpoint_load_dir, checkpoint_save_dir): pipeline_parallel_size = checkpoint_args.pipeline_model_parallel_size use_pipeline_parallel = pipeline_parallel_size > 1 layer_per_stage = checkpoint_args.num_layers / pipeline_parallel_size + # if args.random_init: + # init_method = init_method_normal(args.init_method_std) + # output_layer_init_method = scaled_init_method_normal(args.init_method_std) + for tensor_rank in range(tensor_parallel_size): checkpoint_dicts = {} output_weight = None @@ -146,15 +170,24 @@ def add_exit(args, checkpoint_load_dir, checkpoint_save_dir): if args.use_exit_mlp and (not hasattr(state_dict['args'], 'use_exit_mlp') or not state_dict['args'].use_exit_mlp): state_dict['args'].use_exit_mlp = args.use_exit_mlp for layer_num in exit_layer_nums: + if args.random_init: + init_method = init_method_normal(args.init_method_std) + output_layer_init_method = scaled_init_method_normal(args.init_method_std, layer_num) layer_id = int(layer_num - layer_num_offset) state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.trunk.dense_h_to_4h.weight'] = \ state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.dense_h_to_4h.weight'] state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.trunk.dense_4h_to_h.weight'] = \ state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.dense_4h_to_h.weight'] - state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.branch.dense_h_to_4h.weight'] = \ - state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.dense_h_to_4h.weight'] - state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.branch.dense_4h_to_h.weight'] = \ - state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.dense_4h_to_h.weight'] + if args.random_init: + state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.branch.dense_h_to_4h.weight'] = \ + init_method(torch.empty(state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.dense_h_to_4h.weight'].shape)) + state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.branch.dense_4h_to_h.weight'] = \ + output_layer_init_method(torch.empty(state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.dense_4h_to_h.weight'].shape)) + else: + state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.branch.dense_h_to_4h.weight'] = \ + state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.dense_h_to_4h.weight'] + 
state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.branch.dense_4h_to_h.weight'] = \ + state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.dense_4h_to_h.weight'] state_dict['model']['language_model']['encoder'].pop(f'layers.{layer_id}.mlp.dense_h_to_4h.weight') state_dict['model']['language_model']['encoder'].pop(f'layers.{layer_id}.mlp.dense_4h_to_h.weight') if checkpoint_args.add_bias_linear: @@ -162,10 +195,16 @@ def add_exit(args, checkpoint_load_dir, checkpoint_save_dir): state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.dense_h_to_4h.bias'] state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.trunk.dense_4h_to_h.bias'] = \ state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.dense_4h_to_h.bias'] - state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.branch.dense_h_to_4h.bias'] = \ - state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.dense_h_to_4h.bias'] - state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.branch.dense_4h_to_h.bias'] = \ - state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.dense_4h_to_h.bias'] + if args.random_init: + state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.branch.dense_h_to_4h.bias'] = \ + torch.zeros(state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.dense_h_to_4h.bias'].shape) + state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.branch.dense_4h_to_h.bias'] = \ + torch.zeros(state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.dense_4h_to_h.bias'].shape) + else: + state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.branch.dense_h_to_4h.bias'] = \ + state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.dense_h_to_4h.bias'] + state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.branch.dense_4h_to_h.bias'] = \ + state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.dense_4h_to_h.bias'] state_dict['model']['language_model']['encoder'].pop(f'layers.{layer_id}.mlp.dense_h_to_4h.bias') state_dict['model']['language_model']['encoder'].pop(f'layers.{layer_id}.mlp.dense_4h_to_h.bias') # convert to exit block @@ -204,25 +243,51 @@ def add_exit(args, checkpoint_load_dir, checkpoint_save_dir): # add exit output weight and exit norm for i, layer_num in enumerate(exit_layer_nums): layer_id = int(layer_num - layer_num_offset) - checkpoint_dicts[pipeline_rank]['model']['language_model']['exit_output_layer'][f'{i}.weight'] = output_weight + if args.random_init: + init_method = init_method_normal(args.init_method_std) + output_layer_init_method = scaled_init_method_normal(args.init_method_std, layer_num) + checkpoint_dicts[pipeline_rank]['model']['language_model']['exit_output_layer'][f'{i}.weight'] = init_method(torch.empty(output_weight.shape)) + else: + checkpoint_dicts[pipeline_rank]['model']['language_model']['exit_output_layer'][f'{i}.weight'] = output_weight if args.use_exit_block: - checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.input_norm.weight'] = last_layer_input_norm - checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.self_attention.query_key_value.weight'] = last_layer_atten_qkv - 
checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.self_attention.dense.weight'] = last_layer_atten_dense - checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.post_attention_norm.weight'] = last_layer_post_norm - checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.mlp.dense_h_to_4h.weight'] = last_layer_mlp_h_to_4h - checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.mlp.dense_4h_to_h.weight'] = last_layer_mlp_4h_to_h - if checkpoint_args.add_bias_linear: - checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.self_attention.dense.bias'] = last_layer_atten_dense_bias - checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.mlp.dense_h_to_4h.bias'] = last_layer_h_to_4h_bias - checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.mlp.dense_4h_to_h.bias'] = last_layer_4h_to_h_bias - if checkpoint_args.normalization == 'LayerNorm': - checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.input_norm.bias'] = last_layer_input_norm_bias - checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.post_attention_norm.bias'] = last_layer_post_norm_bias + if args.random_init: + checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.input_norm.weight'] = torch.ones(last_layer_input_norm.shape) + checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.self_attention.query_key_value.weight'] = init_method(torch.empty(last_layer_atten_qkv.shape)) + checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.self_attention.dense.weight'] =output_layer_init_method(torch.empty(last_layer_atten_dense.shape)) + checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.post_attention_norm.weight'] = torch.ones(last_layer_post_norm.shape) + checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.mlp.dense_h_to_4h.weight'] = init_method(torch.empty(last_layer_mlp_h_to_4h.shape)) + checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.mlp.dense_4h_to_h.weight'] = output_layer_init_method(torch.empty(last_layer_mlp_4h_to_h.shape)) + if checkpoint_args.add_bias_linear: + checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.self_attention.dense.bias'] = torch.zeros(last_layer_atten_dense_bias.shape) + checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.mlp.dense_h_to_4h.bias'] = torch.zeros(last_layer_h_to_4h_bias.shape) + checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.mlp.dense_4h_to_h.bias'] = torch.zeros(last_layer_4h_to_h_bias.shape) + if checkpoint_args.normalization == 'LayerNorm': + checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.input_norm.bias'] = torch.zeros(last_layer_input_norm_bias.shape) + checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.post_attention_norm.bias'] = 
torch.zeros(last_layer_post_norm_bias.shape) + else: + checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.input_norm.weight'] = last_layer_input_norm + checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.self_attention.query_key_value.weight'] = last_layer_atten_qkv + checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.self_attention.dense.weight'] = last_layer_atten_dense + checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.post_attention_norm.weight'] = last_layer_post_norm + checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.mlp.dense_h_to_4h.weight'] = last_layer_mlp_h_to_4h + checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.mlp.dense_4h_to_h.weight'] = last_layer_mlp_4h_to_h + if checkpoint_args.add_bias_linear: + checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.self_attention.dense.bias'] = last_layer_atten_dense_bias + checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.mlp.dense_h_to_4h.bias'] = last_layer_h_to_4h_bias + checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.mlp.dense_4h_to_h.bias'] = last_layer_4h_to_h_bias + if checkpoint_args.normalization == 'LayerNorm': + checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.input_norm.bias'] = last_layer_input_norm_bias + checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.post_attention_norm.bias'] = last_layer_post_norm_bias if args.use_exit_norm: - checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_norm.weight'] = final_norm_weight + if args.random_init: + checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_norm.weight'] = torch.ones(final_norm_weight.shape) + else: + checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_norm.weight'] = final_norm_weight if final_norm_bias is not None: - checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_norm.bias'] = final_norm_bias + if args.random_init: + checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_norm.bias'] = torch.zeros(final_norm_bias.shape) + else: + checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_norm.bias'] = final_norm_bias if not use_pipeline_parallel: checkpoint_save_path = os.path.join(checkpoint_save_dir, f'mp_rank_{tensor_rank:02d}', 'model_optim_rng.pt') else: @@ -243,6 +308,8 @@ def convert(args): change_exit_position(args, checkpoint_load_dir, checkpoint_save_dir) elif args.conversion_type == 'add-exit': add_exit(args, checkpoint_load_dir, checkpoint_save_dir) + with open(os.path.join(args.save_dir, 'latest_checkpointed_iteration.txt'), 'w', encoding='utf-8') as f: + f.write(str(args.load_iteration)) if __name__ == '__main__': From 4e24859733e85aa7bd71055977fce9ff8038808c Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 25 Dec 2023 03:53:05 +0000 Subject: [PATCH 09/12] change wandb save dir --- megatron/global_vars.py | 9 ++++++++- tools/request_client.py | 2 +- 2 files 
changed, 9 insertions(+), 2 deletions(-) diff --git a/megatron/global_vars.py b/megatron/global_vars.py index caa257a6..72e7b793 100644 --- a/megatron/global_vars.py +++ b/megatron/global_vars.py @@ -178,6 +178,12 @@ def _set_wandb_writer(args): is_master = args.rank == (args.world_size - 1) name = f'{args.wandb_exp_name}-master' if is_master \ else f'{args.wandb_exp_name}-worker-{pipeline_stage_id}' + if args.wandb_save_dir: + save_dir = args.wandb_save_dir + elif args.save: + save_dir = os.path.join(args.save, 'wandb') + else: + save_dir = os.path.join(os.getcwd(), 'wandb') wandb.init( project=args.wandb_project, group=args.wandb_group, @@ -186,7 +192,8 @@ def _set_wandb_writer(args): config=args, force=False, notes=description, - tags=['master'if is_master else 'worker'] + tags=['master'if is_master else 'worker'], + dir=save_dir ) _GLOBAL_WANDB_WRITER = wandb diff --git a/tools/request_client.py b/tools/request_client.py index ab79fd60..55751626 100644 --- a/tools/request_client.py +++ b/tools/request_client.py @@ -64,7 +64,7 @@ def main( "tools/prompt_example.jsonl", tokens_to_generate=50, use_early_exit=True, - early_exit_thres=0.9, + early_exit_thres=1.0, print_max_prob=True, exit_layers=[] ) From 0ec8c9c9f95118762a1d5a3a786fdc4ca9ceacf6 Mon Sep 17 00:00:00 2001 From: "panxuchen.pxc" Date: Fri, 26 Jan 2024 18:13:02 +0800 Subject: [PATCH 10/12] update readme and example scripts for ee_tuning --- README.md | 50 +++++- .../ee_inference_server.sh | 0 examples/{early_exit => ee_traning}/1-3B.sh | 4 + examples/{early_exit => ee_traning}/13B.sh | 6 +- examples/{early_exit => ee_traning}/30B.sh | 6 +- examples/{early_exit => ee_traning}/7B.sh | 6 +- examples/ee_tuning/convert/add_exit_layers.sh | 99 ++++++++++++ .../ee_tuning/convert/convert_llama_hf.sh | 22 +++ .../tune/llama2_13B_1_exit_mlp_pt.sh | 149 +++++++++++++++++ .../tune/llama2_13B_8_exit_mlp_pt.sh | 149 +++++++++++++++++ .../tune/llama2_70B_1_exit_mlp_pt.sh | 153 ++++++++++++++++++ .../tune/llama2_70B_8_exit_mlp_pt.sh | 153 ++++++++++++++++++ tools/checkpoint/checkpoint_converter.py | 5 +- 13 files changed, 790 insertions(+), 12 deletions(-) rename examples/{early_exit => ee_inference}/ee_inference_server.sh (100%) rename examples/{early_exit => ee_traning}/1-3B.sh (97%) rename examples/{early_exit => ee_traning}/13B.sh (96%) rename examples/{early_exit => ee_traning}/30B.sh (96%) rename examples/{early_exit => ee_traning}/7B.sh (96%) create mode 100755 examples/ee_tuning/convert/add_exit_layers.sh create mode 100755 examples/ee_tuning/convert/convert_llama_hf.sh create mode 100755 examples/ee_tuning/tune/llama2_13B_1_exit_mlp_pt.sh create mode 100755 examples/ee_tuning/tune/llama2_13B_8_exit_mlp_pt.sh create mode 100755 examples/ee_tuning/tune/llama2_70B_1_exit_mlp_pt.sh create mode 100755 examples/ee_tuning/tune/llama2_70B_8_exit_mlp_pt.sh diff --git a/README.md b/README.md index 0c9f5a61..aaca9f85 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # README -[EE-LLM](https://arxiv.org/abs/2312.04916) is a framework for large-scale training and inference of early-exit (EE) large language models (LLMs), which is built upon [Megatron-LM](https://github.com/NVIDIA/Megatron-LM) and currently under active development. +[EE-LLM](https://arxiv.org/abs/2312.04916) is a framework for large-scale training, tunning and inference of early-exit (EE) large language models (LLMs), which is built upon [Megatron-LM](https://github.com/NVIDIA/Megatron-LM). 
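The wandb save-directory fallback introduced in the patch above (PATCH 09) can be summarized by a small standalone sketch. The attribute names (`wandb_save_dir`, `save`) mirror the Megatron argument namespace shown in the diff; the example `SimpleNamespace` and paths are illustrative assumptions only, and the resolved directory is what the patch passes to `wandb.init(dir=...)`.

```python
import os
from types import SimpleNamespace

def resolve_wandb_dir(args):
    """Mirror the fallback order from the patch: explicit --wandb-save-dir,
    then <save>/wandb, then ./wandb in the current working directory."""
    if getattr(args, 'wandb_save_dir', None):
        return args.wandb_save_dir
    if getattr(args, 'save', None):
        return os.path.join(args.save, 'wandb')
    return os.path.join(os.getcwd(), 'wandb')

# Example: no explicit wandb dir, but a checkpoint save path is set.
args = SimpleNamespace(wandb_save_dir=None, save='/checkpoints/EE-LLM/run-01')
print(resolve_wandb_dir(args))  # -> /checkpoints/EE-LLM/run-01/wandb
```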
![](images/ee_architecture.png) @@ -19,16 +19,16 @@ Below are several example training scripts used in our paper. ``` # train 1.3B model -./examples/early_exit/1-3B.sh +./examples/ee_training/1-3B.sh # train 7B model -./examples/early_exit/7B.sh +./examples/ee_training/7B.sh # train 13B model -./example/early_exit/13B.sh +./example/ee_training/13B.sh # train 30B model -./example/early_exit/30B.sh +./example/ee_training/30B.sh ``` @@ -76,6 +76,44 @@ Below are the new configurations of EE-LLM compared to Megatron-LM. You can cust - `--backward-forward-ratio`: An estimate of the ratio of time consumption between backward and forward computation during training, used to automatically calculate the optimal number of inserted microbatches. Default to 2.0. [Experimental] +## Tuning + +EE-LLM has supported to tune an existing standard LLM into a early-exit LLM, which is called "EE-Tuning". + +> Before using EE-Tunning, please make sure your existing LLM checkpoint is in Megatron-LM format. +> As an example, `examples/ee_tuning/convert/convert_llama_hf.sh` provides the functionality to convert the llama2 huggingface checkpoint into Megatron-LM format. + +First, use `tools/checkpoint/checkpoint_converter.py` to add early-exit modules to the checkpoint, and the arguments of the script are listed below: + +- `--load-dir`: The Megatron-LM format standard LLM checkpoint file path. +- `--load-iteration`: The iteration number of checkpoints to be loaded. +- `--save-dir`: The output early-exit LLM checkpoint file path. +- `--add-exit-layer-nums`: The layer numbers of the layers where the early-exit module needs to be added. +- `--use-exit-norm`: Add a norm to the early-exit module. +- `--use-exit-mlp`: Add an MLP to the early-exit module. +- `--use-exit-block`: Add an transformer layer to the early-exit module. +- `--random-init`: Initialize the early-exit module randomly. +- `--megatron-path`: Path to EE-LLM root directory. + +> `examples/ee_tuning/convert/add_exit_layers.sh` provides some conversion examples. + +Then, tune the converted checkpoint using similar method as shown [training process](#training). Below are some example scripts. + +```shell +# tune llama2 13B chat with 8 exits +./examples/ee_tuning/tune/llama2_13B_8_exit_mlp_pt.sh + +# tune llama2 13B chat with 1 exit (only load the first 1/4 of the model) +./examples/ee_tuning/tune/llama2_13B_1_exit_mlp_pt.sh +``` + +Here are the new parameters added by EE-Tuning. + +- `--tune-exit`: Activate the functionality of EE-tunning. +- `--tune-exit-pipeline-parallel-size`: Used to support partial loading functionality, only load pipeline stages whose stage number not larger than this value. + +Other parameters are the same as the training part. + ## Inference We provided an text generation server for inference of early-exit LLMs. @@ -83,7 +121,7 @@ To start a server, you can use the following script. Before running, please set `CHECKPOINT_PATH` to the root folder path of the checkpoint, and set `TP` and `PP` appropriately according to the parallelism of the checkpoint. ``` -./example/early_exit/ee_inference_server.sh +./example/ee_inference/ee_inference_server.sh ``` After the server is started, you can use `tools/request_client.py` to send requests to the server. 
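As a concrete illustration of the converter arguments listed in the Tuning section above, the following Python sketch assembles an `add-exit` invocation of `tools/checkpoint/checkpoint_converter.py`. The flag set mirrors the embedding-norm-mlp example in `examples/ee_tuning/convert/add_exit_layers.sh` for a 40-layer model; the paths are placeholders, not actual checkpoint locations.

```python
import subprocess

# Placeholders: adjust to your local EE-LLM root and checkpoint directories.
megatron_root = "/path/to/EE-LLM"
load_dir = "/path/to/llama2-13b-megatron"   # Megatron-format standard LLM checkpoint
save_dir = "/path/to/llama2-13b-8-exit"     # output early-exit LLM checkpoint

cmd = [
    "python", f"{megatron_root}/tools/checkpoint/checkpoint_converter.py",
    "--load-dir", load_dir,
    "--save-dir", save_dir,
    "--load-iteration", "1",
    "--conversion-type", "add-exit",
    # one exit every 1/8 depth of a 40-layer model, each with a norm and an MLP
    "--add-exit-layer-nums", "5", "10", "15", "20", "25", "30", "35", "40",
    "--use-exit-norm",
    "--use-exit-mlp",
    "--megatron-path", megatron_root,
]
subprocess.run(cmd, check=True)
```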
diff --git a/examples/early_exit/ee_inference_server.sh b/examples/ee_inference/ee_inference_server.sh similarity index 100% rename from examples/early_exit/ee_inference_server.sh rename to examples/ee_inference/ee_inference_server.sh diff --git a/examples/early_exit/1-3B.sh b/examples/ee_traning/1-3B.sh similarity index 97% rename from examples/early_exit/1-3B.sh rename to examples/ee_traning/1-3B.sh index a5481fa3..74bb9d7c 100755 --- a/examples/early_exit/1-3B.sh +++ b/examples/ee_traning/1-3B.sh @@ -145,6 +145,10 @@ OUTPUT_ARGS=" --wandb-exp-name $RUN_NAME \ " +CUR_DIR=$(cd $(dirname "$0") && pwd) +MEGATRON_ROOT_PATH=$(cd "$CUR_DIR/../.." && pwd) +cd $MEGATRON_ROOT_PATH + torchrun $DIST_ARGS \ pretrain_early_exit_gpt.py \ $GPT_ARGS \ diff --git a/examples/early_exit/13B.sh b/examples/ee_traning/13B.sh similarity index 96% rename from examples/early_exit/13B.sh rename to examples/ee_traning/13B.sh index 6ebd22a5..23893587 100755 --- a/examples/early_exit/13B.sh +++ b/examples/ee_traning/13B.sh @@ -1,7 +1,7 @@ #!/bin/bash PROJECT_NAME=EE-LLM -GROUP_NAME=7B-EXIT-8-16-untie-300B +GROUP_NAME=7B-EXIT-8-16-untie-800B RUN_NAME=`date "+%m%d-%H%M"` @@ -147,6 +147,10 @@ OUTPUT_ARGS=" --wandb-exp-name $RUN_NAME \ " +CUR_DIR=$(cd $(dirname "$0") && pwd) +MEGATRON_ROOT_PATH=$(cd "$CUR_DIR/../.." && pwd) +cd $MEGATRON_ROOT_PATH + torchrun $DIST_ARGS \ pretrain_early_exit_gpt.py \ $GPT_ARGS \ diff --git a/examples/early_exit/30B.sh b/examples/ee_traning/30B.sh similarity index 96% rename from examples/early_exit/30B.sh rename to examples/ee_traning/30B.sh index 2903db72..e00c15d0 100755 --- a/examples/early_exit/30B.sh +++ b/examples/ee_traning/30B.sh @@ -1,7 +1,7 @@ #!/bin/bash PROJECT_NAME=EE-LLM -GROUP_NAME=7B-EXIT-8-16-untie-300B +GROUP_NAME=7B-EXIT-8-16-untie-800B RUN_NAME=`date "+%m%d-%H%M"` @@ -147,6 +147,10 @@ OUTPUT_ARGS=" --wandb-exp-name $RUN_NAME \ " +CUR_DIR=$(cd $(dirname "$0") && pwd) +MEGATRON_ROOT_PATH=$(cd "$CUR_DIR/../.." && pwd) +cd $MEGATRON_ROOT_PATH + torchrun $DIST_ARGS \ pretrain_early_exit_gpt.py \ $GPT_ARGS \ diff --git a/examples/early_exit/7B.sh b/examples/ee_traning/7B.sh similarity index 96% rename from examples/early_exit/7B.sh rename to examples/ee_traning/7B.sh index 31439e53..3cd0c622 100755 --- a/examples/early_exit/7B.sh +++ b/examples/ee_traning/7B.sh @@ -1,7 +1,7 @@ #!/bin/bash PROJECT_NAME=EE-LLM -GROUP_NAME=7B-EXIT-8-16-untie-300B +GROUP_NAME=7B-EXIT-8-16-untie-800B RUN_NAME=`date "+%m%d-%H%M"` @@ -147,6 +147,10 @@ OUTPUT_ARGS=" --wandb-exp-name $RUN_NAME \ " +CUR_DIR=$(cd $(dirname "$0") && pwd) +MEGATRON_ROOT_PATH=$(cd "$CUR_DIR/../.." && pwd) +cd $MEGATRON_ROOT_PATH + torchrun $DIST_ARGS \ pretrain_early_exit_gpt.py \ $GPT_ARGS \ diff --git a/examples/ee_tuning/convert/add_exit_layers.sh b/examples/ee_tuning/convert/add_exit_layers.sh new file mode 100755 index 00000000..1d3afa90 --- /dev/null +++ b/examples/ee_tuning/convert/add_exit_layers.sh @@ -0,0 +1,99 @@ +#!/bin/bash + +LOAD_DIR= # path to the converted llama checkpoint in megatron format +SAVE_DIR= # path to save the converted EE LLM checkpoint + +LOAD_ITER=1 +CUR_DIR=$(cd $(dirname "$0") && pwd) +MEGATRON_ROOT_PATH=$(cd "$CUR_DIR/../../.." 
&& pwd) + +# For llama2 13B model (40 layers) + +## add an embedding only exit every 1/8 depth +# python ${MEGATRON_ROOT_PATH}/tools/checkpoint/checkpoint_converter.py \ +# --load-dir $LOAD_DIR \ +# --save-dir $SAVE_DIR \ +# --load-iteration $LOAD_ITER \ +# --conversion-type add-exit \ +# --add-exit-layer-nums 5 10 15 20 25 30 35 40 \ +# --megatron-path $MEGATRON_ROOT_PATH + +## add an embedding-norm exit every 1/8 depth +# python ${MEGATRON_ROOT_PATH}/tools/checkpoint/checkpoint_converter.py \ +# --load-dir $LOAD_DIR \ +# --save-dir $SAVE_DIR \ +# --load-iteration $LOAD_ITER \ +# --conversion-type add-exit \ +# --add-exit-layer-nums 5 10 15 20 25 30 35 40 \ +# --megatron-path $MEGATRON_ROOT_PATH + +## add an embedding-norm-mlp exit every 1/8 depth +python ${MEGATRON_ROOT_PATH}/tools/checkpoint/checkpoint_converter.py \ + --load-dir $LOAD_DIR \ + --save-dir $SAVE_DIR \ + --load-iteration $LOAD_ITER \ + --use-exit-norm \ + --use-exit-mlp \ + --conversion-type add-exit \ + --add-exit-layer-nums 5 10 15 20 25 30 35 40 \ + --megatron-path $MEGATRON_ROOT_PATH + +## add an embedding-norm-layer exit every 1/8 depth +# python ${MEGATRON_ROOT_PATH}/tools/checkpoint/checkpoint_converter.py \ +# --load-dir $LOAD_DIR \ +# --save-dir $SAVE_DIR \ +# --load-iteration $LOAD_ITER \ +# --use-exit-norm \ +# --use-exit-block \ +# --conversion-type add-exit \ +# --add-exit-layer-nums 5 10 15 20 25 30 35 40 \ +# --megatron-path $MEGATRON_ROOT_PATH + +## add an embedding-norm-mlp exit at 1/4 depth +# python ${MEGATRON_ROOT_PATH}/tools/checkpoint/checkpoint_converter.py \ +# --load-dir $LOAD_DIR \ +# --save-dir $SAVE_DIR \ +# --load-iteration $LOAD_ITER \ +# --use-exit-norm \ +# --use-exit-mlp \ +# --conversion-type add-exit \ +# --add-exit-layer-nums 10 \ +# --megatron-path $MEGATRON_ROOT_PATH + +## add an random init embedding-norm-mlp exit at 1/4 depth +# python ${MEGATRON_ROOT_PATH}/tools/checkpoint/checkpoint_converter.py \ +# --load-dir $LOAD_DIR \ +# --save-dir $SAVE_DIR \ +# --load-iteration $LOAD_ITER \ +# --use-exit-norm \ +# --use-exit-mlp \ +# --random-init \ +# --conversion-type add-exit \ +# --add-exit-layer-nums 10 \ +# --megatron-path $MEGATRON_ROOT_PATH + +# For llama2 70B model (80 layers) + +## add an embedding-norm-mlp exit every 1/8 depth +# python ${MEGATRON_ROOT_PATH}/tools/checkpoint/checkpoint_converter.py \ +# --load-dir $LOAD_DIR \ +# --save-dir $SAVE_DIR \ +# --load-iteration $LOAD_ITER \ +# --use-exit-norm \ +# --use-exit-mlp \ +# --conversion-type add-exit \ +# --add-exit-layer-nums 10 20 30 40 50 60 70 80 \ +# --megatron-path $MEGATRON_ROOT_PATH + +# For llama2 7B model (32 layers) + +## add an embedding-norm-mlp exit every 1/8 depth +# python ${MEGATRON_ROOT_PATH}/tools/checkpoint/checkpoint_converter.py \ +# --load-dir $LOAD_DIR \ +# --save-dir $SAVE_DIR \ +# --load-iteration $LOAD_ITER \ +# --use-exit-norm \ +# --use-exit-mlp \ +# --conversion-type add-exit \ +# --add-exit-layer-nums 4 8 12 16 20 24 28 32 \ +# --megatron-path $MEGATRON_ROOT_PATH \ No newline at end of file diff --git a/examples/ee_tuning/convert/convert_llama_hf.sh b/examples/ee_tuning/convert/convert_llama_hf.sh new file mode 100755 index 00000000..fb874cff --- /dev/null +++ b/examples/ee_tuning/convert/convert_llama_hf.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +LOAD_DIR= # path to the llama2 huggingface checkpoint dir +SAVE_DIR= # path to save the converted megatron checkpoint +TP=1 # target tensor parallel size +PP=4 # target pipeline parallel size + +TOKENIZER_PATH= ${LOAD_DIR}/tokenizer.model + +CUR_DIR=$(cd 
$(dirname "$0") && pwd) +MEGATRON_ROOT_PATH=$(cd "$CUR_DIR/../../.." && pwd) + +python $MEGATRON_ROOT_PATH/tools/checkpoint/util.py \ + --model-type EarlyExitGPT \ + --load-dir $LOAD_DIR \ + --save-dir $SAVE_DIR \ + --loader llama2_hf \ + --saver megatron \ + --target-tensor-parallel-size $TP \ + --target-pipeline-parallel-size $PP \ + --megatron-path $MEGATRON_ROOT_PATH \ + --tokenizer-model $TOKENIZER_PATH \ No newline at end of file diff --git a/examples/ee_tuning/tune/llama2_13B_1_exit_mlp_pt.sh b/examples/ee_tuning/tune/llama2_13B_1_exit_mlp_pt.sh new file mode 100755 index 00000000..190e0888 --- /dev/null +++ b/examples/ee_tuning/tune/llama2_13B_1_exit_mlp_pt.sh @@ -0,0 +1,149 @@ +#!/bin/bash + +PROJECT_NAME=EE-TUNE +GROUP_NAME=llama-2-13B-chat-1-EXIT-pt + +CURRENT_TIME=`date "+%m%d-%H%M"` + +MASTER_NAME=${CURRENT_TIME} + +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export OMP_NUM_THREADS=4 + +# Checkpoint configuration +MODEL_HOME= +LOAD_PATH=${MODEL_HOME}/checkpoints/MET-EXP/llama2-13b-chat-1-exit # your checkpoint path +CHECKPOINT_PATH=${MODEL_HOME}/checkpoints/$PROJECT_NAME/$GROUP_NAME +TOKENIZER_PATH=${MODEL_HOME}/tokenizer/tokenizer.model + +# Data configuration +DATA_HOME= +DATASET_ARXIV=${DATA_HOME}/redpajama-arxiv/all +DATASET_BOOKS=${DATA_HOME}/redpajama-book/all +DATASET_C4=${DATA_HOME}/redpajama-c4/all +DATASET_CC=${DATA_HOME}/redpajama-cc/all +DATASET_STACKEXCHANGE=${DATA_HOME}/redpajama-pile-stackexchange/all +DATASET_CODE=${DATA_HOME}/redpajama-stack-code/all +DATASET_WIKIPEDIA=${DATA_HOME}/redpajama-wiki/all +DATASET_PILE_EUROPARL=${DATA_HOME}/the-pile-europarl/all +DATASET_PILE_FREELAW=${DATA_HOME}/the-pile-freelaw/all +DATASET_PILE_HACKERNEWS=${DATA_HOME}/the-pile-hackernews/all +DATASET_PILE_NIH=${DATA_HOME}/the-pile-nih/all +DATASET_PILE_PHILPAPER=${DATA_HOME}/the-pile-philpaper/all +DATASET_PILE_PMA=${DATA_HOME}/the-pile-pubmed-abstract/all +DATASET_PILE_PMC=${DATA_HOME}/the-pile-pubmed-central/all +DATASET_PILE_USPTO=${DATA_HOME}/the-pile-uspto/all + +DATA_PATH="\ + 0.0362 ${DATASET_ARXIV} \ + 0.0657 ${DATASET_BOOKS} \ + 0.2264 ${DATASET_C4} \ + 0.4491 ${DATASET_CC} \ + 0.0246 ${DATASET_STACKEXCHANGE} \ + 0.0810 ${DATASET_CODE} \ + 0.0548 ${DATASET_WIKIPEDIA} \ + 0.0010 ${DATASET_PILE_EUROPARL} \ + 0.0162 ${DATASET_PILE_FREELAW} \ + 0.0006 ${DATASET_PILE_HACKERNEWS} \ + 0.0005 ${DATASET_PILE_NIH} \ + 0.0006 ${DATASET_PILE_PHILPAPER} \ + 0.0065 ${DATASET_PILE_PMA} \ + 0.0318 ${DATASET_PILE_PMC} \ + 0.0050 ${DATASET_PILE_USPTO} \ +" + +NLAYERS=40 +HIDDEN=5120 +HEADS=40 +SEQ=2048 +FFN_SIZE=13824 + +TP=1 +PP=4 + +MICRO_BATCH=4 +GLOBAL_BATCH=16 + + +MASTER_ADDR=127.0.0.1 +MASTER_PORT=5901 +WORLD_SIZE=1 +RANK=0 +NPROC_PER_NODE=4 + +TRAIN_ITER=40000 +EVAL_INTERVAL=50000 +SAVE_INTERVAL=20000 + +DIST_ARGS=" + --master_addr $MASTER_ADDR \ + --master_port $MASTER_PORT \ + --nproc_per_node $NPROC_PER_NODE \ + --nnodes $WORLD_SIZE \ + --node_rank $RANK \ + " + +GPT_ARGS=" + --tensor-model-parallel-size $TP \ + --pipeline-model-parallel-size $PP \ + --query-key-layer-scaling \ + --num-layers $NLAYERS \ + --hidden-size $HIDDEN \ + --num-attention-heads $HEADS \ + --seq-length $SEQ \ + --max-position-embeddings $SEQ \ + --micro-batch-size $MICRO_BATCH \ + --global-batch-size $GLOBAL_BATCH \ + --lr 0.0001 \ + --train-iters $TRAIN_ITER \ + --min-lr 1.0e-5 \ + --lr-warmup-fraction .01 \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --adam-eps 1e-5 \ + --clip-grad 1.0 \ + --bf16 \ + --disable-bias-linear \ + --use-flash-attn \ + --normalization RMSNorm \ + --position-embedding-type rope \ 
+ --swiglu \ + --exit-layer-nums 10 \ + --untie-embeddings-and-output-weights \ + --untie-exit-output-weights \ + --padded-vocab-size 32000 \ + --ffn-hidden-size $FFN_SIZE \ + --finetune \ + --tune-exit-pipeline-parallel-size 1 \ + --tune-exit \ +" + +DATA_ARGS=" + --data-path $DATA_PATH \ + --tokenizer-type Llama2Tokenizer \ + --tokenizer-model $TOKENIZER_PATH \ + --split 990,9,1 \ +" + +OUTPUT_ARGS=" + --log-interval 10 \ + --log-timers-to-tracker \ + --save-interval $SAVE_INTERVAL \ + --eval-interval $EVAL_INTERVAL \ + --eval-iters 1 \ + --wandb-project $PROJECT_NAME \ + --wandb-group $GROUP_NAME \ + --wandb-exp-name $MASTER_NAME \ +" + +CUR_DIR=$(cd $(dirname "$0") && pwd) +MEGATRON_ROOT_PATH=$(cd "$CUR_DIR/../.." && pwd) +cd $MEGATRON_ROOT_PATH + +torchrun $DIST_ARGS \ + pretrain_early_exit_gpt.py \ + $GPT_ARGS \ + $DATA_ARGS \ + $OUTPUT_ARGS \ + --load $LOAD_PATH \ + --save $CHECKPOINT_PATH diff --git a/examples/ee_tuning/tune/llama2_13B_8_exit_mlp_pt.sh b/examples/ee_tuning/tune/llama2_13B_8_exit_mlp_pt.sh new file mode 100755 index 00000000..2500bb4e --- /dev/null +++ b/examples/ee_tuning/tune/llama2_13B_8_exit_mlp_pt.sh @@ -0,0 +1,149 @@ +#!/bin/bash + +PROJECT_NAME=EE-TUNE +GROUP_NAME=llama-2-13B-chat-8-EXIT-pt + +CURRENT_TIME=`date "+%m%d-%H%M"` + +MASTER_NAME=${CURRENT_TIME} + +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export OMP_NUM_THREADS=4 + +# Checkpoint configuration +MODEL_HOME= +LOAD_PATH=${MODEL_HOME}/checkpoints/llama2-13b-chat-8-exit # your checkpoint path +CHECKPOINT_PATH=${MODEL_HOME}/checkpoints/$PROJECT_NAME/$GROUP_NAME +TOKENIZER_PATH=${MODEL_HOME}/tokenizer/tokenizer.model + +# Data configuration +DATA_HOME= +DATASET_ARXIV=${DATA_HOME}/redpajama-arxiv/all +DATASET_BOOKS=${DATA_HOME}/redpajama-book/all +DATASET_C4=${DATA_HOME}/redpajama-c4/all +DATASET_CC=${DATA_HOME}/redpajama-cc/all +DATASET_STACKEXCHANGE=${DATA_HOME}/redpajama-pile-stackexchange/all +DATASET_CODE=${DATA_HOME}/redpajama-stack-code/all +DATASET_WIKIPEDIA=${DATA_HOME}/redpajama-wiki/all +DATASET_PILE_EUROPARL=${DATA_HOME}/the-pile-europarl/all +DATASET_PILE_FREELAW=${DATA_HOME}/the-pile-freelaw/all +DATASET_PILE_HACKERNEWS=${DATA_HOME}/the-pile-hackernews/all +DATASET_PILE_NIH=${DATA_HOME}/the-pile-nih/all +DATASET_PILE_PHILPAPER=${DATA_HOME}/the-pile-philpaper/all +DATASET_PILE_PMA=${DATA_HOME}/the-pile-pubmed-abstract/all +DATASET_PILE_PMC=${DATA_HOME}/the-pile-pubmed-central/all +DATASET_PILE_USPTO=${DATA_HOME}/the-pile-uspto/all + +DATA_PATH="\ + 0.0362 ${DATASET_ARXIV} \ + 0.0657 ${DATASET_BOOKS} \ + 0.2264 ${DATASET_C4} \ + 0.4491 ${DATASET_CC} \ + 0.0246 ${DATASET_STACKEXCHANGE} \ + 0.0810 ${DATASET_CODE} \ + 0.0548 ${DATASET_WIKIPEDIA} \ + 0.0010 ${DATASET_PILE_EUROPARL} \ + 0.0162 ${DATASET_PILE_FREELAW} \ + 0.0006 ${DATASET_PILE_HACKERNEWS} \ + 0.0005 ${DATASET_PILE_NIH} \ + 0.0006 ${DATASET_PILE_PHILPAPER} \ + 0.0065 ${DATASET_PILE_PMA} \ + 0.0318 ${DATASET_PILE_PMC} \ + 0.0050 ${DATASET_PILE_USPTO} \ +" + +NLAYERS=40 +HIDDEN=5120 +HEADS=40 +SEQ=2048 +FFN_SIZE=13824 + +TP=1 +PP=4 + +MICRO_BATCH=4 +GLOBAL_BATCH=16 + + +MASTER_ADDR=127.0.0.1 +MASTER_PORT=5901 +WORLD_SIZE=1 +RANK=0 +NPROC_PER_NODE=8 + +TRAIN_ITER=40000 +EVAL_INTERVAL=50000 +SAVE_INTERVAL=20000 + +DIST_ARGS=" + --master_addr $MASTER_ADDR \ + --master_port $MASTER_PORT \ + --nproc_per_node $NPROC_PER_NODE \ + --nnodes $WORLD_SIZE \ + --node_rank $RANK \ + " + +GPT_ARGS=" + --tensor-model-parallel-size $TP \ + --pipeline-model-parallel-size $PP \ + --query-key-layer-scaling \ + --num-layers $NLAYERS \ + --hidden-size $HIDDEN \ 
+ --num-attention-heads $HEADS \ + --seq-length $SEQ \ + --max-position-embeddings $SEQ \ + --micro-batch-size $MICRO_BATCH \ + --global-batch-size $GLOBAL_BATCH \ + --lr 0.0001 \ + --train-iters $TRAIN_ITER \ + --min-lr 1.0e-5 \ + --lr-warmup-fraction .01 \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --adam-eps 1e-5 \ + --clip-grad 1.0 \ + --bf16 \ + --disable-bias-linear \ + --use-flash-attn \ + --normalization RMSNorm \ + --position-embedding-type rope \ + --swiglu \ + --exit-layer-nums 5 10 15 20 25 30 35 40 \ + --untie-embeddings-and-output-weights \ + --untie-exit-output-weights \ + --padded-vocab-size 32000 \ + --ffn-hidden-size $FFN_SIZE \ + --finetune \ + --tune-exit-pipeline-parallel-size 4 \ + --tune-exit \ +" + +DATA_ARGS=" + --data-path $DATA_PATH \ + --tokenizer-type Llama2Tokenizer \ + --tokenizer-model $TOKENIZER_PATH \ + --split 990,9,1 \ +" + +OUTPUT_ARGS=" + --log-interval 10 \ + --log-timers-to-tracker \ + --save-interval $SAVE_INTERVAL \ + --eval-interval $EVAL_INTERVAL \ + --eval-iters 1 \ + --wandb-project $PROJECT_NAME \ + --wandb-group $GROUP_NAME \ + --wandb-exp-name $MASTER_NAME \ +" + +CUR_DIR=$(cd $(dirname "$0") && pwd) +MEGATRON_ROOT_PATH=$(cd "$CUR_DIR/../.." && pwd) +cd $MEGATRON_ROOT_PATH + +torchrun $DIST_ARGS \ + pretrain_early_exit_gpt.py \ + $GPT_ARGS \ + $DATA_ARGS \ + $OUTPUT_ARGS \ + --load $LOAD_PATH \ + --save $CHECKPOINT_PATH diff --git a/examples/ee_tuning/tune/llama2_70B_1_exit_mlp_pt.sh b/examples/ee_tuning/tune/llama2_70B_1_exit_mlp_pt.sh new file mode 100755 index 00000000..3ce586ad --- /dev/null +++ b/examples/ee_tuning/tune/llama2_70B_1_exit_mlp_pt.sh @@ -0,0 +1,153 @@ +#!/bin/bash + +PROJECT_NAME=EE-TUNE +GROUP_NAME=llama-2-70B-chat-1-EXIT-pt + +CURRENT_TIME=`date "+%m%d-%H%M"` + +RUN_NAME=${CURRENT_TIME} + +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export OMP_NUM_THREADS=4 + +# Checkpoint configuration +MODEL_HOME= +LOAD_PATH=${MODEL_HOME}/checkpoints/MET-EXP/llama2-70b-chat-8-exit # your checkpoint path +CHECKPOINT_PATH=${MODEL_HOME}/checkpoints/$PROJECT_NAME/$GROUP_NAME +TOKENIZER_PATH=${MODEL_HOME}/tokenizer/tokenizer.model + +# Data configuration +DATA_HOME= +DATASET_ARXIV=${DATA_HOME}/redpajama-arxiv/all +DATASET_BOOKS=${DATA_HOME}/redpajama-book/all +DATASET_C4=${DATA_HOME}/redpajama-c4/all +DATASET_CC=${DATA_HOME}/redpajama-cc/all +DATASET_STACKEXCHANGE=${DATA_HOME}/redpajama-pile-stackexchange/all +DATASET_CODE=${DATA_HOME}/redpajama-stack-code/all +DATASET_WIKIPEDIA=${DATA_HOME}/redpajama-wiki/all +DATASET_PILE_EUROPARL=${DATA_HOME}/the-pile-europarl/all +DATASET_PILE_FREELAW=${DATA_HOME}/the-pile-freelaw/all +DATASET_PILE_HACKERNEWS=${DATA_HOME}/the-pile-hackernews/all +DATASET_PILE_NIH=${DATA_HOME}/the-pile-nih/all +DATASET_PILE_PHILPAPER=${DATA_HOME}/the-pile-philpaper/all +DATASET_PILE_PMA=${DATA_HOME}/the-pile-pubmed-abstract/all +DATASET_PILE_PMC=${DATA_HOME}/the-pile-pubmed-central/all +DATASET_PILE_USPTO=${DATA_HOME}/the-pile-uspto/all + +DATA_PATH="\ + 0.0362 ${DATASET_ARXIV} \ + 0.0657 ${DATASET_BOOKS} \ + 0.2264 ${DATASET_C4} \ + 0.4491 ${DATASET_CC} \ + 0.0246 ${DATASET_STACKEXCHANGE} \ + 0.0810 ${DATASET_CODE} \ + 0.0548 ${DATASET_WIKIPEDIA} \ + 0.0010 ${DATASET_PILE_EUROPARL} \ + 0.0162 ${DATASET_PILE_FREELAW} \ + 0.0006 ${DATASET_PILE_HACKERNEWS} \ + 0.0005 ${DATASET_PILE_NIH} \ + 0.0006 ${DATASET_PILE_PHILPAPER} \ + 0.0065 ${DATASET_PILE_PMA} \ + 0.0318 ${DATASET_PILE_PMC} \ + 0.0050 ${DATASET_PILE_USPTO} \ +" + +NLAYERS=80 +HIDDEN=8192 +HEADS=64 +SEQ=2048 +FFN_SIZE=28672 + +TP=1 +PP=4 + +MICRO_BATCH=4 
+GLOBAL_BATCH=16 + + +MASTER_ADDR=127.0.0.1 +MASTER_PORT=5900 +WORLD_SIZE=1 +RANK=0 +NPROC_PER_NODE=4 + +TRAIN_ITER=40000 +EVAL_INTERVAL=40000 +SAVE_INTERVAL=20000 + +DIST_ARGS=" + --master_addr $MASTER_ADDR \ + --master_port $MASTER_PORT \ + --nproc_per_node $NPROC_PER_NODE \ + --nnodes $WORLD_SIZE \ + --node_rank $RANK \ + " + +GPT_ARGS=" + --tensor-model-parallel-size $TP \ + --pipeline-model-parallel-size $PP \ + --num-layers $NLAYERS \ + --hidden-size $HIDDEN \ + --num-attention-heads $HEADS \ + --seq-length $SEQ \ + --max-position-embeddings $SEQ \ + --micro-batch-size $MICRO_BATCH \ + --global-batch-size $GLOBAL_BATCH \ + --lr 0.0001 \ + --train-iters $TRAIN_ITER \ + --sequence-parallel \ + --min-lr 1.0e-5 \ + --lr-warmup-fraction .01 \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --adam-eps 1e-5 \ + --clip-grad 1.0 \ + --bf16 \ + --disable-bias-linear \ + --use-flash-attn \ + --normalization RMSNorm \ + --position-embedding-type rope \ + --swiglu \ + --group-query-attention \ + --num-query-groups 8 \ + --exit-layer-nums 20 \ + --use-exit-norm \ + --use-exit-mlp \ + --untie-embeddings-and-output-weights \ + --untie-exit-output-weights \ + --padded-vocab-size 32000 \ + --ffn-hidden-size $FFN_SIZE \ + --finetune \ + --tune-exit-pipeline-parallel-size 1 \ + --tune-exit \ +" + +DATA_ARGS=" + --data-path $DATA_PATH \ + --tokenizer-type Llama2Tokenizer \ + --tokenizer-model $TOKENIZER_PATH \ + --split 990,9,1 \ +" + +OUTPUT_ARGS=" + --log-interval 10 \ + --log-timers-to-tracker \ + --save-interval $SAVE_INTERVAL \ + --eval-interval $EVAL_INTERVAL \ + --eval-iters 1 \ + --wandb-project $PROJECT_NAME \ + --wandb-group $GROUP_NAME \ + --wandb-exp-name $RUN_NAME \ +" + +CUR_DIR=$(cd $(dirname "$0") && pwd) +MEGATRON_ROOT_PATH=$(cd "$CUR_DIR/../.." && pwd) +cd $MEGATRON_ROOT_PATH + +torchrun $DIST_ARGS \ + pretrain_early_exit_gpt.py \ + $GPT_ARGS \ + $DATA_ARGS \ + $OUTPUT_ARGS \ + --load $LOAD_PATH \ + --save $CHECKPOINT_PATH diff --git a/examples/ee_tuning/tune/llama2_70B_8_exit_mlp_pt.sh b/examples/ee_tuning/tune/llama2_70B_8_exit_mlp_pt.sh new file mode 100755 index 00000000..892d28c9 --- /dev/null +++ b/examples/ee_tuning/tune/llama2_70B_8_exit_mlp_pt.sh @@ -0,0 +1,153 @@ +#!/bin/bash + +PROJECT_NAME=EE-TUNE +GROUP_NAME=llama-2-70B-chat-8-EXIT-pt + +CURRENT_TIME=`date "+%m%d-%H%M"` + +RUN_NAME=${CURRENT_TIME} + +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export OMP_NUM_THREADS=4 + +# Checkpoint configuration +MODEL_HOME= +LOAD_PATH=${MODEL_HOME}/checkpoints/MET-EXP/llama2-70b-chat-8-exit # your checkpoint path +CHECKPOINT_PATH=${MODEL_HOME}/checkpoints/$PROJECT_NAME/$GROUP_NAME +TOKENIZER_PATH=${MODEL_HOME}/tokenizer/tokenizer.model + +# Data configuration +DATA_HOME= +DATASET_ARXIV=${DATA_HOME}/redpajama-arxiv/all +DATASET_BOOKS=${DATA_HOME}/redpajama-book/all +DATASET_C4=${DATA_HOME}/redpajama-c4/all +DATASET_CC=${DATA_HOME}/redpajama-cc/all +DATASET_STACKEXCHANGE=${DATA_HOME}/redpajama-pile-stackexchange/all +DATASET_CODE=${DATA_HOME}/redpajama-stack-code/all +DATASET_WIKIPEDIA=${DATA_HOME}/redpajama-wiki/all +DATASET_PILE_EUROPARL=${DATA_HOME}/the-pile-europarl/all +DATASET_PILE_FREELAW=${DATA_HOME}/the-pile-freelaw/all +DATASET_PILE_HACKERNEWS=${DATA_HOME}/the-pile-hackernews/all +DATASET_PILE_NIH=${DATA_HOME}/the-pile-nih/all +DATASET_PILE_PHILPAPER=${DATA_HOME}/the-pile-philpaper/all +DATASET_PILE_PMA=${DATA_HOME}/the-pile-pubmed-abstract/all +DATASET_PILE_PMC=${DATA_HOME}/the-pile-pubmed-central/all +DATASET_PILE_USPTO=${DATA_HOME}/the-pile-uspto/all + +DATA_PATH="\ + 0.0362 
${DATASET_ARXIV} \ + 0.0657 ${DATASET_BOOKS} \ + 0.2264 ${DATASET_C4} \ + 0.4491 ${DATASET_CC} \ + 0.0246 ${DATASET_STACKEXCHANGE} \ + 0.0810 ${DATASET_CODE} \ + 0.0548 ${DATASET_WIKIPEDIA} \ + 0.0010 ${DATASET_PILE_EUROPARL} \ + 0.0162 ${DATASET_PILE_FREELAW} \ + 0.0006 ${DATASET_PILE_HACKERNEWS} \ + 0.0005 ${DATASET_PILE_NIH} \ + 0.0006 ${DATASET_PILE_PHILPAPER} \ + 0.0065 ${DATASET_PILE_PMA} \ + 0.0318 ${DATASET_PILE_PMC} \ + 0.0050 ${DATASET_PILE_USPTO} \ +" + +NLAYERS=80 +HIDDEN=8192 +HEADS=64 +SEQ=2048 +FFN_SIZE=28672 + +TP=1 +PP=4 + +MICRO_BATCH=4 +GLOBAL_BATCH=16 + + +MASTER_ADDR=127.0.0.1 +MASTER_PORT=5900 +WORLD_SIZE=1 +RANK=0 +NPROC_PER_NODE=8 + +TRAIN_ITER=40000 +EVAL_INTERVAL=40000 +SAVE_INTERVAL=20000 + +DIST_ARGS=" + --master_addr $MASTER_ADDR \ + --master_port $MASTER_PORT \ + --nproc_per_node $NPROC_PER_NODE \ + --nnodes $WORLD_SIZE \ + --node_rank $RANK \ + " + +GPT_ARGS=" + --tensor-model-parallel-size $TP \ + --pipeline-model-parallel-size $PP \ + --num-layers $NLAYERS \ + --hidden-size $HIDDEN \ + --num-attention-heads $HEADS \ + --seq-length $SEQ \ + --max-position-embeddings $SEQ \ + --micro-batch-size $MICRO_BATCH \ + --global-batch-size $GLOBAL_BATCH \ + --lr 0.0001 \ + --train-iters $TRAIN_ITER \ + --sequence-parallel \ + --min-lr 1.0e-5 \ + --lr-warmup-fraction .01 \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --adam-eps 1e-5 \ + --clip-grad 1.0 \ + --bf16 \ + --disable-bias-linear \ + --use-flash-attn \ + --normalization RMSNorm \ + --position-embedding-type rope \ + --swiglu \ + --group-query-attention \ + --num-query-groups 8 \ + --exit-layer-nums 10 20 30 40 50 60 70 80 \ + --use-exit-norm \ + --use-exit-mlp \ + --untie-embeddings-and-output-weights \ + --untie-exit-output-weights \ + --padded-vocab-size 32000 \ + --ffn-hidden-size $FFN_SIZE \ + --finetune \ + --tune-exit-pipeline-parallel-size 4 \ + --tune-exit \ +" + +DATA_ARGS=" + --data-path $DATA_PATH \ + --tokenizer-type Llama2Tokenizer \ + --tokenizer-model $TOKENIZER_PATH \ + --split 990,9,1 \ +" + +OUTPUT_ARGS=" + --log-interval 10 \ + --log-timers-to-tracker \ + --save-interval $SAVE_INTERVAL \ + --eval-interval $EVAL_INTERVAL \ + --eval-iters 1 \ + --wandb-project $PROJECT_NAME \ + --wandb-group $GROUP_NAME \ + --wandb-exp-name $RUN_NAME \ +" + +CUR_DIR=$(cd $(dirname "$0") && pwd) +MEGATRON_ROOT_PATH=$(cd "$CUR_DIR/../.." 
&& pwd) +cd $MEGATRON_ROOT_PATH + +torchrun $DIST_ARGS \ + pretrain_early_exit_gpt.py \ + $GPT_ARGS \ + $DATA_ARGS \ + $OUTPUT_ARGS \ + --load $LOAD_PATH \ + --save $CHECKPOINT_PATH diff --git a/tools/checkpoint/checkpoint_converter.py b/tools/checkpoint/checkpoint_converter.py index 465b844c..3812c721 100644 --- a/tools/checkpoint/checkpoint_converter.py +++ b/tools/checkpoint/checkpoint_converter.py @@ -11,7 +11,7 @@ def get_args(): parser.add_argument('--load-dir', type=str) parser.add_argument('--load-iteration', type=int) parser.add_argument('--save-dir', type=str) - parser.add_argument('--conversion-type', choices=['exit-position', 'add-exit']) + parser.add_argument('--conversion-type', choices=['exit-position', 'add-exit'], default='add-exit') parser.add_argument('--target-exit-position', choices=['pre', 'post'], default='post') parser.add_argument('--add-exit-layer-nums', type=int, nargs='+', default=[]) parser.add_argument('--use-exit-mlp', action='store_true') @@ -19,8 +19,7 @@ def get_args(): parser.add_argument('--use-exit-norm', action='store_true') parser.add_argument('--random-init', action='store_true') parser.add_argument('--init-method-std', type=float, default=0.02) - parser.add_argument('--megatron-path', type=str, default=None, - help='Base directory of deepspeed repository') + parser.add_argument('--megatron-path', type=str, default=None) return parser.parse_args() def load_checkpoint_args(checkpoint_root_path): From 417d9c6aed1f952de36452791b71138b5b77ddf9 Mon Sep 17 00:00:00 2001 From: yanxi-chen Date: Mon, 29 Jan 2024 15:07:44 +0800 Subject: [PATCH 11/12] Update readme (EE-Tuning and overall) --- README.md | 92 +++++++++++++++++++++++++++++++++++++------------------ 1 file changed, 63 insertions(+), 29 deletions(-) diff --git a/README.md b/README.md index aaca9f85..1a1a638a 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,23 @@ -# README +# EE-LLM: Early-Exit Large Language Models -[EE-LLM](https://arxiv.org/abs/2312.04916) is a framework for large-scale training, tunning and inference of early-exit (EE) large language models (LLMs), which is built upon [Megatron-LM](https://github.com/NVIDIA/Megatron-LM). + +[EE-LLM](https://arxiv.org/abs/2312.04916) is a framework for large-scale training and inference of early-exit (EE) large language models (LLMs), which is built upon [Megatron-LM](https://github.com/NVIDIA/Megatron-LM) and compatible with 3D parallelism (namely data, tensor, sequence and pipeline parallelism). ![](images/ee_architecture.png) + +As shown in the above figure, an early-exit LLM can convert intermediate hidden states into outputs. +During inference, the model can select adaptively one early/final exit to generate the output for each input, without running the full-model forward pass. + +Our system supports two methods of training early-exit LLMs: + +- Full-parameter training, which updates model parameters by optimizing a weighted sum of losses from multiple exits; +- EE-Tuning, a parameter-efficient approach that augments an existing pre-trained LLM with early-exit layers and tunes them while modules of the original LLM are frozen. + +Further details about the usage and functionalities of EE-LLM are introduced in the following. + + + ## Installation The installation of EE-LLM is the same as Megatron-LM. @@ -12,12 +26,12 @@ We recommand using the 22.12 version of [NGC's PyTorch container](https://catalo For more details about the installation of Megatron-LM, please refer to Megatron-LM's [README](README_Megatron_LM.md). 
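The full-parameter training method described in the introduction above optimizes a weighted sum of losses from multiple exits. The following minimal sketch shows that objective in isolation; the tensor shapes, weight values, and function name are illustrative assumptions, not EE-LLM's actual training code (which integrates this into the pipeline-parallel schedules).

```python
import torch
import torch.nn.functional as F

def early_exit_training_loss(exit_logits, final_logits, labels,
                             exit_weights, final_weight=1.0):
    """Weighted sum of cross-entropy losses from each early exit plus the final exit.
    exit_logits: list of [batch, seq, vocab] tensors, one per early exit."""
    loss = final_weight * F.cross_entropy(final_logits.flatten(0, 1), labels.flatten())
    for w, logits in zip(exit_weights, exit_logits):
        loss = loss + w * F.cross_entropy(logits.flatten(0, 1), labels.flatten())
    return loss

# Toy example: two early exits plus the final output, with illustrative weights.
batch, seq, vocab = 2, 8, 32000
labels = torch.randint(vocab, (batch, seq))
exits = [torch.randn(batch, seq, vocab) for _ in range(2)]
final = torch.randn(batch, seq, vocab)
print(early_exit_training_loss(exits, final, labels, exit_weights=[0.25, 0.25]))
```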
-## Training +## Full-parameter training Below are several example training scripts used in our paper. -``` +```shell # train 1.3B model ./examples/ee_training/1-3B.sh @@ -40,7 +54,7 @@ for more details about Megatron-LM's data preprocessing, please refer to [Data P > Running the training scripts requires 16 Nvidia A100-80G GPUs or higher hardware specifications. To run them with fewer GPUs, please set the parallelism degrees therein to smaller values. -Below are the new configurations of EE-LLM compared to Megatron-LM. You can customize your own early-exit LLM by modifying these configurations. +Below are some new configurations of EE-LLM compared to Megatron-LM. You can customize your own early-exit LLM by modifying these configurations. ### Configurations for model architectures @@ -76,51 +90,71 @@ Below are the new configurations of EE-LLM compared to Megatron-LM. You can cust - `--backward-forward-ratio`: An estimate of the ratio of time consumption between backward and forward computation during training, used to automatically calculate the optimal number of inserted microbatches. Default to 2.0. [Experimental] -## Tuning -EE-LLM has supported to tune an existing standard LLM into a early-exit LLM, which is called "EE-Tuning". +## EE-Tuning + + +> Before using EE-Tuning, please make sure that the existing LLM checkpoint is in Megatron-LM format. +> As an example, `examples/ee_tuning/convert/convert_llama_hf.sh` provides the functionality of converting the Llama 2 HuggingFace checkpoint into Megatron-LM format. + + +### Stage 1: initialize early-exit layers + +The first step of EE-Tuning is to use `tools/checkpoint/checkpoint_converter.py` to add early-exit layers to the standard LLM checkpoint. +Example scripts can be found in the following file: -> Before using EE-Tunning, please make sure your existing LLM checkpoint is in Megatron-LM format. -> As an example, `examples/ee_tuning/convert/convert_llama_hf.sh` provides the functionality to convert the llama2 huggingface checkpoint into Megatron-LM format. +```shell +examples/ee_tuning/convert/add_exit_layers.sh +``` + +The relevant arguments are listed below: + +- `--load-dir`: Path to the standard LLM checkpoint in Megatron-LM format. + +- `--load-iteration`: The iteration number of the checkpoint to be loaded. + +- `--save-dir`: Path to the output early-exit LLM checkpoint. + +- `--add-exit-layer-nums`: Indices of the backbone Transformer layers that early exits are added to. + +- `--use-exit-norm`: Add layer normalization (LayerNorm/RMSNorm) to the early-exit layer. -First, use `tools/checkpoint/checkpoint_converter.py` to add early-exit modules to the checkpoint, and the arguments of the script are listed below: +- `--use-exit-mlp`: Add a MLP to the early-exit layer. + +- `--use-exit-block`: Add a Transformer layer to the early-exit layer. + +- `--random-init`: Initialize model parameters of early-exit layers randomly. Otherwise, they are initialized as duplication of certain modules of the original LLM. -- `--load-dir`: The Megatron-LM format standard LLM checkpoint file path. -- `--load-iteration`: The iteration number of checkpoints to be loaded. -- `--save-dir`: The output early-exit LLM checkpoint file path. -- `--add-exit-layer-nums`: The layer numbers of the layers where the early-exit module needs to be added. -- `--use-exit-norm`: Add a norm to the early-exit module. -- `--use-exit-mlp`: Add an MLP to the early-exit module. -- `--use-exit-block`: Add an transformer layer to the early-exit module. 
-- `--random-init`: Initialize the early-exit module randomly. - `--megatron-path`: Path to EE-LLM root directory. -> `examples/ee_tuning/convert/add_exit_layers.sh` provides some conversion examples. -Then, tune the converted checkpoint using similar method as shown [training process](#training). Below are some example scripts. +### Stage 2: tune early-exit layers + +The second step of EE-Tuning is to tune the early-exit layers of the converted checkpoint, using scripts similar to those for [full-parameter training](#training). Below are some example scripts. ```shell -# tune llama2 13B chat with 8 exits +# tune Llama 2-Chat 13B with 8 exits ./examples/ee_tuning/tune/llama2_13B_8_exit_mlp_pt.sh -# tune llama2 13B chat with 1 exit (only load the first 1/4 of the model) +# tune Llama 2-Chat 13B with 1 exit (only load the first 1/4 of the model) ./examples/ee_tuning/tune/llama2_13B_1_exit_mlp_pt.sh ``` -Here are the new parameters added by EE-Tuning. +Below are the new parameters relevant to EE-Tuning. Other parameters are the same as those for full-parameter training. + +- `--tune-exit`: Activate the functionality of EE-Tuning. + +- `--tune-exit-pipeline-parallel-size`: Used to support partial checkpoint loading, only load pipeline stages whose stage numbers are not larger than this value. -- `--tune-exit`: Activate the functionality of EE-tunning. -- `--tune-exit-pipeline-parallel-size`: Used to support partial loading functionality, only load pipeline stages whose stage number not larger than this value. -Other parameters are the same as the training part. ## Inference -We provided an text generation server for inference of early-exit LLMs. +We provided a text generation server for inference of early-exit LLMs. To start a server, you can use the following script. -Before running, please set `CHECKPOINT_PATH` to the root folder path of the checkpoint, and set `TP` and `PP` appropriately according to the parallelism of the checkpoint. +Before running, please set `CHECKPOINT_PATH` to the root folder path of the checkpoint, and set `TP` and `PP` appropriately according to the parallelism degrees of the checkpoint. -``` +```shell ./example/ee_inference/ee_inference_server.sh ``` From b696a9f29c19e090c6aa6e771c24ad771f517876 Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 29 Jan 2024 11:47:15 +0000 Subject: [PATCH 12/12] update docs and scripts for ee_tuning --- README.md | 9 +++++++-- examples/ee_inference/ee_inference_server.sh | 4 ++++ examples/{ee_traning => ee_training}/1-3B.sh | 2 +- examples/{ee_traning => ee_training}/13B.sh | 4 ++-- examples/{ee_traning => ee_training}/30B.sh | 4 ++-- examples/{ee_traning => ee_training}/7B.sh | 2 +- .../ee_tuning/convert/convert_llama_hf.sh | 2 +- .../tune/llama2_13B_1_exit_mlp_pt.sh | 17 +++++++++-------- .../tune/llama2_13B_8_exit_mlp_pt.sh | 17 +++++++++-------- .../tune/llama2_70B_1_exit_mlp_pt.sh | 19 +++++++++---------- .../tune/llama2_70B_8_exit_mlp_pt.sh | 19 +++++++++---------- 11 files changed, 54 insertions(+), 45 deletions(-) rename examples/{ee_traning => ee_training}/1-3B.sh (99%) rename examples/{ee_traning => ee_training}/13B.sh (98%) rename examples/{ee_traning => ee_training}/30B.sh (98%) rename examples/{ee_traning => ee_training}/7B.sh (99%) diff --git a/README.md b/README.md index 1a1a638a..d57dacb3 100644 --- a/README.md +++ b/README.md @@ -39,10 +39,10 @@ Below are several example training scripts used in our paper. 
./examples/ee_training/7B.sh # train 13B model -./example/ee_training/13B.sh +./examples/ee_training/13B.sh # train 30B model -./example/ee_training/30B.sh +./examples/ee_training/30B.sh ``` @@ -165,8 +165,13 @@ Below are some parameters for early-exit LLM inference, which can be found in `t - `early_exit_thres`: The confidence threshold used to determine whether to execute early exiting, ranging from 0.0 to 1.0. +- `exit_layers`: Only the early-exit layers listed here will be activated. If empty, all available early-exit layers will be activated. + - `print_max_prob`: If set, the inference server will print the token with the highest confidence and the confidence values at all exits. +## Checkpoints + +The model checkpoints mentioned in our paper will be released soon. ## BibTeX diff --git a/examples/ee_inference/ee_inference_server.sh b/examples/ee_inference/ee_inference_server.sh index dfe9fd3b..7a3f9f97 100755 --- a/examples/ee_inference/ee_inference_server.sh +++ b/examples/ee_inference/ee_inference_server.sh @@ -37,6 +37,10 @@ SERVER_ARGS=" --port $PORT " +CUR_DIR=$(cd $(dirname "$0") && pwd) +MEGATRON_ROOT_PATH=$(cd "$CUR_DIR/../.." && pwd) +cd $MEGATRON_ROOT_PATH + torchrun $DIST_ARGS \ tools/run_early_exit_text_generation_server.py \ $SERVER_ARGS diff --git a/examples/ee_traning/1-3B.sh b/examples/ee_training/1-3B.sh similarity index 99% rename from examples/ee_traning/1-3B.sh rename to examples/ee_training/1-3B.sh index 74bb9d7c..2c3896a9 100755 --- a/examples/ee_traning/1-3B.sh +++ b/examples/ee_training/1-3B.sh @@ -139,7 +139,7 @@ OUTPUT_ARGS=" --log-timers-to-tracker \ --save-interval $SAVE_INTERVAL \ --eval-interval $EVAL_INTERVAL \ - --eval-iters 0 \ + --eval-iters 10 \ --wandb-project $PROJECT_NAME \ --wandb-group $GROUP_NAME \ --wandb-exp-name $RUN_NAME \ diff --git a/examples/ee_traning/13B.sh b/examples/ee_training/13B.sh similarity index 98% rename from examples/ee_traning/13B.sh rename to examples/ee_training/13B.sh index 23893587..ba7e3db6 100755 --- a/examples/ee_traning/13B.sh +++ b/examples/ee_training/13B.sh @@ -1,7 +1,7 @@ #!/bin/bash PROJECT_NAME=EE-LLM -GROUP_NAME=7B-EXIT-8-16-untie-800B +GROUP_NAME=13B-EXIT-10-20-untie-800B RUN_NAME=`date "+%m%d-%H%M"` @@ -141,7 +141,7 @@ OUTPUT_ARGS=" --log-timers-to-tracker \ --save-interval $SAVE_INTERVAL \ --eval-interval $EVAL_INTERVAL \ - --eval-iters 0 \ + --eval-iters 10 \ --wandb-project $PROJECT_NAME \ --wandb-group $GROUP_NAME \ --wandb-exp-name $RUN_NAME \ diff --git a/examples/ee_traning/30B.sh b/examples/ee_training/30B.sh similarity index 98% rename from examples/ee_traning/30B.sh rename to examples/ee_training/30B.sh index e00c15d0..d284b430 100755 --- a/examples/ee_traning/30B.sh +++ b/examples/ee_training/30B.sh @@ -1,7 +1,7 @@ #!/bin/bash PROJECT_NAME=EE-LLM -GROUP_NAME=7B-EXIT-8-16-untie-800B +GROUP_NAME=30B-EXIT-15-30-untie-800B RUN_NAME=`date "+%m%d-%H%M"` @@ -141,7 +141,7 @@ OUTPUT_ARGS=" --log-timers-to-tracker \ --save-interval $SAVE_INTERVAL \ --eval-interval $EVAL_INTERVAL \ - --eval-iters 0 \ + --eval-iters 10 \ --wandb-project $PROJECT_NAME \ --wandb-group $GROUP_NAME \ --wandb-exp-name $RUN_NAME \ diff --git a/examples/ee_traning/7B.sh b/examples/ee_training/7B.sh similarity index 99% rename from examples/ee_traning/7B.sh rename to examples/ee_training/7B.sh index 3cd0c622..74c3abef 100755 --- a/examples/ee_traning/7B.sh +++ b/examples/ee_training/7B.sh @@ -141,7 +141,7 @@ OUTPUT_ARGS=" --log-timers-to-tracker \ --save-interval $SAVE_INTERVAL \ --eval-interval $EVAL_INTERVAL \ - --eval-iters 0 \ + 
--eval-iters 10 \ --wandb-project $PROJECT_NAME \ --wandb-group $GROUP_NAME \ --wandb-exp-name $RUN_NAME \ diff --git a/examples/ee_tuning/convert/convert_llama_hf.sh b/examples/ee_tuning/convert/convert_llama_hf.sh index fb874cff..02b226c3 100755 --- a/examples/ee_tuning/convert/convert_llama_hf.sh +++ b/examples/ee_tuning/convert/convert_llama_hf.sh @@ -5,7 +5,7 @@ SAVE_DIR= # path to save the converted megatron checkpoint TP=1 # target tensor parallel size PP=4 # target pipeline parallel size -TOKENIZER_PATH= ${LOAD_DIR}/tokenizer.model +TOKENIZER_PATH=${LOAD_DIR}/tokenizer.model CUR_DIR=$(cd $(dirname "$0") && pwd) MEGATRON_ROOT_PATH=$(cd "$CUR_DIR/../../.." && pwd) diff --git a/examples/ee_tuning/tune/llama2_13B_1_exit_mlp_pt.sh b/examples/ee_tuning/tune/llama2_13B_1_exit_mlp_pt.sh index 190e0888..1383c2f9 100755 --- a/examples/ee_tuning/tune/llama2_13B_1_exit_mlp_pt.sh +++ b/examples/ee_tuning/tune/llama2_13B_1_exit_mlp_pt.sh @@ -11,10 +11,9 @@ export CUDA_DEVICE_MAX_CONNECTIONS=1 export OMP_NUM_THREADS=4 # Checkpoint configuration -MODEL_HOME= -LOAD_PATH=${MODEL_HOME}/checkpoints/MET-EXP/llama2-13b-chat-1-exit # your checkpoint path -CHECKPOINT_PATH=${MODEL_HOME}/checkpoints/$PROJECT_NAME/$GROUP_NAME -TOKENIZER_PATH=${MODEL_HOME}/tokenizer/tokenizer.model +LOAD_PATH= # your checkpoint path +TOKENIZER_PATH= # your tokenizer path +CHECKPOINT_PATH= # checkpoint save path # Data configuration DATA_HOME= @@ -108,14 +107,16 @@ GPT_ARGS=" --normalization RMSNorm \ --position-embedding-type rope \ --swiglu \ - --exit-layer-nums 10 \ --untie-embeddings-and-output-weights \ - --untie-exit-output-weights \ --padded-vocab-size 32000 \ --ffn-hidden-size $FFN_SIZE \ --finetune \ - --tune-exit-pipeline-parallel-size 1 \ --tune-exit \ + --untie-exit-output-weights \ + --use-exit-norm \ + --use-exit-mlp \ + --tune-exit-pipeline-parallel-size 4 \ + --exit-layer-nums 10 \ " DATA_ARGS=" @@ -137,7 +138,7 @@ OUTPUT_ARGS=" " CUR_DIR=$(cd $(dirname "$0") && pwd) -MEGATRON_ROOT_PATH=$(cd "$CUR_DIR/../.." && pwd) +MEGATRON_ROOT_PATH=$(cd "$CUR_DIR/../../.." && pwd) cd $MEGATRON_ROOT_PATH torchrun $DIST_ARGS \ diff --git a/examples/ee_tuning/tune/llama2_13B_8_exit_mlp_pt.sh b/examples/ee_tuning/tune/llama2_13B_8_exit_mlp_pt.sh index 2500bb4e..d1ee4d5d 100755 --- a/examples/ee_tuning/tune/llama2_13B_8_exit_mlp_pt.sh +++ b/examples/ee_tuning/tune/llama2_13B_8_exit_mlp_pt.sh @@ -11,10 +11,9 @@ export CUDA_DEVICE_MAX_CONNECTIONS=1 export OMP_NUM_THREADS=4 # Checkpoint configuration -MODEL_HOME= -LOAD_PATH=${MODEL_HOME}/checkpoints/llama2-13b-chat-8-exit # your checkpoint path -CHECKPOINT_PATH=${MODEL_HOME}/checkpoints/$PROJECT_NAME/$GROUP_NAME -TOKENIZER_PATH=${MODEL_HOME}/tokenizer/tokenizer.model +LOAD_PATH= # your checkpoint path +TOKENIZER_PATH= # your tokenizer path +CHECKPOINT_PATH= # checkpoint save path # Data configuration DATA_HOME= @@ -108,14 +107,16 @@ GPT_ARGS=" --normalization RMSNorm \ --position-embedding-type rope \ --swiglu \ - --exit-layer-nums 5 10 15 20 25 30 35 40 \ --untie-embeddings-and-output-weights \ - --untie-exit-output-weights \ --padded-vocab-size 32000 \ --ffn-hidden-size $FFN_SIZE \ --finetune \ - --tune-exit-pipeline-parallel-size 4 \ --tune-exit \ + --untie-exit-output-weights \ + --use-exit-norm \ + --use-exit-mlp \ + --tune-exit-pipeline-parallel-size 4 \ + --exit-layer-nums 5 10 15 20 25 30 35 40 \ " DATA_ARGS=" @@ -137,7 +138,7 @@ OUTPUT_ARGS=" " CUR_DIR=$(cd $(dirname "$0") && pwd) -MEGATRON_ROOT_PATH=$(cd "$CUR_DIR/../.." 
&& pwd) +MEGATRON_ROOT_PATH=$(cd "$CUR_DIR/../../.." && pwd) cd $MEGATRON_ROOT_PATH torchrun $DIST_ARGS \ diff --git a/examples/ee_tuning/tune/llama2_70B_1_exit_mlp_pt.sh b/examples/ee_tuning/tune/llama2_70B_1_exit_mlp_pt.sh index 3ce586ad..57f06fd4 100755 --- a/examples/ee_tuning/tune/llama2_70B_1_exit_mlp_pt.sh +++ b/examples/ee_tuning/tune/llama2_70B_1_exit_mlp_pt.sh @@ -11,10 +11,9 @@ export CUDA_DEVICE_MAX_CONNECTIONS=1 export OMP_NUM_THREADS=4 # Checkpoint configuration -MODEL_HOME= -LOAD_PATH=${MODEL_HOME}/checkpoints/MET-EXP/llama2-70b-chat-8-exit # your checkpoint path -CHECKPOINT_PATH=${MODEL_HOME}/checkpoints/$PROJECT_NAME/$GROUP_NAME -TOKENIZER_PATH=${MODEL_HOME}/tokenizer/tokenizer.model +LOAD_PATH= # your checkpoint path +TOKENIZER_PATH= # your tokenizer path +CHECKPOINT_PATH= # checkpoint save path # Data configuration DATA_HOME= @@ -110,16 +109,16 @@ GPT_ARGS=" --swiglu \ --group-query-attention \ --num-query-groups 8 \ - --exit-layer-nums 20 \ - --use-exit-norm \ - --use-exit-mlp \ --untie-embeddings-and-output-weights \ - --untie-exit-output-weights \ --padded-vocab-size 32000 \ --ffn-hidden-size $FFN_SIZE \ --finetune \ - --tune-exit-pipeline-parallel-size 1 \ --tune-exit \ + --untie-exit-output-weights \ + --use-exit-norm \ + --use-exit-mlp \ + --tune-exit-pipeline-parallel-size 1 \ + --exit-layer-nums 20 \ " DATA_ARGS=" @@ -141,7 +140,7 @@ OUTPUT_ARGS=" " CUR_DIR=$(cd $(dirname "$0") && pwd) -MEGATRON_ROOT_PATH=$(cd "$CUR_DIR/../.." && pwd) +MEGATRON_ROOT_PATH=$(cd "$CUR_DIR/../../.." && pwd) cd $MEGATRON_ROOT_PATH torchrun $DIST_ARGS \ diff --git a/examples/ee_tuning/tune/llama2_70B_8_exit_mlp_pt.sh b/examples/ee_tuning/tune/llama2_70B_8_exit_mlp_pt.sh index 892d28c9..28d86aaa 100755 --- a/examples/ee_tuning/tune/llama2_70B_8_exit_mlp_pt.sh +++ b/examples/ee_tuning/tune/llama2_70B_8_exit_mlp_pt.sh @@ -11,10 +11,9 @@ export CUDA_DEVICE_MAX_CONNECTIONS=1 export OMP_NUM_THREADS=4 # Checkpoint configuration -MODEL_HOME= -LOAD_PATH=${MODEL_HOME}/checkpoints/MET-EXP/llama2-70b-chat-8-exit # your checkpoint path -CHECKPOINT_PATH=${MODEL_HOME}/checkpoints/$PROJECT_NAME/$GROUP_NAME -TOKENIZER_PATH=${MODEL_HOME}/tokenizer/tokenizer.model +LOAD_PATH= # your checkpoint path +TOKENIZER_PATH= # your tokenizer path +CHECKPOINT_PATH= # checkpoint save path # Data configuration DATA_HOME= @@ -110,16 +109,16 @@ GPT_ARGS=" --swiglu \ --group-query-attention \ --num-query-groups 8 \ - --exit-layer-nums 10 20 30 40 50 60 70 80 \ - --use-exit-norm \ - --use-exit-mlp \ --untie-embeddings-and-output-weights \ - --untie-exit-output-weights \ --padded-vocab-size 32000 \ --ffn-hidden-size $FFN_SIZE \ --finetune \ - --tune-exit-pipeline-parallel-size 4 \ --tune-exit \ + --untie-exit-output-weights \ + --use-exit-norm \ + --use-exit-mlp \ + --tune-exit-pipeline-parallel-size 4 \ + --exit-layer-nums 10 20 30 40 50 60 70 80 \ " DATA_ARGS=" @@ -141,7 +140,7 @@ OUTPUT_ARGS=" " CUR_DIR=$(cd $(dirname "$0") && pwd) -MEGATRON_ROOT_PATH=$(cd "$CUR_DIR/../.." && pwd) +MEGATRON_ROOT_PATH=$(cd "$CUR_DIR/../../.." && pwd) cd $MEGATRON_ROOT_PATH torchrun $DIST_ARGS \
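For reference, the confidence-based decision that `early_exit_thres` controls (see the inference parameters in the README changes above) can be sketched as follows. This is a conceptual illustration assuming the confidence measure is the maximum softmax probability at an exit and that exiting uses a simple threshold comparison; it is not a verbatim copy of the inference server's logic.

```python
import torch

def should_exit(exit_logits, early_exit_thres):
    """Return (exit_now, token): exit early when the top softmax probability at
    this exit reaches the threshold; higher thresholds are more conservative."""
    probs = torch.softmax(exit_logits, dim=-1)
    max_prob, token = probs.max(dim=-1)
    return bool(max_prob >= early_exit_thres), int(token)

# Toy example: a single next-token distribution at one early exit.
logits = torch.tensor([1.0, 4.0, 0.5, 2.0])
print(should_exit(logits, early_exit_thres=0.9))
```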