diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 8537b7e754..e94a0f6b88 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -1,4 +1,5 @@ """Module for models and model loading""" + # pylint: disable=too-many-lines import logging @@ -504,6 +505,9 @@ def load_model( bnb_config = { "load_in_8bit": True, } + # Exclude mamba blocks from int8 quantization for jamba + if cfg.model_config_type == "jamba": + bnb_config["llm_int8_skip_modules"] = ["mamba"] model_kwargs["quantization_config"] = BitsAndBytesConfig( **bnb_config, )