diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py
index 96fb5332..27017175 100644
--- a/fast_llm/engine/checkpoint/huggingface.py
+++ b/fast_llm/engine/checkpoint/huggingface.py
@@ -120,14 +120,14 @@ def _export_config(cls, config: FastLLMModelConfig) -> dict[str, typing.Any]:
             cls.base_model_converter_class.export_config(config.base_model),
             {
                 "model_type": cls.get_huggingface_model_type(),
-                "architecture": cls.architecture,
+                "architectures": [cls.architecture],
             },
         )
 
     @classmethod
     def _import_config(cls, config: dict[str, typing.Any]) -> FastLLMModelConfig:
         Assert.eq(config["model_type"], cls.get_huggingface_model_type())
-        Assert.eq(config["architecture"], cls.architecture)
+        Assert.eq(config["architectures"], [cls.architecture])
         return cls._model_class.from_dict({"base_model": cls.base_model_converter_class.import_config(config)})
 
     def _create_weight_converters(self) -> list[WeightConverter]:
diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py
index 4b984963..7550df04 100644
--- a/fast_llm/models/gpt/conversion/apriel.py
+++ b/fast_llm/models/gpt/conversion/apriel.py
@@ -8,10 +8,15 @@
 from fast_llm.layers.attention.config import AttentionConfig
 from fast_llm.layers.block.config import BlockSequenceConfig, FixedBlockSequenceConfig, PatternBlockSequenceConfig
 from fast_llm.layers.decoder.config import DecoderBlockConfig
+from fast_llm.layers.decoder.mlp.config import MLPConfig
 from fast_llm.layers.ssm.config import DiscreteMamba2Config, Mamba2Config
 from fast_llm.models.gpt.config import GPTModelConfig
 from fast_llm.models.gpt.conversion.config import AprielHybridSSMCheckpointFormat
-from fast_llm.models.gpt.conversion.llama import get_parameter_converter, get_weight_and_bias_converters
+from fast_llm.models.gpt.conversion.llama import (
+    LlamaMLPConverter,
+    get_parameter_converter,
+    get_weight_and_bias_converters,
+)
 from fast_llm.models.gpt.conversion.mistral import (
     MistralBaseModelConverter,
     MistralBlockConverter,
@@ -224,12 +229,31 @@ def get_converters(
         ]
 
 
-class AprielDiscreteMamba2BlockConverter(MistralBlockConverter):
+class AprielMLPConverter(LlamaMLPConverter):
+    @classmethod
+    def import_config(cls, config: dict) -> dict:
+        config["mlp_bias"] = False
+        return super().import_config(config)
+
+    @classmethod
+    def export_config(cls, config: MLPConfig) -> dict:
+        out = super().export_config(config)
+        del out["mlp_bias"]
+        return out
+
+
+class AprielBlockConverterBase(MistralBlockConverter):
+    mlp_converter_class: typing.ClassVar[type[AprielMLPConverter]] = AprielMLPConverter
+
+
+class AprielDiscreteMamba2BlockConverter(AprielBlockConverterBase):
     mixer_converter_class: typing.ClassVar[type[AprielDiscreteMamba2Converter]] = AprielDiscreteMamba2Converter
+    hf_mixer_name: typing.ClassVar[str] = "mixer"
 
 
-class AprielMamba2BlockConverter(MistralBlockConverter):
+class AprielMamba2BlockConverter(AprielBlockConverterBase):
     mixer_converter_class: typing.ClassVar[type[AprielMamba2Converter]] = AprielMamba2Converter
+    hf_mixer_name: typing.ClassVar[str] = "mixer"
 
 
 class AprielBlockConverter:
@@ -239,7 +263,7 @@ class AprielBlockConverter:
         DiscreteMamba2Config: "m2d",
     }
     _converter_classes = {
-        AttentionConfig: MistralBlockConverter,
+        AttentionConfig: AprielBlockConverterBase,
         Mamba2Config: AprielMamba2BlockConverter,
         DiscreteMamba2Config: AprielDiscreteMamba2BlockConverter,
     }
diff --git a/fast_llm/models/gpt/conversion/mistral.py b/fast_llm/models/gpt/conversion/mistral.py
index bfc7d556..b5db3fa0 100644
--- a/fast_llm/models/gpt/conversion/mistral.py
+++ b/fast_llm/models/gpt/conversion/mistral.py
@@ -17,14 +17,20 @@ class MistralAttentionConverter(LlamaAttentionConverter):
 
     @classmethod
     def import_config(cls, config: dict) -> dict:
-        return safe_merge_dicts(super().import_config(config), {"window_size": config["sliding_window"]})
+        config["attention_bias"] = False
+        return safe_merge_dicts(
+            super().import_config(config),
+            {"window_size": config["sliding_window"]},
+        )
 
     @classmethod
     def export_config(cls, config: AttentionConfig) -> dict:
-        return safe_merge_dicts(
+        out = safe_merge_dicts(
             super().export_config(config),
             {"sliding_window": config.window_size},
         )
+        del out["attention_bias"]
+        return out
 
     @classmethod
     def _check_config(cls, config: AttentionConfig) -> None:
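
For reference, Hugging Face config.json files record the model class under the plural "architectures" key as a list, which is what the exporter now emits and the importer now checks. A minimal sketch of the resulting config shape; the model_type and class name below are illustrative assumptions, not values taken from the diff:

    # Sketch only: shows the list-valued "architectures" key the converter now produces.
    exported = {
        "model_type": "mistral",                  # illustrative model_type
        "architectures": ["MistralForCausalLM"],  # plural key, list of class names
        # ... remaining fields come from the base-model converter ...
    }
    # The import side performs the matching list comparison:
    assert exported["architectures"] == ["MistralForCausalLM"]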