diff --git a/configs/mistral-small-24B-eagle3.json b/configs/mistral-small-24B-eagle3.json
new file mode 100644
index 00000000..7db7f501
--- /dev/null
+++ b/configs/mistral-small-24B-eagle3.json
@@ -0,0 +1,27 @@
+{
+  "architectures": [
+    "LlamaForCausalLMEagle3"
+  ],
+  "attention_dropout": 0.0,
+  "bos_token_id": 1,
+  "eos_token_id": 2,
+  "head_dim": 128,
+  "hidden_act": "silu",
+  "hidden_size": 5120,
+  "initializer_range": 0.02,
+  "intermediate_size": 32768,
+  "max_position_embeddings": 32768,
+  "model_type": "llama",
+  "num_attention_heads": 32,
+  "num_hidden_layers": 1,
+  "num_key_value_heads": 8,
+  "rms_norm_eps": 1e-05,
+  "rope_theta": 100000000.0,
+  "sliding_window": null,
+  "tie_word_embeddings": false,
+  "torch_dtype": "bfloat16",
+  "transformers_version": "4.47.0",
+  "use_cache": true,
+  "vocab_size": 131072,
+  "draft_vocab_size": 32000
+}
diff --git a/examples/run_mistral_small_24B_eagle3_online.sh b/examples/run_mistral_small_24B_eagle3_online.sh
new file mode 100644
index 00000000..89bb1314
--- /dev/null
+++ b/examples/run_mistral_small_24B_eagle3_online.sh
@@ -0,0 +1,23 @@
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
+ROOT_DIR=$(dirname $SCRIPT_DIR)
+export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels
+
+# train EAGLE3 for Mistral-Small-24B
+NUM_GPUS=${1:-2}
+
+torchrun \
+    --standalone \
+    --nproc_per_node $NUM_GPUS \
+    $ROOT_DIR/scripts/train_eagle3_online.py \
+    --target-model-path mistralai/Mistral-Small-24B-Instruct-2501 \
+    --draft-model-config $ROOT_DIR/configs/mistral-small-24B-eagle3.json \
+    --train-data-path $ROOT_DIR/cache/dataset/sharegpt.jsonl \
+    --output-dir $ROOT_DIR/outputs/mistral-Small-24B-eagle3 \
+    --num-epochs 2 \
+    --batch-size 1 \
+    --tp 2 \
+    --learning-rate 1e-4 \
+    --max-length 2048 \
+    --chat-template mistral-small-24B \
+    --cache-dir $ROOT_DIR/cache \
+    --attention-backend flex_attention
diff --git a/specforge/data/parse.py b/specforge/data/parse.py
index 9b17c90a..fdc63970 100644
--- a/specforge/data/parse.py
+++ b/specforge/data/parse.py
@@ -41,12 +41,12 @@ class GeneralParser(Parser):
     def __init__(self, tokenizer: PreTrainedTokenizer, chat_template: ChatTemplate):
         super().__init__(tokenizer, chat_template)
         self.system_prompt = chat_template.system_prompt
-        self.user_message_separator = (
-            f"{chat_template.end_of_turn_token}{chat_template.user_header}"
-        )
-        self.assistant_message_separator = (
-            f"{chat_template.end_of_turn_token}{chat_template.assistant_header}"
-        )
+        if chat_template.end_of_turn_token:
+            self.user_message_separator = f"{chat_template.end_of_turn_token or ''}{chat_template.user_header or ''}"
+            self.assistant_message_separator = f"{chat_template.end_of_turn_token or ''}{chat_template.assistant_header or ''}"
+        else:
+            self.user_message_separator = f"{chat_template.end_of_assistant_token or ''}{chat_template.user_header or ''}"
+            self.assistant_message_separator = f"{chat_template.end_of_user_token or ''}{chat_template.assistant_header or ''}"
 
     def parse(
         self,
diff --git a/specforge/data/preprocessing.py b/specforge/data/preprocessing.py
index 4f35d950..0a0cb005 100644
--- a/specforge/data/preprocessing.py
+++ b/specforge/data/preprocessing.py
@@ -71,12 +71,14 @@ def _apply_loss_mask_from_chat_template(
     """
     loss_mask = torch.zeros(len(offsets), dtype=torch.long)
 
-    user_message_separator = (
-        f"{chat_template.end_of_turn_token}{chat_template.user_header}"
-    )
-    assistant_message_separator = (
-        f"{chat_template.end_of_turn_token}{chat_template.assistant_header}"
-    )
+    if chat_template.end_of_turn_token:
+        user_message_separator = (
+            f"{chat_template.end_of_turn_token or ''}{chat_template.user_header or ''}"
+        )
+        assistant_message_separator = f"{chat_template.end_of_turn_token or ''}{chat_template.assistant_header or ''}"
+    else:
+        user_message_separator = f"{chat_template.end_of_assistant_token or ''}{chat_template.user_header or ''}"
+        assistant_message_separator = f"{chat_template.end_of_user_token or ''}{chat_template.assistant_header or ''}"
 
     # Find spans of assistant responses using regex
     assistant_pattern = (
diff --git a/specforge/data/template.py b/specforge/data/template.py
index 2368048b..12a26b72 100644
--- a/specforge/data/template.py
+++ b/specforge/data/template.py
@@ -13,12 +13,17 @@ class ChatTemplate(BaseModel):
         user_header(str): The header for the user.
         system_prompt(str): The system prompt.
         end_of_turn_token(str): The end token of a turn of conversation.
+            If present, end_of_assistant_token and end_of_user_token are ignored.
+        end_of_assistant_token(str): The end token of an assistant turn of conversation.
+        end_of_user_token(str): The end token of a user turn of conversation.
     """
 
     assistant_header: str | None
     user_header: str | None
     system_prompt: str | None
-    end_of_turn_token: str | None
+    end_of_turn_token: str | None = None
+    end_of_assistant_token: str | None = None
+    end_of_user_token: str | None = None
     parser_type: str = "general"


@@ -105,6 +110,23 @@ def get_all_template_names(self) -> List[str]:
         ),
 )
 
+TEMPLATE_REGISTRY.register(
+    name="mistral-small-24B",
+    template=ChatTemplate(
+        assistant_header="[/INST]",
+        user_header="[INST]",
+        system_prompt="You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French startup "
+        "headquartered in Paris. Your knowledge base was last updated on 2023-10-01. The current date "
+        "is 2025-08-31. When you're not sure about some information, you say that you don't have the "
+        "information and don't make up anything. If the user's question is not clear, ambiguous, or "
+        "does not provide enough context for you to accurately answer the question, you do not try to "
+        'answer it right away and you rather ask the user to clarify their request (e.g. "What are '
+        'some good restaurants around me?" => "Where are you?" or "When is the next flight to '
+        'Tokyo" => "Where do you travel from?")',
+        end_of_assistant_token="",
+    ),
+)
+
 TEMPLATE_REGISTRY.register(
     name="qwen",
     template=ChatTemplate(
diff --git a/specforge/modeling/auto.py b/specforge/modeling/auto.py
index 345e3366..f7717bb8 100644
--- a/specforge/modeling/auto.py
+++ b/specforge/modeling/auto.py
@@ -11,6 +11,7 @@
     Llama4Config,
     Llama4TextConfig,
     LlamaConfig,
+    MistralConfig,
     Phi3Config,
     PretrainedConfig,
     Qwen2_5_VLConfig,
@@ -26,6 +27,7 @@
 from .target.gpt_oss import GptOssForCausalLM
 from .target.llama import LlamaForCausalLM
 from .target.llama4 import Llama4ForCausalLM
+from .target.mistral import MistralForCausalLM
 from .target.phi3 import Phi3ForCausalLM
 from .target.qwen2 import Qwen2ForCausalLM
 from .target.qwen3 import Qwen3ForCausalLM
@@ -94,6 +96,7 @@ class AutoDistributedTargetModel(AutoModelForCausalLMBase):
         LlamaConfig: [LlamaForCausalLM],
         Qwen3Config: [Qwen3ForCausalLM],
         Phi3Config: [Phi3ForCausalLM],
+        MistralConfig: [MistralForCausalLM],
         GptOssConfig: [GptOssForCausalLM],
     }

diff --git a/specforge/modeling/target/mistral.py b/specforge/modeling/target/mistral.py
new file mode 100644
index 00000000..d2f92b22
--- /dev/null
+++ b/specforge/modeling/target/mistral.py
@@ -0,0 +1,574 @@
+# coding=utf-8
+# Copyright 2025 Mistral AI and HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Dict, Optional, Union
+
+import torch
+import torch.distributed as dist
+from torch import nn
+from transformers import MistralConfig
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, DynamicCache
+from transformers.generation import GenerationMixin
+from transformers.masking_utils import (
+    create_causal_mask,
+    create_sliding_window_causal_mask,
+)
+from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+from transformers.modeling_layers import GradientCheckpointingLayer
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+)
+from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from transformers.models.mistral.modeling_mistral import (
+    MistralRMSNorm,
+    MistralRotaryEmbedding,
+    apply_rotary_pos_emb,
+    eager_attention_forward,
+)
+from transformers.processing_utils import Unpack
+from transformers.utils import (
+    TransformersKwargs,
+    auto_docstring,
+    can_return_tuple,
+    logging,
+)
+
+from specforge.distributed import get_tp_group
+from specforge.layers.linear import ColumnParallelLinear, RowParallelLinear
+from specforge.modeling.target.base import DistributedTargetModel
+
+logger = logging.get_logger(__name__)
+
+
+class MistralMLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+
+        # distributed linear layers
+        self.tp_group = get_tp_group()
+        self.gate_proj = ColumnParallelLinear(
+            self.hidden_size, self.intermediate_size, bias=False
+        )
+        self.up_proj = ColumnParallelLinear(
+            self.hidden_size, self.intermediate_size, bias=False
+        )
+        self.down_proj = RowParallelLinear(
+            self.intermediate_size, self.hidden_size, bias=False
+        )
+
+        self.act_fn = ACT2FN[config.hidden_act]
+
+    def forward(self, x):
+        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        dist.all_reduce(down_proj, op=dist.ReduceOp.SUM, group=self.tp_group)
+        return down_proj
+
+
+class MistralAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(self, config: MistralConfig, layer_idx: int):
+        super().__init__()
+        self.config = config
+        self.layer_idx = layer_idx
+        self.head_dim = getattr(
+            config, "head_dim", config.hidden_size // config.num_attention_heads
+        )
+        self.num_key_value_groups = (
+            config.num_attention_heads // config.num_key_value_heads
+        )
+        self.scaling = self.head_dim**-0.5
+        self.attention_dropout = config.attention_dropout
+        self.is_causal = True
+
+        # distributed linear layers
+        self.tp_group = get_tp_group()
+        self.q_proj = ColumnParallelLinear(
+            config.hidden_size,
+            config.num_attention_heads * self.head_dim,
+            bias=False,
+        )
+        self.k_proj = ColumnParallelLinear(
+            config.hidden_size,
+            config.num_key_value_heads * self.head_dim,
+            bias=False,
+        )
+        self.v_proj = ColumnParallelLinear(
+            config.hidden_size,
+            config.num_key_value_heads * self.head_dim,
+            bias=False,
+        )
+        self.o_proj = RowParallelLinear(
+            config.num_attention_heads * self.head_dim,
+            config.hidden_size,
+            bias=False,
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.Tensor],
+        past_key_value: Optional[Cache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.head_dim)
+
+        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+        cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(
+            query_states, key_states, cos, sin
+        )
+
+        if past_key_value is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx, cache_kwargs
+            )
+
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            attention_interface = ALL_ATTENTION_FUNCTIONS[
+                self.config._attn_implementation
+            ]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            sliding_window=getattr(
+                self.config, "sliding_window", None
+            ),  # main diff with Llama
+            **kwargs,
+        )
+
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+        dist.all_reduce(attn_output, op=dist.ReduceOp.SUM, group=self.tp_group)
+        return attn_output, attn_weights
+
+
+class MistralDecoderLayer(GradientCheckpointingLayer):
+    def __init__(self, config: MistralConfig, layer_idx: int):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+
+        self.self_attn = MistralAttention(config=config, layer_idx=layer_idx)
+
+        self.mlp = MistralMLP(config)
+        self.input_layernorm = MistralRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.post_attention_layernorm = MistralRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[
+            tuple[torch.Tensor, torch.Tensor]
+        ] = None,  # necessary, but kept here for BC
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> tuple[
+        torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]
+    ]:
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+        # Self Attention
+        hidden_states, self_attn_weights = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+            **kwargs,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        return outputs
+
+
+@auto_docstring
+class MistralPreTrainedModel(PreTrainedModel):
+    config_class = MistralConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["MistralDecoderLayer"]
+    _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_3 = True
+    _supports_flash_attn_2 = True
+    _supports_sdpa = True
+    _supports_flex_attn = True
+    _supports_cache_class = True
+    _supports_quantized_cache = True
+    _supports_static_cache = True
+    _supports_attention_backend = True
+
+    def _init_weights(self, module):
+        std = self.config.initializer_range
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, MistralRMSNorm):
+            module.weight.data.fill_(1.0)
+
+
+@auto_docstring
+class MistralModel(MistralPreTrainedModel):
+    def __init__(self, config: MistralConfig):
+        super().__init__(config)
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = nn.Embedding(
+            config.vocab_size, config.hidden_size, self.padding_idx
+        )
+        self.layers = nn.ModuleList(
+            [
+                MistralDecoderLayer(config, layer_idx)
+                for layer_idx in range(config.num_hidden_layers)
+            ]
+        )
+        self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.rotary_emb = MistralRotaryEmbedding(config=config)
+        self.gradient_checkpointing = False
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.embed_tokens = value
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
+    ) -> BaseModelOutputWithPast:
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError(
+                "You must specify exactly one of input_ids or inputs_embeds"
+            )
+
+        if self.gradient_checkpointing and self.training and use_cache:
+            logger.warning_once(
+                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+            )
+            use_cache = False
+
+        # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
+        if not isinstance(past_key_values, (type(None), Cache)):
+            raise ValueError(
+                "The `past_key_values` should be either a `Cache` object or `None`."
+            )
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        if use_cache and past_key_values is None:
+            past_key_values = DynamicCache()
+
+        if cache_position is None:
+            past_seen_tokens = (
+                past_key_values.get_seq_length() if past_key_values is not None else 0
+            )
+            cache_position = torch.arange(
+                past_seen_tokens,
+                past_seen_tokens + inputs_embeds.shape[1],
+                device=inputs_embeds.device,
+            )
+
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
+        mask_function = (
+            create_causal_mask
+            if self.config.sliding_window is None
+            else create_sliding_window_causal_mask
+        )
+        causal_mask = mask_function(
+            config=self.config,
+            input_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            past_key_values=past_key_values,
+            position_ids=position_ids,
+        )
+
+        hidden_states = inputs_embeds
+
+        # create position embeddings to be shared across the decoder layers
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+
+        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
+                **flash_attn_kwargs,
+            )
+
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+        hidden_states = self.norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=past_key_values if use_cache else None,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+        )
+
+
+@auto_docstring
+class MistralForCausalLM(
+    MistralPreTrainedModel, GenerationMixin, DistributedTargetModel
+):
+    _tied_weights_keys = ["lm_head.weight"]
+    _tp_plan = {"lm_head": "colwise_rep"}
+    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = MistralModel(config)
+        self.vocab_size = config.vocab_size
+
+        # distribute the lm head
+        self.lm_head = ColumnParallelLinear(
+            config.hidden_size, config.vocab_size, bias=False
+        )
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def set_decoder(self, decoder):
+        self.model = decoder
+
+    def get_decoder(self):
+        return self.model
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> CausalLMOutputWithPast:
+        r"""
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, MistralForCausalLM
+
+        >>> model = MistralForCausalLM.from_pretrained("mistralai/Mistral-Small-24B-Instruct-2501")
+        >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-Small-24B-Instruct-2501")
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+        ```"""
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        outputs: BaseModelOutputWithPast = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            cache_position=cache_position,
+            **kwargs,
+        )
+
+        hidden_states = outputs.last_hidden_state
+        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+        slice_indices = (
+            slice(-logits_to_keep, None)
+            if isinstance(logits_to_keep, int)
+            else logits_to_keep
+        )
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
+        logits = self._gather_tensor(logits, get_tp_group())
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def load_weights(self, state_dict: Dict[str, torch.Tensor]):
+        tp_group = get_tp_group()
+
+        updated_state_dict = {}
+        for key, value in state_dict.items():
+            # Ensure that the state dict is a flat dict of keys and tensors. Breaking this assumption
+            # will break recipe code
+            if not isinstance(value, torch.Tensor):
+                raise ValueError(
+                    f"Expected all values in the state dict to be torch.Tensor. "
+                    f"Found {type(value)} instead."
+                )
+
+            module_key = ".".join(key.split(".")[:-1])
+            module = self.get_submodule(module_key)
+
+            # get the module type based on key and shard accordingly
+            if isinstance(module, RowParallelLinear) and key.endswith(".weight"):
+                value = self._shard_tensor(value, tp_group, -1)
+            elif isinstance(module, ColumnParallelLinear) and key.endswith(".weight"):
+                value = self._shard_tensor(value, tp_group, 0)
+            elif isinstance(module, ColumnParallelLinear) and key.endswith(".bias"):
+                value = self._shard_tensor(value, tp_group, 0)
+
+            updated_state_dict[key] = value
+
+        # load state dict
+        self.load_state_dict(updated_state_dict, strict=False)
+
+
+__all__ = [
+    "MistralForCausalLM",
+    "MistralPreTrainedModel",
+    "MistralModel",
+]
diff --git a/tests/test_target_modeling/test_mistral_tp.py b/tests/test_target_modeling/test_mistral_tp.py
new file mode 100644
index 00000000..e3526184
--- /dev/null
+++ b/tests/test_target_modeling/test_mistral_tp.py
@@ -0,0 +1,88 @@
+import os
+import tempfile
+import unittest
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from accelerate.utils import set_seed
+from transformers import MistralConfig, MistralForCausalLM
+
+from specforge.distributed import init_distributed
+
+
+def test_mistral_tp(rank, world_size, temp_dir):
+    os.environ["RANK"] = str(rank)
+    os.environ["WORLD_SIZE"] = str(world_size)
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "29501"
+
+    init_distributed(tp_size=2)
+    set_seed(42)
+    config = MistralConfig(
+        vocab_size=1000,
+        hidden_size=384,
+        intermediate_size=512,
+        num_hidden_layers=2,
+        max_position_embeddings=1024,
+        num_attention_heads=10,
+        num_key_value_heads=2,
+        head_dim=64,
+        tie_word_embeddings=False,
+        initializer_range=0.02,
+        hidden_act="silu",
+        rms_norm_eps=1e-05,
+    )
+
+    # create the single-gpu reference model
+    model = MistralForCausalLM(config).cuda()
+
+    from specforge.modeling.target.mistral import (
+        MistralForCausalLM as DistMistralForCausalLM,
+    )
+
+    dist_model = DistMistralForCausalLM(config).cuda()
+
+    # save the model weights to a temp directory
+    if dist.get_rank() == 0:
+        model.save_pretrained(temp_dir)
+        print(f"Saved model to {temp_dir}")
+    dist.barrier()
+
+    # load the model weights to the distributed model
+    print(f"Loading model from {temp_dir}")
+    dist_model.load_checkpoint(temp_dir)
+    dist.barrier()
+
+    # create data
+    input_ids = torch.randint(0, 1000, (1, 256)).cuda()
+    attention_mask = torch.ones_like(input_ids).cuda()
+
+    expected_logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
+    dist_logits = dist_model(input_ids=input_ids, attention_mask=attention_mask).logits
+
+    assert torch.allclose(
+        expected_logits,
+        dist_logits,
+        rtol=1e-5,
+        atol=1e-5,
+    ), f"Logits are not close, {expected_logits} vs {dist_logits}"
+
+
+class TestMistralTP(unittest.TestCase):
+
+    def setUp(self):
+        self.temp_dir = tempfile.TemporaryDirectory()
+
+    def tearDown(self):
+        self.temp_dir.cleanup()
+
+    def test_mistral_tp(self):
+        mp.spawn(test_mistral_tp, nprocs=2, args=(2, self.temp_dir.name))
+
+
+if __name__ == "__main__":
+    suite = unittest.TestSuite()
+    suite.addTest(unittest.makeSuite(TestMistralTP))
+    runner = unittest.TextTestRunner(verbosity=2)
+    runner.run(suite)
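
A minimal sketch of how the new separator fallback resolves for the `mistral-small-24B` template (assuming `specforge.data.template` from this branch is importable; the system prompt is truncated here for brevity):

    from specforge.data.template import ChatTemplate

    # The template registered above: no end_of_turn_token, empty end_of_assistant_token.
    tpl = ChatTemplate(
        assistant_header="[/INST]",
        user_header="[INST]",
        system_prompt="You are Mistral Small 3, ...",
        end_of_assistant_token="",
    )

    # Mirrors the branch added to GeneralParser.__init__ and
    # _apply_loss_mask_from_chat_template.
    if tpl.end_of_turn_token:
        user_sep = f"{tpl.end_of_turn_token or ''}{tpl.user_header or ''}"
        assistant_sep = f"{tpl.end_of_turn_token or ''}{tpl.assistant_header or ''}"
    else:
        user_sep = f"{tpl.end_of_assistant_token or ''}{tpl.user_header or ''}"
        assistant_sep = f"{tpl.end_of_user_token or ''}{tpl.assistant_header or ''}"

    print(user_sep)       # "[INST]"
    print(assistant_sep)  # "[/INST]"

Because `end_of_turn_token` defaults to None, existing templates keep their previous behavior, while Mistral-style templates can mark turn boundaries with the headers alone.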