From 7612050bd2f36cf0b7debf45c351a84a43e84ec2 Mon Sep 17 00:00:00 2001
From: parkervg
Date: Fri, 21 Jun 2024 11:42:21 -0400
Subject: [PATCH] Fixing model_kwarg flow

---
 blendsql/models/local/_transformers.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/blendsql/models/local/_transformers.py b/blendsql/models/local/_transformers.py
index 49bcb33a..e90abccc 100644
--- a/blendsql/models/local/_transformers.py
+++ b/blendsql/models/local/_transformers.py
@@ -2,6 +2,8 @@
 from outlines.models import transformers, LogitsGenerator
 from .._model import LocalModel
 
+DEFAULT_KWARGS = {"do_sample": True, "temperature": 0.0, "top_p": 1.0}
+
 _has_transformers = importlib.util.find_spec("transformers") is not None
 _has_torch = importlib.util.find_spec("torch") is not None
 
@@ -41,6 +43,7 @@ def __init__(self, model_name_or_path: str, caching: bool = True, **kwargs):
             model_name_or_path=model_name_or_path,
             requires_config=False,
             tokenizer=transformers.AutoTokenizer.from_pretrained(model_name_or_path),
+            load_model_kwargs=DEFAULT_KWARGS | kwargs,
             caching=caching,
             **kwargs,
         )
@@ -49,12 +52,7 @@ def _load_model(self) -> LogitsGenerator:
         # https://huggingface.co/blog/how-to-generate
         return transformers(
             self.model_name_or_path,
-            model_kwargs={
-                "do_sample": True,
-                "temperature": 0.0,
-                "top_p": 1.0,
-                "trust_remote_code": True,
-            },
+            model_kwargs=self.load_model_kwargs,
         )
 
 