diff --git a/scripts/train_eagle3_offline.py b/scripts/train_eagle3_offline.py
index 03e5c9f6..69b95b45 100644
--- a/scripts/train_eagle3_offline.py
+++ b/scripts/train_eagle3_offline.py
@@ -226,8 +226,14 @@ def main():
         .cuda()
         .to(torch.bfloat16)
     )
-    draft_model.load_embedding(args.target_model_path, embedding_key=args.embedding_key)
-    draft_model.freeze_embedding()
+    if (
+        not hasattr(draft_model_config, "target_hidden_size")
+        or draft_model_config.target_hidden_size == draft_model_config.hidden_size
+    ):
+        draft_model.load_embedding(
+            args.target_model_path, embedding_key=args.embedding_key
+        )
+        draft_model.freeze_embedding()
     print_with_rank("Initialized draft model")
 
     # build dataloaders
diff --git a/scripts/train_eagle3_online.py b/scripts/train_eagle3_online.py
index 46c3816c..416b7aaa 100644
--- a/scripts/train_eagle3_online.py
+++ b/scripts/train_eagle3_online.py
@@ -259,8 +259,14 @@ def main():
         .cuda()
         .to(torch.bfloat16)
     )
-    draft_model.load_embedding(args.target_model_path, embedding_key=args.embedding_key)
-    draft_model.freeze_embedding()
+    if (
+        not hasattr(draft_model_config, "target_hidden_size")
+        or draft_model_config.target_hidden_size == draft_model_config.hidden_size
+    ):
+        draft_model.load_embedding(
+            args.target_model_path, embedding_key=args.embedding_key
+        )
+        draft_model.freeze_embedding()
     print_with_rank("Initialized draft model")
 
     # build dataloaders
diff --git a/specforge/modeling/draft/llama3_eagle.py b/specforge/modeling/draft/llama3_eagle.py
index 024254ff..3d416854 100644
--- a/specforge/modeling/draft/llama3_eagle.py
+++ b/specforge/modeling/draft/llama3_eagle.py
@@ -793,14 +793,15 @@ def __init__(self, config, quant_config=None, attention_backend="sdpa") -> None:
         )
         self.midlayer = LlamaDecoderLayer(config, attention_backend=attention_backend)
 
+        self.hidden_size = config.hidden_size
         if hasattr(config, "target_hidden_size"):
-            self.fc = torch.nn.Linear(
-                config.target_hidden_size * 3, config.hidden_size, bias=False
-            )
+            self.target_hidden_size = config.target_hidden_size
         else:
-            self.fc = torch.nn.Linear(
-                config.hidden_size * 3, config.hidden_size, bias=False
-            )
+            self.target_hidden_size = config.hidden_size
+
+        self.fc = torch.nn.Linear(
+            self.target_hidden_size * 3, self.hidden_size, bias=False
+        )
 
         self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.lm_head = nn.Linear(
@@ -874,7 +875,7 @@ def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
 
     def project_hidden_states(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # eagle 3 requires hidden states from 3 layers
-        assert hidden_states.size(-1) == self.config.hidden_size * 3
+        assert hidden_states.size(-1) == self.target_hidden_size * 3
         return self.fc(hidden_states)
 
     def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
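
With this patch, the `fc` layer projects the three concatenated target-layer hidden states (`3 * target_hidden_size`) down to the draft model's own `hidden_size`, so the two sizes no longer have to match. Below is a minimal sketch of that projection in isolation; the concrete sizes (4096 and 2048) are illustrative assumptions, not values taken from the patch:

```python
import torch

# Hypothetical sizes for illustration only: a wider target model feeding a
# narrower draft model. fc maps 3 concatenated target-layer hidden states
# (3 * target_hidden_size) down to the draft model's hidden_size.
target_hidden_size = 4096  # hidden size of the target model's layers (assumed)
hidden_size = 2048         # hidden size of the draft model (assumed)

fc = torch.nn.Linear(target_hidden_size * 3, hidden_size, bias=False)

# EAGLE-3 concatenates hidden states from 3 target layers along the last dim.
batch, seq = 2, 8
hidden_states = torch.randn(batch, seq, target_hidden_size * 3)

# Mirrors the patched assert in project_hidden_states.
assert hidden_states.size(-1) == target_hidden_size * 3
projected = fc(hidden_states)
print(projected.shape)  # torch.Size([2, 8, 2048])
```

Correspondingly, the training scripts now copy and freeze the target embedding only when `target_hidden_size` is absent or equal to `hidden_size`, presumably because the target's embedding matrix would not fit the draft model's embedding width otherwise.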