Skip to content

Commit

Permalink
Fixing model_kwarg flow
Browse files Browse the repository at this point in the history
  • Loading branch information
parkervg committed Jun 21, 2024
1 parent 72a6c16 commit 7612050
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions blendsql/models/local/_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from outlines.models import transformers, LogitsGenerator
from .._model import LocalModel

DEFAULT_KWARGS = {"do_sample": True, "temperature": 0.0, "top_p": 1.0}

_has_transformers = importlib.util.find_spec("transformers") is not None
_has_torch = importlib.util.find_spec("torch") is not None

Expand Down Expand Up @@ -41,6 +43,7 @@ def __init__(self, model_name_or_path: str, caching: bool = True, **kwargs):
model_name_or_path=model_name_or_path,
requires_config=False,
tokenizer=transformers.AutoTokenizer.from_pretrained(model_name_or_path),
load_model_kwargs=DEFAULT_KWARGS | kwargs,
caching=caching,
**kwargs,
)
Expand All @@ -49,12 +52,7 @@ def _load_model(self) -> LogitsGenerator:
# https://huggingface.co/blog/how-to-generate
return transformers(
self.model_name_or_path,
model_kwargs={
"do_sample": True,
"temperature": 0.0,
"top_p": 1.0,
"trust_remote_code": True,
},
model_kwargs=self.load_model_kwargs,
)


Expand Down

0 comments on commit 7612050

Please sign in to comment.