From f8de21f0abccecc63e25e2bbf06eb975a4005fc8 Mon Sep 17 00:00:00 2001
From: Luke Kumar
Date: Mon, 29 Sep 2025 16:44:55 +0000
Subject: [PATCH 1/4] change m2_indexes to multiple

---
 .../models/ssm/external/make_hybrid_checkpoint_with_mil.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/fast_llm/models/ssm/external/make_hybrid_checkpoint_with_mil.py b/fast_llm/models/ssm/external/make_hybrid_checkpoint_with_mil.py
index 6ce283525..920da3b85 100644
--- a/fast_llm/models/ssm/external/make_hybrid_checkpoint_with_mil.py
+++ b/fast_llm/models/ssm/external/make_hybrid_checkpoint_with_mil.py
@@ -70,10 +70,11 @@ def convert_layers(transformer, mamba_config, hybrid_block_layout, init_with_kqv
 
 
 @click.command()
-@click.option("--m2_indexes", type=int, nargs="-1", required=True)
+@click.option("--m2_indexes", type=int, multiple=True, required=True)
 @click.option("--hybrid_checkpoint", type=str, required=True)
 @click.option("--save_dir", type=str, required=True)
 def main(m2_indexes: list, hybrid_checkpoint: str, save_dir: str):
+    print(f"m2_indexes: {m2_indexes}")
     m2_indexes = list(m2_indexes)  # convert tuple -> list
     path_base = "/mnt/checkpoints/upstream/Apriel-Nemotron-15b-Thinker"
     transformer = AutoModelForCausalLM.from_pretrained(path_base, trust_remote_code=True)
@@ -82,7 +83,7 @@ def main(m2_indexes: list, hybrid_checkpoint: str, save_dir: str):
     hybrid_block_layout = hybrid_config.hybrid_block_layout
     for m2_index in m2_indexes:
         hybrid_block_layout[m2_index] = "m2"
-    print(hybrid_block_layout)
+    print(f"hybrid_block_layout: {hybrid_block_layout}")
 
     convert_layers(transformer, hybrid_config, hybrid_block_layout, True, torch.bfloat16)
     hybrid_config.ssm_cfg["activation"] = "silu"

From ed800ae86325e2a6de9214a55a150f11f700ef3a Mon Sep 17 00:00:00 2001
From: Luke Kumar
Date: Thu, 16 Oct 2025 18:24:50 +0000
Subject: [PATCH 2/4] formatting

---
 .../configuration_ssm_hybrid_apriel15b.py |   2 +-
 .../modeling_ssm_hybrid_apriel15b.py      | 632 +-----------------
 2 files changed, 34 insertions(+), 600 deletions(-)

diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py
index 98d2fc28d..5e91816f4 100644
--- a/fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py
+++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py
@@ -29,7 +29,7 @@ class AprielSSMHybridConfig(MistralConfig):
 
     model_type = "apriel_ssm_thinker_hybrid"
 
-    def __init__(self, hybrid_block_layout=["m2d"], ssm_cfg=None, **kwargs):
+    def __init__(self, hybrid_block_layout=["m2"], ssm_cfg=None, **kwargs):
         super().__init__(**kwargs)
         self.hybrid_block_layout = hybrid_block_layout
         self.head_dim = self.head_dim or self.hidden_size // self.num_attention_heads  # as in transformers 4.51.3
diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py
index 9f4588a29..6f4d2d328 100644
--- a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py
+++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py
@@ -6,10 +6,10 @@
 import torch
 import torch.nn.functional as F
 from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
+from configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig
 from einops import rearrange, repeat
 from mamba_ssm.ops.selective_scan_interface import
selective_scan_fn from mamba_ssm.ops.triton.selective_state_update import selective_state_update -from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined from torch import nn from transformers import GenerationMixin from transformers.cache_utils import Cache, DynamicCache @@ -21,12 +21,6 @@ from transformers.utils import LossKwargs, can_return_tuple, logging from transformers.utils.generic import ModelOutput -from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig - -# from vllm.model_executor.layers.mamba.ops.mamba_ssm import selective_scan_fn as varlen_selective_scan_fn -# from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_fn as varlen_causal_conv1d_fn - - logger = logging.get_logger(__name__) @@ -45,178 +39,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -class HybridMambaAttentionStaticCache(Cache): - def __init__(self, config: AprielSSMHybridConfig, batch_size, max_length, dtype=torch.float16, device=None): - super().__init__() # config, batch_size, max_length, device, dtype) - self.dtype = dtype - self.hybrid_override_pattern = config.hybrid_block_layout - self.has_previous_state = False # only used by mamba - intermediate_size = config.ssm_cfg["d_inner"] - ssm_state_size = config.ssm_cfg["d_state"] - conv_kernel_size = config.ssm_cfg["d_conv"] - self.n_qk_heads = config.ssm_cfg["n_qk_heads"] - assert intermediate_size % self.n_qk_heads == 0, "d_inner must be divisible by n_qk_heads" - self.head_d = intermediate_size // self.n_qk_heads - self.conv_states = [] - self.ssm_states = [] - self.transformer_layers = [] - self.key_cache: list[torch.Tensor] = [] - self.value_cache: list[torch.Tensor] = [] - - self.batch_size = batch_size - self.head_dim = ( - config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads - ) - self.max_cache_len = config.max_position_embeddings if max_length is None else max_length - - self.num_key_value_heads = ( - config.num_attention_heads - if getattr(config, "num_key_value_heads", None) is None - else config.num_key_value_heads - ) - cache_shape = (self.batch_size, self.num_key_value_heads, max_length, self.head_dim) - - for i in range(config.num_hidden_layers): - if self.hybrid_override_pattern[i] == "m2d": - # Mamba layer - new_layer_conv_state = torch.zeros( - batch_size, - conv_kernel_size, - intermediate_size + 2 * self.n_qk_heads * ssm_state_size, - device=device, - dtype=dtype, - ).transpose(1, 2) - - new_layer_ssm_state = torch.zeros( - batch_size, self.n_qk_heads, self.head_d, ssm_state_size, device=device, dtype=dtype - ) - new_layer_key_cache = None # torch.zeros((0,), dtype=dtype, device=device) - new_layer_value_cache = None # torch.zeros((0,), dtype=dtype, device=device) - else: - # Attention or MLP layer - new_layer_conv_state = None # torch.tensor((0,), dtype=dtype, device=device) - new_layer_ssm_state = None # torch.tensor((0,), dtype=dtype, device=device) - new_layer_key_cache = torch.zeros(cache_shape, dtype=dtype, device=device) - new_layer_value_cache = torch.zeros(cache_shape, dtype=dtype, device=device) - self.transformer_layers.append(i) - - # if not is_torchdynamo_compiling(): - # self.register_buffer(f"key_cache_{i}", torch.zeros(cache_shape, dtype=dtype, device=device)) - # self.register_buffer(f"value_cache_{i}", torch.zeros(cache_shape, dtype=dtype, device=device)) - # 
new_layer_key_cache = getattr(self, f"key_cache_{i}") - # new_layer_value_cache = getattr(self, f"value_cache_{i}") - # torch._dynamo.mark_static_address(new_layer_key_cache) - # torch._dynamo.mark_static_address(new_layer_value_cache) - # self.register_buffer(f"conv_states_{i}", new_layer_conv_state) - # self.register_buffer(f"ssm_states_{i}", new_layer_ssm_state) - # torch._dynamo.mark_static_address(new_layer_conv_state) - # torch._dynamo.mark_static_address(new_layer_ssm_state) - # new_layer_ssm_state = getattr(self, f"ssm_states_{i}") - # new_layer_conv_state = getattr(self, f"conv_states_{i}") - - self.key_cache.append(new_layer_key_cache) - self.value_cache.append(new_layer_value_cache) - self.conv_states.append(new_layer_conv_state) - self.ssm_states.append(new_layer_ssm_state) - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - It is VERY important to index using a tensor, otherwise you introduce a copy to the device. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input - to know how where to write in the cache. - - Return: - A tuple containing the updated key and value states. - """ - - cache_position = cache_kwargs.get("cache_position") - - k_out = self.key_cache[layer_idx] - v_out = self.value_cache[layer_idx] - key_states = key_states.to(k_out.dtype) - value_states = value_states.to(v_out.dtype) - - if cache_position is None: - k_out.copy_(key_states) - v_out.copy_(value_states) - else: - # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to - # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place - # operation, that avoids copies and uses less memory. - try: - k_out.index_copy_(2, cache_position, key_states) - v_out.index_copy_(2, cache_position, value_states) - except NotImplementedError: - # The operator 'aten::index_copy.out' is not currently implemented for the MPS device. - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states - - return k_out, v_out - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - device = self.conv_states[layer_idx].device - self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) - device = self.ssm_states[layer_idx].device - self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) - - def get_seq_length(self, layer_idx: Optional[int] = None) -> int: - """Returns the sequence length of the cached states that were seen by the model.""" - # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. 
To save on compute, let's - # limit the check to the first batch member and head dimension. - # TODO: deprecate this function in favor of `cache_position` - if layer_idx is None: - if len(self.transformer_layers) > 0: - layer_idx = self.transformer_layers[0] - else: - return 0 - return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() - - def get_max_cache_shape(self) -> Optional[int]: - return self.max_cache_len - - # Copied from modeling_mamba2.py - def update_conv_state( - self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False - ) -> torch.Tensor: - if cache_init: - self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device) - else: - self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) - self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device) - return self.conv_states[layer_idx] - - def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): - self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) - return self.ssm_states[layer_idx] - - def reset(self): - self.conv_states.zero_() - self.ssm_states.zero_() - - # Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/jamba/modeling_jamba.py class HybridMambaAttentionDynamicCache(DynamicCache): """ @@ -445,344 +267,6 @@ def apply_mask_to_padding_states(hidden_states, attention_mask): return hidden_states -# This is from LLmaba/Mohawk: https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py - - -class DiscreteMamba2(nn.Module): - def __init__( - self, - d_model, - d_state=64, - n_qk_heads=32, - n_v_heads=32, - d_conv=4, - expand=1, - activation="identity", - bias=False, - conv_bias=True, - chunk_size=128, - layer_idx=None, - device=None, - dtype=None, - d_inner=None, - **kwargs, # Absorb kwarg for general module - ): - """ - See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. 
- Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr" - - Other options are all experimental and should not need to be configured - """ - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.d_model = d_model - self.d_state = d_state - self.d_conv = d_conv - self.expand = expand - self.d_inner = self.expand * self.d_model if d_inner is None else d_inner - self.n_qk_heads = n_qk_heads - self.n_v_heads = n_v_heads - self.headdim = self.d_inner // self.n_v_heads - assert self.n_v_heads == self.d_inner // self.headdim - assert self.d_inner % self.headdim == 0 - assert self.n_v_heads % self.n_qk_heads == 0 - self.activation = activation - self.chunk_size = chunk_size - self.layer_idx = layer_idx - self.bias = bias - self.kwargs = kwargs - - # Projections - self.in_proj = nn.Linear( - self.d_model, - 2 * self.d_inner + 2 * self.n_qk_heads * self.d_state + self.n_v_heads, - bias=bias, - **factory_kwargs, - ) - self.z_bias = ( - nn.Parameter(torch.zeros(self.d_inner, device=device)) if not bias else 0 - ) # make sure z_bias always exists - - # Convolutional layer - conv_dim = self.d_inner + 2 * self.n_qk_heads * self.d_state - self.conv_bias = conv_bias - self.conv1d = nn.Conv1d( - in_channels=conv_dim, - out_channels=conv_dim, - bias=conv_bias, - kernel_size=d_conv, - groups=conv_dim, - padding=d_conv - 1, - **factory_kwargs, - ) - - # Activation after conv - if self.activation == "identity": - self.act = nn.Identity() - elif self.activation in ["silu", "swish"]: - self.act = nn.SiLU() - else: - raise ValueError(f"Unknown activation {self.activation}") - - # D "skip" parameter - self.D = nn.Parameter(torch.ones(self.n_v_heads, device=device)) - self.D._optim = {"weight_decay": 0.0} - - # out_proj - self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) - # In __init__, pre-allocate these tensors - # self.zeros_buffer = torch.zeros((self.n_v_heads, self.headdim), device=device, dtype=dtype) - # self.ones_buffer = torch.ones((self.n_v_heads, self.headdim, self.d_state), device=device, dtype=dtype) - - # @property - # def d_output(self): - # return self.d_model - - # @property - # def state_to_tensor(self): - # return self.layer.state_to_tensor - - def forward( - self, - u, - past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, - attention_mask: Optional[torch.Tensor] = None, - return_mixer_matrix=False, - **kwargs, - ): - """ - u: (B, L, D) - Returns: same shape as u - For later refference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bamba/modeling_bamba.py - """ - assert is_fast_path_available and "cuda" in self.in_proj.weight.device.type, "Only support fast path on cuda" - cache_position = kwargs.get("cache_position", None) - batch, seqlen, dim = u.shape - u = apply_mask_to_padding_states(u, attention_mask) - ssm_state, conv_state = None, None - use_precomputed_states = False - ######################################################### - # Quick and dirty to work with CG - if "inference_params" in kwargs: - seqlen_offset = kwargs["inference_params"].seqlen_offset - if seqlen_offset > 0: - use_precomputed_states = True - else: - seqlen_offset = kwargs.get("seqlen_offset", cache_position[0]) if cache_position is not None else 0 - use_precomputed_states = ( - past_key_value is not None - and past_key_value.has_previous_state - and seqlen == 1 - and past_key_value.conv_states[self.layer_idx].shape[0] - == 
past_key_value.ssm_states[self.layer_idx].shape[0] - == batch - and cache_position is not None - and seqlen_offset > 0 - ) - ######################################################### - ssm_state, conv_state = self._get_states_from_cache(past_key_value, batch) - if use_precomputed_states: - # ssm_state, conv_state = self._get_states_from_cache(past_key_value, batch) - u = u.squeeze(1) if len(u.shape) == 3 else u - out, _, _ = self.step(u, ssm_state, conv_state) - out = out.unsqueeze(1) if len(u.shape) == 2 else out - return {"hidden_states": out} - else: - outputs = {} - # Hacky way to initialize state during inference - chunk_size = self.chunk_size # if ssm_state is None else seqlen - - # Pad input to nearest multiple of chunklen - padded_len = (1 + (seqlen - 1) // chunk_size) * chunk_size - u = F.pad(u, (0, 0, 0, padded_len - seqlen)) - - # Project input - xBCzA_log = self.in_proj(u) - xBC, z, A_log = torch.split( - xBCzA_log, - [ - self.d_inner + 2 * self.n_qk_heads * self.d_state, - self.d_inner, - self.n_v_heads, - ], - dim=-1, - ) - - if conv_state is not None: - # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv - # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. - xBC_t = rearrange(xBC[:, :seqlen, :], "b l d -> b d l") - conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W) - - # Convolutional layer - xBC = self.convolutional_forward(xBC, padded_len) - - x, B, C = torch.split( - xBC, - [ - self.d_inner, - self.n_qk_heads * self.d_state, - self.n_qk_heads * self.d_state, - ], - dim=-1, - ) - - x = rearrange(x, "b l (h n) -> b l h n", h=self.n_v_heads) - B = rearrange(B, "b l (h n) -> b l h n", h=self.n_qk_heads) - C = rearrange(C, "b l (h n) -> b l h n", h=self.n_qk_heads) - - # SSM forward - result = mamba_chunk_scan_combined( - x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), - dt=A_log, - dt_softplus=True, - A=-torch.ones(self.n_v_heads, device=A_log.device), - B=B, - C=C, - chunk_size=chunk_size, - # initial_states=(state["ssm"] if state is not None else None), # currently not supported by mamba_ssm.utils.generation - return_final_states=(ssm_state is not None), - ) - - if ssm_state is not None: - y, ssm_state_update = result - ssm_state.copy_(ssm_state_update) - else: - y = result - - Du = torch.einsum("h,blhp->blhp", self.D, x) - y = rearrange(y + Du, "b l h p -> b l (h p)") - - # Norm and gate - out = self.out_proj(y * F.silu(z + self.z_bias)) - outputs["hidden_states"] = out[:, :seqlen, :] - - if return_mixer_matrix: - outputs["transfer_matrix"] = materialize_mixer(A_log=A_log, B=B, C=C, D=self.D)[..., :seqlen, :seqlen] - return outputs - - def step(self, u, ssm_state, conv_state, **kwargs): - """ - u: (B D) - state: dict of states - Returns: same shape as u - """ - - # Project input - xBCzA_log = self.in_proj(u) - xBC, z, A_log = torch.split( - xBCzA_log, - [ - self.d_inner + 2 * self.n_qk_heads * self.d_state, - self.d_inner, - self.n_v_heads, - ], - dim=-1, - ) - - xBC, conv_state_new = self.convolutional_step(xBC, conv_state) - if conv_state_new is not None: - raise NotImplementedError("Should not end up here snce only support fast path.") - # conv_state.copy_(conv_state_new) # update state in place, only for slow pass - - x, B, C = torch.split( - xBC, - [ - self.d_inner, - self.n_qk_heads * self.d_state, - self.n_qk_heads * self.d_state, - ], - dim=-1, - ) - - x = rearrange(x, "b (h s) -> b h s", h=self.n_v_heads) - B = rearrange(B, "b (h s) -> b h s", 
h=self.n_qk_heads) - C = rearrange(C, "b (h s) -> b h s", h=self.n_qk_heads) - - ssm_state = ssm_state.to(x.dtype) - # does nto work with CG, probably becuase zeros and ones are on CPU - # zeros = self.zeros_buffer.to(A_log.device).to(x.dtype) # Just cast, don't allocate - # ones = self.ones_buffer.to(A_log.device).to(x.dtype) - zeros = torch.zeros((self.n_v_heads, self.headdim), device=A_log.device).to(dtype=x.dtype) - ones = torch.ones((self.n_v_heads, self.headdim, self.d_state), device=A_log.device).to(dtype=x.dtype) - y = selective_state_update( - x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), - dt=repeat(A_log, "b h -> b h p", p=self.headdim), - dt_softplus=True, - A=-ones, - B=B, - C=C, - state=ssm_state, # will be updated in place - dt_bias=zeros, - D=zeros, - ) - - y = y + self.D[:, None] * x - y = rearrange(y, "b h p -> b (h p)") - - # Norm and gate - out = self.out_proj(y * F.silu(z + self.z_bias)) - - return out, ssm_state, conv_state - - def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): - """ - conv_state: (batch, d_conv, conv1d.weight.shape[0]) - ssm_state: (batch, n_qk_heads, headdim, d_state) - """ - assert self.layer_idx is not None - # Allocate memory if not exists - # if self.layer_idx not in inference_params.ssm_states: - # inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache( - # batch_size, inference_params.max_seqlen, dtype=torch.float32 - # ) - # Get states - ssm_states = inference_params.ssm_states[self.layer_idx] - conv_states = inference_params.conv_states[self.layer_idx] - if initialize_states: - ssm_states.zero_() - conv_states.zero_() - return ssm_states, conv_states - - def convolutional_forward(self, xBC, padded_len): - if causal_conv1d_fn is None or self.activation not in [ - "silu", - "swish", - "identity", - ]: - xBC = self.act(self.conv1d(xBC.transpose(1, 2))[..., :padded_len].transpose(1, 2)) - else: - xBC = causal_conv1d_fn( - xBC.transpose(1, 2), - rearrange(self.conv1d.weight, "d 1 w -> d w"), - self.conv1d.bias, - activation=None if self.activation == "identity" else self.activation, - ).transpose(1, 2) - return xBC - - def convolutional_step(self, xBC, conv_state): - # Convolutional layer - conv_state = conv_state.to(xBC.dtype) - if causal_conv1d_update: - xBC = causal_conv1d_update( - xBC, - conv_state, - rearrange(self.conv1d.weight, "d 1 w -> d w"), - self.conv1d.bias, - self.activation if self.activation != "identity" else None, - ) - return xBC, None - else: - conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) - conv_state[:, :, -1] = xBC - xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) - if self.conv_bias: - xBC = xBC + self.conv1d.bias - xBC = self.act(xBC).to(xBC.dtype) # Some activations change dtype - - return xBC, conv_state - - class Mamba2(nn.Module): def __init__( self, @@ -901,14 +385,9 @@ def forward( hidden_states: (B, L, D) Returns: same shape as hidden_states """ - cu_seqlens = None assert is_fast_path_available and "cuda" in self.in_proj.weight.device.type, "Only support fast path on cuda" cache_position = kwargs.get("cache_position", None) batch, seqlen, dim = hidden_states.shape - # mamba_mask = ( - # None if seqlen == 1 else mamba_mask - # ) # prevent that hidden_states are expanded to mask's seq. dimention., i.e. 
we do not need apply_mask_to_padding_states when generating single token at a time - # hidden_states = apply_mask_to_padding_states(hidden_states, mamba_mask) ssm_state, conv_state = None, None use_precomputed_states = False @@ -935,7 +414,6 @@ def forward( ssm_state, conv_state = self._get_states_from_cache(past_key_value, batch) if use_precomputed_states: - # ssm_state, conv_state = self._get_states_from_cache(past_key_value, batch) out, _, _ = self.step(hidden_states, conv_state, ssm_state) return {"hidden_states": out} @@ -970,76 +448,44 @@ def forward( x = repeat_kv(x, self.repeat_group) x = rearrange(x, "b n_group l dstate -> b (n_group dstate) l") - if cu_seqlens is not None: - # variable length path - x = varlen_causal_conv1d_fn( - x.squeeze(0) if cu_seqlens is not None else x, # Add batch dimension + # Compute short convolution + if conv_state is not None: + # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv + # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. + # Update state (B D W) + conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) + if causal_conv1d_fn is None: + x = self.act(self.conv1d(x)[..., :seqlen]).transpose(1, 2) + else: + assert self.activation in ["silu", "swish"] + x = causal_conv1d_fn( + x=x, weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), bias=self.conv1d.bias, activation=self.activation, - conv_states=conv_state, - query_start_loc=cu_seqlens, ) - x = x.unsqueeze(0) if cu_seqlens is not None else x - else: - # Compute short convolution - if conv_state is not None: - # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv - # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. - # Update state (B D W) - conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) - if causal_conv1d_fn is None: - x = self.act(self.conv1d(x)[..., :seqlen]).transpose(1, 2) - else: - assert self.activation in ["silu", "swish"] - x = causal_conv1d_fn( - x=x, - weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), - bias=self.conv1d.bias, - activation=self.activation, - ) # .transpose(1, 2) - # x = apply_mask_to_padding_states(x, mamba_mask).transpose( - # 1, 2 - # ) # zero out everything that comes from padding tokens if not self.repeat_kv_before_conv: x = rearrange(x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.d_state) x = repeat_kv(x, self.repeat_group) x = rearrange(x, "b n_group l dstate -> b (n_group dstate) l") - if cu_seqlens is not None: - # use variable length decoding - y = varlen_selective_scan_fn( - x.squeeze(0), - ssm_state, - dt.squeeze(0), - A, - B.squeeze(0), - C.squeeze(0), - self.D.float(), - z=z.squeeze(0), - delta_bias=self.dt_proj.bias.float(), - delta_softplus=True, - query_start_loc=cu_seqlens, - ) - y = y.unsqueeze(0) - else: - y = selective_scan_fn( - x, - dt, - A, - B, - C, - self.D.float(), - z=z, - delta_bias=self.dt_proj.bias.float(), - delta_softplus=True, - return_last_state=(ssm_state is not None), - ) + y = selective_scan_fn( + x, + dt, + A, + B, + C, + self.D.float(), + z=z, + delta_bias=self.dt_proj.bias.float(), + delta_softplus=True, + return_last_state=(ssm_state is not None), + ) - if ssm_state is not None: - y, last_state = y - ssm_state.copy_(rearrange(last_state, "b (h d) n -> b h d n", h=self.num_C_head)) + if ssm_state is not None: + y, last_state = y + ssm_state.copy_(rearrange(last_state, "b (h d) n -> b h d n", h=self.num_C_head)) y = rearrange(y, "b d l -> b l d") out = self.out_proj(y) @@ 
-1126,11 +572,7 @@ def _get_states_from_cache(self, inference_params, batch_size, initialize_states ssm_state: (batch, n_qk_heads, headdim, d_state) """ assert self.layer_idx is not None - # Allocate memory if not exists - # if self.layer_idx not in inference_params.ssm_states: - # inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache( - # batch_size, inference_params.max_seqlen, dtype=torch.float32 - # ) + # Get states ssm_states = inference_params.ssm_states[self.layer_idx] conv_states = inference_params.conv_states[self.layer_idx] @@ -1140,8 +582,8 @@ def _get_states_from_cache(self, inference_params, batch_size, initialize_states return ssm_states, conv_states -class AprielSSMDecoderLayer(nn.Module): - _mixer_class = DiscreteMamba2 +class AprielSSMM2DecoderLayer(nn.Module): + _mixer_class = Mamba2 def __init__(self, config: AprielSSMHybridConfig, layer_idx: int, device=None, dtype=None, **kwargs): super().__init__(**kwargs) @@ -1181,16 +623,11 @@ def forward( hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - # outputs["hidden_states"] = hidden_states outputs = (hidden_states,) return outputs -class AprielSSMM2DecoderLayer(AprielSSMDecoderLayer): - _mixer_class = Mamba2 - - class AprielHybridIdentity(nn.Module): def __init__(self, config: AprielSSMHybridConfig): super().__init__() @@ -1215,9 +652,7 @@ def __init__(self, config: AprielSSMHybridConfig, **kwargs): blocks = [] logger.info(f"Loading hyubrid model with the following layout: {config.hybrid_block_layout}") for layer_idx, type in enumerate(config.hybrid_block_layout): - if type == "m2d": - blocks.append(AprielSSMDecoderLayer(config, layer_idx)) - elif type == "m2": + if type == "m2": blocks.append(AprielSSMM2DecoderLayer(config, layer_idx)) elif type == "t": blocks.append(MistralDecoderLayer(config, layer_idx)) @@ -1246,7 +681,7 @@ def forward( ) -> BaseModelOutputWithPast: use_cache = use_cache if use_cache is not None else self.config.use_cache if use_cache and past_key_values is None: - # for the case where prepare_inputs_for_generation is not called to create the cache (as in fast-llm test) + # for the case where prepare_inputs_for_generation is not called to create the cache (as in fast-llm test) batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0] past_key_values = HybridMambaAttentionDynamicCache(self.config, batch_size, self.dtype, device=self.device) output = super().forward( @@ -1379,7 +814,6 @@ def prepare_inputs_for_generation( "use_cache": use_cache, "attention_mask": attention_mask, "output_router_logits": output_router_logits, - # "logits_to_keep": self.config.num_logits_to_keep, "cache_position": cache_position, } ) From 10fdd925ada96cc6a35c0b9c017b7bdd5dd70c0a Mon Sep 17 00:00:00 2001 From: Luke Kumar Date: Thu, 16 Oct 2025 18:27:55 +0000 Subject: [PATCH 3/4] drop wrong file --- .../make_hybrid_checkpoint_with_mil.py | 108 ------------------ 1 file changed, 108 deletions(-) delete mode 100644 fast_llm/models/ssm/external/make_hybrid_checkpoint_with_mil.py diff --git a/fast_llm/models/ssm/external/make_hybrid_checkpoint_with_mil.py b/fast_llm/models/ssm/external/make_hybrid_checkpoint_with_mil.py deleted file mode 100644 index 920da3b85..000000000 --- a/fast_llm/models/ssm/external/make_hybrid_checkpoint_with_mil.py +++ /dev/null @@ -1,108 +0,0 @@ -import gc - -import click -import torch -from transformers import AutoModelForCausalLM - -from 
fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig -from fast_llm.models.ssm.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import ( - AprielSSMM2DecoderLayer, - AprielThinkerSSMHybridForCausalLM, -) - -device = "cuda" if torch.cuda.is_available() else "cpu" - - -def convert_layers(transformer, mamba_config, hybrid_block_layout, init_with_kqvo, torch_dtype=torch.bfloat16): - config = transformer.config - embed_dim = config.hidden_size - num_heads = config.num_attention_heads - num_heads_kv = config.num_key_value_heads - head_dim = embed_dim // num_heads - head_dim * num_heads - head_dim * num_heads_kv - - for layer_idx, type in enumerate(hybrid_block_layout): - print("Converting layer %d...", layer_idx) - # Fetch the layer module for easier access - layer_module = transformer.model.layers._modules[f"{layer_idx}"] - if type == "t": - print("Skipping transformer layer %d..." % layer_idx) - elif type == "m2": - print("Converting layer %d..." % layer_idx) - # Use MambaDecoderLayer for the remaining layers - mamba_encoder = AprielSSMM2DecoderLayer( - mamba_config, - layer_idx, - device="cpu", - dtype=torch_dtype, - ) - - mamba_encoder.mlp.load_state_dict(layer_module.mlp.state_dict()) - mamba_encoder.input_layernorm.load_state_dict(layer_module.input_layernorm.state_dict()) - mamba_encoder.post_attention_layernorm.load_state_dict(layer_module.post_attention_layernorm.state_dict()) - mamba_encoder.mixer.out_proj.load_state_dict(layer_module.self_attn.o_proj.state_dict()) - - if init_with_kqvo: - # Copy weights: [z, x, B, C, dt], x -> v, B -> k, C -> q - mamba_encoder.mixer.in_proj.weight.data[ - mamba_config.ssm_cfg["d_inner"] : mamba_config.ssm_cfg["d_inner"] + mamba_config.ssm_cfg["d_xb"], : - ].copy_(layer_module.self_attn.v_proj.weight.data) - mamba_encoder.mixer.in_proj.weight.data[ - mamba_config.ssm_cfg["d_inner"] - + mamba_config.ssm_cfg["d_xb"] : mamba_config.ssm_cfg["d_inner"] - + 2 * mamba_config.ssm_cfg["d_xb"], - :, - ].copy_(layer_module.self_attn.k_proj.weight.data) - mamba_encoder.mixer.in_proj.weight.data[ - mamba_config.ssm_cfg["d_inner"] - + 2 * mamba_config.ssm_cfg["d_xb"] : 2 * mamba_config.ssm_cfg["d_inner"] - + 2 * mamba_config.ssm_cfg["d_xb"], - :, - ].copy_(layer_module.self_attn.q_proj.weight.data) - - print("Init Mamba using Attention") - - transformer.model.layers[layer_idx] = mamba_encoder - - else: - raise ValueError(f"Invalid layer type: {type}") - - -@click.command() -@click.option("--m2_indexes", type=int, multiple=True, required=True) -@click.option("--hybrid_checkpoint", type=str, required=True) -@click.option("--save_dir", type=str, required=True) -def main(m2_indexes: list, hybrid_checkpoint: str, save_dir: str): - print(f"m2_indexes: {m2_indexes}") - m2_indexes = list(m2_indexes) # convert tuple -> list - path_base = "/mnt/checkpoints/upstream/Apriel-Nemotron-15b-Thinker" - transformer = AutoModelForCausalLM.from_pretrained(path_base, trust_remote_code=True) - hybrid_config = AprielSSMHybridConfig.from_pretrained(hybrid_checkpoint) - - hybrid_block_layout = hybrid_config.hybrid_block_layout - for m2_index in m2_indexes: - hybrid_block_layout[m2_index] = "m2" - print(f"hybrid_block_layout: {hybrid_block_layout}") - - convert_layers(transformer, hybrid_config, hybrid_block_layout, True, torch.bfloat16) - hybrid_config.ssm_cfg["activation"] = "silu" - - # load all existing ssm layers - hybrid_model = AprielThinkerSSMHybridForCausalLM.from_pretrained(hybrid_checkpoint) - state_dict = 
hybrid_model.state_dict() - missing, unexpected = transformer.load_state_dict(state_dict, strict=False) - for m2_index in m2_indexes: - assert f"model.layers.{m2_index}.mixer.A_log" in missing - assert f"model.layers.{m2_index}.self_attn.q_proj.weight" in unexpected - print(missing) - print(unexpected) - transformer.save_pretrained(save_dir) - - hybrid_config.save_pretrained(save_dir) - - gc.collect() - - -if __name__ == "__main__": - main() From bbec888a857f674b13a255dae000300e26b1ed89 Mon Sep 17 00:00:00 2001 From: Luke Kumar Date: Thu, 16 Oct 2025 18:30:49 +0000 Subject: [PATCH 4/4] reset file --- .../make_hybrid_checkpoint_with_mil.py | 107 ++++++++++++++++++ 1 file changed, 107 insertions(+) create mode 100644 fast_llm/models/ssm/external/make_hybrid_checkpoint_with_mil.py diff --git a/fast_llm/models/ssm/external/make_hybrid_checkpoint_with_mil.py b/fast_llm/models/ssm/external/make_hybrid_checkpoint_with_mil.py new file mode 100644 index 000000000..6ce283525 --- /dev/null +++ b/fast_llm/models/ssm/external/make_hybrid_checkpoint_with_mil.py @@ -0,0 +1,107 @@ +import gc + +import click +import torch +from transformers import AutoModelForCausalLM + +from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig +from fast_llm.models.ssm.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import ( + AprielSSMM2DecoderLayer, + AprielThinkerSSMHybridForCausalLM, +) + +device = "cuda" if torch.cuda.is_available() else "cpu" + + +def convert_layers(transformer, mamba_config, hybrid_block_layout, init_with_kqvo, torch_dtype=torch.bfloat16): + config = transformer.config + embed_dim = config.hidden_size + num_heads = config.num_attention_heads + num_heads_kv = config.num_key_value_heads + head_dim = embed_dim // num_heads + head_dim * num_heads + head_dim * num_heads_kv + + for layer_idx, type in enumerate(hybrid_block_layout): + print("Converting layer %d...", layer_idx) + # Fetch the layer module for easier access + layer_module = transformer.model.layers._modules[f"{layer_idx}"] + if type == "t": + print("Skipping transformer layer %d..." % layer_idx) + elif type == "m2": + print("Converting layer %d..." 
% layer_idx) + # Use MambaDecoderLayer for the remaining layers + mamba_encoder = AprielSSMM2DecoderLayer( + mamba_config, + layer_idx, + device="cpu", + dtype=torch_dtype, + ) + + mamba_encoder.mlp.load_state_dict(layer_module.mlp.state_dict()) + mamba_encoder.input_layernorm.load_state_dict(layer_module.input_layernorm.state_dict()) + mamba_encoder.post_attention_layernorm.load_state_dict(layer_module.post_attention_layernorm.state_dict()) + mamba_encoder.mixer.out_proj.load_state_dict(layer_module.self_attn.o_proj.state_dict()) + + if init_with_kqvo: + # Copy weights: [z, x, B, C, dt], x -> v, B -> k, C -> q + mamba_encoder.mixer.in_proj.weight.data[ + mamba_config.ssm_cfg["d_inner"] : mamba_config.ssm_cfg["d_inner"] + mamba_config.ssm_cfg["d_xb"], : + ].copy_(layer_module.self_attn.v_proj.weight.data) + mamba_encoder.mixer.in_proj.weight.data[ + mamba_config.ssm_cfg["d_inner"] + + mamba_config.ssm_cfg["d_xb"] : mamba_config.ssm_cfg["d_inner"] + + 2 * mamba_config.ssm_cfg["d_xb"], + :, + ].copy_(layer_module.self_attn.k_proj.weight.data) + mamba_encoder.mixer.in_proj.weight.data[ + mamba_config.ssm_cfg["d_inner"] + + 2 * mamba_config.ssm_cfg["d_xb"] : 2 * mamba_config.ssm_cfg["d_inner"] + + 2 * mamba_config.ssm_cfg["d_xb"], + :, + ].copy_(layer_module.self_attn.q_proj.weight.data) + + print("Init Mamba using Attention") + + transformer.model.layers[layer_idx] = mamba_encoder + + else: + raise ValueError(f"Invalid layer type: {type}") + + +@click.command() +@click.option("--m2_indexes", type=int, nargs="-1", required=True) +@click.option("--hybrid_checkpoint", type=str, required=True) +@click.option("--save_dir", type=str, required=True) +def main(m2_indexes: list, hybrid_checkpoint: str, save_dir: str): + m2_indexes = list(m2_indexes) # convert tuple -> list + path_base = "/mnt/checkpoints/upstream/Apriel-Nemotron-15b-Thinker" + transformer = AutoModelForCausalLM.from_pretrained(path_base, trust_remote_code=True) + hybrid_config = AprielSSMHybridConfig.from_pretrained(hybrid_checkpoint) + + hybrid_block_layout = hybrid_config.hybrid_block_layout + for m2_index in m2_indexes: + hybrid_block_layout[m2_index] = "m2" + print(hybrid_block_layout) + + convert_layers(transformer, hybrid_config, hybrid_block_layout, True, torch.bfloat16) + hybrid_config.ssm_cfg["activation"] = "silu" + + # load all existing ssm layers + hybrid_model = AprielThinkerSSMHybridForCausalLM.from_pretrained(hybrid_checkpoint) + state_dict = hybrid_model.state_dict() + missing, unexpected = transformer.load_state_dict(state_dict, strict=False) + for m2_index in m2_indexes: + assert f"model.layers.{m2_index}.mixer.A_log" in missing + assert f"model.layers.{m2_index}.self_attn.q_proj.weight" in unexpected + print(missing) + print(unexpected) + transformer.save_pretrained(save_dir) + + hybrid_config.save_pretrained(save_dir) + + gc.collect() + + +if __name__ == "__main__": + main()
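
Note on PATCH 1/4: click options do not support variadic nargs (nargs expects a fixed integer, so the old nargs="-1" is not a valid click idiom); repeated flags are collected with multiple=True instead, which hands the callback a tuple that main() then converts to a list. Below is a minimal, self-contained sketch of that behaviour, assuming a throwaway command named demo (illustrative only, not part of the series):

    import click


    @click.command()
    @click.option("--m2_indexes", type=int, multiple=True, required=True)
    def demo(m2_indexes: tuple) -> None:
        # click gathers every repeated --m2_indexes flag into a tuple of ints.
        m2_indexes = list(m2_indexes)  # convert tuple -> list, as the patched main() does
        click.echo(f"m2_indexes: {m2_indexes}")


    if __name__ == "__main__":
        demo()

Run as "python demo.py --m2_indexes 2 --m2_indexes 5" this prints "m2_indexes: [2, 5]"; the patched make_hybrid_checkpoint_with_mil.py accepts --m2_indexes the same way, one flag per index to convert.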