@@ -4,7 +4,7 @@
 from fast_llm.engine.checkpoint.external import SplitWeightConverter, WeightConverter
 from fast_llm.layers.decoder.mlp.config import MoEMLPConfig
 from fast_llm.models.gpt.conversion.config import MixtralCheckpointFormat
-from fast_llm.models.gpt.conversion.llama import LlamaMLPConverter, get_weight_and_bias_converters
+from fast_llm.models.gpt.conversion.llama import LlamaMLPConverter, MLPLayer2Converter, get_weight_and_bias_converters
 from fast_llm.models.gpt.conversion.mistral import (
     MistralBaseModelConverter,
     MistralBlockConverter,
@@ -50,16 +50,29 @@ def get_converters(
         return [
             *get_weight_and_bias_converters(
                 f"{fast_llm_prefix}.router",
-                () if drop_on_export else (f"{hf_prefix}.router",),
-                config.add_linear_biases,
+                f"{hf_prefix}.gate",
+                False,
+                drop_on_export=drop_on_export,
+            ),
+            *get_weight_and_bias_converters(
+                f"{fast_llm_prefix}.layer_1",
+                tuple(f"{hf_prefix}.experts.{i}.{w}" for i in range(config.experts) for w in ("w1", "w3")),
+                False,
                 SplitWeightConverter,
                 drop_on_export=drop_on_export,
             ),
-            *super().get_converters(config, fast_llm_prefix, hf_prefix, drop_on_export=drop_on_export),
+            *get_weight_and_bias_converters(
+                f"{fast_llm_prefix}.layer_2",
+                tuple(f"{hf_prefix}.experts.{i}.w2" for i in range(config.experts)),
+                False,
+                MLPLayer2Converter,
+                drop_on_export=drop_on_export,
+            ),
         ]


 class MixtralBlockConverter(MistralBlockConverter):
+    hf_mlp_name: typing.ClassVar[str] = "block_sparse_moe"
     mlp_converter_class: typing.ClassVar[type[MixtralMLPConverter]] = MixtralMLPConverter
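For context, the hunk above maps Fast-LLM's fused MoE tensors (`router`, `layer_1`, `layer_2`) onto Hugging Face Mixtral's `block_sparse_moe` module, which stores a dense `gate` plus per-expert `w1`/`w3` (gate/up) and `w2` (down) projections. Below is a minimal illustrative sketch of the resulting weight-name mapping; the prefixes, the `num_experts` value, and the dict layout are assumptions for the example, not the Fast-LLM converter API.

# Illustrative sketch of the name mapping the new converters produce.
# The prefixes and expert count below are example assumptions, not Fast-LLM API.
num_experts = 8  # Mixtral-8x7B style expert count (assumed for the example)
fast_llm_prefix = "layers.1.mlp"  # hypothetical Fast-LLM tensor prefix
hf_prefix = "model.layers.0.block_sparse_moe"  # Hugging Face Mixtral naming

mapping = {
    # Router: a single dense gate per block, exported without bias.
    f"{fast_llm_prefix}.router.weight": [f"{hf_prefix}.gate.weight"],
    # layer_1 holds the fused gate/up projections for all experts;
    # SplitWeightConverter splits it into per-expert w1 (gate) and w3 (up).
    f"{fast_llm_prefix}.layer_1.weight": [
        f"{hf_prefix}.experts.{i}.{w}.weight" for i in range(num_experts) for w in ("w1", "w3")
    ],
    # layer_2 holds the down projections; MLPLayer2Converter splits it into per-expert w2.
    f"{fast_llm_prefix}.layer_2.weight": [
        f"{hf_prefix}.experts.{i}.w2.weight" for i in range(num_experts)
    ],
}

for fast_llm_name, hf_names in mapping.items():
    print(fast_llm_name, "->", hf_names[:2], "..." if len(hf_names) > 2 else "")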