From a0b6e45b3e6af541e48703b0b358b19149c3a471 Mon Sep 17 00:00:00 2001
From: Toolkit User
Date: Thu, 17 Jul 2025 14:05:06 +0000
Subject: [PATCH 01/23] add assert

---
 fast_llm/data/dataset/gpt/sampled.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py
index 42062a58..969cafa7 100644
--- a/fast_llm/data/dataset/gpt/sampled.py
+++ b/fast_llm/data/dataset/gpt/sampled.py
@@ -144,6 +144,10 @@ def _sample(self) -> None:
         document_sizes, image_sizes = self._indexed_dataset.get_document_sizes()
         document_sizes = torch.from_numpy(document_sizes).to(self._device)
         if image_sizes.any():
+            assert self._parameters.max_image_size is not None, (
+                f"Dataset {self._indexed_dataset.name} contains images, but no max_image_size is set."
+                f"image_sizes: {image_sizes}, max_image_size: {self._parameters.max_image_size}"
+            )
             image_token_sizes = []
             for i, sizes in enumerate(image_sizes):
                 image_token_sizes.append(

From d35c6853ef0156ba551cf20fa3f07db0da86bfa9 Mon Sep 17 00:00:00 2001
From: Toolkit User
Date: Thu, 17 Jul 2025 15:00:54 +0000
Subject: [PATCH 02/23] move check to config validation

---
 fast_llm/data/dataset/gpt/sampled.py | 4 ----
 fast_llm/models/gpt/config.py        | 3 +++
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py
index 969cafa7..42062a58 100644
--- a/fast_llm/data/dataset/gpt/sampled.py
+++ b/fast_llm/data/dataset/gpt/sampled.py
@@ -144,10 +144,6 @@ def _sample(self) -> None:
         document_sizes, image_sizes = self._indexed_dataset.get_document_sizes()
         document_sizes = torch.from_numpy(document_sizes).to(self._device)
         if image_sizes.any():
-            assert self._parameters.max_image_size is not None, (
-                f"Dataset {self._indexed_dataset.name} contains images, but no max_image_size is set."
-                f"image_sizes: {image_sizes}, max_image_size: {self._parameters.max_image_size}"
-            )
             image_token_sizes = []
             for i, sizes in enumerate(image_sizes):
                 image_token_sizes.append(
diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py
index 039b97f8..bc64821f 100644
--- a/fast_llm/models/gpt/config.py
+++ b/fast_llm/models/gpt/config.py
@@ -257,6 +257,9 @@ def _validate(self) -> None:
             Assert.none(reference_model.model.base_model.cross_entropy_splits)
             Assert.eq(reference_model.model.base_model.parallel_embeddings, self.model.base_model.parallel_embeddings)
             Assert.geq(reference_model.model.base_model.prediction_heads, self.model.base_model.prediction_heads)
+        if self.model.base_model.vision_encoder.enabled:
+            assert self.batch.max_image_size is not None, "max_image_size must be set when using vision encoder"
+            Assert.gt(self.batch.max_image_size, 0)
 
     @classmethod
     def _from_dict(

From 3345ab122a9d5aa50f17c11cfa95e4c7a1279cef Mon Sep 17 00:00:00 2001
From: Toolkit User
Date: Thu, 17 Jul 2025 18:25:15 +0000
Subject: [PATCH 03/23] debug log

---
 fast_llm/engine/distributed/distributed.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/fast_llm/engine/distributed/distributed.py b/fast_llm/engine/distributed/distributed.py
index 200074ee..e9247871 100644
--- a/fast_llm/engine/distributed/distributed.py
+++ b/fast_llm/engine/distributed/distributed.py
@@ -46,6 +46,8 @@ def __init__(
             Assert.in_range_incl(self._local_world_size, 1, torch.cuda.device_count())
         torch.cuda.init()
         self._device = torch.device(self._rank)
+        logger.info(f"Using device {self._device} for rank {self._rank}.")
+        logger.info(f"Number of local devices: {torch.cuda.device_count()}.")
         torch.cuda.set_device(self._device)
 
         if self._world_size > 1:

From b0b52fa98b366ffa52c3844b177b88d3b558c36e Mon Sep 17 00:00:00 2001
From: Toolkit User
Date: Thu, 17 Jul 2025 18:35:27 +0000
Subject: [PATCH 04/23] fix device

---
 fast_llm/engine/distributed/distributed.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/fast_llm/engine/distributed/distributed.py b/fast_llm/engine/distributed/distributed.py
index e9247871..1f879830 100644
--- a/fast_llm/engine/distributed/distributed.py
+++ b/fast_llm/engine/distributed/distributed.py
@@ -45,7 +45,7 @@ def __init__(
         else:
             Assert.in_range_incl(self._local_world_size, 1, torch.cuda.device_count())
         torch.cuda.init()
-        self._device = torch.device(self._rank)
+        self._device = torch.device(self._rank % self._local_world_size)
         logger.info(f"Using device {self._device} for rank {self._rank}.")
         logger.info(f"Number of local devices: {torch.cuda.device_count()}.")
         torch.cuda.set_device(self._device)

From 2be288ca52a459a24d086d520e1e9ad338e442c7 Mon Sep 17 00:00:00 2001
From: Toolkit User
Date: Thu, 17 Jul 2025 19:41:57 +0000
Subject: [PATCH 05/23] remove log

---
 fast_llm/engine/distributed/distributed.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/fast_llm/engine/distributed/distributed.py b/fast_llm/engine/distributed/distributed.py
index 1f879830..f17a8f45 100644
--- a/fast_llm/engine/distributed/distributed.py
+++ b/fast_llm/engine/distributed/distributed.py
@@ -46,8 +46,6 @@ def __init__(
             Assert.in_range_incl(self._local_world_size, 1, torch.cuda.device_count())
         torch.cuda.init()
         self._device = torch.device(self._rank % self._local_world_size)
-        logger.info(f"Using device {self._device} for rank {self._rank}.")
-        logger.info(f"Number of local devices: {torch.cuda.device_count()}.")
         torch.cuda.set_device(self._device)
 
         if self._world_size > 1:

From d1caa987e648719ca8418c8bd93a2f31a46e7d2c Mon Sep 17 00:00:00 2001
From: Toolkit User
Date: Thu, 17 Jul 2025 20:47:42 +0000
Subject: [PATCH 06/23] fix name

---
 fast_llm/layers/ssm/config.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py
index c69ada38..46d629aa 100644
--- a/fast_llm/layers/ssm/config.py
+++ b/fast_llm/layers/ssm/config.py
@@ -22,7 +22,7 @@ class SSMDimNames:
     v_heads = "v_heads"  # Number of V heads
 
     # Mamba 2
-    x_proj_dim_2 = "x_proj_dim"  # d_xb
+    x_proj_dim_2 = "x_proj_dim_2"  # d_xb
 
 
 class SSMBlockType(enum.StrEnum):

From 0fbe88156f2fd8c9d218f554b63bb699eba49cc0 Mon Sep 17 00:00:00 2001
From: Toolkit User
Date: Fri, 18 Jul 2025 13:48:46 +0000
Subject: [PATCH 07/23] fix hybrid get_layers

---
 fast_llm/models/gpt/model.py | 12 +++++++-----
 fast_llm/models/ssm/model.py |  3 +--
 2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py
index 6356cf23..8d70a894 100644
--- a/fast_llm/models/gpt/model.py
+++ b/fast_llm/models/gpt/model.py
@@ -110,13 +110,15 @@ def get_vision_layers(self) -> list[Layer]:
             MultiModalEmbedding(self._config, self._tensor_space),
         ]
 
+    def get_embedding_layers(self) -> list[Layer]:
+        if self._config.vision_encoder.enabled:
+            return self.get_vision_layers()
+        else:
+            return [LanguageModelEmbedding(self._config, self._tensor_space)]
+
     def get_layers(self) -> list[Layer]:
         return [
-            *(
-                [LanguageModelEmbedding(self._config, self._tensor_space)]
-                if not self._config.vision_encoder.enabled
-                else self.get_vision_layers()
-            ),
+            *(self.get_embedding_layers()),
             *[
                 TransformerLayer(
                     self._config.transformer,
diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py
index 02a5ac23..80f9ca8b 100644
--- a/fast_llm/models/ssm/model.py
+++ b/fast_llm/models/ssm/model.py
@@ -3,7 +3,6 @@
 
 from fast_llm.engine.base_model.base_model import Layer
 from fast_llm.engine.distributed.config import DistributedConfig
-from fast_llm.layers.language_model.embedding import LanguageModelEmbedding
 from fast_llm.layers.language_model.head import LanguageModelHead
 from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2
 from fast_llm.layers.ssm.llamba_block import LlambaBlock
@@ -94,7 +93,7 @@ def get_layers(self) -> list[Layer]:
         Create a list of layers for the model, interleaving Transformer and Mamba blocks according to the block pattern.
""" - layers = [LanguageModelEmbedding(self._config, self._tensor_space)] + layers = self.get_embedding_layers() # Create blocks according to pattern for i, block_type in enumerate(self._config.hybrid_block_layout): From 854f305507f5d8852a8f8e5cc56cb3ca1b6248d5 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Fri, 18 Jul 2025 14:21:02 +0000 Subject: [PATCH 08/23] debug --- fast_llm/layers/ssm/mamba2.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index a03509ab..96116abe 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -1,3 +1,4 @@ +import logging import math import typing @@ -10,6 +11,8 @@ from fast_llm.tensor import ParameterMeta, init_fill_, init_ones_, init_uniform_, kaiming_init_ from fast_llm.utils import get_lr_scale +logger = logging.getLogger(__name__) + try: from mamba_ssm.ops.selective_scan_interface import selective_scan_fn # noqa @@ -144,6 +147,9 @@ def init_from_tensor_( value: torch.Tensor, ) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa + logger.info( + f"Initializing {meta.tensor_name} with shape {meta.shape} from tensor with shape {value.shape}" + ) return tensor.copy_(value) return init_ @@ -156,6 +162,7 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) lr_scale=mamba_layer_lr_scale, ) # define bias outside the linear layer since its also used in the selective_scan_fn + logger.info(f"td_inner: {td_inner}, inv_dt: {inv_dt.shape}") self.dt_proj_bias = ParameterMeta.from_dims( (td_inner,), init_method=init_from_tensor_(inv_dt), lr_scale=mamba_layer_lr_scale ) @@ -166,6 +173,7 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) d=self.d_inner, ).contiguous() A_log = torch.log(A).flatten() # Keep A_log in fp32 + logger.info(f"A_log: {A_log.shape}, td_inner: {td_inner}, td_state: {td_state}") self.A_log = ParameterMeta.from_dims( (td_inner, td_state), init_method=init_from_tensor_(A_log), From b7b81931ad634646476663222f04585957633bee Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Fri, 18 Jul 2025 20:30:41 +0000 Subject: [PATCH 09/23] add llava hybrid format --- fast_llm/models/ssm/config.py | 20 ++++++++++++++++++++ fast_llm/models/ssm/conversion.py | 11 ++++++++++- 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index cc83f11b..95c8ca84 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -166,6 +166,26 @@ def get_handler_class(cls) -> type[CheckpointHandler]: return AprielThinkerSSMHHybridHuggingfaceCheckpointHandler +# class LlavaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): +# name: typing.ClassVar[str] = "llava" +# # Using default values for vision and text models. 
Can be overridden in the config +# vision_name: typing.ClassVar[str] = "pixtral" +# text_name: typing.ClassVar[str] = "mistral" + + +class LlavaHybridHuggingfaceCheckpointFormat(CheckpointFormat): + support_optimizer: typing.ClassVar[bool] = False + name: typing.ClassVar[str] = "llava_hybrid" + vision_name: typing.ClassVar[str] = "pixtral" + text_name: typing.ClassVar[str] = "apriel_ssm_thinker_hybrid" + + @classmethod + def get_handler_class(cls) -> type[CheckpointHandler]: + from fast_llm.models.ssm.conversion import LlavaHybridHuggingfaceCheckpointHandler + + return LlavaHybridHuggingfaceCheckpointHandler + + @config_class(dynamic_type={FastLLMModelConfig: "hybrid_ssm"}) class HybridSSMModelConfig(FastLLMModelConfig): _abstract = False diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index d5730025..ddf45c6c 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -20,13 +20,18 @@ from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import RMSNormalizationConfig from fast_llm.layers.ssm.config import SSMBlockType -from fast_llm.models.gpt.conversion import CommonLlamaHuggingfaceCheckpointHandler, MLPLayer2Converter +from fast_llm.models.gpt.conversion import ( + CommonLlamaHuggingfaceCheckpointHandler, + LlavaHuggingfaceCheckpointHandler, + MLPLayer2Converter, +) from fast_llm.models.ssm.config import ( AprielSSMHHybridHuggingfaceCheckpointFormat, AprielSSMHuggingfaceCheckpointFormat, AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, HybridSSMModelConfig, LLambaHuggingfaceCheckpointFormat, + LlavaHybridHuggingfaceCheckpointFormat, ) from fast_llm.models.ssm.model import HybridSSMModel from fast_llm.utils import Assert @@ -762,3 +767,7 @@ def _load_config(cls, directory: pathlib.Path | str) -> dict: def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]) -> None: with open(directory / "config.json", "w") as f: json.dump(config, f) + + +class LlavaHybridHuggingfaceCheckpointHandler(LlavaHuggingfaceCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = LlavaHybridHuggingfaceCheckpointFormat From a1589da4f320c05e2515ec8559109e0fb6f6b3cc Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Fri, 18 Jul 2025 20:47:04 +0000 Subject: [PATCH 10/23] workaround init --- fast_llm/layers/ssm/mamba2.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 96116abe..8a61a896 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -148,9 +148,14 @@ def init_from_tensor_( ) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa logger.info( - f"Initializing {meta.tensor_name} with shape {meta.shape} from tensor with shape {value.shape}" + f"Initializing {meta.tensor_name} with shape {meta.shape}, tensor shape {tensor.shape} from value shape {value.shape}" ) - return tensor.copy_(value) + # TODO: fix and remove try-except + try: + return tensor.copy_(value) + except RuntimeError as e: + logger.error(f"Failed to copy value to tensor: {e}") + return tensor.fill_(0.0) return init_ From a202b4ce4d8ba34edf867895dbe4e40b06039918 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Fri, 18 Jul 2025 21:00:26 +0000 Subject: [PATCH 11/23] update --- fast_llm/models/ssm/config.py | 1 + 1 file changed, 1 insertion(+) diff --git 
a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 95c8ca84..55a7ef54 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -196,6 +196,7 @@ class HybridSSMModelConfig(FastLLMModelConfig): AprielSSMHuggingfaceCheckpointFormat, AprielSSMHHybridHuggingfaceCheckpointFormat, AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, + LlavaHybridHuggingfaceCheckpointFormat, ) @classmethod From b675ec2d1e073a0430962832dd37fd86e08fe768 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Fri, 18 Jul 2025 21:16:16 +0000 Subject: [PATCH 12/23] update --- fast_llm/models/ssm/conversion.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index ddf45c6c..0b2a36c6 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -771,3 +771,4 @@ def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.An class LlavaHybridHuggingfaceCheckpointHandler(LlavaHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = LlavaHybridHuggingfaceCheckpointFormat + _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig From a055d2a8aff442173291cb890a31a3664fd15262 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Fri, 18 Jul 2025 22:15:47 +0000 Subject: [PATCH 13/23] update --- fast_llm/models/gpt/conversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index e01aaf70..3a53754d 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -664,7 +664,7 @@ def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.A class PixtralHuggingfaceCheckpointHandler(WeightAndBiasConverterMixin, HuggingfaceStateDictCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = PixtralGPTHuggingfaceCheckpointFormat - _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig + _model_class: typing.ClassVar[FastLLMModelConfig] = FastLLMModelConfig @classmethod def _create_config_converters(cls) -> list[ParamConverter]: From 8e4ef5d52f94960820b8dbe062c12256e3249373 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Fri, 18 Jul 2025 22:43:02 +0000 Subject: [PATCH 14/23] refactoring attempt --- fast_llm/models/gpt/conversion.py | 17 +++++++++++++---- fast_llm/models/ssm/conversion.py | 7 +++++++ 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 3a53754d..79d7099c 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -14,6 +14,7 @@ AutoStateDictCheckpointHandler, ConstantExportParamConverter, ConstantImportParamConverter, + ExternalStateDictCheckpointHandler, IgnoreExportWeightConverter, IgnoreImportParamConverter, IgnoreImportWeightConverter, @@ -873,6 +874,14 @@ class LlavaHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = LlavaGPTHuggingfaceCheckpointFormat _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig + @classmethod + def get_vision_handler_class(cls) -> type[ExternalStateDictCheckpointHandler]: + return AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.format.vision_name) + + @classmethod + def get_text_handler_class(cls) -> type[ExternalStateDictCheckpointHandler]: + return AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.format.text_name) + @classmethod 
def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: cfg_dict = cls._load_config(config.path) @@ -944,8 +953,8 @@ def _import_config(cls, config: dict[str, typing.Any]) -> GPTBaseModelConfig: @classmethod def _export_config(cls, config: BaseModelConfig) -> dict[str, typing.Any]: exported_config = {} - vision_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.format.vision_name) - text_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.format.text_name) + vision_handler_cls = cls.get_vision_handler_class() + text_handler_cls = cls.get_text_handler_class() for converter in vision_handler_cls._create_config_converters(): try: values = converter.export_params( @@ -991,10 +1000,10 @@ def _export_config(cls, config: BaseModelConfig) -> dict[str, typing.Any]: return exported_config def _create_weight_converters(self): - vision_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(self.format.vision_name) + vision_handler_cls = self.get_vision_handler_class() vision_handler = vision_handler_cls(self._model) converters = vision_handler._create_weight_converters(hf_base_prefix="vision_tower.", offset=0) - text_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(self.format.text_name) + text_handler_cls = self.get_text_handler_class() text_handler = text_handler_cls(self._model) converters.extend( text_handler._create_weight_converters(hf_base_prefix="language_model.", offset=vision_handler.num_layers) diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index 0b2a36c6..d62a0549 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -7,6 +7,7 @@ from fast_llm.engine.checkpoint.external import ( ConstantExportParamConverter, ConstantImportParamConverter, + ExternalStateDictCheckpointHandler, IgnoreImportParamConverter, IgnoreImportWeightConverter, MappedConfigParamConverter, @@ -772,3 +773,9 @@ def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.An class LlavaHybridHuggingfaceCheckpointHandler(LlavaHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = LlavaHybridHuggingfaceCheckpointFormat _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig + + @classmethod + def get_text_handler_class(cls) -> type[ExternalStateDictCheckpointHandler]: + from fast_llm.models.ssm.conversion import AprielThinkerSSMHHybridHuggingfaceCheckpointHandler + + return AprielThinkerSSMHHybridHuggingfaceCheckpointHandler From 4bfce67042390d0eef54a5697457469b2443a812 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Mon, 21 Jul 2025 14:11:11 +0000 Subject: [PATCH 15/23] update ssm conversion: use hf_prefix/offset --- fast_llm/models/ssm/conversion.py | 59 +++++++++++++++++++++---------- 1 file changed, 41 insertions(+), 18 deletions(-) diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index d62a0549..55ed3058 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -221,7 +221,11 @@ def _create_config_converters(cls) -> list[ParamConverter]: ), ] - def _create_weight_converters(self) -> list[WeightConverter]: + def _create_weight_converters( + self, + hf_base_prefix: str = "", + offset: int = 0, + ) -> list[WeightConverter]: converters = super()._create_weight_converters() or [] num_layers = self._model.config.base_model.transformer.num_layers @@ -230,55 +234,65 @@ def _create_weight_converters(self) -> list[WeightConverter]: 
for i in range(num_layers): # SSM converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mixer.in_proj", f"model.layers.{i}.mixer.in_proj", ssm_bias + f"layers.{offset+i+1}.mixer.in_proj", f"{hf_base_prefix}model.layers.{i}.mixer.in_proj", ssm_bias ) converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mixer.out_proj", f"model.layers.{i}.mixer.out_proj", ssm_bias + f"layers.{offset+i+1}.mixer.out_proj", f"{hf_base_prefix}model.layers.{i}.mixer.out_proj", ssm_bias ) converters.append( - WeightConverter(f"layers.{i+1}.mixer.D", f"model.layers.{i}.mixer.D", self._model.config.base_model) + WeightConverter( + f"layers.{offset+i+1}.mixer.D", + f"{hf_base_prefix}model.layers.{i}.mixer.D", + self._model.config.base_model, + ) ) converters.append( WeightConverter( - f"layers.{i+1}.mixer.z_bias", f"model.layers.{i}.mixer.z_bias", self._model.config.base_model + f"layers.{offset+i+1}.mixer.z_bias", + f"{hf_base_prefix}model.layers.{i}.mixer.z_bias", + self._model.config.base_model, ) ) converters.append( WeightConverter( - f"layers.{i+1}.mixer.z_bias", f"model.layers.{i}.mixer.z_bias", self._model.config.base_model + f"layers.{offset+i+1}.mixer.z_bias", + f"{hf_base_prefix}model.layers.{i}.mixer.z_bias", + self._model.config.base_model, ) ) converters.append( WeightConverter( - f"layers.{i+1}.mixer.conv1d_weight", - f"model.layers.{i}.mixer.conv1d.weight", + f"layers.{offset+i+1}.mixer.conv1d_weight", + f"{hf_base_prefix}model.layers.{i}.mixer.conv1d.weight", self._model.config.base_model, ) ) converters.append( WeightConverter( - f"layers.{i+1}.mixer.conv1d_bias", - f"model.layers.{i}.mixer.conv1d.bias", + f"layers.{offset+i+1}.mixer.conv1d_bias", + f"{hf_base_prefix}model.layers.{i}.mixer.conv1d.bias", self._model.config.base_model, ) ) # ================================================ # Mamba2 specific parameters converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mixer.dt_proj", f"model.layers.{i}.mixer.dt_proj", False + f"layers.{offset+i+1}.mixer.dt_proj", f"{hf_base_prefix}model.layers.{i}.mixer.dt_proj", False ) # bias is treated separately in Mamba2 and must always exist (https://github.com/jxiw/M1/blob/537a1ca5407a786a99dc6c721873493cf8750d5e/mamba/hybrid_mamba_layer.py) converters.append( WeightConverter( - f"layers.{i+1}.mixer.dt_proj_bias", - f"model.layers.{i}.mixer.dt_proj.bias", + f"layers.{offset+i+1}.mixer.dt_proj_bias", + f"{hf_base_prefix}model.layers.{i}.mixer.dt_proj.bias", self._model.config.base_model, ) ) converters.append( WeightConverter( - f"layers.{i+1}.mixer.A_log", f"model.layers.{i}.mixer.A_log", self._model.config.base_model + f"layers.{offset+i+1}.mixer.A_log", + f"{hf_base_prefix}model.layers.{i}.mixer.A_log", + self._model.config.base_model, ) ) @@ -572,11 +586,16 @@ def _create_config_converters(cls) -> list[ParamConverter]: ), ] - def _create_weight_converters(self) -> list[WeightConverter]: - converters = super()._create_weight_converters() + def _create_weight_converters( + self, + hf_base_prefix: str = "", + offset: int = 0, + ) -> list[WeightConverter]: + converters = super()._create_weight_converters(hf_base_prefix, offset) num_layers = self._model.config.base_model.transformer.num_layers norm_bias: bool = False + # TODO: use hf_base_prefix and offset # Embedding and output if self._model.config.base_model.tie_word_embeddings: converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) @@ -710,8 +729,12 @@ class AprielThinkerSSMHHybridHuggingfaceCheckpointHandler( 
_hf_prefix: str = "model" architecture: typing.ClassVar[str] = "AprielThinkerSSMHybridForCausalLM" - def _create_weight_converters(self) -> list[WeightConverter]: - converters = super()._create_weight_converters() + def _create_weight_converters( + self, + hf_base_prefix: str = "", + offset: int = 0, + ) -> list[WeightConverter]: + converters = super()._create_weight_converters(hf_base_prefix, offset) # num_layers = self._model.config.base_model.transformer.num_layers # # Embedding and output # if self._model.config.base_model.tie_word_embeddings: From f93c51f465998370e4a9ac2b1f2eb78903b57ba3 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Mon, 21 Jul 2025 21:50:55 +0000 Subject: [PATCH 16/23] draft llava hybrid --- fast_llm/engine/checkpoint/huggingface.py | 3 + fast_llm/models/ssm/conversion.py | 18 ++- .../configuration_llava_hybrid.py | 110 ++++++++++++++++++ .../llava_hybrid/modeling_llava_hybrid.py | 12 ++ 4 files changed, 141 insertions(+), 2 deletions(-) create mode 100644 fast_llm/models/ssm/external/llava_hybrid/configuration_llava_hybrid.py create mode 100644 fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index 16b3e005..4cfff4af 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -134,6 +134,7 @@ class CustomModelingExportMixin: configuration_file: typing.ClassVar[str] configuration_cls: typing.ClassVar[type[PretrainedConfig]] generation_utils_file: str | None = None + additional_files: typing.ClassVar[list[str]] = [] # Use custom config instead of relying on the transformers library @classmethod @@ -159,3 +160,5 @@ def _copy_modeling_files(self, config: CheckpointSaveConfig) -> None: gen_config = pathlib.Path(self.generation_utils_file).parent / "generation_config.json" if gen_config.exists(): shutil.copy(gen_config, config.path) + for file in self.additional_files: + shutil.copy(file, config.path) diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index 55ed3058..7d908a13 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -3,6 +3,8 @@ import pathlib import typing +from transformers.configuration_utils import PretrainedConfig + from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( ConstantExportParamConverter, @@ -16,7 +18,7 @@ SplitWeightConverter, WeightConverter, ) -from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler +from fast_llm.engine.checkpoint.huggingface import CustomModelingExportMixin, HuggingfaceStateDictCheckpointHandler from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import RMSNormalizationConfig @@ -34,6 +36,11 @@ LLambaHuggingfaceCheckpointFormat, LlavaHybridHuggingfaceCheckpointFormat, ) +from fast_llm.models.ssm.external.apriel_15b_hybrid import ( + configuration_ssm_hybrid_apriel15b, + modeling_ssm_hybrid_apriel15b, +) +from fast_llm.models.ssm.external.llava_hybrid import configuration_llava_hybrid, modeling_llava_hybrid from fast_llm.models.ssm.model import HybridSSMModel from fast_llm.utils import Assert @@ -793,9 +800,16 @@ def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.An json.dump(config, f) -class LlavaHybridHuggingfaceCheckpointHandler(LlavaHuggingfaceCheckpointHandler): +class 
LlavaHybridHuggingfaceCheckpointHandler(CustomModelingExportMixin, LlavaHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = LlavaHybridHuggingfaceCheckpointFormat + modeling_file = modeling_llava_hybrid.__file__ + configuration_file = configuration_llava_hybrid.__file__ + configuration_cls: typing.ClassVar[type[PretrainedConfig]] = configuration_llava_hybrid.LlavaHybridConfig _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig + additional_files = [ + modeling_ssm_hybrid_apriel15b.__file__, + configuration_ssm_hybrid_apriel15b.__file__, + ] @classmethod def get_text_handler_class(cls) -> type[ExternalStateDictCheckpointHandler]: diff --git a/fast_llm/models/ssm/external/llava_hybrid/configuration_llava_hybrid.py b/fast_llm/models/ssm/external/llava_hybrid/configuration_llava_hybrid.py new file mode 100644 index 00000000..09e17a92 --- /dev/null +++ b/fast_llm/models/ssm/external/llava_hybrid/configuration_llava_hybrid.py @@ -0,0 +1,110 @@ +from transformers import MistralConfig +from transformers.configuration_utils import PretrainedConfig +from transformers.models.auto import CONFIG_MAPPING +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +# Copied from configuration_ssm_hybrid_apriel15b.py +# TODO: split into mamba 2 and discrete mamba 2 configs with a base dict +ssm_config_default = { + # discrete mamba2 + "d_state": 64, + "n_v_heads": 32, + "n_qk_heads": 32, + "expand": 1, + "chunk_size": 128, + "activation": "identity", + "bias": False, + "d_conv": 4, + "d_inner": 32 * 128, + # mamba2 + "d_xb": None, # will be set to model dim + "dt_rank": "auto", + "dt_min": 0.001, + "dt_max": 0.1, + "dt_init": "random", + "dt_scale": 1.0, + "dt_init_floor": 1e-4, + "conv_bias": True, +} + + +class AprielSSMHybridConfig(MistralConfig): + model_type = "apriel_ssm_thinker_hybrid" + + def __init__(self, hybrid_block_layout=["m2d"], ssm_cfg=None, **kwargs): + super().__init__(**kwargs) + self.hybrid_block_layout = hybrid_block_layout + self.head_dim = self.head_dim or self.hidden_size // self.num_attention_heads # as in transformers 4.51.3 + self.ssm_cfg = ssm_cfg or ssm_config_default + + for k, v in ssm_config_default.items(): + if k not in self.ssm_cfg: + self.ssm_cfg[k] = v # to make sure all elements are present in the config + + +class LlavaHybridConfig(PretrainedConfig): + """ + Configuration class for Llava SSM-Hybrid-decoder model. + """ + + model_type = "llava_hybrid" + + def __init__( + self, + vision_config=None, + text_config=None, + image_token_index=32000, + projector_hidden_act="gelu", + vision_feature_select_strategy="default", + vision_feature_layer=-2, + image_seq_length=576, + multimodal_projector_bias=True, + **kwargs, + ): + self.image_token_index = image_token_index + self.projector_hidden_act = projector_hidden_act + self.image_seq_length = image_seq_length + + if vision_feature_select_strategy not in ["default", "full"]: + raise ValueError( + "vision_feature_select_strategy should be one of 'default', 'full'." 
+ f"Got: {vision_feature_select_strategy}" + ) + + self.vision_feature_select_strategy = vision_feature_select_strategy + self.vision_feature_layer = vision_feature_layer + + if isinstance(vision_config, dict): + vision_config["model_type"] = ( + vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model" + ) + vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config) + elif vision_config is None: + vision_config = CONFIG_MAPPING["clip_vision_model"]( + intermediate_size=4096, + hidden_size=1024, + patch_size=14, + image_size=336, + num_hidden_layers=24, + num_attention_heads=16, + vocab_size=32000, + projection_dim=768, + ) + + self.vision_config = vision_config + + if isinstance(text_config, dict): + if text_config.get("model_type") == "apriel_ssm_thinker_hybrid": + text_config = AprielSSMHybridConfig(**text_config) + else: + text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama" + text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) + elif text_config is None: + text_config = CONFIG_MAPPING["llama"]() + + self.text_config = text_config + self.multimodal_projector_bias = multimodal_projector_bias + + super().__init__(**kwargs) diff --git a/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py b/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py new file mode 100644 index 00000000..d58b3535 --- /dev/null +++ b/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py @@ -0,0 +1,12 @@ +from transformers import LlavaModel + +from .configuration_llava_hybrid import LlavaHybridConfig + + +class LlavaHybridModel(LlavaModel): + """ + Llava SSM-Hybrid-decoder model. + """ + + def __init__(self, config: LlavaHybridConfig): + super().__init__(config) From d10eaad67d80efd00f70c574d5743c94a42f90eb Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Wed, 23 Jul 2025 15:10:01 +0000 Subject: [PATCH 17/23] fix and add test config --- fast_llm/layers/vision_encoder/adapter.py | 8 +++--- fast_llm/layers/vision_encoder/config.py | 12 ++++++++ fast_llm/models/ssm/conversion.py | 8 +++++- tests/utils/model_configs.py | 34 ++++++++++++++++++++++- 4 files changed, 56 insertions(+), 6 deletions(-) diff --git a/fast_llm/layers/vision_encoder/adapter.py b/fast_llm/layers/vision_encoder/adapter.py index 41ea065d..d324d522 100644 --- a/fast_llm/layers/vision_encoder/adapter.py +++ b/fast_llm/layers/vision_encoder/adapter.py @@ -24,15 +24,15 @@ def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): input_dim, tensor_space.get_tensor_dim(VisionEncoderDimNames.adapter_size), bias=True, - weight_init_method=init_normal_(), - bias_init_method=init_normal_(), + weight_init_method=init_normal_(std=config.adapter_init_method_std), + bias_init_method=init_normal_(std=config.adapter_init_method_std), ) self.layer_2 = Linear( tensor_space.get_tensor_dim(VisionEncoderDimNames.adapter_size), tensor_space.get_tensor_dim(TransformerDimNames.hidden), bias=True, - weight_init_method=init_normal_(), - bias_init_method=init_normal_(), + weight_init_method=init_normal_(std=config.adapter_init_method_std), + bias_init_method=init_normal_(std=config.adapter_init_method_std), ) def forward( diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index 2ea7f611..a705d948 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -150,6 +150,18 @@ class VisionEncoderConfig(BaseModelConfig): 
hint=FieldHint.feature, valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) + adapter_init_method_std: float = Field( + default=None, + desc="Standard deviation for the normal initialization of the adapter weights. Default: adapter_size ** -0.5.", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + + def _validate(self) -> None: + with self._set_implicit_default(): + if self.adapter_init_method_std is None: + self.adapter_init_method_std = self.adapter_size**-0.5 + super()._validate() def setup_tensor_space(self, tensor_space: TensorSpace): tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.out_channels, self.transformer.hidden_size)) diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index 7d908a13..97794b4f 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -233,7 +233,13 @@ def _create_weight_converters( hf_base_prefix: str = "", offset: int = 0, ) -> list[WeightConverter]: - converters = super()._create_weight_converters() or [] + converters = ( + super()._create_weight_converters( + hf_base_prefix=hf_base_prefix, + offset=offset, + ) + or [] + ) num_layers = self._model.config.base_model.transformer.num_layers ssm_bias: bool = self._model.config.base_model.ssm.add_bias_linear diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index b96a8963..6af35eeb 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -19,7 +19,7 @@ Qwen2GPTHuggingfaceCheckpointFormat, Starcoder2GPTHuggingfaceCheckpointFormat, ) -from fast_llm.models.ssm.config import LLambaHuggingfaceCheckpointFormat +from fast_llm.models.ssm.config import LLambaHuggingfaceCheckpointFormat, LlavaHybridHuggingfaceCheckpointFormat from tests.utils.dataset import MODEL_DATASET_PREFIX, MODEL_TEST_VOCAB_SIZE from tests.utils.distributed_configs import DistributedTestingConfig @@ -521,6 +521,38 @@ def _update_and_add_testing_config( }, ) +_update_and_add_testing_config( + # Tests hybrid ssm, llamba converter. 
+ "hybrid_mamba2", + "vision_hybrid_mamba2", + model_type="hybrid_ssm", + extra_args=[ + "batch.max_image_size=128", + "model.base_model.vision_encoder.type=pixtral", + "model.base_model.vision_encoder.transformer.type=image_encoder", + "model.base_model.vision_encoder.transformer.gated=True", + "model.base_model.vision_encoder.transformer.num_layers=2", + "model.base_model.vision_encoder.transformer.hidden_size=256", + "model.base_model.vision_encoder.transformer.num_attention_heads=8", + "model.base_model.vision_encoder.transformer.head_groups=4", + "model.base_model.vision_encoder.transformer.init_method_std=0.022", + "model.base_model.vision_encoder.transformer.rotary.type=rope_2d", + "model.base_model.vision_encoder.adapter_size=512", + "model.distributed.training_dtype=torch.bfloat16", + ], + megatron_args=None, + checkpoint_format=LlavaHybridHuggingfaceCheckpointFormat, + groups={ + ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, + ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, + }, + compare_factor=16.0, +) + @pytest.fixture(scope="session", params=MODEL_CONFIGS.keys()) def model_testing_config(request) -> ModelTestingConfig: From 7c8de47cbc424b39a5abeb0c09881bb159ee2893 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Wed, 23 Jul 2025 19:57:24 +0000 Subject: [PATCH 18/23] conversion fixes and tests --- fast_llm/engine/checkpoint/external.py | 6 +++- .../layers/vision_encoder/preprocessing.py | 33 ++++++++++++++----- fast_llm/models/gpt/conversion.py | 10 +++--- fast_llm/models/gpt/model.py | 6 +++- fast_llm/models/ssm/config.py | 3 +- fast_llm/models/ssm/conversion.py | 14 ++++++++ .../llava_hybrid/modeling_llava_hybrid.py | 11 ++++++- fast_llm/models/ssm/huggingface.py | 6 ++-- fast_llm/models/ssm/model.py | 7 ++++ tests/utils/model_configs.py | 8 +++-- 10 files changed, 84 insertions(+), 20 deletions(-) diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index 72db80f6..f8a42b31 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -283,7 +283,7 @@ def _export_config(cls, config: BaseModelConfig) -> dict[str, typing.Any]: return exported_config # Noqa @classmethod - def _import_config(cls, config: dict[str, typing.Any]) -> BaseModelConfig: # noqa + def _import_config_dict(cls, config: dict[str, typing.Any]) -> dict[str | tuple[str, ...], typing.Any]: kwargs = {} for converter in cls._get_config_converters(): try: @@ -306,7 +306,11 @@ def _import_config(cls, config: dict[str, typing.Any]) -> BaseModelConfig: # no kwargs[fast_llm_name] = value except Exception as e: raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + return kwargs + @classmethod + def _import_config(cls, config: dict[str, typing.Any]) -> BaseModelConfig: # noqa + kwargs = cls._import_config_dict(config) return cls._model_class.get_base_model_config_class().from_dict({}, kwargs) def _convert_state_dict( diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 3b857ba2..adacd380 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -163,11 +163,12 @@ def preprocess(self, 
tokens, kwargs: dict[str, typing.Any]) -> None: for imgs in images ] - labels = kwargs[LanguageModelKwargs.labels] - if (self._config.image_break_token is not None) or (self._config.image_end_token is not None): - # If image break or end token is present, we need to replace image token ids to -100 in labels - # TODO: avoid double cloning labels in case of loss masking spans? - labels = labels.clone() + if LanguageModelKwargs.labels in kwargs: + labels = kwargs[LanguageModelKwargs.labels] + if (self._config.image_break_token is not None) or (self._config.image_end_token is not None): + # If image break or end token is present, we need to replace image token ids to -100 in labels + # TODO: avoid double cloning labels in case of loss masking spans? + labels = labels.clone() patches = [] patch_position_ids = [] @@ -191,8 +192,9 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: image_break=self._config.image_break_token is not None, image_end=self._config.image_end_token is not None, ) - # set labels for image patches to -100 - labels[idx, max(position - 1, 0) : position + num_tokens - 1] = -100 + if LanguageModelKwargs.labels in kwargs: + # set labels for image patches to -100 + labels[idx, max(position - 1, 0) : position + num_tokens - 1] = -100 if seqlen > max_seqlen: max_seqlen = seqlen cu_seqlens.append(cu_seqlens[-1] + seqlen) @@ -261,4 +263,19 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: ) kwargs[VisionTransformerKwargs.max_seqlen_q] = max_seqlen kwargs[VisionTransformerKwargs.max_seqlen_k] = max_seqlen - kwargs[LanguageModelKwargs.labels] = labels + if LanguageModelKwargs.labels in kwargs: + kwargs[LanguageModelKwargs.labels] = labels + + # TODO: add proper preprocessing for attention-mask when not using flash attention + # Following is just a dummy code to run the tests. 
+ kwargs[self._config.transformer._transformer_kwargs.attention_mask] = torch.ones( + (1, 1, kwargs[TransformerKwargs.sequence_length], 1, kwargs[TransformerKwargs.sequence_length]), + dtype=torch.bool, + device=self._tensor_space.distributed.device, + ) + kwargs[self._config.transformer._transformer_kwargs.attention_mask_value] = torch.full( + [], + torch.finfo(self._distributed_config.training_dtype.torch).min, + dtype=self._distributed_config.training_dtype.torch, + device=self._tensor_space.distributed.device, + ) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 79d7099c..2f4d9b61 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -884,13 +884,15 @@ def get_text_handler_class(cls) -> type[ExternalStateDictCheckpointHandler]: @classmethod def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: + vision_handler_cls = cls.get_vision_handler_class() + text_handler_cls = cls.get_text_handler_class() cfg_dict = cls._load_config(config.path) kwargs = {} if "text_config" in cfg_dict: - text_kwargs = cls._import_config(cfg_dict["text_config"]) + text_kwargs = text_handler_cls._import_config_dict(cfg_dict["text_config"]) kwargs.update(text_kwargs) if "vision_config" in cfg_dict: - vision_kwargs = cls._import_config(cfg_dict["vision_config"]) + vision_kwargs = vision_handler_cls._import_config_dict(cfg_dict["vision_config"]) vision_kwargs = {tuple(["vision_encoder"] + list(key)): value for key, value in vision_kwargs.items()} kwargs.update(vision_kwargs) kwargs.update( @@ -927,9 +929,9 @@ def _create_config_converters(cls) -> list[ParamConverter]: @classmethod def _import_config(cls, config: dict[str, typing.Any]) -> GPTBaseModelConfig: - handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(config["model_type"]) + # handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(config["model_type"]) kwargs = {} - for converter in handler_cls._create_config_converters(): + for converter in cls._create_config_converters(): try: values = () for export_name in converter.export_names: diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 8d70a894..1e439e72 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -150,7 +150,11 @@ def preprocess_meta( micro_sequence_length = sequence_length if self._config.vision_encoder.enabled: - max_image_size = batch_meta.max_image_size + try: + max_image_size = batch_meta.max_image_size + except AttributeError: + max_image_size = 256 + logger.warning("Inference mode: max_image_size not provided, defaulting to 256") image_mean = [ self._config.vision_encoder.image_normalization.mean_r, self._config.vision_encoder.image_normalization.mean_g, diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 55a7ef54..41f4eadb 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -178,6 +178,7 @@ class LlavaHybridHuggingfaceCheckpointFormat(CheckpointFormat): name: typing.ClassVar[str] = "llava_hybrid" vision_name: typing.ClassVar[str] = "pixtral" text_name: typing.ClassVar[str] = "apriel_ssm_thinker_hybrid" + trust_remote_code: typing.ClassVar[bool] = True @classmethod def get_handler_class(cls) -> type[CheckpointHandler]: @@ -206,7 +207,7 @@ def get_model_class(cls) -> type["HybridSSMModel"]: return HybridSSMModel @classmethod - def get_huggingface_model_class(cls) -> type["HuggingfaceHybridSSMModelForCausalLM"]: + def 
get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceHybridSSMModelForCausalLM"]: from fast_llm.models.ssm.huggingface import HuggingfaceHybridSSMModelForCausalLM return HuggingfaceHybridSSMModelForCausalLM diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index 97794b4f..db960f0f 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -822,3 +822,17 @@ def get_text_handler_class(cls) -> type[ExternalStateDictCheckpointHandler]: from fast_llm.models.ssm.conversion import AprielThinkerSSMHHybridHuggingfaceCheckpointHandler return AprielThinkerSSMHHybridHuggingfaceCheckpointHandler + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + cls.architecture = "LlavaHybridForConditionalGeneration" + return super()._create_config_converters() + [ + ConstantExportParamConverter( + export_names=(("auto_map",),), + export_value={ + "AutoConfig": "configuration_llava_hybrid.LlavaHybridConfig", + "AutoModel": "modeling_llava_hybrid.LlavaHybridModel", + "AutoModelForConditionalGeneration": "modeling_llava_hybrid.LlavaHybridForConditionalGeneration", + }, + ), + ] diff --git a/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py b/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py index d58b3535..78e390a1 100644 --- a/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py +++ b/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py @@ -1,4 +1,5 @@ -from transformers import LlavaModel +from torch import nn +from transformers import LlavaForConditionalGeneration, LlavaModel from .configuration_llava_hybrid import LlavaHybridConfig @@ -10,3 +11,11 @@ class LlavaHybridModel(LlavaModel): def __init__(self, config: LlavaHybridConfig): super().__init__(config) + + +class LlavaHybridForConditionalGeneration(LlavaForConditionalGeneration): + def __init__(self, config: LlavaHybridConfig): + super(LlavaForConditionalGeneration, self).__init__(config) + self.model = LlavaHybridModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() diff --git a/fast_llm/models/ssm/huggingface.py b/fast_llm/models/ssm/huggingface.py index 77cd346f..02f47207 100644 --- a/fast_llm/models/ssm/huggingface.py +++ b/fast_llm/models/ssm/huggingface.py @@ -1,9 +1,10 @@ import logging +import typing -from fast_llm.engine.huggingface.config import HuggingfaceModelConfig +from fast_llm.engine.inference.config import HuggingfaceModelConfig from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM from fast_llm.models.ssm.config import HybridSSMModelConfig -from fast_llm.models.ssm.model import HybridSSMModel +from fast_llm.models.ssm.model import HybridSSMInferenceRunner, HybridSSMModel logger = logging.getLogger(__name__) @@ -17,5 +18,6 @@ class HuggingfaceSSMModelConfig(HuggingfaceModelConfig): class HuggingfaceHybridSSMModelForCausalLM(HuggingfaceGPTModelForCausalLM): config_class = HuggingfaceSSMModelConfig config: HuggingfaceSSMModelConfig + runner_class: typing.ClassVar[type[HybridSSMInferenceRunner]] = HybridSSMInferenceRunner model_class = HybridSSMModel _fast_llm_model: HybridSSMModel diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 80f9ca8b..df15907d 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -3,12 +3,14 @@ from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.distributed.config import 
DistributedConfig +from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.layers.language_model.head import LanguageModelHead from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 from fast_llm.layers.ssm.llamba_block import LlambaBlock from fast_llm.layers.ssm.mamba2 import Mamba2 from fast_llm.layers.ssm.mamba_layer import MambaLayer from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.models.gpt.config import GPTBatchConfig from fast_llm.models.gpt.model import GPTBaseModel, GPTModel from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, HybridSSMModelConfig, SSMBlockType @@ -164,3 +166,8 @@ class HybridSSMModel[ConfigType: HybridSSMModelConfig](GPTModel[ConfigType]): config_class: typing.ClassVar[type[HybridSSMModelConfig]] = HybridSSMModelConfig base_model_class: typing.ClassVar[type[HybridSSMBaseModel]] = HybridSSMBaseModel + + +class HybridSSMInferenceRunner(InferenceRunner): + model_class: typing.ClassVar[type[HybridSSMModel]] = HybridSSMModel + batch_config_class: typing.ClassVar[type[GPTBatchConfig]] = GPTBatchConfig diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 6af35eeb..c2bb6800 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -529,12 +529,16 @@ def _update_and_add_testing_config( extra_args=[ "batch.max_image_size=128", "model.base_model.vision_encoder.type=pixtral", + "model.base_model.vision_encoder.patch_norm.type=rms_norm", + "model.base_model.vision_encoder.transformer.add_linear_biases=False", + "model.base_model.vision_encoder.transformer.causal=False", + "model.base_model.vision_encoder.transformer.normalization.type=rms_norm", "model.base_model.vision_encoder.transformer.type=image_encoder", "model.base_model.vision_encoder.transformer.gated=True", "model.base_model.vision_encoder.transformer.num_layers=2", "model.base_model.vision_encoder.transformer.hidden_size=256", "model.base_model.vision_encoder.transformer.num_attention_heads=8", - "model.base_model.vision_encoder.transformer.head_groups=4", + "model.base_model.vision_encoder.transformer.head_groups=8", "model.base_model.vision_encoder.transformer.init_method_std=0.022", "model.base_model.vision_encoder.transformer.rotary.type=rope_2d", "model.base_model.vision_encoder.adapter_size=512", @@ -545,7 +549,7 @@ def _update_and_add_testing_config( groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, From 752274be515a8ee417d454f9e2f6032427327489 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Thu, 24 Jul 2025 15:46:41 +0000 Subject: [PATCH 19/23] fix conversion --- fast_llm/models/gpt/conversion.py | 5 +- fast_llm/models/ssm/conversion.py | 4 +- .../configuration_llava_hybrid.py | 7 ++ .../llava_hybrid/modeling_llava_hybrid.py | 105 +++++++++++++++++- tests/models/test_checkpoint.py | 2 +- 5 files changed, 115 insertions(+), 8 deletions(-) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 2f4d9b61..f376db3f 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -872,6 +872,7 @@ def num_layers(self) -> int: 
class LlavaHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = LlavaGPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "LlavaForConditionalGeneration" _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig @classmethod @@ -912,9 +913,7 @@ def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetad @classmethod def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ - ConstantExportParamConverter( - export_names=(("architectures",),), export_value=["LlavaForConditionalGeneration"] - ), + ConstantExportParamConverter(export_names=(("architectures",),), export_value=[cls.architecture]), MappedConfigParamConverter( fast_llm_names=(("vision_encoder", "adapter_activation_type"),), export_names=(("projector_hidden_act",),), diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index db960f0f..c765c956 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -808,6 +808,7 @@ def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.An class LlavaHybridHuggingfaceCheckpointHandler(CustomModelingExportMixin, LlavaHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = LlavaHybridHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "LlavaHybridForConditionalGeneration" modeling_file = modeling_llava_hybrid.__file__ configuration_file = configuration_llava_hybrid.__file__ configuration_cls: typing.ClassVar[type[PretrainedConfig]] = configuration_llava_hybrid.LlavaHybridConfig @@ -825,14 +826,13 @@ def get_text_handler_class(cls) -> type[ExternalStateDictCheckpointHandler]: @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - cls.architecture = "LlavaHybridForConditionalGeneration" return super()._create_config_converters() + [ ConstantExportParamConverter( export_names=(("auto_map",),), export_value={ "AutoConfig": "configuration_llava_hybrid.LlavaHybridConfig", "AutoModel": "modeling_llava_hybrid.LlavaHybridModel", - "AutoModelForConditionalGeneration": "modeling_llava_hybrid.LlavaHybridForConditionalGeneration", + "AutoModelForVision2Seq": "modeling_llava_hybrid.LlavaHybridForConditionalGeneration", }, ), ] diff --git a/fast_llm/models/ssm/external/llava_hybrid/configuration_llava_hybrid.py b/fast_llm/models/ssm/external/llava_hybrid/configuration_llava_hybrid.py index 09e17a92..b8e822d9 100644 --- a/fast_llm/models/ssm/external/llava_hybrid/configuration_llava_hybrid.py +++ b/fast_llm/models/ssm/external/llava_hybrid/configuration_llava_hybrid.py @@ -57,6 +57,7 @@ def __init__( text_config=None, image_token_index=32000, projector_hidden_act="gelu", + projector_intermediate_size=4096, vision_feature_select_strategy="default", vision_feature_layer=-2, image_seq_length=576, @@ -65,6 +66,8 @@ def __init__( ): self.image_token_index = image_token_index self.projector_hidden_act = projector_hidden_act + # projector_intermediate_size is an addition to the original Llava config + self.projector_intermediate_size = projector_intermediate_size self.image_seq_length = image_seq_length if vision_feature_select_strategy not in ["default", "full"]: @@ -96,6 +99,7 @@ def __init__( self.vision_config = vision_config if isinstance(text_config, dict): + # Load the custom SSM hybrid config if specified if text_config.get("model_type") == "apriel_ssm_thinker_hybrid": text_config = 
AprielSSMHybridConfig(**text_config) else: @@ -108,3 +112,6 @@ def __init__( self.multimodal_projector_bias = multimodal_projector_bias super().__init__(**kwargs) + + +__all__ = ["LlavaHybridConfig"] diff --git a/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py b/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py index 78e390a1..6917fea9 100644 --- a/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py +++ b/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py @@ -1,7 +1,31 @@ from torch import nn -from transformers import LlavaForConditionalGeneration, LlavaModel +from transformers import AutoModel, LlavaForConditionalGeneration, LlavaModel +from transformers.activations import ACT2FN from .configuration_llava_hybrid import LlavaHybridConfig +from .modeling_ssm_hybrid_apriel15b import AprielThinkerSSMHybridModel, HybridMambaAttentionDynamicCache + + +class LlavaMultiModalProjector(nn.Module): + def __init__(self, config: LlavaHybridConfig): + super().__init__() + # We have hidden_size * the number of vision feature layers + num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer) + self.linear_1 = nn.Linear( + config.vision_config.hidden_size * num_feature_layers, + config.projector_intermediate_size, + bias=config.multimodal_projector_bias, + ) + self.act = ACT2FN[config.projector_hidden_act] + self.linear_2 = nn.Linear( + config.projector_intermediate_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias + ) + + def forward(self, image_features): + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states class LlavaHybridModel(LlavaModel): @@ -9,13 +33,90 @@ class LlavaHybridModel(LlavaModel): Llava SSM-Hybrid-decoder model. 
""" + config_class = LlavaHybridConfig + def __init__(self, config: LlavaHybridConfig): - super().__init__(config) + super(LlavaModel, self).__init__(config) + self.vision_tower = AutoModel.from_config(config.vision_config) + + self.multi_modal_projector = LlavaMultiModalProjector(config) + assert ( + config.text_config.model_type == "apriel_ssm_thinker_hybrid" + ), "Only Apriel SSM Hybrid model is supported in LlavaHybridModel" + # from .modeling_ssm_hybrid_apriel15b import AprielThinkerSSMHybridModel + self.language_model = AprielThinkerSSMHybridModel(config.text_config) + self.post_init() class LlavaHybridForConditionalGeneration(LlavaForConditionalGeneration): + config_class = LlavaHybridConfig + def __init__(self, config: LlavaHybridConfig): super(LlavaForConditionalGeneration, self).__init__(config) self.model = LlavaHybridModel(config) self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) self.post_init() + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + output_router_logits=False, + cache_position=None, + position_ids=None, + use_cache=True, + **kwargs, + ): + # Copy of the method from `AprielThinkerSSMHybridForCausalLM` + # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache` + + empty_past_kv = past_key_values is None or not isinstance(past_key_values, HybridMambaAttentionDynamicCache) + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. 
+        # (we can't check exception 3 while compiling)
+        if not empty_past_kv:
+            if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]:  # Exception 1  # Exception 3
+                input_ids = input_ids[:, -cache_position.shape[0] :]
+            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
+                input_ids = input_ids[:, cache_position]
+        else:
+            past_key_values = HybridMambaAttentionDynamicCache(
+                self.config, input_ids.shape[0], self.dtype, device=self.device
+            )
+
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if not empty_past_kv:
+                position_ids = position_ids[:, -input_ids.shape[1] :]
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and empty_past_kv:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases
+
+        # Copy from `LlavaForConditionalGeneration.prepare_inputs_for_generation`
+        if cache_position[0] == 0:
+            # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
+            # Otherwise we need pixel values to be passed to model
+            model_inputs["pixel_values"] = kwargs.get("pixel_values")  # pixel_values is not a named argument here, it arrives via **kwargs
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache": use_cache,
+                "attention_mask": attention_mask,
+                "output_router_logits": output_router_logits,
+                # "logits_to_keep": self.config.num_logits_to_keep,
+                "cache_position": cache_position,
+            }
+        )
+        return model_inputs
diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py
index 05acf23d..b1e9e74f 100644
--- a/tests/models/test_checkpoint.py
+++ b/tests/models/test_checkpoint.py
@@ -309,7 +309,7 @@ def test_huggingface_model(model_testing_config, get_convert_path):
     errors = []
     auto_model = (
         transformers.AutoModel
-        if model_testing_config.name in ("diffusion_llama", "dream")
+        if model_testing_config.name in ("diffusion_llama", "dream", "vision_hybrid_mamba2")
         else transformers.AutoModelForCausalLM
     )
     model_as_hf = auto_model.from_pretrained(

From 8c03b54a05ba05324b16d638d10db91db8da0a06 Mon Sep 17 00:00:00 2001
From: Toolkit User
Date: Thu, 24 Jul 2025 20:13:11 +0000
Subject: [PATCH 20/23] use hybrid cache, update test

---
 .../modeling_ssm_hybrid_apriel15b.py       | 33 ++++++++++++++++++-
 .../llava_hybrid/modeling_llava_hybrid.py  |  6 ++--
 tests/models/test_checkpoint.py            | 11 ++++---
 3 files changed, 42 insertions(+), 8 deletions(-)

diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py
index f8f6a052..bd12243e 100644
--- a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py
+++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py
@@ -18,7 +18,7 @@ from transformers.modeling_utils import PreTrainedModel
 from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralMLP, MistralModel, MistralRMSNorm
 from transformers.processing_utils import Unpack
-from transformers.utils import LossKwargs, logging
+from transformers.utils import LossKwargs, can_return_tuple, logging
 from transformers.utils.generic import ModelOutput
 from 
fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig @@ -1209,6 +1209,37 @@ def __init__(self, config: AprielSSMHybridConfig, **kwargs): # Initialize weights and apply final processing self.post_init() + @can_return_tuple + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: + use_cache = use_cache if use_cache is not None else self.config.use_cache + if use_cache and past_key_values is None: + batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0] + past_key_values = HybridMambaAttentionDynamicCache(self.config, batch_size, self.dtype, device=self.device) + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **flash_attn_kwargs, + ) + class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... diff --git a/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py b/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py index 6917fea9..9896d91d 100644 --- a/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py +++ b/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py @@ -3,7 +3,6 @@ from transformers.activations import ACT2FN from .configuration_llava_hybrid import LlavaHybridConfig -from .modeling_ssm_hybrid_apriel15b import AprielThinkerSSMHybridModel, HybridMambaAttentionDynamicCache class LlavaMultiModalProjector(nn.Module): @@ -43,7 +42,8 @@ def __init__(self, config: LlavaHybridConfig): assert ( config.text_config.model_type == "apriel_ssm_thinker_hybrid" ), "Only Apriel SSM Hybrid model is supported in LlavaHybridModel" - # from .modeling_ssm_hybrid_apriel15b import AprielThinkerSSMHybridModel + from .modeling_ssm_hybrid_apriel15b import AprielThinkerSSMHybridModel + self.language_model = AprielThinkerSSMHybridModel(config.text_config) self.post_init() @@ -69,6 +69,8 @@ def prepare_inputs_for_generation( use_cache=True, **kwargs, ): + from .modeling_ssm_hybrid_apriel15b import HybridMambaAttentionDynamicCache + # Copy of the method from `AprielThinkerSSMHybridForCausalLM` # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache` diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index b1e9e74f..73bb24c8 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -307,11 +307,12 @@ def test_huggingface_model(model_testing_config, get_convert_path): ) ) errors = [] - auto_model = ( - transformers.AutoModel - if model_testing_config.name in ("diffusion_llama", "dream", "vision_hybrid_mamba2") - else transformers.AutoModelForCausalLM - ) + if model_testing_config.name in ("diffusion_llama", "dream"): + auto_model = transformers.AutoModel + elif model_testing_config.name == "vision_hybrid_mamba2": + auto_model = transformers.AutoModelForVision2Seq + 
else: + auto_model = transformers.AutoModelForCausalLM model_as_hf = auto_model.from_pretrained( hf_path, trust_remote_code=model_testing_config.checkpoint_format.trust_remote_code ).cuda() From d51f817c34097792d50ff7a6696092fa620fad9c Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Thu, 24 Jul 2025 21:58:25 +0000 Subject: [PATCH 21/23] finish ssm-hybrid conversion --- fast_llm/models/ssm/config.py | 1 + fast_llm/models/ssm/conversion.py | 40 +++++++++++++------ .../modeling_ssm_hybrid_apriel15b.py | 8 +++- 3 files changed, 35 insertions(+), 14 deletions(-) diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 41f4eadb..70362a40 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -158,6 +158,7 @@ def get_handler_class(cls) -> type[CheckpointHandler]: class AprielThinkerSSMHHybridHuggingfaceCheckpointFormat(CheckpointFormat): support_optimizer: typing.ClassVar[bool] = False name: typing.ClassVar[str] = "apriel_ssm_thinker_hybrid" + trust_remote_code: typing.ClassVar[bool] = True @classmethod def get_handler_class(cls) -> type[CheckpointHandler]: diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index c765c956..640615e0 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -727,6 +727,7 @@ def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.An class AprielThinkerSSMHHybridHuggingfaceCheckpointHandler( + CustomModelingExportMixin, HybridModelCheckpointHandler, # handles the block structure parameter CommonSSMHuggingfaceCheckpointHandler, # handles the SSM layers CommonLlamaHuggingfaceCheckpointHandler, # handles the LLama layers @@ -741,6 +742,11 @@ class AprielThinkerSSMHHybridHuggingfaceCheckpointHandler( _default_block_type: str = SSMBlockType.mamba2_discrete.value _hf_prefix: str = "model" architecture: typing.ClassVar[str] = "AprielThinkerSSMHybridForCausalLM" + modeling_file = modeling_ssm_hybrid_apriel15b.__file__ + configuration_file = configuration_ssm_hybrid_apriel15b.__file__ + configuration_cls: typing.ClassVar[type[PretrainedConfig]] = ( + configuration_ssm_hybrid_apriel15b.AprielSSMHybridConfig + ) def _create_weight_converters( self, @@ -767,6 +773,14 @@ def _create_weight_converters( @classmethod def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ + ConstantExportParamConverter( + export_names=(("auto_map",),), + export_value={ + "AutoConfig": "configuration_ssm_hybrid_apriel15b.AprielSSMHybridConfig", + "AutoModel": "modeling_ssm_hybrid_apriel15b.AprielThinkerSSMHybridModel", + "AutoModelForCausalLM": "modeling_ssm_hybrid_apriel15b.AprielThinkerSSMHybridForCausalLM", + }, + ), RenameParamConverter( fast_llm_names=(("ssm", "d_inner"),), export_names=(("ssm_cfg", "d_inner"),), @@ -791,19 +805,19 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig ), ] - @classmethod - def _load_config(cls, directory: pathlib.Path | str) -> dict: - if not os.path.exists(directory / "config.json"): - raise FileNotFoundError(f"config.json not found in {directory}") - with open(directory / "config.json") as f: - config = json.load(f) - Assert.eq(config["model_type"], cls.get_huggingface_model_type()) - return config - - @classmethod - def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]) -> None: - with open(directory / "config.json", "w") as f: - json.dump(config, f) + # @classmethod + # def _load_config(cls, directory: 
pathlib.Path | str) -> dict: + # if not os.path.exists(directory / "config.json"): + # raise FileNotFoundError(f"config.json not found in {directory}") + # with open(directory / "config.json") as f: + # config = json.load(f) + # Assert.eq(config["model_type"], cls.get_huggingface_model_type()) + # return config + + # @classmethod + # def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]) -> None: + # with open(directory / "config.json", "w") as f: + # json.dump(config, f) class LlavaHybridHuggingfaceCheckpointHandler(CustomModelingExportMixin, LlavaHuggingfaceCheckpointHandler): diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py index bd12243e..da7984c7 100644 --- a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py +++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py @@ -357,7 +357,13 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx if len(self.key_cache) <= layer_idx: return 0 - return self.key_cache[layer_idx].shape[-2] + is_empty_layer = ( + len(self.key_cache) == 0 # no cache in any layer + or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it + or not self.key_cache[layer_idx].numel() # the layer has no cache + ) + return self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0 + # return self.key_cache[layer_idx].shape[-2] def reset(self): self.conv_states.zero_() From f898ff25648a66f8cc8bf5f44c1d43c3ffe3fa34 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Thu, 24 Jul 2025 22:00:28 +0000 Subject: [PATCH 22/23] fix architecture classvar --- fast_llm/models/gpt/conversion.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index f376db3f..fb180106 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -368,10 +368,10 @@ def _get_weight_and_bias_converters( class Starcoder2HuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = Starcoder2GPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "Starcoder2ForCausalLM" @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - cls.architecture = "Starcoder2ForCausalLM" return super()._create_config_converters() + [ ConstantImportParamConverter( fast_llm_names=(("transformer", "rotary", "type"),), @@ -495,10 +495,10 @@ def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.A class LlamaHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = LlamaGPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "LlamaForCausalLM" @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - cls.architecture = "LlamaForCausalLM" return super()._create_config_converters() + [ # TODO: Llama supports biases ConstantExportParamConverter(export_names=(("attention_bias",),), export_value=False), @@ -547,10 +547,10 @@ def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.A class Qwen2HuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = 
Qwen2GPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "Qwen2ForCausalLM" @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - cls.architecture = "Qwen2ForCausalLM" return super()._create_config_converters() + [ ConstantImportParamConverter( fast_llm_names=(("transformer", "normalization", "type"),), @@ -593,10 +593,10 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig class MistralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = MistralGPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "MistralForCausalLM" @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - cls.architecture = "MistralForCausalLM" return super()._create_config_converters() + [ IgnoreImportParamConverter(export_names=(("sliding_window",),), ignore_export_value=None), ] @@ -1014,10 +1014,10 @@ def _create_weight_converters(self): class MixtralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = MixtralGPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "MixtralForCausalLM" @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - cls.architecture = "MixtralForCausalLM" return super()._create_config_converters() + [ ConstantImportParamConverter( fast_llm_names=(("transformer", "expert_routing_type"),), fast_llm_value=RoutingType.topk @@ -1055,13 +1055,13 @@ class MTPLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, CommonLlam from fast_llm.models.gpt.external.mtp_llama import configuration_mtp_llama, modeling_mtp_llama format: typing.ClassVar[type[CheckpointFormat]] = MTPLlamaGPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "MTPLlamaForCausalLM" modeling_file = modeling_mtp_llama.__file__ configuration_file = configuration_mtp_llama.__file__ configuration_cls: typing.ClassVar[type[PretrainedConfig]] = MTPLlamaConfig @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - cls.architecture = "MTPLlamaForCausalLM" return super()._create_config_converters() + [ ConstantExportParamConverter( export_names=(("auto_map",),), @@ -1143,6 +1143,7 @@ class DiffusionDreamHuggingfaceCheckpointHandler(CustomModelingExportMixin, Qwen from fast_llm.models.gpt.external.diffusion_dream import configuration_dream, generation_utils, modeling_dream format: typing.ClassVar[type[CheckpointFormat]] = DiffusionDreamGPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "DreamModel" modeling_file = modeling_dream.__file__ configuration_file = configuration_dream.__file__ generation_utils_file = generation_utils.__file__ @@ -1150,7 +1151,6 @@ class DiffusionDreamHuggingfaceCheckpointHandler(CustomModelingExportMixin, Qwen @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - cls.architecture = "DreamModel" return super()._create_config_converters() + [ ConstantExportParamConverter( export_names=(("auto_map",),), @@ -1171,6 +1171,7 @@ class DiffusionLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, Llam ) format: typing.ClassVar[type[CheckpointFormat]] = DiffusionLlamaGPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "DiffusionLlamaModel" modeling_file = modeling_diffusion_llama.__file__ configuration_file = configuration_diffusion_llama.__file__ generation_utils_file = generation_utils.__file__ @@ -1178,7 +1179,6 @@ class 
DiffusionLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, Llam
 
     @classmethod
     def _create_config_converters(cls) -> list[ParamConverter]:
-        cls.architecture = "DiffusionLlamaModel"
         return super()._create_config_converters() + [
             ConstantExportParamConverter(
                 export_names=(("auto_map",),),

From c447fc3857800484929fc6cfb80ce879f201210c Mon Sep 17 00:00:00 2001
From: Toolkit User
Date: Thu, 24 Jul 2025 22:05:33 +0000
Subject: [PATCH 23/23] add llava test and m2 conversion test

---
 tests/models/test_checkpoint.py |  2 +-
 tests/utils/model_configs.py    | 46 ++++++++++++++++++++++++++++++---
 2 files changed, 44 insertions(+), 4 deletions(-)

diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py
index 73bb24c8..da719a42 100644
--- a/tests/models/test_checkpoint.py
+++ b/tests/models/test_checkpoint.py
@@ -309,7 +309,7 @@ def test_huggingface_model(model_testing_config, get_convert_path):
     errors = []
     if model_testing_config.name in ("diffusion_llama", "dream"):
         auto_model = transformers.AutoModel
-    elif model_testing_config.name == "vision_hybrid_mamba2":
+    elif model_testing_config.name in ("llava", "vision_hybrid_mamba2"):
         auto_model = transformers.AutoModelForVision2Seq
     else:
         auto_model = transformers.AutoModelForCausalLM
     model_as_hf = auto_model.from_pretrained(
         hf_path, trust_remote_code=model_testing_config.checkpoint_format.trust_remote_code
     ).cuda()
diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py
index c2bb6800..643ca6c2 100644
--- a/tests/utils/model_configs.py
+++ b/tests/utils/model_configs.py
@@ -13,13 +13,18 @@ DiffusionDreamGPTHuggingfaceCheckpointFormat,
     DiffusionLlamaGPTHuggingfaceCheckpointFormat,
     LlamaGPTHuggingfaceCheckpointFormat,
+    LlavaGPTHuggingfaceCheckpointFormat,
     MistralGPTHuggingfaceCheckpointFormat,
     MixtralGPTHuggingfaceCheckpointFormat,
     MTPLlamaGPTHuggingfaceCheckpointFormat,
     Qwen2GPTHuggingfaceCheckpointFormat,
     Starcoder2GPTHuggingfaceCheckpointFormat,
 )
-from fast_llm.models.ssm.config import LLambaHuggingfaceCheckpointFormat, LlavaHybridHuggingfaceCheckpointFormat
+from fast_llm.models.ssm.config import (
+    AprielThinkerSSMHHybridHuggingfaceCheckpointFormat,
+    LLambaHuggingfaceCheckpointFormat,
+    LlavaHybridHuggingfaceCheckpointFormat,
+)
 from tests.utils.dataset import MODEL_DATASET_PREFIX, MODEL_TEST_VOCAB_SIZE
 from tests.utils.distributed_configs import DistributedTestingConfig
 
@@ -450,6 +455,41 @@ def _update_and_add_testing_config(
     compare_factor=2.0,
 )
 
+_update_and_add_testing_config(
+    # Tests the llava vision model, llava converter.
+ "llama", + "llava", + extra_args=[ + "batch.max_image_size=128", + "model.base_model.vision_encoder.type=pixtral", + "model.base_model.vision_encoder.patch_norm.type=rms_norm", + "model.base_model.vision_encoder.transformer.add_linear_biases=False", + "model.base_model.vision_encoder.transformer.causal=False", + "model.base_model.vision_encoder.transformer.normalization.type=rms_norm", + "model.base_model.vision_encoder.transformer.type=image_encoder", + "model.base_model.vision_encoder.transformer.gated=True", + "model.base_model.vision_encoder.transformer.num_layers=2", + "model.base_model.vision_encoder.transformer.hidden_size=256", + "model.base_model.vision_encoder.transformer.num_attention_heads=8", + "model.base_model.vision_encoder.transformer.head_groups=8", + "model.base_model.vision_encoder.transformer.init_method_std=0.022", + "model.base_model.vision_encoder.transformer.rotary.type=rope_2d", + "model.base_model.vision_encoder.adapter_size=256", + "model.distributed.training_dtype=torch.bfloat16", + ], + megatron_args=None, + checkpoint_format=LlavaGPTHuggingfaceCheckpointFormat, + groups={ + ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, + ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, + }, + compare_factor=8.0, +) + _update_and_add_testing_config( # Tests hybrid ssm, llamba converter. "llama", @@ -510,11 +550,11 @@ def _update_and_add_testing_config( "model.base_model.hybrid_block_layout=['t','m2']", ], megatron_args=None, - checkpoint_format=None, + checkpoint_format=AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant,