Skip to content

Commit

Permalink
Merge pull request #509 from SMAntony/main
Browse files Browse the repository at this point in the history
Added quantization support for huggingface models
  • Loading branch information
zainhoda authored Jun 21, 2024
2 parents dce3186 + 3cd833f commit 8cc20fb
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions src/vanna/hf/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@

class Hf(VannaBase):
def __init__(self, config=None):
model_name = self.config.get(
"model_name", None
) # e.g. meta-llama/Meta-Llama-3-8B-Instruct
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
model_name_or_path = self.config.get(
"model_name_or_path", None
) # e.g. meta-llama/Meta-Llama-3-8B-Instruct or local path to the model checkpoint files
# list of quantization methods supported by transformers package: https://huggingface.co/docs/transformers/main/en/quantization/overview
quantization_config = self.config.get("quantization_config", None)
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype="auto",
model_name_or_path,
quantization_config=quantization_config,
device_map="auto",
)

Expand Down

0 comments on commit 8cc20fb

Please sign in to comment.