90 changes: 45 additions & 45 deletions scripts/train_eagle3_offline.py
@@ -9,8 +9,7 @@
import torch.distributed as dist
from accelerate.utils import set_seed
from datasets import load_dataset
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy, StateDictType
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
from tqdm import tqdm
from transformers import AutoTokenizer

@@ -21,7 +20,7 @@
generate_vocab_mapping_file,
prepare_dp_dataloaders,
)
from specforge.distributed import destroy_distributed, get_dp_group, init_distributed
from specforge.distributed import destroy_distributed, get_dp_group, init_distributed, get_dp_device_mesh
from specforge.modeling.target.target_head import TargetHead
from specforge.optimizer import BF16Optimizer
from specforge.tracker import create_tracker, get_tracker_class
@@ -31,6 +30,7 @@
print_on_rank0,
print_with_rank,
rank_0_priority,
get_full_optimizer_state,
)


@@ -340,19 +340,10 @@ def main():
length=args.ttt_length,
attention_backend=args.draft_attention_backend,
)
eagle3_model = FSDP(
eagle3_model,
use_orig_params=True,
mixed_precision=MixedPrecision(
param_dtype=torch.bfloat16,
buffer_dtype=torch.bfloat16,
reduce_dtype=torch.float32,
keep_low_precision_grads=False,
),
sharding_strategy=ShardingStrategy.NO_SHARD,
ignored_modules=[],
process_group=get_dp_group(),
)
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
fsdp_config = {"mesh": get_dp_device_mesh(), "mp_policy": mp_policy}
fully_shard(eagle3_model, **fsdp_config)

print_with_rank(f"Initialized Eagle3 FSDP model")
global_step, batch_index = 0, 0
log_dict = defaultdict(float)
@@ -373,8 +364,8 @@ def main():
# Run training
train_dataloader.sampler.set_epoch(epoch + 1)
draft_model.train()
epoch_acces = [[] for _ in range(eagle3_model.module.length)]
epoch_plosses = [[] for _ in range(eagle3_model.module.length)]
epoch_acces = [[] for _ in range(eagle3_model.length)]
epoch_plosses = [[] for _ in range(eagle3_model.length)]

if dist.get_rank() == 0:
progress_bar = tqdm(
@@ -530,33 +521,42 @@ def main():
os.makedirs(epoch_output_dir, exist_ok=True)
dist.barrier()

with FSDP.state_dict_type(eagle3_model, StateDictType.FULL_STATE_DICT):
model_state_dict = eagle3_model.state_dict()
state_to_save = {
"epoch": epoch,
"args": args,
}
state_to_save.update(optimizer.state_dict())
draft_model_state_dict = {
k.replace("draft_model.", ""): v
for k, v in model_state_dict.items()
if "draft_model." in k and "embed" not in k.lower()
}

if dist.get_rank() == 0:
torch.save(
state_to_save,
os.path.join(epoch_output_dir, "training_state.pt"),
)
print_on_rank0(
f"Saved full training state to {epoch_output_dir}/training_state.pt"
)
draft_model.save_pretrained(
epoch_output_dir,
state_dict=draft_model_state_dict,
)
print_on_rank0(f"Saved model configuration to {epoch_output_dir}")
dist.barrier()
model_state_dict = eagle3_model.state_dict()

state_to_save = {
"epoch": epoch,
"args": args,
}

optimizer_state_dict = optimizer.state_dict()
optimizer_state_dict["optimizer_state_dict"] = get_full_optimizer_state(optimizer_state_dict["optimizer_state_dict"])

state_to_save.update(optimizer_state_dict)

draft_model_state_dict = {
k.replace("draft_model.", ""): (
v.full_tensor()
if isinstance(v, torch.distributed.tensor.DTensor)
else v
)
for k, v in model_state_dict.items()
if "draft_model." in k and "embed" not in k.lower()
}

if dist.get_rank() == 0:
torch.save(
state_to_save,
os.path.join(epoch_output_dir, "training_state.pt"),
)
print_on_rank0(
f"Saved full training state to {epoch_output_dir}/training_state.pt"
)
draft_model.save_pretrained(
epoch_output_dir,
state_dict=draft_model_state_dict,
)
print_on_rank0(f"Saved model configuration to {epoch_output_dir}")
dist.barrier()

# Close the tracker at the end of training
tracker.close()
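For reference, a minimal standalone sketch (not part of this diff) of the FSDP2 pattern applied above, assuming torch >= 2.6 and a torchrun-launched process group; ToyModel and the 1-D mesh are illustrative stand-ins for eagle3_model and get_dp_device_mesh():

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard
from torch.distributed.tensor import DTensor


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 16)

    def forward(self, x):
        return self.linear(x)


def wrap_and_gather(world_size: int) -> dict:
    # FSDP2 shards the module in place over a 1-D data-parallel mesh; there is no wrapper
    # object, so attributes are reached directly on the module (e.g. eagle3_model.length
    # above, not eagle3_model.module.length).
    mesh = init_device_mesh("cuda", (world_size,))
    mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
    model = ToyModel().cuda()
    fully_shard(model, mesh=mesh, mp_policy=mp_policy)

    # state_dict() now holds DTensor shards; .full_tensor() all-gathers a full copy, which
    # is why the checkpointing code above runs on every rank before the rank-0 save.
    return {
        k: v.full_tensor() if isinstance(v, DTensor) else v
        for k, v in model.state_dict().items()
    }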
9 changes: 7 additions & 2 deletions specforge/distributed.py
@@ -8,9 +8,9 @@
_DEVICE_MESH = None
_TP_DEVICE_MESH = None
_TP_GROUP = None
_DP_DEVICE_MESH = None
_DP_GROUP = None


def get_tp_group():
global _TP_GROUP
return _TP_GROUP
@@ -30,6 +30,10 @@ def get_tp_device_mesh():
global _TP_DEVICE_MESH
return _TP_DEVICE_MESH

def get_dp_device_mesh():
global _DP_DEVICE_MESH
return _DP_DEVICE_MESH


def init_distributed(timeout: int = 10, tp_size: int = 1):
"""Initialize distributed training.
@@ -55,11 +59,12 @@ def init_distributed(timeout: int = 10, tp_size: int = 1):

# we need to create a 1D submesh
tp_device_mesh = dist.DeviceMesh.from_group(tp_group, device_type="cuda")
global _TP_GROUP, _DP_GROUP, _DEVICE_MESH, _TP_DEVICE_MESH
global _TP_GROUP, _DP_GROUP, _DEVICE_MESH, _TP_DEVICE_MESH, _DP_DEVICE_MESH
_DEVICE_MESH = device_mesh
_TP_GROUP = tp_group
_TP_DEVICE_MESH = tp_device_mesh
_DP_GROUP = dp_group
_DP_DEVICE_MESH = dist.DeviceMesh.from_group(dp_group, device_type="cuda")


def destroy_distributed():
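A short sketch (assuming torch.distributed is already initialized, e.g. via torchrun) of how a cached data-parallel mesh like _DP_DEVICE_MESH can be built from an existing process group:

import torch.distributed as dist

# Illustrative: treat every rank as part of a single data-parallel group (tp_size == 1).
dp_group = dist.new_group(ranks=list(range(dist.get_world_size())))
dp_mesh = dist.DeviceMesh.from_group(dp_group, device_type="cuda")
# The resulting 1-D mesh is what get_dp_device_mesh() hands to fully_shard() in the training script.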
26 changes: 26 additions & 0 deletions specforge/utils.py
@@ -219,3 +219,29 @@ def create_draft_config_from_target(
dist.barrier()

return output_path

def get_full_optimizer_state(optimizer_state_dict: dict):
"""
Convert an optimizer state dict that may contain DTensors into full tensors for saving.

Args:
optimizer_state_dict (dict): Optimizer state dict possibly containing DTensors
Returns:
dict: Optimizer state dict with full tensors
"""
full_optimizer_state_dict = {
k: v for k, v in optimizer_state_dict.items() if k != "state"
}
if "state" in optimizer_state_dict:
full_optimizer_state_dict["state"] = {
param_id: {
state_key: (
state_tensor.full_tensor()
if isinstance(state_tensor, torch.distributed.tensor.DTensor)
else state_tensor
)
for state_key, state_tensor in param_state.items()
}
for param_id, param_state in optimizer_state_dict["state"].items()
}
return full_optimizer_state_dict
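
A toy illustration of the expected call pattern; the state dict below is hypothetical, while in the training script the input comes from optimizer.state_dict()["optimizer_state_dict"]:

import torch
from specforge.utils import get_full_optimizer_state

toy_state = {
    "param_groups": [{"lr": 1e-4, "params": [0]}],
    "state": {0: {"step": torch.tensor(1.0), "exp_avg": torch.zeros(8)}},
}
full_state = get_full_optimizer_state(toy_state)
# Plain tensors pass through untouched; DTensor entries (present after fully_shard)
# are gathered into full tensors so torch.save() writes a rank-agnostic checkpoint.
assert torch.equal(full_state["state"][0]["exp_avg"], toy_state["state"][0]["exp_avg"])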