diff --git a/src/vanna/hf/hf.py b/src/vanna/hf/hf.py
index feb7ea5a..bed19630 100644
--- a/src/vanna/hf/hf.py
+++ b/src/vanna/hf/hf.py
@@ -6,13 +6,15 @@ class Hf(VannaBase):
     def __init__(self, config=None):
-        model_name = self.config.get(
-            "model_name", None
-        )  # e.g. meta-llama/Meta-Llama-3-8B-Instruct
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        model_name_or_path = self.config.get(
+            "model_name_or_path", None
+        )  # e.g. meta-llama/Meta-Llama-3-8B-Instruct or local path to the model checkpoint files
+        # list of quantization methods supported by transformers package: https://huggingface.co/docs/transformers/main/en/quantization/overview
+        quantization_config = self.config.get("quantization_config", None)
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
         self.model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            torch_dtype="auto",
+            model_name_or_path,
+            quantization_config=quantization_config,
             device_map="auto",
         )