diff --git a/langport/model/executor/huggingface.py b/langport/model/executor/huggingface.py
index 229c4cc..51f442c 100644
--- a/langport/model/executor/huggingface.py
+++ b/langport/model/executor/huggingface.py
@@ -259,14 +259,17 @@ def load_model(
             # raises an error on incompatible platforms
             from transformers import BitsAndBytesConfig
 
-            if "max_memory" in kwargs:
-                kwargs["max_memory"]["cpu"] = (
-                    str(math.floor(psutil.virtual_memory().available / 2**20)) + "Mib"
-                )
+            if "max_memory" not in kwargs:
+                kwargs["max_memory"] = {}
+            kwargs["max_memory"]["cpu"] = (
+                str(math.floor(psutil.virtual_memory().available / 2**20)) + "Mib"
+            )
             kwargs["quantization_config"] = BitsAndBytesConfig(
-                load_in_8bit_fp32_cpu_offload=cpu_offloading
+                llm_int8_enable_fp32_cpu_offload=cpu_offloading,
+                load_in_8bit=quantization.lower() in ["8", "8bit", "8-bit", "8bits", "8-bits"],
+                load_in_4bit=quantization.lower() in ["4", "4bit", "4-bit", "4bits", "4-bits"],
             )
-            kwargs["load_in_8bit"] = quantization!=None
+            kwargs["device_map"] = "auto"
             # Load model
             model, tokenizer = self._load_hf_model(adapter, model_path, kwargs)
         elif quantization is not None or gptq:
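
For context, the snippet below is a minimal sketch (not part of the patch) of how the kwargs built in this hunk are ultimately consumed by `transformers`. The model name, GPU memory budget, and the `quantization`/`cpu_offloading` values are placeholder assumptions for illustration only.

```python
# Sketch of the downstream effect of the new kwargs; values are placeholders.
import math

import psutil
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

quantization = "8bit"     # hypothetical user setting, e.g. "8bit" or "4bit"
cpu_offloading = True     # hypothetical user setting

quantization_config = BitsAndBytesConfig(
    llm_int8_enable_fp32_cpu_offload=cpu_offloading,
    load_in_8bit=quantization.lower() in ["8", "8bit", "8-bit", "8bits", "8-bits"],
    load_in_4bit=quantization.lower() in ["4", "4bit", "4-bit", "4bits", "4-bits"],
)

# Cap the CPU budget at the currently available system RAM, as the patch does;
# the GPU budget here is an arbitrary placeholder.
max_memory = {
    0: "10GiB",
    "cpu": str(math.floor(psutil.virtual_memory().available / 2**20)) + "MiB",
}

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-1.3b",              # placeholder model
    quantization_config=quantization_config,
    device_map="auto",
    max_memory=max_memory,
)
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b")
```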