Commit 4db6271

Fix Mixtral conversion (#365)
1 parent ca3b000 commit 4db6271

2 files changed: +25 -8 lines changed


fast_llm/models/gpt/conversion/llama.py

Lines changed: 8 additions & 4 deletions
@@ -354,6 +354,10 @@ class LlamaBlockConverter:
     mixer_converter_class: typing.ClassVar[type[LlamaAttentionConverter]] = LlamaAttentionConverter
     mlp_converter_class: typing.ClassVar[type[LlamaMLPConverter]] = LlamaMLPConverter
     normalization_converter_class: typing.ClassVar[type[LlamaNormalizationConverter]] = LlamaNormalizationConverter
+    hf_mixer_name: typing.ClassVar[str] = "self_attn"
+    hf_mlp_name: typing.ClassVar[str] = "mlp"
+    hf_norm_1_name: typing.ClassVar[str] = "input_layernorm"
+    hf_norm_2_name: typing.ClassVar[str] = "post_attention_layernorm"

     @classmethod
     def import_config(cls, config: dict, hidden_size: int) -> dict:
@@ -380,25 +384,25 @@ def get_converters(
             *cls.mixer_converter_class.get_converters(
                 config.mixer,
                 f"{fast_llm_prefix}.mixer",
-                f"{hf_prefix}.self_attn",
+                f"{hf_prefix}.{cls.hf_mixer_name}",
                 drop_on_export,
             ),
             *cls.mlp_converter_class.get_converters(
                 config.mlp,
                 f"{fast_llm_prefix}.mlp",
-                f"{hf_prefix}.mlp",
+                f"{hf_prefix}.{cls.hf_mlp_name}",
                 drop_on_export,
             ),
             *cls.normalization_converter_class.get_converters(
                 config.normalization,
                 f"{fast_llm_prefix}.norm_1",
-                f"{hf_prefix}.input_layernorm",
+                f"{hf_prefix}.{cls.hf_norm_1_name}",
                 drop_on_export,
             ),
             *cls.normalization_converter_class.get_converters(
                 config.normalization,
                 f"{fast_llm_prefix}.norm_2",
-                f"{hf_prefix}.post_attention_layernorm",
+                f"{hf_prefix}.{cls.hf_norm_2_name}",
                 drop_on_export,
             ),
         ]
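Taken together, the llama.py hunks lift the HuggingFace submodule names that were hard-coded in LlamaBlockConverter.get_converters into class-level attributes (hf_mixer_name, hf_mlp_name, hf_norm_1_name, hf_norm_2_name), so a subclass can retarget a single submodule by overriding one string instead of re-implementing get_converters. Below is a minimal, self-contained sketch of that override pattern; the class names and the get_hf_prefixes helper are illustrative stand-ins, not the actual Fast-LLM API.

import typing


class BlockConverterSketch:
    # Default HF submodule names, matching Llama-style checkpoints.
    hf_mixer_name: typing.ClassVar[str] = "self_attn"
    hf_mlp_name: typing.ClassVar[str] = "mlp"

    @classmethod
    def get_hf_prefixes(cls, hf_prefix: str) -> dict[str, str]:
        # Build HF weight prefixes from the class-level names instead of
        # hard-coded strings, so subclasses only override data, not logic.
        return {
            "mixer": f"{hf_prefix}.{cls.hf_mixer_name}",
            "mlp": f"{hf_prefix}.{cls.hf_mlp_name}",
        }


class MixtralBlockConverterSketch(BlockConverterSketch):
    # Mixtral stores its MoE MLP under "block_sparse_moe" in HF checkpoints.
    hf_mlp_name: typing.ClassVar[str] = "block_sparse_moe"


print(BlockConverterSketch.get_hf_prefixes("model.layers.0"))
# {'mixer': 'model.layers.0.self_attn', 'mlp': 'model.layers.0.mlp'}
print(MixtralBlockConverterSketch.get_hf_prefixes("model.layers.0"))
# {'mixer': 'model.layers.0.self_attn', 'mlp': 'model.layers.0.block_sparse_moe'}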

fast_llm/models/gpt/conversion/mixtral.py

Lines changed: 17 additions & 4 deletions
@@ -4,7 +4,7 @@
 from fast_llm.engine.checkpoint.external import SplitWeightConverter, WeightConverter
 from fast_llm.layers.decoder.mlp.config import MoEMLPConfig
 from fast_llm.models.gpt.conversion.config import MixtralCheckpointFormat
-from fast_llm.models.gpt.conversion.llama import LlamaMLPConverter, get_weight_and_bias_converters
+from fast_llm.models.gpt.conversion.llama import LlamaMLPConverter, MLPLayer2Converter, get_weight_and_bias_converters
 from fast_llm.models.gpt.conversion.mistral import (
     MistralBaseModelConverter,
     MistralBlockConverter,
@@ -50,16 +50,29 @@ def get_converters(
         return [
             *get_weight_and_bias_converters(
                 f"{fast_llm_prefix}.router",
-                () if drop_on_export else (f"{hf_prefix}.router",),
-                config.add_linear_biases,
+                f"{hf_prefix}.gate",
+                False,
+                drop_on_export=drop_on_export,
+            ),
+            *get_weight_and_bias_converters(
+                f"{fast_llm_prefix}.layer_1",
+                tuple(f"{hf_prefix}.experts.{i}.{w}" for i in range(config.experts) for w in ("w1", "w3")),
+                False,
                 SplitWeightConverter,
                 drop_on_export=drop_on_export,
             ),
-            *super().get_converters(config, fast_llm_prefix, hf_prefix, drop_on_export=drop_on_export),
+            *get_weight_and_bias_converters(
+                f"{fast_llm_prefix}.layer_2",
+                tuple(f"{hf_prefix}.experts.{i}.w2" for i in range(config.experts)),
+                False,
+                MLPLayer2Converter,
+                drop_on_export=drop_on_export,
+            ),
         ]


 class MixtralBlockConverter(MistralBlockConverter):
+    hf_mlp_name: typing.ClassVar[str] = "block_sparse_moe"
     mlp_converter_class: typing.ClassVar[type[MixtralMLPConverter]] = MixtralMLPConverter
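For context, HF Mixtral checkpoints keep the MoE block under block_sparse_moe, with a gate router and per-expert projections named w1 (gate), w3 (up), and w2 (down), while Fast-LLM keeps the experts fused in layer_1 and layer_2; the rewritten MixtralMLPConverter.get_converters now builds that mapping directly instead of delegating to super(). The sketch below only enumerates the HF-side weight names the new converters target for one block; the helper function is hypothetical, and the real wiring is done by get_weight_and_bias_converters together with SplitWeightConverter and MLPLayer2Converter.

def mixtral_hf_expert_weights(hf_prefix: str, num_experts: int) -> dict[str, tuple[str, ...]]:
    # Hypothetical helper: list the HF weight names behind each fused Fast-LLM tensor.
    return {
        # Router: a single linear layer named "gate" in HF Mixtral checkpoints.
        "router": (f"{hf_prefix}.gate",),
        # Fast-LLM's fused layer_1 splits into each expert's w1 and w3 projections.
        "layer_1": tuple(
            f"{hf_prefix}.experts.{i}.{w}" for i in range(num_experts) for w in ("w1", "w3")
        ),
        # Fast-LLM's layer_2 maps to each expert's w2 (down) projection.
        "layer_2": tuple(f"{hf_prefix}.experts.{i}.w2" for i in range(num_experts)),
    }


if __name__ == "__main__":
    names = mixtral_hf_expert_weights("model.layers.0.block_sparse_moe", num_experts=2)
    for fast_llm_name, hf_names in names.items():
        print(fast_llm_name, "->", hf_names)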
