diff --git a/blendsql/models/local/_transformers.py b/blendsql/models/local/_transformers.py
index 49bcb33..e90abcc 100644
--- a/blendsql/models/local/_transformers.py
+++ b/blendsql/models/local/_transformers.py
@@ -2,6 +2,8 @@
 from outlines.models import transformers, LogitsGenerator
 from .._model import LocalModel
 
+DEFAULT_KWARGS = {"do_sample": True, "temperature": 0.0, "top_p": 1.0}
+
 _has_transformers = importlib.util.find_spec("transformers") is not None
 _has_torch = importlib.util.find_spec("torch") is not None
 
@@ -41,6 +43,7 @@ def __init__(self, model_name_or_path: str, caching: bool = True, **kwargs):
             model_name_or_path=model_name_or_path,
             requires_config=False,
             tokenizer=transformers.AutoTokenizer.from_pretrained(model_name_or_path),
+            load_model_kwargs=DEFAULT_KWARGS | kwargs,
             caching=caching,
             **kwargs,
         )
@@ -49,12 +52,7 @@ def _load_model(self) -> LogitsGenerator:
        # https://huggingface.co/blog/how-to-generate
        return transformers(
            self.model_name_or_path,
-            model_kwargs={
-                "do_sample": True,
-                "temperature": 0.0,
-                "top_p": 1.0,
-                "trust_remote_code": True,
-            },
+            model_kwargs=self.load_model_kwargs,
        )
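
Usage sketch (not part of the diff): with this change, the generation defaults are no longer hard-coded in _load_model; they come from DEFAULT_KWARGS merged with caller-supplied kwargs via the dict union operator (Python 3.9+), so caller keys win. Note that "trust_remote_code": True is dropped from the defaults, so callers who need it must now pass it explicitly. The class name TransformersLLM below is an assumption (the class definition is not visible in these hunks), and the model name is only an example.

# Hypothetical usage; TransformersLLM and the import path are assumed, not shown in the diff.
from blendsql.models import TransformersLLM

model = TransformersLLM(
    "microsoft/Phi-3-mini-4k-instruct",
    caching=True,
    temperature=0.7,          # overrides the temperature=0.0 default in DEFAULT_KWARGS
    trust_remote_code=True,   # no longer a default; pass explicitly if the model requires it
)

One caveat on the design: because **kwargs is still forwarded to super().__init__ in addition to being folded into load_model_kwargs, any override is passed through both paths; whether that is intended depends on what the parent constructor accepts.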