17 changes: 4 additions & 13 deletions scripts/train_eagle3_offline.py
@@ -541,20 +541,11 @@ def main():
for k, v in model_state_dict.items()
if "draft_model." in k and "embed" not in k.lower()
}
draft_model.save_pretrained(
os.path.join(args.output_dir, f"epoch_{epoch}"),
state_dict=draft_model_state_dict,
)

if dist.get_rank() == 0:
torch.save(
state_to_save,
os.path.join(epoch_output_dir, "training_state.pt"),
)
print_on_rank0(
f"Saved full training state to {epoch_output_dir}/training_state.pt"
)
draft_model.save_pretrained(
epoch_output_dir,
state_dict=draft_model_state_dict,
)
print_on_rank0(f"Saved model configuration to {epoch_output_dir}")
dist.barrier()

# Close the tracker at the end of training
35 changes: 16 additions & 19 deletions scripts/train_eagle3_online.py
@@ -291,22 +291,17 @@ def main():

# load model with resume
if draft_model_last_checkpoint:
draft_model = (
AutoEagle3DraftModel.from_pretrained(
draft_model_last_checkpoint, attention_backend=args.attention_backend,
torch_dtype=torch.bfloat16
)
.cuda()

)
draft_model = AutoEagle3DraftModel.from_pretrained(
draft_model_last_checkpoint,
attention_backend=args.attention_backend,
torch_dtype=torch.bfloat16,
).cuda()
else:
draft_model = (
AutoEagle3DraftModel.from_config(
draft_model_config, attention_backend=args.attention_backend,
torch_dtype=torch.bfloat16
)
.cuda()
)
draft_model = AutoEagle3DraftModel.from_config(
draft_model_config,
attention_backend=args.attention_backend,
torch_dtype=torch.bfloat16,
).cuda()
draft_model.load_embedding(args.target_model_path, embedding_key=args.embedding_key)
draft_model.freeze_embedding()
print_with_rank("Initialized draft model")
@@ -652,6 +647,12 @@ def main():
if "draft_model." in k and "embed" not in k.lower()
}

# The new save_pretrained method handles all TP logic internally.
# It ensures only global rank 0 writes to disk.
draft_model.save_pretrained(
epoch_output_dir,
state_dict=draft_model_state_dict,
)
if dist.get_rank() == 0:
torch.save(
state_to_save,
@@ -660,10 +661,6 @@
print_on_rank0(
f"Saved full training state to {epoch_output_dir}/training_state.pt"
)
draft_model.save_pretrained(
epoch_output_dir,
state_dict=draft_model_state_dict,
)
print_on_rank0(f"Saved model configuration to {epoch_output_dir}")
dist.barrier()

15 changes: 15 additions & 0 deletions specforge/layers/linear.py
@@ -2,6 +2,7 @@
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function

from specforge.distributed import get_tp_group

@@ -125,3 +126,17 @@ def load_state_dict(self, state_dict, strict=True):

def __repr__(self):
return f"ColumnParallelLinear(in_features={self.in_features}, out_features={self.out_features_per_shard}, tp_size={self.tp_size}, tp_rank={self.tp_rank})"


class _AllReduce(Function):
@staticmethod
def forward(ctx, input, op, group):
# ctx is a context object that can be used to stash information for backward computation
output = input.clone()
dist.all_reduce(output, op=op, group=group)
return output

@staticmethod
def backward(ctx, grad_output):
# The gradient of all_reduce is the identity, so we can return grad_output unchanged; op and group receive no gradient
return grad_output, None, None
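
A minimal, single-process sketch of why the all-reduce is wrapped in an autograd Function: a bare dist.all_reduce call is invisible to autograd, whereas _AllReduce gives the reduction an identity backward so each rank's partial output still receives gradients. The snippet redefines the class locally and spins up a world_size=1 gloo group so it runs without a launcher; the init address and port are arbitrary placeholders.

import torch
import torch.distributed as dist
from torch.autograd import Function

class _AllReduce(Function):  # mirrors the class added above
    @staticmethod
    def forward(ctx, input, op, group):
        output = input.clone()
        dist.all_reduce(output, op=op, group=group)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # identity backward; op and group receive no gradient
        return grad_output, None, None

# Single-process "distributed" setup so the example is runnable anywhere.
dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29512", rank=0, world_size=1)

x = torch.randn(2, 4, requires_grad=True)
y = _AllReduce.apply(x, dist.ReduceOp.SUM, dist.group.WORLD)
y.sum().backward()
print(x.grad)  # all ones: gradients flow through the reduction unchanged
dist.destroy_process_group()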
71 changes: 71 additions & 0 deletions specforge/modeling/draft/base.py
@@ -27,12 +27,15 @@
from typing import Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
from huggingface_hub import snapshot_download
from safetensors import safe_open
from transformers.cache_utils import Cache
from transformers.modeling_utils import PreTrainedModel

from specforge.distributed import get_tp_group
from specforge.layers.linear import ColumnParallelLinear, RowParallelLinear
from specforge.modeling._mask_utils import _expand_mask, _make_causal_mask


@@ -191,3 +194,71 @@ def load_vocab_mapping(self, file_path: str) -> None:
vocab_mapping = torch.load(file_path)
self.t2d.copy_(vocab_mapping["t2d"])
self.d2t.copy_(vocab_mapping["d2t"])

def save_pretrained(self, save_directory, state_dict=None, **kwargs):
"""
Overrides save_pretrained to handle TP weight aggregation robustly.
This method gathers sharded weights from all TP ranks and saves a single,
complete checkpoint from the main process.
"""
if not dist.is_initialized():
# Standard non-distributed save
super().save_pretrained(save_directory, state_dict=state_dict, **kwargs)
return

# Use the provided state_dict or get it from the model
if state_dict is None:
state_dict = self.state_dict()

# Get distributed process groups and ranks
global_rank = dist.get_rank()
tp_group = get_tp_group()
tp_size = dist.get_world_size(tp_group) if tp_group is not None else 1
tp_rank = dist.get_rank(tp_group) if tp_group is not None else 0

# If not using TP, only rank 0 saves and others do nothing.
if tp_size <= 1:
if global_rank == 0:
super().save_pretrained(save_directory, state_dict=state_dict, **kwargs)
dist.barrier()
return

# --- Aggregation Logic for TP > 1 ---
# Step 1: Each TP rank's leader (tp_rank == 0) will reconstruct the full state dict.
reconstructed_state_dict = None
if tp_rank == 0:
reconstructed_state_dict = {}

# All ranks in a TP group participate in gathering shards for each parameter.
modules = dict(self.named_modules())
for name, param in state_dict.items():
# Gather shards from all TP ranks into a list
tensor_list = [torch.empty_like(param) for _ in range(tp_size)]
dist.all_gather(tensor_list, param.contiguous(), group=tp_group)

# Let the tp_rank 0 process handle the concatenation
if tp_rank == 0:
module_name = ".".join(name.split(".")[:-1])
module = modules.get(module_name)

if isinstance(module, ColumnParallelLinear) and name.endswith(
".weight"
):
# Concat along dimension 0 for ColumnParallel
reconstructed_state_dict[name] = torch.cat(tensor_list, dim=0)
elif isinstance(module, RowParallelLinear) and name.endswith(".weight"):
# Concat along dimension 1 for RowParallel
reconstructed_state_dict[name] = torch.cat(tensor_list, dim=1)
else:
# Non-parallel layers (biases, norms, etc.) are identical across ranks
reconstructed_state_dict[name] = tensor_list[0]

# Step 2: Only the global rank 0 process saves the final model.
if global_rank == 0:
print(f"Rank {global_rank} saving aggregated model checkpoint...")
super().save_pretrained(
save_directory, state_dict=reconstructed_state_dict, **kwargs
)

# Step 3: Barrier to ensure all processes wait until saving is complete.
dist.barrier()
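
A standalone check of the concatenation rules used above, with no process group required: a ColumnParallelLinear shards its weight along dim 0 (output features) and a RowParallelLinear along dim 1 (input features), so gathering one shard per rank and concatenating along that dimension recovers the full matrix. The shapes below are made up for illustration and assume two TP ranks.

import torch

tp_size = 2
full_col_weight = torch.randn(8, 6)  # ColumnParallelLinear weight: (out_features, in_features)
full_row_weight = torch.randn(6, 8)  # RowParallelLinear weight:    (out_features, in_features)

# What each rank would hold locally.
col_shards = full_col_weight.chunk(tp_size, dim=0)  # split over out_features
row_shards = full_row_weight.chunk(tp_size, dim=1)  # split over in_features

# What save_pretrained does after all_gather: concatenate along the sharded dim.
assert torch.equal(torch.cat(col_shards, dim=0), full_col_weight)
assert torch.equal(torch.cat(row_shards, dim=1), full_row_weight)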
95 changes: 53 additions & 42 deletions specforge/modeling/draft/llama3_eagle.py
@@ -3,6 +3,7 @@
from typing import List, Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
@@ -11,6 +12,8 @@
from transformers.cache_utils import Cache
from transformers.models.llama.configuration_llama import LlamaConfig

from specforge.distributed import get_tp_group
from specforge.layers.linear import ColumnParallelLinear, RowParallelLinear, _AllReduce
from specforge.modeling.draft.flex_attention import (
compile_friendly_create_block_mask,
compile_friendly_flex_attention,
@@ -343,27 +346,42 @@ class LlamaAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.tp_group = get_tp_group()
self._tp_size = (
dist.get_world_size(self.tp_group) if self.tp_group is not None else 1
)
self._tp_rank = dist.get_rank(self.tp_group) if self.tp_group is not None else 0
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
if hasattr(config, "head_dim"):
self.head_dim = config.head_dim
else:
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads

# adjust head number based on tp size
self.num_heads = config.num_attention_heads // self._tp_size
self.num_key_value_heads = config.num_key_value_heads // self._tp_size
assert (
config.num_attention_heads % self._tp_size == 0
), "num_attention_heads must be divisible by tp_size"
assert (
config.num_key_value_heads % self._tp_size == 0
), "num_key_value_heads must be divisible by tp_size"

self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings

self.q_proj = nn.Linear(
self.hidden_size * 2, self.num_heads * self.head_dim, bias=False
self.q_proj = ColumnParallelLinear(
self.hidden_size * 2, config.num_attention_heads * self.head_dim, bias=False
)
self.k_proj = nn.Linear(
self.hidden_size * 2, self.num_key_value_heads * self.head_dim, bias=False
self.k_proj = ColumnParallelLinear(
self.hidden_size * 2, config.num_key_value_heads * self.head_dim, bias=False
)
self.v_proj = nn.Linear(
self.hidden_size * 2, self.num_key_value_heads * self.head_dim, bias=False
self.v_proj = ColumnParallelLinear(
self.hidden_size * 2, config.num_key_value_heads * self.head_dim, bias=False
)
self.o_proj = nn.Linear(
self.num_heads * self.head_dim, self.hidden_size, bias=False
self.o_proj = RowParallelLinear(
config.num_attention_heads * self.head_dim, self.hidden_size, bias=False
)
self._init_rope()

@@ -512,7 +530,6 @@ def forward(
(attn_weights, attn_weightsi[..., None]), dim=-1
)

# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
@@ -530,7 +547,10 @@ def forward(
attn_output = attn_output.reshape(bsz, q_len, self.head_dim * self.num_heads)

attn_output = self.o_proj(attn_output)

if self._tp_size > 1:
attn_output = _AllReduce.apply(
attn_output, dist.ReduceOp.SUM, self.tp_group
)
return attn_output
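
A quick shape check for the head partitioning above, using assumed Llama-like sizes (none of these numbers come from the PR): the column-parallel q/k/v projections give each rank num_attention_heads / tp_size query heads and num_key_value_heads / tp_size KV heads, and the row-parallel o_proj emits a partial hidden_size output that _AllReduce then sums.

hidden_size = 4096            # assumed example values, purely illustrative
num_attention_heads = 32
num_key_value_heads = 8
head_dim = hidden_size // num_attention_heads  # 128
tp_size = 4

assert num_attention_heads % tp_size == 0 and num_key_value_heads % tp_size == 0
heads_per_rank = num_attention_heads // tp_size      # 8 query heads per rank
kv_heads_per_rank = num_key_value_heads // tp_size   # 2 KV heads per rank

q_out_per_rank = heads_per_rank * head_dim           # 1024 output columns of q_proj per rank
kv_out_per_rank = kv_heads_per_rank * head_dim       # 256 output columns of k_proj/v_proj per rank
# o_proj (RowParallelLinear) consumes each rank's 1024-wide slice and produces a
# partial (bsz, q_len, hidden_size) tensor; the all-reduce sums those partials.
print(q_out_per_rank, kv_out_per_rank)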


@@ -648,44 +668,35 @@ class LlamaMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config

self.tp_group = get_tp_group()
self._tp_size = (
dist.get_world_size(self.tp_group) if self.tp_group is not None else 1
)

self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.gate_proj = ColumnParallelLinear(
self.hidden_size, self.intermediate_size, bias=False
)
self.up_proj = ColumnParallelLinear(
self.hidden_size, self.intermediate_size, bias=False
)
self.down_proj = RowParallelLinear(
self.intermediate_size, self.hidden_size, bias=False
)
self.act_fn = ACT2FN[config.hidden_act]

def forward(self, x):
if self.config.pretraining_tp > 1:
slice = self.intermediate_size // self.config.pretraining_tp
gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
up_proj_slices = self.up_proj.weight.split(slice, dim=0)
down_proj_slices = self.down_proj.weight.split(slice, dim=1)

gate_proj = torch.cat(
[
F.linear(x, gate_proj_slices[i])
for i in range(self.config.pretraining_tp)
],
dim=-1,
)
up_proj = torch.cat(
[
F.linear(x, up_proj_slices[i])
for i in range(self.config.pretraining_tp)
],
dim=-1,
)

intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
down_proj = [
F.linear(intermediate_states[i], down_proj_slices[i])
for i in range(self.config.pretraining_tp)
]
down_proj = sum(down_proj)
else:
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
# The pretraining_tp > 1 branch was removed in favor of the unified tensor-parallel layers.
gate_output = self.gate_proj(x)
up_output = self.up_proj(x)

down_proj = self.down_proj(self.act_fn(gate_output) * up_output)

if self._tp_size > 1:
down_proj = _AllReduce.apply(down_proj, dist.ReduceOp.SUM, self.tp_group)
return down_proj
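
To see why a single all-reduce at the end reproduces the unsharded MLP, here is a small standalone check that simulates two TP ranks by slicing dense weights. SiLU is assumed as the activation and all shapes are illustrative; it demonstrates only the column/row-parallel algebra, not the actual layer classes.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden, inter, tp_size = 16, 32, 2
x = torch.randn(3, hidden)
W_gate = torch.randn(inter, hidden)
W_up = torch.randn(inter, hidden)
W_down = torch.randn(hidden, inter)

# Unsharded reference: down_proj(silu(gate_proj(x)) * up_proj(x))
ref = F.linear(F.silu(F.linear(x, W_gate)) * F.linear(x, W_up), W_down)

# Simulated TP: gate/up are column-parallel (weight split along dim 0),
# down is row-parallel (weight split along dim 1); each "rank" yields a partial output.
partials = []
for r in range(tp_size):
    g = F.linear(x, W_gate.chunk(tp_size, dim=0)[r])
    u = F.linear(x, W_up.chunk(tp_size, dim=0)[r])
    partials.append(F.linear(F.silu(g) * u, W_down.chunk(tp_size, dim=1)[r]))

# The _AllReduce(SUM) in forward() is exactly this sum over ranks.
assert torch.allclose(sum(partials), ref, atol=1e-5)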


2 changes: 1 addition & 1 deletion specforge/tracker.py
@@ -196,7 +196,7 @@ def log(self, log_dict: Dict[str, Any], step: Optional[int] = None):
swanlab.log(log_dict, step=step)

def close(self):
if self.rank == 0 and self.is_initialized and swanlab.is_running():
if self.rank == 0 and self.is_initialized:
swanlab.finish()
self.is_initialized = False
