diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index 72db80f6..f8a42b31 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -283,7 +283,7 @@ def _export_config(cls, config: BaseModelConfig) -> dict[str, typing.Any]: return exported_config # Noqa @classmethod - def _import_config(cls, config: dict[str, typing.Any]) -> BaseModelConfig: # noqa + def _import_config_dict(cls, config: dict[str, typing.Any]) -> dict[str | tuple[str, ...], typing.Any]: kwargs = {} for converter in cls._get_config_converters(): try: @@ -306,7 +306,11 @@ def _import_config(cls, config: dict[str, typing.Any]) -> BaseModelConfig: # no kwargs[fast_llm_name] = value except Exception as e: raise RuntimeError(f"Config conversion failed for converter {converter}", *e.args) + return kwargs + @classmethod + def _import_config(cls, config: dict[str, typing.Any]) -> BaseModelConfig: # noqa + kwargs = cls._import_config_dict(config) return cls._model_class.get_base_model_config_class().from_dict({}, kwargs) def _convert_state_dict( diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index 16b3e005..4cfff4af 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -134,6 +134,7 @@ class CustomModelingExportMixin: configuration_file: typing.ClassVar[str] configuration_cls: typing.ClassVar[type[PretrainedConfig]] generation_utils_file: str | None = None + additional_files: typing.ClassVar[list[str]] = [] # Use custom config instead of relying on the transformers library @classmethod @@ -159,3 +160,5 @@ def _copy_modeling_files(self, config: CheckpointSaveConfig) -> None: gen_config = pathlib.Path(self.generation_utils_file).parent / "generation_config.json" if gen_config.exists(): shutil.copy(gen_config, config.path) + for file in self.additional_files: + shutil.copy(file, config.path) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index c69ada38..46d629aa 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -22,7 +22,7 @@ class SSMDimNames: v_heads = "v_heads" # Number of V heads # Mamba 2 - x_proj_dim_2 = "x_proj_dim" # d_xb + x_proj_dim_2 = "x_proj_dim_2" # d_xb class SSMBlockType(enum.StrEnum): diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index a03509ab..8a61a896 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -1,3 +1,4 @@ +import logging import math import typing @@ -10,6 +11,8 @@ from fast_llm.tensor import ParameterMeta, init_fill_, init_ones_, init_uniform_, kaiming_init_ from fast_llm.utils import get_lr_scale +logger = logging.getLogger(__name__) + try: from mamba_ssm.ops.selective_scan_interface import selective_scan_fn # noqa @@ -144,7 +147,15 @@ def init_from_tensor_( value: torch.Tensor, ) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - return tensor.copy_(value) + logger.info( + f"Initializing {meta.tensor_name} with shape {meta.shape}, tensor shape {tensor.shape} from value shape {value.shape}" + ) + # TODO: fix and remove try-except + try: + return tensor.copy_(value) + except RuntimeError as e: + logger.error(f"Failed to copy value to tensor: {e}") + return tensor.fill_(0.0) return init_ @@ -156,6 +167,7 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: 
torch.Generator) lr_scale=mamba_layer_lr_scale, ) # define bias outside the linear layer since its also used in the selective_scan_fn + logger.info(f"td_inner: {td_inner}, inv_dt: {inv_dt.shape}") self.dt_proj_bias = ParameterMeta.from_dims( (td_inner,), init_method=init_from_tensor_(inv_dt), lr_scale=mamba_layer_lr_scale ) @@ -166,6 +178,7 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) d=self.d_inner, ).contiguous() A_log = torch.log(A).flatten() # Keep A_log in fp32 + logger.info(f"A_log: {A_log.shape}, td_inner: {td_inner}, td_state: {td_state}") self.A_log = ParameterMeta.from_dims( (td_inner, td_state), init_method=init_from_tensor_(A_log), diff --git a/fast_llm/layers/vision_encoder/adapter.py b/fast_llm/layers/vision_encoder/adapter.py index 41ea065d..d324d522 100644 --- a/fast_llm/layers/vision_encoder/adapter.py +++ b/fast_llm/layers/vision_encoder/adapter.py @@ -24,15 +24,15 @@ def __init__(self, config: VisionEncoderConfig, tensor_space: TensorSpace): input_dim, tensor_space.get_tensor_dim(VisionEncoderDimNames.adapter_size), bias=True, - weight_init_method=init_normal_(), - bias_init_method=init_normal_(), + weight_init_method=init_normal_(std=config.adapter_init_method_std), + bias_init_method=init_normal_(std=config.adapter_init_method_std), ) self.layer_2 = Linear( tensor_space.get_tensor_dim(VisionEncoderDimNames.adapter_size), tensor_space.get_tensor_dim(TransformerDimNames.hidden), bias=True, - weight_init_method=init_normal_(), - bias_init_method=init_normal_(), + weight_init_method=init_normal_(std=config.adapter_init_method_std), + bias_init_method=init_normal_(std=config.adapter_init_method_std), ) def forward( diff --git a/fast_llm/layers/vision_encoder/config.py b/fast_llm/layers/vision_encoder/config.py index 2ea7f611..a705d948 100644 --- a/fast_llm/layers/vision_encoder/config.py +++ b/fast_llm/layers/vision_encoder/config.py @@ -150,6 +150,18 @@ class VisionEncoderConfig(BaseModelConfig): hint=FieldHint.feature, valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) + adapter_init_method_std: float = Field( + default=None, + desc="Standard deviation for the normal initialization of the adapter weights. Default: adapter_size ** -0.5.", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + + def _validate(self) -> None: + with self._set_implicit_default(): + if self.adapter_init_method_std is None: + self.adapter_init_method_std = self.adapter_size**-0.5 + super()._validate() def setup_tensor_space(self, tensor_space: TensorSpace): tensor_space.add_tensor_dim(TensorDim(VisionEncoderDimNames.out_channels, self.transformer.hidden_size)) diff --git a/fast_llm/layers/vision_encoder/preprocessing.py b/fast_llm/layers/vision_encoder/preprocessing.py index 3b857ba2..adacd380 100644 --- a/fast_llm/layers/vision_encoder/preprocessing.py +++ b/fast_llm/layers/vision_encoder/preprocessing.py @@ -163,11 +163,12 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: for imgs in images ] - labels = kwargs[LanguageModelKwargs.labels] - if (self._config.image_break_token is not None) or (self._config.image_end_token is not None): - # If image break or end token is present, we need to replace image token ids to -100 in labels - # TODO: avoid double cloning labels in case of loss masking spans? 
- labels = labels.clone() + if LanguageModelKwargs.labels in kwargs: + labels = kwargs[LanguageModelKwargs.labels] + if (self._config.image_break_token is not None) or (self._config.image_end_token is not None): + # If image break or end token is present, we need to replace image token ids to -100 in labels + # TODO: avoid double cloning labels in case of loss masking spans? + labels = labels.clone() patches = [] patch_position_ids = [] @@ -191,8 +192,9 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: image_break=self._config.image_break_token is not None, image_end=self._config.image_end_token is not None, ) - # set labels for image patches to -100 - labels[idx, max(position - 1, 0) : position + num_tokens - 1] = -100 + if LanguageModelKwargs.labels in kwargs: + # set labels for image patches to -100 + labels[idx, max(position - 1, 0) : position + num_tokens - 1] = -100 if seqlen > max_seqlen: max_seqlen = seqlen cu_seqlens.append(cu_seqlens[-1] + seqlen) @@ -261,4 +263,19 @@ def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: ) kwargs[VisionTransformerKwargs.max_seqlen_q] = max_seqlen kwargs[VisionTransformerKwargs.max_seqlen_k] = max_seqlen - kwargs[LanguageModelKwargs.labels] = labels + if LanguageModelKwargs.labels in kwargs: + kwargs[LanguageModelKwargs.labels] = labels + + # TODO: add proper preprocessing for attention-mask when not using flash attention + # Following is just a dummy code to run the tests. + kwargs[self._config.transformer._transformer_kwargs.attention_mask] = torch.ones( + (1, 1, kwargs[TransformerKwargs.sequence_length], 1, kwargs[TransformerKwargs.sequence_length]), + dtype=torch.bool, + device=self._tensor_space.distributed.device, + ) + kwargs[self._config.transformer._transformer_kwargs.attention_mask_value] = torch.full( + [], + torch.finfo(self._distributed_config.training_dtype.torch).min, + dtype=self._distributed_config.training_dtype.torch, + device=self._tensor_space.distributed.device, + ) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 039b97f8..bc64821f 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -257,6 +257,9 @@ def _validate(self) -> None: Assert.none(reference_model.model.base_model.cross_entropy_splits) Assert.eq(reference_model.model.base_model.parallel_embeddings, self.model.base_model.parallel_embeddings) Assert.geq(reference_model.model.base_model.prediction_heads, self.model.base_model.prediction_heads) + if self.model.base_model.vision_encoder.enabled: + assert self.batch.max_image_size is not None, "max_image_size must be set when using vision encoder" + Assert.gt(self.batch.max_image_size, 0) @classmethod def _from_dict( diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index e01aaf70..fb180106 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -14,6 +14,7 @@ AutoStateDictCheckpointHandler, ConstantExportParamConverter, ConstantImportParamConverter, + ExternalStateDictCheckpointHandler, IgnoreExportWeightConverter, IgnoreImportParamConverter, IgnoreImportWeightConverter, @@ -367,10 +368,10 @@ def _get_weight_and_bias_converters( class Starcoder2HuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = Starcoder2GPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "Starcoder2ForCausalLM" @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - cls.architecture = 
"Starcoder2ForCausalLM" return super()._create_config_converters() + [ ConstantImportParamConverter( fast_llm_names=(("transformer", "rotary", "type"),), @@ -494,10 +495,10 @@ def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.A class LlamaHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = LlamaGPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "LlamaForCausalLM" @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - cls.architecture = "LlamaForCausalLM" return super()._create_config_converters() + [ # TODO: Llama supports biases ConstantExportParamConverter(export_names=(("attention_bias",),), export_value=False), @@ -546,10 +547,10 @@ def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.A class Qwen2HuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = Qwen2GPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "Qwen2ForCausalLM" @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - cls.architecture = "Qwen2ForCausalLM" return super()._create_config_converters() + [ ConstantImportParamConverter( fast_llm_names=(("transformer", "normalization", "type"),), @@ -592,10 +593,10 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig class MistralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = MistralGPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "MistralForCausalLM" @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - cls.architecture = "MistralForCausalLM" return super()._create_config_converters() + [ IgnoreImportParamConverter(export_names=(("sliding_window",),), ignore_export_value=None), ] @@ -664,7 +665,7 @@ def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.A class PixtralHuggingfaceCheckpointHandler(WeightAndBiasConverterMixin, HuggingfaceStateDictCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = PixtralGPTHuggingfaceCheckpointFormat - _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig + _model_class: typing.ClassVar[FastLLMModelConfig] = FastLLMModelConfig @classmethod def _create_config_converters(cls) -> list[ParamConverter]: @@ -871,17 +872,28 @@ def num_layers(self) -> int: class LlavaHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = LlavaGPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "LlavaForConditionalGeneration" _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig + @classmethod + def get_vision_handler_class(cls) -> type[ExternalStateDictCheckpointHandler]: + return AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.format.vision_name) + + @classmethod + def get_text_handler_class(cls) -> type[ExternalStateDictCheckpointHandler]: + return AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.format.text_name) + @classmethod def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: + vision_handler_cls = cls.get_vision_handler_class() + text_handler_cls = cls.get_text_handler_class() cfg_dict = cls._load_config(config.path) kwargs = {} if "text_config" in cfg_dict: - text_kwargs = cls._import_config(cfg_dict["text_config"]) + text_kwargs = 
text_handler_cls._import_config_dict(cfg_dict["text_config"]) kwargs.update(text_kwargs) if "vision_config" in cfg_dict: - vision_kwargs = cls._import_config(cfg_dict["vision_config"]) + vision_kwargs = vision_handler_cls._import_config_dict(cfg_dict["vision_config"]) vision_kwargs = {tuple(["vision_encoder"] + list(key)): value for key, value in vision_kwargs.items()} kwargs.update(vision_kwargs) kwargs.update( @@ -901,9 +913,7 @@ def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetad @classmethod def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ - ConstantExportParamConverter( - export_names=(("architectures",),), export_value=["LlavaForConditionalGeneration"] - ), + ConstantExportParamConverter(export_names=(("architectures",),), export_value=[cls.architecture]), MappedConfigParamConverter( fast_llm_names=(("vision_encoder", "adapter_activation_type"),), export_names=(("projector_hidden_act",),), @@ -918,9 +928,9 @@ def _create_config_converters(cls) -> list[ParamConverter]: @classmethod def _import_config(cls, config: dict[str, typing.Any]) -> GPTBaseModelConfig: - handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(config["model_type"]) + # handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(config["model_type"]) kwargs = {} - for converter in handler_cls._create_config_converters(): + for converter in cls._create_config_converters(): try: values = () for export_name in converter.export_names: @@ -944,8 +954,8 @@ def _import_config(cls, config: dict[str, typing.Any]) -> GPTBaseModelConfig: @classmethod def _export_config(cls, config: BaseModelConfig) -> dict[str, typing.Any]: exported_config = {} - vision_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.format.vision_name) - text_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(cls.format.text_name) + vision_handler_cls = cls.get_vision_handler_class() + text_handler_cls = cls.get_text_handler_class() for converter in vision_handler_cls._create_config_converters(): try: values = converter.export_params( @@ -991,10 +1001,10 @@ def _export_config(cls, config: BaseModelConfig) -> dict[str, typing.Any]: return exported_config def _create_weight_converters(self): - vision_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(self.format.vision_name) + vision_handler_cls = self.get_vision_handler_class() vision_handler = vision_handler_cls(self._model) converters = vision_handler._create_weight_converters(hf_base_prefix="vision_tower.", offset=0) - text_handler_cls = AutoGPTHuggingfaceCheckpointHandler.get_handler_class(self.format.text_name) + text_handler_cls = self.get_text_handler_class() text_handler = text_handler_cls(self._model) converters.extend( text_handler._create_weight_converters(hf_base_prefix="language_model.", offset=vision_handler.num_layers) @@ -1004,10 +1014,10 @@ def _create_weight_converters(self): class MixtralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = MixtralGPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "MixtralForCausalLM" @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - cls.architecture = "MixtralForCausalLM" return super()._create_config_converters() + [ ConstantImportParamConverter( fast_llm_names=(("transformer", "expert_routing_type"),), fast_llm_value=RoutingType.topk @@ -1045,13 +1055,13 @@ class 
MTPLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, CommonLlam from fast_llm.models.gpt.external.mtp_llama import configuration_mtp_llama, modeling_mtp_llama format: typing.ClassVar[type[CheckpointFormat]] = MTPLlamaGPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "MTPLlamaForCausalLM" modeling_file = modeling_mtp_llama.__file__ configuration_file = configuration_mtp_llama.__file__ configuration_cls: typing.ClassVar[type[PretrainedConfig]] = MTPLlamaConfig @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - cls.architecture = "MTPLlamaForCausalLM" return super()._create_config_converters() + [ ConstantExportParamConverter( export_names=(("auto_map",),), @@ -1133,6 +1143,7 @@ class DiffusionDreamHuggingfaceCheckpointHandler(CustomModelingExportMixin, Qwen from fast_llm.models.gpt.external.diffusion_dream import configuration_dream, generation_utils, modeling_dream format: typing.ClassVar[type[CheckpointFormat]] = DiffusionDreamGPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "DreamModel" modeling_file = modeling_dream.__file__ configuration_file = configuration_dream.__file__ generation_utils_file = generation_utils.__file__ @@ -1140,7 +1151,6 @@ class DiffusionDreamHuggingfaceCheckpointHandler(CustomModelingExportMixin, Qwen @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - cls.architecture = "DreamModel" return super()._create_config_converters() + [ ConstantExportParamConverter( export_names=(("auto_map",),), @@ -1161,6 +1171,7 @@ class DiffusionLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, Llam ) format: typing.ClassVar[type[CheckpointFormat]] = DiffusionLlamaGPTHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "DiffusionLlamaModel" modeling_file = modeling_diffusion_llama.__file__ configuration_file = configuration_diffusion_llama.__file__ generation_utils_file = generation_utils.__file__ @@ -1168,7 +1179,6 @@ class DiffusionLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, Llam @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - cls.architecture = "DiffusionLlamaModel" return super()._create_config_converters() + [ ConstantExportParamConverter( export_names=(("auto_map",),), diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 6356cf23..1e439e72 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -110,13 +110,15 @@ def get_vision_layers(self) -> list[Layer]: MultiModalEmbedding(self._config, self._tensor_space), ] + def get_embedding_layers(self) -> list[Layer]: + if self._config.vision_encoder.enabled: + return self.get_vision_layers() + else: + return [LanguageModelEmbedding(self._config, self._tensor_space)] + def get_layers(self) -> list[Layer]: return [ - *( - [LanguageModelEmbedding(self._config, self._tensor_space)] - if not self._config.vision_encoder.enabled - else self.get_vision_layers() - ), + *(self.get_embedding_layers()), *[ TransformerLayer( self._config.transformer, @@ -148,7 +150,11 @@ def preprocess_meta( micro_sequence_length = sequence_length if self._config.vision_encoder.enabled: - max_image_size = batch_meta.max_image_size + try: + max_image_size = batch_meta.max_image_size + except AttributeError: + max_image_size = 256 + logger.warning("Inference mode: max_image_size not provided, defaulting to 256") image_mean = [ self._config.vision_encoder.image_normalization.mean_r, self._config.vision_encoder.image_normalization.mean_g, diff --git 
a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index cc83f11b..70362a40 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -158,6 +158,7 @@ def get_handler_class(cls) -> type[CheckpointHandler]: class AprielThinkerSSMHHybridHuggingfaceCheckpointFormat(CheckpointFormat): support_optimizer: typing.ClassVar[bool] = False name: typing.ClassVar[str] = "apriel_ssm_thinker_hybrid" + trust_remote_code: typing.ClassVar[bool] = True @classmethod def get_handler_class(cls) -> type[CheckpointHandler]: @@ -166,6 +167,27 @@ def get_handler_class(cls) -> type[CheckpointHandler]: return AprielThinkerSSMHHybridHuggingfaceCheckpointHandler +# class LlavaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): +# name: typing.ClassVar[str] = "llava" +# # Using default values for vision and text models. Can be overridden in the config +# vision_name: typing.ClassVar[str] = "pixtral" +# text_name: typing.ClassVar[str] = "mistral" + + +class LlavaHybridHuggingfaceCheckpointFormat(CheckpointFormat): + support_optimizer: typing.ClassVar[bool] = False + name: typing.ClassVar[str] = "llava_hybrid" + vision_name: typing.ClassVar[str] = "pixtral" + text_name: typing.ClassVar[str] = "apriel_ssm_thinker_hybrid" + trust_remote_code: typing.ClassVar[bool] = True + + @classmethod + def get_handler_class(cls) -> type[CheckpointHandler]: + from fast_llm.models.ssm.conversion import LlavaHybridHuggingfaceCheckpointHandler + + return LlavaHybridHuggingfaceCheckpointHandler + + @config_class(dynamic_type={FastLLMModelConfig: "hybrid_ssm"}) class HybridSSMModelConfig(FastLLMModelConfig): _abstract = False @@ -176,6 +198,7 @@ class HybridSSMModelConfig(FastLLMModelConfig): AprielSSMHuggingfaceCheckpointFormat, AprielSSMHHybridHuggingfaceCheckpointFormat, AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, + LlavaHybridHuggingfaceCheckpointFormat, ) @classmethod @@ -185,7 +208,7 @@ def get_model_class(cls) -> type["HybridSSMModel"]: return HybridSSMModel @classmethod - def get_huggingface_model_class(cls) -> type["HuggingfaceHybridSSMModelForCausalLM"]: + def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceHybridSSMModelForCausalLM"]: from fast_llm.models.ssm.huggingface import HuggingfaceHybridSSMModelForCausalLM return HuggingfaceHybridSSMModelForCausalLM diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index d5730025..640615e0 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -3,10 +3,13 @@ import pathlib import typing +from transformers.configuration_utils import PretrainedConfig + from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( ConstantExportParamConverter, ConstantImportParamConverter, + ExternalStateDictCheckpointHandler, IgnoreImportParamConverter, IgnoreImportWeightConverter, MappedConfigParamConverter, @@ -15,19 +18,29 @@ SplitWeightConverter, WeightConverter, ) -from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler +from fast_llm.engine.checkpoint.huggingface import CustomModelingExportMixin, HuggingfaceStateDictCheckpointHandler from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import RMSNormalizationConfig from fast_llm.layers.ssm.config import SSMBlockType -from fast_llm.models.gpt.conversion import CommonLlamaHuggingfaceCheckpointHandler, MLPLayer2Converter +from 
fast_llm.models.gpt.conversion import ( + CommonLlamaHuggingfaceCheckpointHandler, + LlavaHuggingfaceCheckpointHandler, + MLPLayer2Converter, +) from fast_llm.models.ssm.config import ( AprielSSMHHybridHuggingfaceCheckpointFormat, AprielSSMHuggingfaceCheckpointFormat, AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, HybridSSMModelConfig, LLambaHuggingfaceCheckpointFormat, + LlavaHybridHuggingfaceCheckpointFormat, +) +from fast_llm.models.ssm.external.apriel_15b_hybrid import ( + configuration_ssm_hybrid_apriel15b, + modeling_ssm_hybrid_apriel15b, ) +from fast_llm.models.ssm.external.llava_hybrid import configuration_llava_hybrid, modeling_llava_hybrid from fast_llm.models.ssm.model import HybridSSMModel from fast_llm.utils import Assert @@ -215,8 +228,18 @@ def _create_config_converters(cls) -> list[ParamConverter]: ), ] - def _create_weight_converters(self) -> list[WeightConverter]: - converters = super()._create_weight_converters() or [] + def _create_weight_converters( + self, + hf_base_prefix: str = "", + offset: int = 0, + ) -> list[WeightConverter]: + converters = ( + super()._create_weight_converters( + hf_base_prefix=hf_base_prefix, + offset=offset, + ) + or [] + ) num_layers = self._model.config.base_model.transformer.num_layers ssm_bias: bool = self._model.config.base_model.ssm.add_bias_linear @@ -224,55 +247,65 @@ def _create_weight_converters(self) -> list[WeightConverter]: for i in range(num_layers): # SSM converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mixer.in_proj", f"model.layers.{i}.mixer.in_proj", ssm_bias + f"layers.{offset+i+1}.mixer.in_proj", f"{hf_base_prefix}model.layers.{i}.mixer.in_proj", ssm_bias ) converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mixer.out_proj", f"model.layers.{i}.mixer.out_proj", ssm_bias + f"layers.{offset+i+1}.mixer.out_proj", f"{hf_base_prefix}model.layers.{i}.mixer.out_proj", ssm_bias ) converters.append( - WeightConverter(f"layers.{i+1}.mixer.D", f"model.layers.{i}.mixer.D", self._model.config.base_model) + WeightConverter( + f"layers.{offset+i+1}.mixer.D", + f"{hf_base_prefix}model.layers.{i}.mixer.D", + self._model.config.base_model, + ) ) converters.append( WeightConverter( - f"layers.{i+1}.mixer.z_bias", f"model.layers.{i}.mixer.z_bias", self._model.config.base_model + f"layers.{offset+i+1}.mixer.z_bias", + f"{hf_base_prefix}model.layers.{i}.mixer.z_bias", + self._model.config.base_model, ) ) converters.append( WeightConverter( - f"layers.{i+1}.mixer.z_bias", f"model.layers.{i}.mixer.z_bias", self._model.config.base_model + f"layers.{offset+i+1}.mixer.z_bias", + f"{hf_base_prefix}model.layers.{i}.mixer.z_bias", + self._model.config.base_model, ) ) converters.append( WeightConverter( - f"layers.{i+1}.mixer.conv1d_weight", - f"model.layers.{i}.mixer.conv1d.weight", + f"layers.{offset+i+1}.mixer.conv1d_weight", + f"{hf_base_prefix}model.layers.{i}.mixer.conv1d.weight", self._model.config.base_model, ) ) converters.append( WeightConverter( - f"layers.{i+1}.mixer.conv1d_bias", - f"model.layers.{i}.mixer.conv1d.bias", + f"layers.{offset+i+1}.mixer.conv1d_bias", + f"{hf_base_prefix}model.layers.{i}.mixer.conv1d.bias", self._model.config.base_model, ) ) # ================================================ # Mamba2 specific parameters converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mixer.dt_proj", f"model.layers.{i}.mixer.dt_proj", False + f"layers.{offset+i+1}.mixer.dt_proj", f"{hf_base_prefix}model.layers.{i}.mixer.dt_proj", False ) # bias is treated separately in Mamba2 and must 
always exist (https://github.com/jxiw/M1/blob/537a1ca5407a786a99dc6c721873493cf8750d5e/mamba/hybrid_mamba_layer.py) converters.append( WeightConverter( - f"layers.{i+1}.mixer.dt_proj_bias", - f"model.layers.{i}.mixer.dt_proj.bias", + f"layers.{offset+i+1}.mixer.dt_proj_bias", + f"{hf_base_prefix}model.layers.{i}.mixer.dt_proj.bias", self._model.config.base_model, ) ) converters.append( WeightConverter( - f"layers.{i+1}.mixer.A_log", f"model.layers.{i}.mixer.A_log", self._model.config.base_model + f"layers.{offset+i+1}.mixer.A_log", + f"{hf_base_prefix}model.layers.{i}.mixer.A_log", + self._model.config.base_model, ) ) @@ -566,11 +599,16 @@ def _create_config_converters(cls) -> list[ParamConverter]: ), ] - def _create_weight_converters(self) -> list[WeightConverter]: - converters = super()._create_weight_converters() + def _create_weight_converters( + self, + hf_base_prefix: str = "", + offset: int = 0, + ) -> list[WeightConverter]: + converters = super()._create_weight_converters(hf_base_prefix, offset) num_layers = self._model.config.base_model.transformer.num_layers norm_bias: bool = False + # TODO: use hf_base_prefix and offset # Embedding and output if self._model.config.base_model.tie_word_embeddings: converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) @@ -689,6 +727,7 @@ def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.An class AprielThinkerSSMHHybridHuggingfaceCheckpointHandler( + CustomModelingExportMixin, HybridModelCheckpointHandler, # handles the block structure parameter CommonSSMHuggingfaceCheckpointHandler, # handles the SSM layers CommonLlamaHuggingfaceCheckpointHandler, # handles the LLama layers @@ -703,9 +742,18 @@ class AprielThinkerSSMHHybridHuggingfaceCheckpointHandler( _default_block_type: str = SSMBlockType.mamba2_discrete.value _hf_prefix: str = "model" architecture: typing.ClassVar[str] = "AprielThinkerSSMHybridForCausalLM" + modeling_file = modeling_ssm_hybrid_apriel15b.__file__ + configuration_file = configuration_ssm_hybrid_apriel15b.__file__ + configuration_cls: typing.ClassVar[type[PretrainedConfig]] = ( + configuration_ssm_hybrid_apriel15b.AprielSSMHybridConfig + ) - def _create_weight_converters(self) -> list[WeightConverter]: - converters = super()._create_weight_converters() + def _create_weight_converters( + self, + hf_base_prefix: str = "", + offset: int = 0, + ) -> list[WeightConverter]: + converters = super()._create_weight_converters(hf_base_prefix, offset) # num_layers = self._model.config.base_model.transformer.num_layers # # Embedding and output # if self._model.config.base_model.tie_word_embeddings: @@ -725,6 +773,14 @@ def _create_weight_converters(self) -> list[WeightConverter]: @classmethod def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ + ConstantExportParamConverter( + export_names=(("auto_map",),), + export_value={ + "AutoConfig": "configuration_ssm_hybrid_apriel15b.AprielSSMHybridConfig", + "AutoModel": "modeling_ssm_hybrid_apriel15b.AprielThinkerSSMHybridModel", + "AutoModelForCausalLM": "modeling_ssm_hybrid_apriel15b.AprielThinkerSSMHybridForCausalLM", + }, + ), RenameParamConverter( fast_llm_names=(("ssm", "d_inner"),), export_names=(("ssm_cfg", "d_inner"),), @@ -749,16 +805,48 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig ), ] + # @classmethod + # def _load_config(cls, directory: pathlib.Path | str) -> dict: + # if not os.path.exists(directory / "config.json"): + # 
raise FileNotFoundError(f"config.json not found in {directory}") + # with open(directory / "config.json") as f: + # config = json.load(f) + # Assert.eq(config["model_type"], cls.get_huggingface_model_type()) + # return config + + # @classmethod + # def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]) -> None: + # with open(directory / "config.json", "w") as f: + # json.dump(config, f) + + +class LlavaHybridHuggingfaceCheckpointHandler(CustomModelingExportMixin, LlavaHuggingfaceCheckpointHandler): + format: typing.ClassVar[type[CheckpointFormat]] = LlavaHybridHuggingfaceCheckpointFormat + architecture: typing.ClassVar[str] = "LlavaHybridForConditionalGeneration" + modeling_file = modeling_llava_hybrid.__file__ + configuration_file = configuration_llava_hybrid.__file__ + configuration_cls: typing.ClassVar[type[PretrainedConfig]] = configuration_llava_hybrid.LlavaHybridConfig + _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig + additional_files = [ + modeling_ssm_hybrid_apriel15b.__file__, + configuration_ssm_hybrid_apriel15b.__file__, + ] + @classmethod - def _load_config(cls, directory: pathlib.Path | str) -> dict: - if not os.path.exists(directory / "config.json"): - raise FileNotFoundError(f"config.json not found in {directory}") - with open(directory / "config.json") as f: - config = json.load(f) - Assert.eq(config["model_type"], cls.get_huggingface_model_type()) - return config + def get_text_handler_class(cls) -> type[ExternalStateDictCheckpointHandler]: + from fast_llm.models.ssm.conversion import AprielThinkerSSMHHybridHuggingfaceCheckpointHandler + + return AprielThinkerSSMHHybridHuggingfaceCheckpointHandler @classmethod - def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]) -> None: - with open(directory / "config.json", "w") as f: - json.dump(config, f) + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + ConstantExportParamConverter( + export_names=(("auto_map",),), + export_value={ + "AutoConfig": "configuration_llava_hybrid.LlavaHybridConfig", + "AutoModel": "modeling_llava_hybrid.LlavaHybridModel", + "AutoModelForVision2Seq": "modeling_llava_hybrid.LlavaHybridForConditionalGeneration", + }, + ), + ] diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py index f8f6a052..da7984c7 100644 --- a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py +++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py @@ -18,7 +18,7 @@ from transformers.modeling_utils import PreTrainedModel from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralMLP, MistralModel, MistralRMSNorm from transformers.processing_utils import Unpack -from transformers.utils import LossKwargs, logging +from transformers.utils import LossKwargs, can_return_tuple, logging from transformers.utils.generic import ModelOutput from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig @@ -357,7 +357,13 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx if len(self.key_cache) <= layer_idx: return 0 - return self.key_cache[layer_idx].shape[-2] + is_empty_layer = ( + len(self.key_cache) == 0 # no cache in any layer + 
or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it + or not self.key_cache[layer_idx].numel() # the layer has no cache + ) + return self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0 + # return self.key_cache[layer_idx].shape[-2] def reset(self): self.conv_states.zero_() @@ -1209,6 +1215,37 @@ def __init__(self, config: AprielSSMHybridConfig, **kwargs): # Initialize weights and apply final processing self.post_init() + @can_return_tuple + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: + use_cache = use_cache if use_cache is not None else self.config.use_cache + if use_cache and past_key_values is None: + batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0] + past_key_values = HybridMambaAttentionDynamicCache(self.config, batch_size, self.dtype, device=self.device) + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **flash_attn_kwargs, + ) + class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... diff --git a/fast_llm/models/ssm/external/llava_hybrid/configuration_llava_hybrid.py b/fast_llm/models/ssm/external/llava_hybrid/configuration_llava_hybrid.py new file mode 100644 index 00000000..b8e822d9 --- /dev/null +++ b/fast_llm/models/ssm/external/llava_hybrid/configuration_llava_hybrid.py @@ -0,0 +1,117 @@ +from transformers import MistralConfig +from transformers.configuration_utils import PretrainedConfig +from transformers.models.auto import CONFIG_MAPPING +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +# Copied from configuration_ssm_hybrid_apriel15b.py +# TODO: split into mamba 2 and discrete mamba 2 configs with a base dict +ssm_config_default = { + # discrete mamba2 + "d_state": 64, + "n_v_heads": 32, + "n_qk_heads": 32, + "expand": 1, + "chunk_size": 128, + "activation": "identity", + "bias": False, + "d_conv": 4, + "d_inner": 32 * 128, + # mamba2 + "d_xb": None, # will be set to model dim + "dt_rank": "auto", + "dt_min": 0.001, + "dt_max": 0.1, + "dt_init": "random", + "dt_scale": 1.0, + "dt_init_floor": 1e-4, + "conv_bias": True, +} + + +class AprielSSMHybridConfig(MistralConfig): + model_type = "apriel_ssm_thinker_hybrid" + + def __init__(self, hybrid_block_layout=["m2d"], ssm_cfg=None, **kwargs): + super().__init__(**kwargs) + self.hybrid_block_layout = hybrid_block_layout + self.head_dim = self.head_dim or self.hidden_size // self.num_attention_heads # as in transformers 4.51.3 + self.ssm_cfg = ssm_cfg or ssm_config_default + + for k, v in ssm_config_default.items(): + if k not in self.ssm_cfg: + self.ssm_cfg[k] = v # to make sure all elements are present in the config + + +class LlavaHybridConfig(PretrainedConfig): + """ + Configuration class for Llava SSM-Hybrid-decoder model. 
+ """ + + model_type = "llava_hybrid" + + def __init__( + self, + vision_config=None, + text_config=None, + image_token_index=32000, + projector_hidden_act="gelu", + projector_intermediate_size=4096, + vision_feature_select_strategy="default", + vision_feature_layer=-2, + image_seq_length=576, + multimodal_projector_bias=True, + **kwargs, + ): + self.image_token_index = image_token_index + self.projector_hidden_act = projector_hidden_act + # projector_intermediate_size is an addition to the original Llava config + self.projector_intermediate_size = projector_intermediate_size + self.image_seq_length = image_seq_length + + if vision_feature_select_strategy not in ["default", "full"]: + raise ValueError( + "vision_feature_select_strategy should be one of 'default', 'full'." + f"Got: {vision_feature_select_strategy}" + ) + + self.vision_feature_select_strategy = vision_feature_select_strategy + self.vision_feature_layer = vision_feature_layer + + if isinstance(vision_config, dict): + vision_config["model_type"] = ( + vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model" + ) + vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config) + elif vision_config is None: + vision_config = CONFIG_MAPPING["clip_vision_model"]( + intermediate_size=4096, + hidden_size=1024, + patch_size=14, + image_size=336, + num_hidden_layers=24, + num_attention_heads=16, + vocab_size=32000, + projection_dim=768, + ) + + self.vision_config = vision_config + + if isinstance(text_config, dict): + # Load the custom SSM hybrid config if specified + if text_config.get("model_type") == "apriel_ssm_thinker_hybrid": + text_config = AprielSSMHybridConfig(**text_config) + else: + text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama" + text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) + elif text_config is None: + text_config = CONFIG_MAPPING["llama"]() + + self.text_config = text_config + self.multimodal_projector_bias = multimodal_projector_bias + + super().__init__(**kwargs) + + +__all__ = ["LlavaHybridConfig"] diff --git a/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py b/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py new file mode 100644 index 00000000..9896d91d --- /dev/null +++ b/fast_llm/models/ssm/external/llava_hybrid/modeling_llava_hybrid.py @@ -0,0 +1,124 @@ +from torch import nn +from transformers import AutoModel, LlavaForConditionalGeneration, LlavaModel +from transformers.activations import ACT2FN + +from .configuration_llava_hybrid import LlavaHybridConfig + + +class LlavaMultiModalProjector(nn.Module): + def __init__(self, config: LlavaHybridConfig): + super().__init__() + # We have hidden_size * the number of vision feature layers + num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer) + self.linear_1 = nn.Linear( + config.vision_config.hidden_size * num_feature_layers, + config.projector_intermediate_size, + bias=config.multimodal_projector_bias, + ) + self.act = ACT2FN[config.projector_hidden_act] + self.linear_2 = nn.Linear( + config.projector_intermediate_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias + ) + + def forward(self, image_features): + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class LlavaHybridModel(LlavaModel): + """ + Llava SSM-Hybrid-decoder 
model. + """ + + config_class = LlavaHybridConfig + + def __init__(self, config: LlavaHybridConfig): + super(LlavaModel, self).__init__(config) + self.vision_tower = AutoModel.from_config(config.vision_config) + + self.multi_modal_projector = LlavaMultiModalProjector(config) + assert ( + config.text_config.model_type == "apriel_ssm_thinker_hybrid" + ), "Only Apriel SSM Hybrid model is supported in LlavaHybridModel" + from .modeling_ssm_hybrid_apriel15b import AprielThinkerSSMHybridModel + + self.language_model = AprielThinkerSSMHybridModel(config.text_config) + self.post_init() + + +class LlavaHybridForConditionalGeneration(LlavaForConditionalGeneration): + config_class = LlavaHybridConfig + + def __init__(self, config: LlavaHybridConfig): + super(LlavaForConditionalGeneration, self).__init__(config) + self.model = LlavaHybridModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + output_router_logits=False, + cache_position=None, + position_ids=None, + use_cache=True, + **kwargs, + ): + from .modeling_ssm_hybrid_apriel15b import HybridMambaAttentionDynamicCache + + # Copy of the method from `AprielThinkerSSMHybridForCausalLM` + # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache` + + empty_past_kv = past_key_values is None or not isinstance(past_key_values, HybridMambaAttentionDynamicCache) + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. 
+        # (we can't check exception 3 while compiling)
+        if not empty_past_kv:
+            if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]:  # Exception 1  # Exception 3
+                input_ids = input_ids[:, -cache_position.shape[0] :]
+            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
+                input_ids = input_ids[:, cache_position]
+        else:
+            past_key_values = HybridMambaAttentionDynamicCache(
+                self.config, input_ids.shape[0], self.dtype, device=self.device
+            )
+
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if not empty_past_kv:
+                position_ids = position_ids[:, -input_ids.shape[1] :]
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and empty_past_kv:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases
+
+        # Copy from `LlavaForConditionalGeneration.prepare_inputs_for_generation`
+        if cache_position[0] == 0:
+            # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
+            # Otherwise we need pixel values to be passed to model
+            model_inputs["pixel_values"] = kwargs.get("pixel_values")
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache": use_cache,
+                "attention_mask": attention_mask,
+                "output_router_logits": output_router_logits,
+                # "logits_to_keep": self.config.num_logits_to_keep,
+                "cache_position": cache_position,
+            }
+        )
+        return model_inputs
diff --git a/fast_llm/models/ssm/huggingface.py b/fast_llm/models/ssm/huggingface.py
index 77cd346f..02f47207 100644
--- a/fast_llm/models/ssm/huggingface.py
+++ b/fast_llm/models/ssm/huggingface.py
@@ -1,9 +1,10 @@
 import logging
+import typing
 
-from fast_llm.engine.huggingface.config import HuggingfaceModelConfig
+from fast_llm.engine.inference.config import HuggingfaceModelConfig
 from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM
 from fast_llm.models.ssm.config import HybridSSMModelConfig
-from fast_llm.models.ssm.model import HybridSSMModel
+from fast_llm.models.ssm.model import HybridSSMInferenceRunner, HybridSSMModel
 
 logger = logging.getLogger(__name__)
 
@@ -17,5 +18,6 @@ class HuggingfaceSSMModelConfig(HuggingfaceModelConfig):
 class HuggingfaceHybridSSMModelForCausalLM(HuggingfaceGPTModelForCausalLM):
     config_class = HuggingfaceSSMModelConfig
     config: HuggingfaceSSMModelConfig
+    runner_class: typing.ClassVar[type[HybridSSMInferenceRunner]] = HybridSSMInferenceRunner
     model_class = HybridSSMModel
     _fast_llm_model: HybridSSMModel
diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py
index 02a5ac23..df15907d 100644
--- a/fast_llm/models/ssm/model.py
+++ b/fast_llm/models/ssm/model.py
@@ -3,13 +3,14 @@
 
 from fast_llm.engine.base_model.base_model import Layer
 from fast_llm.engine.distributed.config import DistributedConfig
-from fast_llm.layers.language_model.embedding import LanguageModelEmbedding
+from fast_llm.engine.inference.runner import InferenceRunner
 from fast_llm.layers.language_model.head import LanguageModelHead
 from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2
 from fast_llm.layers.ssm.llamba_block import LlambaBlock
 from fast_llm.layers.ssm.mamba2 import Mamba2
 from fast_llm.layers.ssm.mamba_layer import MambaLayer
 from fast_llm.layers.transformer.transformer import TransformerLayer
+from fast_llm.models.gpt.config import GPTBatchConfig
 from fast_llm.models.gpt.model import GPTBaseModel, GPTModel
 from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, HybridSSMModelConfig, SSMBlockType
@@ -94,7 +95,7 @@ def get_layers(self) -> list[Layer]:
         Create a list of layers for the model, interleaving Transformer and Mamba blocks
         according to the block pattern.
         """
-        layers = [LanguageModelEmbedding(self._config, self._tensor_space)]
+        layers = self.get_embedding_layers()
 
         # Create blocks according to pattern
         for i, block_type in enumerate(self._config.hybrid_block_layout):
@@ -165,3 +166,8 @@ class HybridSSMModel[ConfigType: HybridSSMModelConfig](GPTModel[ConfigType]):
 
     config_class: typing.ClassVar[type[HybridSSMModelConfig]] = HybridSSMModelConfig
     base_model_class: typing.ClassVar[type[HybridSSMBaseModel]] = HybridSSMBaseModel
+
+
+class HybridSSMInferenceRunner(InferenceRunner):
+    model_class: typing.ClassVar[type[HybridSSMModel]] = HybridSSMModel
+    batch_config_class: typing.ClassVar[type[GPTBatchConfig]] = GPTBatchConfig
diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py
index 05acf23d..da719a42 100644
--- a/tests/models/test_checkpoint.py
+++ b/tests/models/test_checkpoint.py
@@ -307,11 +307,12 @@ def test_huggingface_model(model_testing_config, get_convert_path):
         )
     )
     errors = []
-    auto_model = (
-        transformers.AutoModel
-        if model_testing_config.name in ("diffusion_llama", "dream")
-        else transformers.AutoModelForCausalLM
-    )
+    if model_testing_config.name in ("diffusion_llama", "dream"):
+        auto_model = transformers.AutoModel
+    elif model_testing_config.name in ("llava", "vision_hybrid_mamba2"):
+        auto_model = transformers.AutoModelForVision2Seq
+    else:
+        auto_model = transformers.AutoModelForCausalLM
     model_as_hf = auto_model.from_pretrained(
         hf_path, trust_remote_code=model_testing_config.checkpoint_format.trust_remote_code
     ).cuda()
diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py
index b96a8963..643ca6c2 100644
--- a/tests/utils/model_configs.py
+++ b/tests/utils/model_configs.py
@@ -13,13 +13,18 @@
     DiffusionDreamGPTHuggingfaceCheckpointFormat,
     DiffusionLlamaGPTHuggingfaceCheckpointFormat,
     LlamaGPTHuggingfaceCheckpointFormat,
+    LlavaGPTHuggingfaceCheckpointFormat,
     MistralGPTHuggingfaceCheckpointFormat,
     MixtralGPTHuggingfaceCheckpointFormat,
     MTPLlamaGPTHuggingfaceCheckpointFormat,
     Qwen2GPTHuggingfaceCheckpointFormat,
     Starcoder2GPTHuggingfaceCheckpointFormat,
 )
-from fast_llm.models.ssm.config import LLambaHuggingfaceCheckpointFormat
+from fast_llm.models.ssm.config import (
+    AprielThinkerSSMHHybridHuggingfaceCheckpointFormat,
+    LLambaHuggingfaceCheckpointFormat,
+    LlavaHybridHuggingfaceCheckpointFormat,
+)
 from tests.utils.dataset import MODEL_DATASET_PREFIX, MODEL_TEST_VOCAB_SIZE
 from tests.utils.distributed_configs import DistributedTestingConfig
@@ -450,6 +455,41 @@ def _update_and_add_testing_config(
     compare_factor=2.0,
 )
 
+_update_and_add_testing_config(
+    # Tests the Llava multimodal setup (Pixtral-style vision encoder on a Llama decoder) and the Llava converter.
+ "llama", + "llava", + extra_args=[ + "batch.max_image_size=128", + "model.base_model.vision_encoder.type=pixtral", + "model.base_model.vision_encoder.patch_norm.type=rms_norm", + "model.base_model.vision_encoder.transformer.add_linear_biases=False", + "model.base_model.vision_encoder.transformer.causal=False", + "model.base_model.vision_encoder.transformer.normalization.type=rms_norm", + "model.base_model.vision_encoder.transformer.type=image_encoder", + "model.base_model.vision_encoder.transformer.gated=True", + "model.base_model.vision_encoder.transformer.num_layers=2", + "model.base_model.vision_encoder.transformer.hidden_size=256", + "model.base_model.vision_encoder.transformer.num_attention_heads=8", + "model.base_model.vision_encoder.transformer.head_groups=8", + "model.base_model.vision_encoder.transformer.init_method_std=0.022", + "model.base_model.vision_encoder.transformer.rotary.type=rope_2d", + "model.base_model.vision_encoder.adapter_size=256", + "model.distributed.training_dtype=torch.bfloat16", + ], + megatron_args=None, + checkpoint_format=LlavaGPTHuggingfaceCheckpointFormat, + groups={ + ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, + ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, + }, + compare_factor=8.0, +) + _update_and_add_testing_config( # Tests hybrid ssm, llamba converter. "llama", @@ -510,15 +550,51 @@ def _update_and_add_testing_config( "model.base_model.hybrid_block_layout=['t','m2']", ], megatron_args=None, - checkpoint_format=None, + checkpoint_format=AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, - ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, + ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, + }, +) + +_update_and_add_testing_config( + # Tests hybrid ssm, llamba converter. 
+ "hybrid_mamba2", + "vision_hybrid_mamba2", + model_type="hybrid_ssm", + extra_args=[ + "batch.max_image_size=128", + "model.base_model.vision_encoder.type=pixtral", + "model.base_model.vision_encoder.patch_norm.type=rms_norm", + "model.base_model.vision_encoder.transformer.add_linear_biases=False", + "model.base_model.vision_encoder.transformer.causal=False", + "model.base_model.vision_encoder.transformer.normalization.type=rms_norm", + "model.base_model.vision_encoder.transformer.type=image_encoder", + "model.base_model.vision_encoder.transformer.gated=True", + "model.base_model.vision_encoder.transformer.num_layers=2", + "model.base_model.vision_encoder.transformer.hidden_size=256", + "model.base_model.vision_encoder.transformer.num_attention_heads=8", + "model.base_model.vision_encoder.transformer.head_groups=8", + "model.base_model.vision_encoder.transformer.init_method_std=0.022", + "model.base_model.vision_encoder.transformer.rotary.type=rope_2d", + "model.base_model.vision_encoder.adapter_size=512", + "model.distributed.training_dtype=torch.bfloat16", + ], + megatron_args=None, + checkpoint_format=LlavaHybridHuggingfaceCheckpointFormat, + groups={ + ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, }, + compare_factor=16.0, )