diff --git a/scripts/train_eagle3_offline.py b/scripts/train_eagle3_offline.py
index 83f3961c..5977b866 100644
--- a/scripts/train_eagle3_offline.py
+++ b/scripts/train_eagle3_offline.py
@@ -9,8 +9,7 @@
 import torch.distributed as dist
 from accelerate.utils import set_seed
 from datasets import load_dataset
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.fsdp import MixedPrecision, ShardingStrategy, StateDictType
+from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
 from tqdm import tqdm
 from transformers import AutoTokenizer
 
@@ -21,7 +20,7 @@
     generate_vocab_mapping_file,
     prepare_dp_dataloaders,
 )
-from specforge.distributed import destroy_distributed, get_dp_group, init_distributed
+from specforge.distributed import destroy_distributed, get_dp_group, init_distributed, get_dp_device_mesh
 from specforge.modeling.target.target_head import TargetHead
 from specforge.optimizer import BF16Optimizer
 from specforge.tracker import create_tracker, get_tracker_class
@@ -31,6 +30,7 @@
     print_on_rank0,
     print_with_rank,
     rank_0_priority,
+    get_full_optimizer_state,
 )
 
 
@@ -340,19 +340,10 @@ def main():
         length=args.ttt_length,
         attention_backend=args.draft_attention_backend,
     )
-    eagle3_model = FSDP(
-        eagle3_model,
-        use_orig_params=True,
-        mixed_precision=MixedPrecision(
-            param_dtype=torch.bfloat16,
-            buffer_dtype=torch.bfloat16,
-            reduce_dtype=torch.float32,
-            keep_low_precision_grads=False,
-        ),
-        sharding_strategy=ShardingStrategy.NO_SHARD,
-        ignored_modules=[],
-        process_group=get_dp_group(),
-    )
+    mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
+    fsdp_config = {"mesh": get_dp_device_mesh(), "mp_policy": mp_policy}
+    fully_shard(eagle3_model, **fsdp_config)
+    print_with_rank("Initialized Eagle3 FSDP model")
 
     global_step, batch_index = 0, 0
     log_dict = defaultdict(float)
@@ -373,8 +364,8 @@ def main():
         # Run training
         train_dataloader.sampler.set_epoch(epoch + 1)
         draft_model.train()
-        epoch_acces = [[] for _ in range(eagle3_model.module.length)]
-        epoch_plosses = [[] for _ in range(eagle3_model.module.length)]
+        epoch_acces = [[] for _ in range(eagle3_model.length)]
+        epoch_plosses = [[] for _ in range(eagle3_model.length)]
 
         if dist.get_rank() == 0:
             progress_bar = tqdm(
@@ -530,33 +521,42 @@ def main():
             os.makedirs(epoch_output_dir, exist_ok=True)
         dist.barrier()
 
-        with FSDP.state_dict_type(eagle3_model, StateDictType.FULL_STATE_DICT):
-            model_state_dict = eagle3_model.state_dict()
-            state_to_save = {
-                "epoch": epoch,
-                "args": args,
-            }
-            state_to_save.update(optimizer.state_dict())
-            draft_model_state_dict = {
-                k.replace("draft_model.", ""): v
-                for k, v in model_state_dict.items()
-                if "draft_model." in k and "embed" not in k.lower()
-            }
-
-            if dist.get_rank() == 0:
-                torch.save(
-                    state_to_save,
-                    os.path.join(epoch_output_dir, "training_state.pt"),
-                )
-                print_on_rank0(
-                    f"Saved full training state to {epoch_output_dir}/training_state.pt"
-                )
-                draft_model.save_pretrained(
-                    epoch_output_dir,
-                    state_dict=draft_model_state_dict,
-                )
-                print_on_rank0(f"Saved model configuration to {epoch_output_dir}")
-            dist.barrier()
+        model_state_dict = eagle3_model.state_dict()
+
+        state_to_save = {
+            "epoch": epoch,
+            "args": args,
+        }
+
+        optimizer_state_dict = optimizer.state_dict()
+        optimizer_state_dict["optimizer_state_dict"] = get_full_optimizer_state(optimizer_state_dict["optimizer_state_dict"])
+
+        state_to_save.update(optimizer_state_dict)
+
+        draft_model_state_dict = {
+            k.replace("draft_model.", ""): (
+                v.full_tensor()
+                if isinstance(v, torch.distributed.tensor.DTensor)
+                else v
+            )
+            for k, v in model_state_dict.items()
+            if "draft_model." in k and "embed" not in k.lower()
+        }
+
+        if dist.get_rank() == 0:
+            torch.save(
+                state_to_save,
+                os.path.join(epoch_output_dir, "training_state.pt"),
+            )
+            print_on_rank0(
+                f"Saved full training state to {epoch_output_dir}/training_state.pt"
+            )
+            draft_model.save_pretrained(
+                epoch_output_dir,
+                state_dict=draft_model_state_dict,
+            )
+            print_on_rank0(f"Saved model configuration to {epoch_output_dir}")
+        dist.barrier()
 
     # Close the tracker at the end of training
     tracker.close()
diff --git a/specforge/distributed.py b/specforge/distributed.py
index e26d05a4..d61a7ae2 100644
--- a/specforge/distributed.py
+++ b/specforge/distributed.py
@@ -8,9 +8,9 @@
 _DEVICE_MESH = None
 _TP_DEVICE_MESH = None
 _TP_GROUP = None
+_DP_DEVICE_MESH = None
 _DP_GROUP = None
-
 
 def get_tp_group():
     global _TP_GROUP
     return _TP_GROUP
@@ -30,6 +30,10 @@ def get_tp_device_mesh():
     global _TP_DEVICE_MESH
     return _TP_DEVICE_MESH
 
+def get_dp_device_mesh():
+    global _DP_DEVICE_MESH
+    return _DP_DEVICE_MESH
+
 
 def init_distributed(timeout: int = 10, tp_size: int = 1):
     """Initialize distributed training.
@@ -55,11 +59,12 @@ def init_distributed(timeout: int = 10, tp_size: int = 1):
     # we need to create a 1D submesh
     tp_device_mesh = dist.DeviceMesh.from_group(tp_group, device_type="cuda")
 
-    global _TP_GROUP, _DP_GROUP, _DEVICE_MESH, _TP_DEVICE_MESH
+    global _TP_GROUP, _DP_GROUP, _DEVICE_MESH, _TP_DEVICE_MESH, _DP_DEVICE_MESH
     _DEVICE_MESH = device_mesh
     _TP_GROUP = tp_group
     _TP_DEVICE_MESH = tp_device_mesh
     _DP_GROUP = dp_group
+    _DP_DEVICE_MESH = dist.DeviceMesh.from_group(dp_group, device_type="cuda")
 
 
 def destroy_distributed():
diff --git a/specforge/utils.py b/specforge/utils.py
index 8064afc2..710fe8c4 100644
--- a/specforge/utils.py
+++ b/specforge/utils.py
@@ -219,3 +219,29 @@ def create_draft_config_from_target(
     dist.barrier()
 
     return output_path
+
+def get_full_optimizer_state(optimizer_state_dict: dict):
+    """
+    Convert an optimizer state dict with DTensors to full tensors for saving.
+
+    Args:
+        optimizer_state_dict (dict): Optimizer state dict possibly containing DTensors
+    Returns:
+        dict: Optimizer state dict with full tensors
+    """
+    full_optimizer_state_dict = {
+        k: v for k, v in optimizer_state_dict.items() if k != "state"
+    }
+    if "state" in optimizer_state_dict:
+        full_optimizer_state_dict["state"] = {
+            param_id: {
+                state_key: (
+                    state_tensor.full_tensor()
+                    if isinstance(state_tensor, torch.distributed.tensor.DTensor)
+                    else state_tensor
+                )
+                for state_key, state_tensor in param_state.items()
+            }
+            for param_id, param_state in optimizer_state_dict["state"].items()
+        }
+    return full_optimizer_state_dict
\ No newline at end of file
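Reviewer note: below is a minimal, standalone sketch of the FSDP2 save path this patch adopts, for context only; it is not SpecForge code. It assumes PyTorch >= 2.6 (where fully_shard and MixedPrecisionPolicy are importable from torch.distributed.fsdp), bf16-capable CUDA GPUs, a torchrun launch, and a toy nn.Sequential model; the file name fsdp2_save_sketch.py, the to_full helper, and checkpoint.pt are hypothetical. The to_full helper mirrors the DTensor-to-full-tensor gathering done by get_full_optimizer_state and the new draft_model_state_dict comprehension.

    # fsdp2_save_sketch.py (hypothetical name): FSDP2 wrap + full-tensor checkpointing sketch.
    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard
    from torch.distributed.tensor import DTensor


    def to_full(t):
        # DTensors are sharded across the mesh; gather them into ordinary tensors.
        return t.full_tensor() if isinstance(t, DTensor) else t


    def main():
        dist.init_process_group("nccl")
        torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
        mesh = init_device_mesh("cuda", (dist.get_world_size(),))

        model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64)).cuda()
        # bf16 compute with fp32 gradient reduction, matching the policy the patch configures.
        fully_shard(
            model,
            mesh=mesh,
            mp_policy=MixedPrecisionPolicy(
                param_dtype=torch.bfloat16, reduce_dtype=torch.float32
            ),
        )
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

        # One dummy step so the optimizer state is populated with DTensors too.
        loss = model(torch.randn(8, 64, device="cuda")).float().pow(2).mean()
        loss.backward()
        optimizer.step()

        # Gather sharded model and optimizer tensors before torch.save,
        # as the new save path in train_eagle3_offline.py does.
        model_state = {k: to_full(v) for k, v in model.state_dict().items()}
        opt_state = optimizer.state_dict()
        opt_state["state"] = {
            pid: {k: to_full(v) for k, v in s.items()}
            for pid, s in opt_state["state"].items()
        }

        if dist.get_rank() == 0:
            torch.save({"model": model_state, "optimizer": opt_state}, "checkpoint.pt")
        dist.barrier()
        dist.destroy_process_group()


    if __name__ == "__main__":
        main()

Run with, e.g., torchrun --nproc_per_node=2 fsdp2_save_sketch.py. Because .full_tensor() is a collective, the gathering comprehensions run on every rank and only the final torch.save is gated on rank 0, which matches the ordering the patch keeps in the training script.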