From 85f55bb43c3f503920a006ec009d4772c4ca3590 Mon Sep 17 00:00:00 2001
From: parkervg
Date: Sat, 28 Sep 2024 13:28:42 -0400
Subject: [PATCH] Chat model logging

---
 blendsql/models/local/_transformers.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/blendsql/models/local/_transformers.py b/blendsql/models/local/_transformers.py
index 97e4b94..aeb5528 100644
--- a/blendsql/models/local/_transformers.py
+++ b/blendsql/models/local/_transformers.py
@@ -1,6 +1,8 @@
 import importlib.util
 from typing import Optional
 
+from colorama import Fore
+from ..._logger import logger
 from .._model import LocalModel, ModelObj
 
 DEFAULT_KWARGS = {"do_sample": True, "temperature": 0.0, "top_p": 1.0}
@@ -47,6 +49,7 @@ def __init__(
         transformers.logging.set_verbosity_error()
         if config is None:
             config = {}
+
         super().__init__(
             model_name_or_path=model_name_or_path,
             requires_config=False,
@@ -61,12 +64,20 @@ def _load_model(self) -> ModelObj:
         from guidance.models import Transformers
         import torch
 
-        return Transformers(
+        lm = Transformers(
             self.model_name_or_path,
             echo=False,
             device_map="cuda" if torch.cuda.is_available() else "cpu",
             **self.load_model_kwargs,
         )
+        # Try to infer if we're in chat mode
+        if lm.engine.tokenizer._orig_tokenizer.chat_template is None:
+            logger.debug(
+                Fore.YELLOW
+                + "chat_template not found in tokenizer config.\nBlendSQL currently only works with chat models"
+                + Fore.RESET
+            )
+        return lm
 
 
 class TransformersVisionModel(TransformersLLM):