Commit

Chat model logging
parkervg committed Sep 28, 2024
1 parent 3fa7bd8 commit 85f55bb
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion blendsql/models/local/_transformers.py
@@ -1,6 +1,8 @@
 import importlib.util
 from typing import Optional
+from colorama import Fore
 
+from ..._logger import logger
 from .._model import LocalModel, ModelObj
 
 DEFAULT_KWARGS = {"do_sample": True, "temperature": 0.0, "top_p": 1.0}
@@ -47,6 +49,7 @@ def __init__(
         transformers.logging.set_verbosity_error()
         if config is None:
             config = {}
+
         super().__init__(
             model_name_or_path=model_name_or_path,
             requires_config=False,
@@ -61,12 +64,20 @@ def _load_model(self) -> ModelObj:
         from guidance.models import Transformers
         import torch
 
-        return Transformers(
+        lm = Transformers(
             self.model_name_or_path,
             echo=False,
             device_map="cuda" if torch.cuda.is_available() else "cpu",
             **self.load_model_kwargs,
         )
+        # Try to infer if we're in chat mode
+        if lm.engine.tokenizer._orig_tokenizer.chat_template is None:
+            logger.debug(
+                Fore.YELLOW
+                + "chat_template not found in tokenizer config.\nBlendSQL currently only works with chat models"
+                + Fore.RESET
+            )
+        return lm
 
 
 class TransformersVisionModel(TransformersLLM):

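As context for the change above: the new check hinges on the Hugging Face tokenizer's chat_template attribute. Instruction-tuned ("chat") checkpoints ship a Jinja chat template in their tokenizer config, while base models usually leave it unset. Below is a minimal standalone sketch of that same check, assuming only the transformers library; the model id is an illustrative example and is not part of this commit.

from transformers import AutoTokenizer

# Example checkpoint for illustration only; substitute any Hugging Face model id.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M-Instruct")

# Chat/instruct models define a Jinja chat template in their tokenizer config;
# base models typically leave tokenizer.chat_template as None.
if tokenizer.chat_template is None:
    print("chat_template not found in tokenizer config -- likely a base model.")
else:
    print("chat_template found -- model supports chat-style prompting.")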