4 changes: 2 additions & 2 deletions fast_llm/engine/checkpoint/huggingface.py
@@ -120,14 +120,14 @@ def _export_config(cls, config: FastLLMModelConfig) -> dict[str, typing.Any]:
             cls.base_model_converter_class.export_config(config.base_model),
             {
                 "model_type": cls.get_huggingface_model_type(),
-                "architecture": cls.architecture,
+                "architectures": [cls.architecture],
             },
         )

     @classmethod
     def _import_config(cls, config: dict[str, typing.Any]) -> FastLLMModelConfig:
         Assert.eq(config["model_type"], cls.get_huggingface_model_type())
-        Assert.eq(config["architecture"], cls.architecture)
+        Assert.eq(config["architectures"], [cls.architecture])
         return cls._model_class.from_dict({"base_model": cls.base_model_converter_class.import_config(config)})

     def _create_weight_converters(self) -> list[WeightConverter]:
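Context for the hunk above: Hugging Face config.json files record the model class under the plural key "architectures", whose value is a list of class names, so exporting and asserting on a one-element list matches what transformers reads and writes. A minimal sketch of the resulting shape, with illustrative placeholder values rather than this repository's real ones:

    # Sketch only: typical Hugging Face config.json fields after this change
    # (values here are illustrative placeholders).
    hf_config = {
        "model_type": "llama",                  # cls.get_huggingface_model_type()
        "architectures": ["LlamaForCausalLM"],  # [cls.architecture] -- plural key, list value
    }
    assert hf_config["architectures"] == ["LlamaForCausalLM"]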
32 changes: 28 additions & 4 deletions fast_llm/models/gpt/conversion/apriel.py
@@ -8,10 +8,15 @@
 from fast_llm.layers.attention.config import AttentionConfig
 from fast_llm.layers.block.config import BlockSequenceConfig, FixedBlockSequenceConfig, PatternBlockSequenceConfig
 from fast_llm.layers.decoder.config import DecoderBlockConfig
+from fast_llm.layers.decoder.mlp.config import MLPConfig
 from fast_llm.layers.ssm.config import DiscreteMamba2Config, Mamba2Config
 from fast_llm.models.gpt.config import GPTModelConfig
 from fast_llm.models.gpt.conversion.config import AprielHybridSSMCheckpointFormat
-from fast_llm.models.gpt.conversion.llama import get_parameter_converter, get_weight_and_bias_converters
+from fast_llm.models.gpt.conversion.llama import (
+    LlamaMLPConverter,
+    get_parameter_converter,
+    get_weight_and_bias_converters,
+)
 from fast_llm.models.gpt.conversion.mistral import (
     MistralBaseModelConverter,
     MistralBlockConverter,
@@ -224,12 +229,31 @@ def get_converters(
         ]


-class AprielDiscreteMamba2BlockConverter(MistralBlockConverter):
+class AprielMLPConverter(LlamaMLPConverter):
+    @classmethod
+    def import_config(cls, config: dict) -> dict:
+        config["mlp_bias"] = False
+        return super().import_config(config)
+
+    @classmethod
+    def export_config(cls, config: MLPConfig) -> dict:
+        out = super().export_config(config)
+        del out["mlp_bias"]
+        return out
+
+
+class AprielBlockConverterBase(MistralBlockConverter):
+    mlp_converter_class: typing.ClassVar[type[AprielMLPConverter]] = AprielMLPConverter
+
+
+class AprielDiscreteMamba2BlockConverter(AprielBlockConverterBase):
     mixer_converter_class: typing.ClassVar[type[AprielDiscreteMamba2Converter]] = AprielDiscreteMamba2Converter
+    hf_mixer_name: typing.ClassVar[str] = "mixer"


-class AprielMamba2BlockConverter(MistralBlockConverter):
+class AprielMamba2BlockConverter(AprielBlockConverterBase):
     mixer_converter_class: typing.ClassVar[type[AprielMamba2Converter]] = AprielMamba2Converter
+    hf_mixer_name: typing.ClassVar[str] = "mixer"


 class AprielBlockConverter:
@@ -239,7 +263,7 @@ class AprielBlockConverter:
         DiscreteMamba2Config: "m2d",
     }
     _converter_classes = {
-        AttentionConfig: MistralBlockConverter,
+        AttentionConfig: AprielBlockConverterBase,
         Mamba2Config: AprielMamba2BlockConverter,
         DiscreteMamba2Config: AprielDiscreteMamba2BlockConverter,
     }
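The new AprielMLPConverter pins mlp_bias to False on import and deletes the key again on export; the Mistral attention converter below applies the same pin-then-strip pattern to attention_bias. A standalone sketch of that round trip, using hypothetical dict-based converter classes rather than the repository's real ones:

    # Hypothetical sketch: pin a key the checkpoint format never stores on import,
    # strip it again on export, so the exported config stays minimal.
    class BaseMLPConverterSketch:
        @classmethod
        def import_config(cls, config: dict) -> dict:
            return {"bias": config["mlp_bias"]}

        @classmethod
        def export_config(cls, config: dict) -> dict:
            return {"mlp_bias": config["bias"]}


    class NoBiasMLPConverterSketch(BaseMLPConverterSketch):
        @classmethod
        def import_config(cls, config: dict) -> dict:
            config = dict(config)
            config["mlp_bias"] = False  # the format implies no MLP biases
            return super().import_config(config)

        @classmethod
        def export_config(cls, config: dict) -> dict:
            out = super().export_config(config)
            del out["mlp_bias"]  # keep the implied key out of the exported dict
            return out


    imported = NoBiasMLPConverterSketch.import_config({})
    assert imported == {"bias": False}
    assert NoBiasMLPConverterSketch.export_config(imported) == {}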
10 changes: 8 additions & 2 deletions fast_llm/models/gpt/conversion/mistral.py
@@ -17,14 +17,20 @@
 class MistralAttentionConverter(LlamaAttentionConverter):
     @classmethod
     def import_config(cls, config: dict) -> dict:
-        return safe_merge_dicts(super().import_config(config), {"window_size": config["sliding_window"]})
+        config["attention_bias"] = False
+        return safe_merge_dicts(
+            super().import_config(config),
+            {"window_size": config["sliding_window"]},
+        )

     @classmethod
     def export_config(cls, config: AttentionConfig) -> dict:
-        return safe_merge_dicts(
+        out = safe_merge_dicts(
             super().export_config(config),
             {"sliding_window": config.window_size},
         )
+        del out["attention_bias"]
+        return out

     @classmethod
     def _check_config(cls, config: AttentionConfig) -> None:
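Both methods above route their result through safe_merge_dicts. The helper's exact contract is not shown in this diff; the sketch below assumes it merges several dicts and fails loudly on duplicate keys, which is the behaviour the call sites appear to rely on:

    # Hypothetical stand-in for safe_merge_dicts (assumed contract, not the
    # repository's implementation): merge dicts, reject duplicate keys.
    def safe_merge_dicts_sketch(*dicts: dict) -> dict:
        merged: dict = {}
        for d in dicts:
            for key, value in d.items():
                if key in merged:
                    raise ValueError(f"duplicate key {key!r}")
                merged[key] = value
        return merged


    print(safe_merge_dicts_sketch({"window_size": 4096}, {"rope_theta": 10000.0}))
    # {'window_size': 4096, 'rope_theta': 10000.0}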